diff --git a/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java b/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java index 09e7fe0d01c6..90a2f421e30d 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/FunctionManager.java @@ -22,6 +22,7 @@ import io.trino.connector.system.GlobalSystemConnector; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.CatalogHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.AggregationImplementation; @@ -192,6 +193,7 @@ private static void verifyMethodHandleSignature(BoundSignature boundSignature, S .mapToInt(InvocationArgumentConvention::getParameterCount) .sum(); expectedParameterCount += methodType.parameterList().stream().filter(ConnectorSession.class::equals).count(); + expectedParameterCount += convention.getReturnConvention().getParameterCount(); if (scalarFunctionImplementation.getInstanceFactory().isPresent()) { expectedParameterCount++; } @@ -262,6 +264,12 @@ private static void verifyMethodHandleSignature(BoundSignature boundSignature, S verifyFunctionSignature(methodType.returnType().isAssignableFrom(wrap(returnType.getJavaType())), "Expected return type to be %s, but is %s", returnType.getJavaType(), wrap(methodType.returnType())); break; + case BLOCK_BUILDER: + verifyFunctionSignature(methodType.lastParameterType().equals(BlockBuilder.class), + "Expected last argument type to be BlockBuilder, but is %s", methodType.lastParameterType()); + verifyFunctionSignature(methodType.returnType().equals(void.class), + "Expected return type to be void, but is %s", methodType.returnType()); + break; default: throw new UnsupportedOperationException("Unknown return convention: " + convention.getReturnConvention()); } diff --git a/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunction.java b/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunction.java index 81f1fd46ec04..9a3478642f34 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunction.java +++ b/core/trino-main/src/main/java/io/trino/metadata/PolymorphicScalarFunction.java @@ -144,6 +144,9 @@ private static boolean matchesParameterAndReturnTypes( } methodParameterIndex += argumentConvention.getParameterCount(); } + if (returnConvention == InvocationReturnConvention.BLOCK_BUILDER) { + throw new UnsupportedOperationException("BLOCK_BUILDER return convention is not yet supported"); + } return method.getReturnType().equals(getNullAwareContainerType(boundSignature.getReturnType().getJavaType(), returnConvention)); } @@ -174,6 +177,7 @@ private static Class getNullAwareContainerType(Class clazz, InvocationRetu return switch (returnConvention) { case NULLABLE_RETURN -> Primitives.wrap(clazz); case DEFAULT_ON_NULL, FAIL_ON_NULL -> clazz; + case BLOCK_BUILDER -> void.class; }; } diff --git a/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java b/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java index 61fdf175ff2a..a8477213765c 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java +++ b/core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java @@ -134,6 +134,7 @@ import io.trino.operator.scalar.GenericIndeterminateOperator; import io.trino.operator.scalar.GenericLessThanOperator; import io.trino.operator.scalar.GenericLessThanOrEqualOperator; +import io.trino.operator.scalar.GenericReadValueOperator; import io.trino.operator.scalar.GenericXxHash64Operator; import io.trino.operator.scalar.HmacFunctions; import io.trino.operator.scalar.HyperLogLogFunctions; @@ -569,6 +570,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators .functions(MAP_FILTER_FUNCTION, new MapTransformKeysFunction(blockTypeOperators), MAP_TRANSFORM_VALUES_FUNCTION) .function(FORMAT_FUNCTION) .function(TRY_CAST) + .function(new GenericReadValueOperator(typeOperators)) .function(new GenericEqualOperator(typeOperators)) .function(new GenericHashCodeOperator(typeOperators)) .function(new GenericXxHash64Operator(typeOperators)) diff --git a/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionsParserHelper.java b/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionsParserHelper.java index 53eb65b5aec8..f4961ac15f55 100644 --- a/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionsParserHelper.java +++ b/core/trino-main/src/main/java/io/trino/operator/annotations/FunctionsParserHelper.java @@ -63,6 +63,7 @@ import static io.trino.spi.function.OperatorType.IS_DISTINCT_FROM; import static io.trino.spi.function.OperatorType.LESS_THAN; import static io.trino.spi.function.OperatorType.LESS_THAN_OR_EQUAL; +import static io.trino.spi.function.OperatorType.READ_VALUE; import static io.trino.spi.function.OperatorType.XX_HASH_64; import static io.trino.sql.analyzer.TypeSignatureTranslator.parseTypeSignature; import static java.lang.String.CASE_INSENSITIVE_ORDER; @@ -119,6 +120,9 @@ else if (ORDERABLE_TYPE_OPERATORS.contains(operator)) { verifyTypeSignatureDoesNotContainAnyTypeParameters(typeSignature, typeSignature, typeParameterNames); } } + else if (operator == READ_VALUE) { + verifyOperatorSignature(operator, argumentTypes); + } else { throw new IllegalArgumentException("Operator dependency on " + operator + " is not allowed"); } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayAllMatchFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayAllMatchFunction.java index ef4194b2ff18..0f0efb57aa80 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayAllMatchFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayAllMatchFunction.java @@ -14,15 +14,20 @@ package io.trino.operator.scalar; import io.trino.spi.block.Block; +import io.trino.spi.function.Convention; import io.trino.spi.function.Description; +import io.trino.spi.function.OperatorDependency; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlNullable; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; -import io.trino.spi.function.TypeParameterSpecialization; import io.trino.spi.type.StandardTypes; -import io.trino.spi.type.Type; +import java.lang.invoke.MethodHandle; + +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.OperatorType.READ_VALUE; import static java.lang.Boolean.FALSE; @Description("Returns true if all elements of the array match the given predicate") @@ -32,110 +37,20 @@ public final class ArrayAllMatchFunction private ArrayAllMatchFunction() {} @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = Object.class) @SqlType(StandardTypes.BOOLEAN) @SqlNullable - public static Boolean allMatchObject( - @TypeParameter("T") Type elementType, + public static Boolean allMatch( + @OperatorDependency(operator = READ_VALUE, argumentTypes = "T", convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle readValue, @SqlType("array(T)") Block arrayBlock, @SqlType("function(T, boolean)") ObjectToBooleanFunction function) + throws Throwable { boolean hasNullResult = false; int positionCount = arrayBlock.getPositionCount(); for (int i = 0; i < positionCount; i++) { Object element = null; if (!arrayBlock.isNull(i)) { - element = elementType.getObject(arrayBlock, i); - } - Boolean match = function.apply(element); - if (FALSE.equals(match)) { - return false; - } - if (match == null) { - hasNullResult = true; - } - } - if (hasNullResult) { - return null; - } - return true; - } - - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = long.class) - @SqlType(StandardTypes.BOOLEAN) - @SqlNullable - public static Boolean allMatchLong( - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") LongToBooleanFunction function) - { - boolean hasNullResult = false; - int positionCount = arrayBlock.getPositionCount(); - for (int i = 0; i < positionCount; i++) { - Long element = null; - if (!arrayBlock.isNull(i)) { - element = elementType.getLong(arrayBlock, i); - } - Boolean match = function.apply(element); - if (FALSE.equals(match)) { - return false; - } - if (match == null) { - hasNullResult = true; - } - } - if (hasNullResult) { - return null; - } - return true; - } - - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = double.class) - @SqlType(StandardTypes.BOOLEAN) - @SqlNullable - public static Boolean allMatchDouble( - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") DoubleToBooleanFunction function) - { - boolean hasNullResult = false; - int positionCount = arrayBlock.getPositionCount(); - for (int i = 0; i < positionCount; i++) { - Double element = null; - if (!arrayBlock.isNull(i)) { - element = elementType.getDouble(arrayBlock, i); - } - Boolean match = function.apply(element); - if (FALSE.equals(match)) { - return false; - } - if (match == null) { - hasNullResult = true; - } - } - if (hasNullResult) { - return null; - } - return true; - } - - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = boolean.class) - @SqlType(StandardTypes.BOOLEAN) - @SqlNullable - public static Boolean allMatchBoolean( - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") BooleanToBooleanFunction function) - { - boolean hasNullResult = false; - int positionCount = arrayBlock.getPositionCount(); - for (int i = 0; i < positionCount; i++) { - Boolean element = null; - if (!arrayBlock.isNull(i)) { - element = elementType.getBoolean(arrayBlock, i); + element = readValue.invoke(arrayBlock, i); } Boolean match = function.apply(element); if (FALSE.equals(match)) { diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayAnyMatchFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayAnyMatchFunction.java index dfb72414f15d..7e45b1faa261 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayAnyMatchFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayAnyMatchFunction.java @@ -14,15 +14,20 @@ package io.trino.operator.scalar; import io.trino.spi.block.Block; +import io.trino.spi.function.Convention; import io.trino.spi.function.Description; +import io.trino.spi.function.OperatorDependency; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlNullable; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; -import io.trino.spi.function.TypeParameterSpecialization; import io.trino.spi.type.StandardTypes; -import io.trino.spi.type.Type; +import java.lang.invoke.MethodHandle; + +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.OperatorType.READ_VALUE; import static java.lang.Boolean.TRUE; @Description("Returns true if the array contains one or more elements that match the given predicate") @@ -32,110 +37,20 @@ public final class ArrayAnyMatchFunction private ArrayAnyMatchFunction() {} @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = Object.class) @SqlType(StandardTypes.BOOLEAN) @SqlNullable - public static Boolean anyMatchObject( - @TypeParameter("T") Type elementType, + public static Boolean anyMatch( + @OperatorDependency(operator = READ_VALUE, argumentTypes = "T", convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle readValue, @SqlType("array(T)") Block arrayBlock, @SqlType("function(T, boolean)") ObjectToBooleanFunction function) + throws Throwable { boolean hasNullResult = false; int positionCount = arrayBlock.getPositionCount(); for (int i = 0; i < positionCount; i++) { Object element = null; if (!arrayBlock.isNull(i)) { - element = elementType.getObject(arrayBlock, i); - } - Boolean match = function.apply(element); - if (TRUE.equals(match)) { - return true; - } - if (match == null) { - hasNullResult = true; - } - } - if (hasNullResult) { - return null; - } - return false; - } - - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = long.class) - @SqlType(StandardTypes.BOOLEAN) - @SqlNullable - public static Boolean anyMatchLong( - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") LongToBooleanFunction function) - { - boolean hasNullResult = false; - int positionCount = arrayBlock.getPositionCount(); - for (int i = 0; i < positionCount; i++) { - Long element = null; - if (!arrayBlock.isNull(i)) { - element = elementType.getLong(arrayBlock, i); - } - Boolean match = function.apply(element); - if (TRUE.equals(match)) { - return true; - } - if (match == null) { - hasNullResult = true; - } - } - if (hasNullResult) { - return null; - } - return false; - } - - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = double.class) - @SqlType(StandardTypes.BOOLEAN) - @SqlNullable - public static Boolean anyMatchDouble( - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") DoubleToBooleanFunction function) - { - boolean hasNullResult = false; - int positionCount = arrayBlock.getPositionCount(); - for (int i = 0; i < positionCount; i++) { - Double element = null; - if (!arrayBlock.isNull(i)) { - element = elementType.getDouble(arrayBlock, i); - } - Boolean match = function.apply(element); - if (TRUE.equals(match)) { - return true; - } - if (match == null) { - hasNullResult = true; - } - } - if (hasNullResult) { - return null; - } - return false; - } - - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = boolean.class) - @SqlType(StandardTypes.BOOLEAN) - @SqlNullable - public static Boolean anyMatchBoolean( - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") BooleanToBooleanFunction function) - { - boolean hasNullResult = false; - int positionCount = arrayBlock.getPositionCount(); - for (int i = 0; i < positionCount; i++) { - Boolean element = null; - if (!arrayBlock.isNull(i)) { - element = elementType.getBoolean(arrayBlock, i); + element = readValue.invoke(arrayBlock, i); } Boolean match = function.apply(element); if (TRUE.equals(match)) { diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayConcatFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayConcatFunction.java index 86d42d8b0501..931e921761be 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayConcatFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayConcatFunction.java @@ -13,13 +13,11 @@ */ package io.trino.operator.scalar; -import com.google.common.collect.ImmutableList; import io.trino.annotation.UsedByGeneratedCode; import io.trino.metadata.SqlScalarFunction; -import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BufferedArrayValueBuilder; import io.trino.spi.function.BoundSignature; import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.Signature; @@ -29,6 +27,7 @@ import io.trino.sql.gen.VarArgsToArrayAdapterGenerator; import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; import java.util.Optional; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; @@ -36,7 +35,8 @@ import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.type.TypeSignature.arrayType; import static io.trino.sql.gen.VarArgsToArrayAdapterGenerator.generateVarArgsToArrayAdapter; -import static io.trino.util.Reflection.methodHandle; +import static java.lang.invoke.MethodHandles.lookup; +import static java.lang.invoke.MethodType.methodType; import static java.util.Collections.nCopies; public final class ArrayConcatFunction @@ -47,8 +47,19 @@ public final class ArrayConcatFunction private static final String FUNCTION_NAME = "concat"; private static final String DESCRIPTION = "Concatenates given arrays"; - private static final MethodHandle METHOD_HANDLE = methodHandle(ArrayConcatFunction.class, "concat", Type.class, Object.class, Block[].class); - private static final MethodHandle USER_STATE_FACTORY = methodHandle(ArrayConcatFunction.class, "createState", Type.class); + private static final MethodHandle METHOD_HANDLE; + private static final MethodHandle USER_STATE_FACTORY; + + static { + try { + MethodHandles.Lookup lookup = lookup(); + METHOD_HANDLE = lookup.findStatic(ArrayConcatFunction.class, "concat", methodType(Block.class, Type.class, Object.class, Block[].class)); + USER_STATE_FACTORY = lookup.findStatic(ArrayConcatFunction.class, "createState", methodType(Object.class, ArrayType.class)); + } + catch (ReflectiveOperationException e) { + throw new ExceptionInInitializerError(e); + } + } private ArrayConcatFunction() { @@ -78,7 +89,7 @@ protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) Block.class, boundSignature.getArity(), METHOD_HANDLE.bindTo(arrayType.getElementType()), - USER_STATE_FACTORY.bindTo(arrayType.getElementType())); + USER_STATE_FACTORY.bindTo(arrayType)); return new ChoicesSpecializedSqlScalarFunction( boundSignature, @@ -89,9 +100,9 @@ protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) } @UsedByGeneratedCode - public static Object createState(Type elementType) + public static Object createState(ArrayType arrayType) { - return new PageBuilder(ImmutableList.of(elementType)); + return BufferedArrayValueBuilder.createBuffered(arrayType); } @UsedByGeneratedCode @@ -99,12 +110,12 @@ public static Block concat(Type elementType, Object state, Block[] blocks) { int resultPositionCount = 0; - // fast path when there is at most one non empty block + // fast path when there is at most one non-empty block Block nonEmptyBlock = null; - for (int i = 0; i < blocks.length; i++) { - resultPositionCount += blocks[i].getPositionCount(); - if (blocks[i].getPositionCount() > 0) { - nonEmptyBlock = blocks[i]; + for (Block value : blocks) { + resultPositionCount += value.getPositionCount(); + if (value.getPositionCount() > 0) { + nonEmptyBlock = value; } } if (nonEmptyBlock == null) { @@ -114,19 +125,12 @@ public static Block concat(Type elementType, Object state, Block[] blocks) return nonEmptyBlock; } - PageBuilder pageBuilder = (PageBuilder) state; - if (pageBuilder.isFull()) { - pageBuilder.reset(); - } - - BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); - for (int blockIndex = 0; blockIndex < blocks.length; blockIndex++) { - Block block = blocks[blockIndex]; - for (int i = 0; i < block.getPositionCount(); i++) { - elementType.appendTo(block, i, blockBuilder); + return ((BufferedArrayValueBuilder) state).build(resultPositionCount, elementBuilder -> { + for (Block block : blocks) { + for (int i = 0; i < block.getPositionCount(); i++) { + elementType.appendTo(block, i, elementBuilder); + } } - } - pageBuilder.declarePositions(resultPositionCount); - return blockBuilder.getRegion(blockBuilder.getPositionCount() - resultPositionCount, resultPositionCount); + }); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayElementAtFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayElementAtFunction.java index ba0ccebe94b2..4dbcc6e3cce7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayElementAtFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayElementAtFunction.java @@ -15,14 +15,21 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.function.Convention; import io.trino.spi.function.Description; +import io.trino.spi.function.OperatorDependency; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlNullable; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; import io.trino.spi.type.Type; +import java.lang.invoke.MethodHandle; + import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.OperatorType.READ_VALUE; import static java.lang.Math.toIntExact; @ScalarFunction("element_at") @@ -34,55 +41,12 @@ private ArrayElementAtFunction() {} @TypeParameter("E") @SqlNullable @SqlType("E") - public static Long longElementAt(@TypeParameter("E") Type elementType, @SqlType("array(E)") Block array, @SqlType("bigint") long index) - { - int position = checkedIndexToBlockPosition(array, index); - if (position == -1) { - return null; - } - if (array.isNull(position)) { - return null; - } - - return elementType.getLong(array, position); - } - - @TypeParameter("E") - @SqlNullable - @SqlType("E") - public static Boolean booleanElementAt(@TypeParameter("E") Type elementType, @SqlType("array(E)") Block array, @SqlType("bigint") long index) - { - int position = checkedIndexToBlockPosition(array, index); - if (position == -1) { - return null; - } - if (array.isNull(position)) { - return null; - } - - return elementType.getBoolean(array, position); - } - - @TypeParameter("E") - @SqlNullable - @SqlType("E") - public static Double doubleElementAt(@TypeParameter("E") Type elementType, @SqlType("array(E)") Block array, @SqlType("bigint") long index) - { - int position = checkedIndexToBlockPosition(array, index); - if (position == -1) { - return null; - } - if (array.isNull(position)) { - return null; - } - - return elementType.getDouble(array, position); - } - - @TypeParameter("E") - @SqlNullable - @SqlType("E") - public static Object sliceElementAt(@TypeParameter("E") Type elementType, @SqlType("array(E)") Block array, @SqlType("bigint") long index) + public static Object elementAt( + @TypeParameter("E") Type elementType, + @OperatorDependency(operator = READ_VALUE, argumentTypes = "E", convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle readValue, + @SqlType("array(E)") Block array, + @SqlType("bigint") long index) + throws Throwable { int position = checkedIndexToBlockPosition(array, index); if (position == -1) { @@ -92,7 +56,7 @@ public static Object sliceElementAt(@TypeParameter("E") Type elementType, @SqlTy return null; } - return elementType.getObject(array, position); + return readValue.invoke(array, position); } /** diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayMaxFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayMaxFunction.java index 127519104fd1..68dab4ac6b79 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayMaxFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayMaxFunction.java @@ -21,16 +21,13 @@ import io.trino.spi.function.SqlNullable; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; -import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.OperatorType.COMPARISON_UNORDERED_FIRST; -import static io.trino.spi.type.DoubleType.DOUBLE; -import static io.trino.spi.type.RealType.REAL; -import static io.trino.util.Failures.internalError; +import static io.trino.spi.function.OperatorType.READ_VALUE; @ScalarFunction("array_max") @Description("Get maximum value of array") @@ -41,139 +38,28 @@ private ArrayMaxFunction() {} @TypeParameter("T") @SqlType("T") @SqlNullable - public static Long longArrayMax( + public static Object arrayMax( @OperatorDependency( operator = COMPARISON_UNORDERED_FIRST, argumentTypes = {"T", "T"}, convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareMethodHandle, - @TypeParameter("T") Type elementType, + @OperatorDependency(operator = READ_VALUE, argumentTypes = "T", convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle readValue, @SqlType("array(T)") Block block) + throws Throwable { - int selectedPosition = findMaxArrayElement(compareMethodHandle, block); - if (selectedPosition < 0) { - return null; - } - return elementType.getLong(block, selectedPosition); - } - - @TypeParameter("T") - @SqlType("T") - @SqlNullable - public static Boolean booleanArrayMax( - @OperatorDependency( - operator = COMPARISON_UNORDERED_FIRST, - argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareMethodHandle, - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block block) - { - int selectedPosition = findMaxArrayElement(compareMethodHandle, block); - if (selectedPosition < 0) { - return null; - } - return elementType.getBoolean(block, selectedPosition); - } - - @TypeParameter("T") - @SqlType("T") - @SqlNullable - public static Double doubleArrayMax( - @OperatorDependency( - operator = COMPARISON_UNORDERED_FIRST, - argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareMethodHandle, - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block block) - { - int selectedPosition = findMaxArrayElement(compareMethodHandle, block); - if (selectedPosition < 0) { - return null; - } - return elementType.getDouble(block, selectedPosition); - } - - @TypeParameter("T") - @SqlType("T") - @SqlNullable - public static Object objectArrayMax( - @OperatorDependency( - operator = COMPARISON_UNORDERED_FIRST, - argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareMethodHandle, - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block block) - { - int selectedPosition = findMaxArrayElement(compareMethodHandle, block); - if (selectedPosition < 0) { - return null; - } - return elementType.getObject(block, selectedPosition); - } - - private static int findMaxArrayElement(MethodHandle compareMethodHandle, Block block) - { - try { - int selectedPosition = -1; - for (int position = 0; position < block.getPositionCount(); position++) { - if (block.isNull(position)) { - return -1; - } - if (selectedPosition < 0 || ((long) compareMethodHandle.invokeExact(block, position, block, selectedPosition)) > 0) { - selectedPosition = position; - } - } - return selectedPosition; - } - catch (Throwable t) { - throw internalError(t); - } - } - - @SqlType("double") - @SqlNullable - public static Double doubleTypeArrayMax(@SqlType("array(double)") Block block) - { - if (block.getPositionCount() == 0) { - return null; - } int selectedPosition = -1; for (int position = 0; position < block.getPositionCount(); position++) { if (block.isNull(position)) { return null; } - if (selectedPosition < 0 || doubleGreater(DOUBLE.getDouble(block, position), DOUBLE.getDouble(block, selectedPosition))) { + if (selectedPosition < 0 || ((long) compareMethodHandle.invokeExact(block, position, block, selectedPosition)) > 0) { selectedPosition = position; } } - return DOUBLE.getDouble(block, selectedPosition); - } - private static boolean doubleGreater(double left, double right) - { - return (left > right) || Double.isNaN(right); - } - - @SqlType("real") - @SqlNullable - public static Long realTypeArrayMax(@SqlType("array(real)") Block block) - { - if (block.getPositionCount() == 0) { + if (selectedPosition < 0) { return null; } - int selectedPosition = -1; - for (int position = 0; position < block.getPositionCount(); position++) { - if (block.isNull(position)) { - return null; - } - if (selectedPosition < 0 || floatGreater(REAL.getFloat(block, position), REAL.getFloat(block, selectedPosition))) { - selectedPosition = position; - } - } - return REAL.getLong(block, selectedPosition); - } - - private static boolean floatGreater(float left, float right) - { - return (left > right) || Float.isNaN(right); + return readValue.invoke(block, selectedPosition); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayMinFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayMinFunction.java index 5bea93fa2316..375aceeee61e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayMinFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayMinFunction.java @@ -21,14 +21,13 @@ import io.trino.spi.function.SqlNullable; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; -import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.OperatorType.COMPARISON_UNORDERED_LAST; -import static io.trino.util.Failures.internalError; +import static io.trino.spi.function.OperatorType.READ_VALUE; @ScalarFunction("array_min") @Description("Get minimum value of array") @@ -39,91 +38,28 @@ private ArrayMinFunction() {} @TypeParameter("T") @SqlType("T") @SqlNullable - public static Long longArrayMin( + public static Object arrayMin( @OperatorDependency( operator = COMPARISON_UNORDERED_LAST, argumentTypes = {"T", "T"}, convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareMethodHandle, - @TypeParameter("T") Type elementType, + @OperatorDependency(operator = READ_VALUE, argumentTypes = "T", convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle readValue, @SqlType("array(T)") Block block) + throws Throwable { - int selectedPosition = findMinArrayElement(compareMethodHandle, block); - if (selectedPosition < 0) { - return null; - } - return elementType.getLong(block, selectedPosition); - } - - @TypeParameter("T") - @SqlType("T") - @SqlNullable - public static Boolean booleanArrayMin( - @OperatorDependency( - operator = COMPARISON_UNORDERED_LAST, - argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareMethodHandle, - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block block) - { - int selectedPosition = findMinArrayElement(compareMethodHandle, block); - if (selectedPosition < 0) { - return null; - } - return elementType.getBoolean(block, selectedPosition); - } - - @TypeParameter("T") - @SqlType("T") - @SqlNullable - public static Double doubleArrayMin( - @OperatorDependency( - operator = COMPARISON_UNORDERED_LAST, - argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareMethodHandle, - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block block) - { - int selectedPosition = findMinArrayElement(compareMethodHandle, block); - if (selectedPosition < 0) { - return null; + int selectedPosition = -1; + for (int position = 0; position < block.getPositionCount(); position++) { + if (block.isNull(position)) { + return null; + } + if (selectedPosition < 0 || ((long) compareMethodHandle.invokeExact(block, position, block, selectedPosition)) < 0) { + selectedPosition = position; + } } - return elementType.getDouble(block, selectedPosition); - } - - @TypeParameter("T") - @SqlType("T") - @SqlNullable - public static Object objectArrayMin( - @OperatorDependency( - operator = COMPARISON_UNORDERED_LAST, - argumentTypes = {"T", "T"}, - convention = @Convention(arguments = {BLOCK_POSITION_NOT_NULL, BLOCK_POSITION_NOT_NULL}, result = FAIL_ON_NULL)) MethodHandle compareMethodHandle, - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block block) - { - int selectedPosition = findMinArrayElement(compareMethodHandle, block); if (selectedPosition < 0) { return null; } - return elementType.getObject(block, selectedPosition); - } - private static int findMinArrayElement(MethodHandle compareMethodHandle, Block block) - { - try { - int selectedPosition = -1; - for (int position = 0; position < block.getPositionCount(); position++) { - if (block.isNull(position)) { - return -1; - } - if (selectedPosition < 0 || ((long) compareMethodHandle.invokeExact(block, position, block, selectedPosition)) < 0) { - selectedPosition = position; - } - } - return selectedPosition; - } - catch (Throwable t) { - throw internalError(t); - } + return readValue.invoke(block, selectedPosition); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayNoneMatchFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayNoneMatchFunction.java index 51b39e5f351a..01b815356570 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayNoneMatchFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayNoneMatchFunction.java @@ -14,14 +14,20 @@ package io.trino.operator.scalar; import io.trino.spi.block.Block; +import io.trino.spi.function.Convention; import io.trino.spi.function.Description; +import io.trino.spi.function.OperatorDependency; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlNullable; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; -import io.trino.spi.function.TypeParameterSpecialization; import io.trino.spi.type.StandardTypes; -import io.trino.spi.type.Type; + +import java.lang.invoke.MethodHandle; + +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.OperatorType.READ_VALUE; @Description("Returns true if all elements of the array don't match the given predicate") @ScalarFunction("none_match") @@ -30,63 +36,15 @@ public final class ArrayNoneMatchFunction private ArrayNoneMatchFunction() {} @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = Object.class) @SqlType(StandardTypes.BOOLEAN) @SqlNullable - public static Boolean noneMatchObject( - @TypeParameter("T") Type elementType, + public static Boolean noneMatch( + @OperatorDependency(operator = READ_VALUE, argumentTypes = "T", convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle readValue, @SqlType("array(T)") Block arrayBlock, @SqlType("function(T, boolean)") ObjectToBooleanFunction function) + throws Throwable { - Boolean anyMatchResult = ArrayAnyMatchFunction.anyMatchObject(elementType, arrayBlock, function); - if (anyMatchResult == null) { - return null; - } - return !anyMatchResult; - } - - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = long.class) - @SqlType(StandardTypes.BOOLEAN) - @SqlNullable - public static Boolean noneMatchLong( - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") LongToBooleanFunction function) - { - Boolean anyMatchResult = ArrayAnyMatchFunction.anyMatchLong(elementType, arrayBlock, function); - if (anyMatchResult == null) { - return null; - } - return !anyMatchResult; - } - - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = double.class) - @SqlType(StandardTypes.BOOLEAN) - @SqlNullable - public static Boolean noneMatchDouble( - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") DoubleToBooleanFunction function) - { - Boolean anyMatchResult = ArrayAnyMatchFunction.anyMatchDouble(elementType, arrayBlock, function); - if (anyMatchResult == null) { - return null; - } - return !anyMatchResult; - } - - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = boolean.class) - @SqlType(StandardTypes.BOOLEAN) - @SqlNullable - public static Boolean noneMatchBoolean( - @TypeParameter("T") Type elementType, - @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") BooleanToBooleanFunction function) - { - Boolean anyMatchResult = ArrayAnyMatchFunction.anyMatchBoolean(elementType, arrayBlock, function); + Boolean anyMatchResult = ArrayAnyMatchFunction.anyMatch(readValue, arrayBlock, function); if (anyMatchResult == null) { return null; } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayRemoveFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayRemoveFunction.java index b8b6134e4f13..879a6c86e1ca 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayRemoveFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayRemoveFunction.java @@ -13,17 +13,16 @@ */ package io.trino.operator.scalar; -import com.google.common.collect.ImmutableList; -import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BufferedArrayValueBuilder; import io.trino.spi.function.Convention; import io.trino.spi.function.Description; import io.trino.spi.function.OperatorDependency; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; @@ -41,12 +40,12 @@ @Description("Remove specified values from the given array") public final class ArrayRemoveFunction { - private final PageBuilder pageBuilder; + private final BufferedArrayValueBuilder arrayValueBuilder; @TypeParameter("E") public ArrayRemoveFunction(@TypeParameter("E") Type elementType) { - pageBuilder = new PageBuilder(ImmutableList.of(elementType)); + arrayValueBuilder = BufferedArrayValueBuilder.createBuffered(new ArrayType(elementType)); } @TypeParameter("E") @@ -56,52 +55,7 @@ public Block remove( operator = EQUAL, argumentTypes = {"E", "E"}, convention = @Convention(arguments = {NEVER_NULL, NEVER_NULL}, result = NULLABLE_RETURN)) - MethodHandle equalsFunction, - @TypeParameter("E") Type type, - @SqlType("array(E)") Block array, - @SqlType("E") long value) - { - return remove(equalsFunction, type, array, (Object) value); - } - - @TypeParameter("E") - @SqlType("array(E)") - public Block remove( - @OperatorDependency( - operator = EQUAL, - argumentTypes = {"E", "E"}, - convention = @Convention(arguments = {NEVER_NULL, NEVER_NULL}, result = NULLABLE_RETURN)) - MethodHandle equalsFunction, - @TypeParameter("E") Type type, - @SqlType("array(E)") Block array, - @SqlType("E") double value) - { - return remove(equalsFunction, type, array, (Object) value); - } - - @TypeParameter("E") - @SqlType("array(E)") - public Block remove( - @OperatorDependency( - operator = EQUAL, - argumentTypes = {"E", "E"}, - convention = @Convention(arguments = {NEVER_NULL, NEVER_NULL}, result = NULLABLE_RETURN)) - MethodHandle equalsFunction, - @TypeParameter("E") Type type, - @SqlType("array(E)") Block array, - @SqlType("E") boolean value) - { - return remove(equalsFunction, type, array, (Object) value); - } - - @TypeParameter("E") - @SqlType("array(E)") - public Block remove( - @OperatorDependency( - operator = EQUAL, - argumentTypes = {"E", "E"}, - convention = @Convention(arguments = {NEVER_NULL, NEVER_NULL}, result = NULLABLE_RETURN)) - MethodHandle equalsFunction, + MethodHandle equalFunction, @TypeParameter("E") Type type, @SqlType("array(E)") Block array, @SqlType("E") Object value) @@ -116,7 +70,7 @@ public Block remove( positions.add(i); continue; } - Boolean result = (Boolean) equalsFunction.invoke(element, value); + Boolean result = (Boolean) equalFunction.invoke(element, value); if (result == null) { throw new TrinoException(NOT_SUPPORTED, "array_remove does not support arrays with elements that are null or contain null"); } @@ -133,16 +87,10 @@ public Block remove( return array; } - if (pageBuilder.isFull()) { - pageBuilder.reset(); - } - BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); - - for (int position : positions) { - type.appendTo(array, position, blockBuilder); - } - - pageBuilder.declarePositions(positions.size()); - return blockBuilder.getRegion(blockBuilder.getPositionCount() - positions.size(), positions.size()); + return arrayValueBuilder.build(positions.size(), elementBuilder -> { + for (int position : positions) { + type.appendTo(array, position, elementBuilder); + } + }); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayReverseFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayReverseFunction.java index e916539cf045..365e1b5a44e8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayReverseFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayReverseFunction.java @@ -13,26 +13,25 @@ */ package io.trino.operator.scalar; -import com.google.common.collect.ImmutableList; -import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BufferedArrayValueBuilder; import io.trino.spi.function.Description; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; @ScalarFunction("reverse") @Description("Returns an array which has the reversed order of the given array.") public final class ArrayReverseFunction { - private final PageBuilder pageBuilder; + private final BufferedArrayValueBuilder arrayValueBuilder; @TypeParameter("E") public ArrayReverseFunction(@TypeParameter("E") Type elementType) { - pageBuilder = new PageBuilder(ImmutableList.of(elementType)); + arrayValueBuilder = BufferedArrayValueBuilder.createBuffered(new ArrayType(elementType)); } @TypeParameter("E") @@ -47,16 +46,10 @@ public Block reverse( return block; } - if (pageBuilder.isFull()) { - pageBuilder.reset(); - } - - BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); - for (int i = arrayLength - 1; i >= 0; i--) { - type.appendTo(block, i, blockBuilder); - } - pageBuilder.declarePositions(arrayLength); - - return blockBuilder.getRegion(blockBuilder.getPositionCount() - arrayLength, arrayLength); + return arrayValueBuilder.build(arrayLength, elementBuilder -> { + for (int i = arrayLength - 1; i >= 0; i--) { + type.appendTo(block, i, elementBuilder); + } + }); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayShuffleFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayShuffleFunction.java index 81fecc20580f..ffce6ef37a5c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayShuffleFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayShuffleFunction.java @@ -13,14 +13,13 @@ */ package io.trino.operator.scalar; -import com.google.common.collect.ImmutableList; -import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BufferedArrayValueBuilder; import io.trino.spi.function.Description; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import java.util.concurrent.ThreadLocalRandom; @@ -29,14 +28,14 @@ @Description("Generates a random permutation of the given array.") public final class ArrayShuffleFunction { - private final PageBuilder pageBuilder; + private final BufferedArrayValueBuilder arrayValueBuilder; private static final int INITIAL_LENGTH = 128; private int[] positions = new int[INITIAL_LENGTH]; @TypeParameter("E") public ArrayShuffleFunction(@TypeParameter("E") Type elementType) { - pageBuilder = new PageBuilder(ImmutableList.of(elementType)); + arrayValueBuilder = BufferedArrayValueBuilder.createBuffered(new ArrayType(elementType)); } @TypeParameter("E") @@ -62,17 +61,10 @@ public Block shuffle( positions[index] = swap; } - if (pageBuilder.isFull()) { - pageBuilder.reset(); - } - - BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); - - for (int i = 0; i < length; i++) { - type.appendTo(block, positions[i], blockBuilder); - } - pageBuilder.declarePositions(length); - - return blockBuilder.getRegion(blockBuilder.getPositionCount() - length, length); + return arrayValueBuilder.build(length, elementBuilder -> { + for (int i = 0; i < length; i++) { + type.appendTo(block, positions[i], elementBuilder); + } + }); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySortComparatorFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySortComparatorFunction.java index 4d900e0b7caf..b27fca5b36c1 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySortComparatorFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySortComparatorFunction.java @@ -13,118 +13,75 @@ */ package io.trino.operator.scalar; -import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; -import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BufferedArrayValueBuilder; +import io.trino.spi.function.Convention; import io.trino.spi.function.Description; +import io.trino.spi.function.OperatorDependency; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; -import io.trino.spi.function.TypeParameterSpecialization; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.sql.gen.lambda.LambdaFunctionInterface; +import java.lang.invoke.MethodHandle; import java.util.Comparator; import java.util.List; +import static com.google.common.base.Throwables.throwIfUnchecked; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.OperatorType.READ_VALUE; import static io.trino.util.Failures.checkCondition; @ScalarFunction("array_sort") @Description("Sorts the given array with a lambda comparator.") public final class ArraySortComparatorFunction { - private final PageBuilder pageBuilder; + private final BufferedArrayValueBuilder arrayValueBuilder; private static final int INITIAL_LENGTH = 128; private List positions = Ints.asList(new int[INITIAL_LENGTH]); @TypeParameter("T") public ArraySortComparatorFunction(@TypeParameter("T") Type elementType) { - pageBuilder = new PageBuilder(ImmutableList.of(elementType)); + arrayValueBuilder = BufferedArrayValueBuilder.createBuffered(new ArrayType(elementType)); } @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = long.class) @SqlType("array(T)") - public Block sortLong( - @TypeParameter("T") Type type, - @SqlType("array(T)") Block block, - @SqlType("function(T, T, integer)") ComparatorLongLambda function) - { - int arrayLength = block.getPositionCount(); - initPositionsList(arrayLength); - - Comparator comparator = (x, y) -> comparatorResult(function.apply( - block.isNull(x) ? null : type.getLong(block, x), - block.isNull(y) ? null : type.getLong(block, y))); - - sortPositions(arrayLength, comparator); - - return computeResultBlock(type, block, arrayLength); - } - - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = double.class) - @SqlType("array(T)") - public Block sortDouble( - @TypeParameter("T") Type type, - @SqlType("array(T)") Block block, - @SqlType("function(T, T, integer)") ComparatorDoubleLambda function) - { - int arrayLength = block.getPositionCount(); - initPositionsList(arrayLength); - - Comparator comparator = (x, y) -> comparatorResult(function.apply( - block.isNull(x) ? null : type.getDouble(block, x), - block.isNull(y) ? null : type.getDouble(block, y))); - - sortPositions(arrayLength, comparator); - - return computeResultBlock(type, block, arrayLength); - } - - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = boolean.class) - @SqlType("array(T)") - public Block sortBoolean( - @TypeParameter("T") Type type, - @SqlType("array(T)") Block block, - @SqlType("function(T, T, integer)") ComparatorBooleanLambda function) - { - int arrayLength = block.getPositionCount(); - initPositionsList(arrayLength); - - Comparator comparator = (x, y) -> comparatorResult(function.apply( - block.isNull(x) ? null : type.getBoolean(block, x), - block.isNull(y) ? null : type.getBoolean(block, y))); - - sortPositions(arrayLength, comparator); - - return computeResultBlock(type, block, arrayLength); - } - - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = Object.class) - @SqlType("array(T)") - public Block sortObject( + public Block sort( @TypeParameter("T") Type type, + @OperatorDependency(operator = READ_VALUE, argumentTypes = "T", convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = FAIL_ON_NULL)) MethodHandle readValue, @SqlType("array(T)") Block block, @SqlType("function(T, T, integer)") ComparatorObjectLambda function) { int arrayLength = block.getPositionCount(); initPositionsList(arrayLength); - Comparator comparator = (x, y) -> comparatorResult(function.apply( - block.isNull(x) ? null : type.getObject(block, x), - block.isNull(y) ? null : type.getObject(block, y))); + Comparator comparator = (x, y) -> { + try { + return comparatorResult(function.apply( + block.isNull(x) ? null : readValue.invoke(block, x), + block.isNull(y) ? null : readValue.invoke(block, y))); + } + catch (Throwable e) { + throwIfUnchecked(e); + throw new RuntimeException(e); + } + }; sortPositions(arrayLength, comparator); - return computeResultBlock(type, block, arrayLength); + return arrayValueBuilder.build(arrayLength, elementBuilder -> { + for (int i = 0; i < arrayLength; i++) { + type.appendTo(block, positions.get(i), elementBuilder); + } + }); } private void initPositionsList(int arrayLength) @@ -149,22 +106,6 @@ private void sortPositions(int arrayLength, Comparator comparator) } } - private Block computeResultBlock(Type type, Block block, int arrayLength) - { - if (pageBuilder.isFull()) { - pageBuilder.reset(); - } - - BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); - - for (int i = 0; i < arrayLength; ++i) { - type.appendTo(block, positions.get(i), blockBuilder); - } - pageBuilder.declarePositions(arrayLength); - - return blockBuilder.getRegion(blockBuilder.getPositionCount() - arrayLength, arrayLength); - } - private static int comparatorResult(Long result) { checkCondition( @@ -174,27 +115,6 @@ private static int comparatorResult(Long result) return result.intValue(); } - @FunctionalInterface - public interface ComparatorLongLambda - extends LambdaFunctionInterface - { - Long apply(Long x, Long y); - } - - @FunctionalInterface - public interface ComparatorDoubleLambda - extends LambdaFunctionInterface - { - Long apply(Double x, Double y); - } - - @FunctionalInterface - public interface ComparatorBooleanLambda - extends LambdaFunctionInterface - { - Long apply(Boolean x, Boolean y); - } - @FunctionalInterface public interface ComparatorObjectLambda extends LambdaFunctionInterface diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySortFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySortFunction.java index ad9bb0842014..4a279f1268ad 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySortFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySortFunction.java @@ -13,16 +13,15 @@ */ package io.trino.operator.scalar; -import com.google.common.collect.ImmutableList; -import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BufferedArrayValueBuilder; import io.trino.spi.function.Convention; import io.trino.spi.function.Description; import io.trino.spi.function.OperatorDependency; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.type.BlockTypeOperators.BlockPositionComparison; import it.unimi.dsi.fastutil.ints.IntArrayList; @@ -36,14 +35,14 @@ public final class ArraySortFunction { public static final String NAME = "array_sort"; - private final PageBuilder pageBuilder; + private final BufferedArrayValueBuilder arrayValueBuilder; private static final int INITIAL_LENGTH = 128; private final IntArrayList positions = new IntArrayList(INITIAL_LENGTH); @TypeParameter("E") public ArraySortFunction(@TypeParameter("E") Type elementType) { - pageBuilder = new PageBuilder(ImmutableList.of(elementType)); + arrayValueBuilder = BufferedArrayValueBuilder.createBuffered(new ArrayType(elementType)); } @TypeParameter("E") @@ -78,17 +77,10 @@ public Block sort( return (int) comparisonOperator.compare(block, left, block, right); }); - if (pageBuilder.isFull()) { - pageBuilder.reset(); - } - - BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); - - for (int i = 0; i < arrayLength; i++) { - type.appendTo(block, positions.getInt(i), blockBuilder); - } - pageBuilder.declarePositions(arrayLength); - - return blockBuilder.getRegion(blockBuilder.getPositionCount() - arrayLength, arrayLength); + return arrayValueBuilder.build(arrayLength, elementBuilder -> { + for (int i = 0; i < arrayLength; i++) { + type.appendTo(block, positions.getInt(i), elementBuilder); + } + }); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySubscriptOperator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySubscriptOperator.java index 07019bfff6ee..433218c90d89 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySubscriptOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArraySubscriptOperator.java @@ -14,12 +14,12 @@ package io.trino.operator.scalar; import com.google.common.collect.ImmutableList; -import io.airlift.slice.Slice; -import io.trino.annotation.UsedByGeneratedCode; import io.trino.metadata.SqlScalarFunction; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionDependencies; +import io.trino.spi.function.FunctionDependencyDeclaration; import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.Signature; import io.trino.spi.type.Type; @@ -28,28 +28,44 @@ import java.lang.invoke.MethodHandle; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; +import static io.trino.spi.function.InvocationConvention.simpleConvention; +import static io.trino.spi.function.OperatorType.READ_VALUE; import static io.trino.spi.function.OperatorType.SUBSCRIPT; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.TypeSignature.arrayType; -import static io.trino.util.Reflection.methodHandle; import static java.lang.Math.toIntExact; import static java.lang.String.format; -import static java.util.Objects.requireNonNull; +import static java.lang.invoke.MethodHandles.collectArguments; +import static java.lang.invoke.MethodHandles.empty; +import static java.lang.invoke.MethodHandles.explicitCastArguments; +import static java.lang.invoke.MethodHandles.guardWithTest; +import static java.lang.invoke.MethodHandles.lookup; +import static java.lang.invoke.MethodHandles.permuteArguments; +import static java.lang.invoke.MethodType.methodType; public class ArraySubscriptOperator extends SqlScalarFunction { public static final ArraySubscriptOperator ARRAY_SUBSCRIPT = new ArraySubscriptOperator(); - private static final MethodHandle METHOD_HANDLE_BOOLEAN = methodHandle(ArraySubscriptOperator.class, "booleanSubscript", Type.class, Block.class, long.class); - private static final MethodHandle METHOD_HANDLE_LONG = methodHandle(ArraySubscriptOperator.class, "longSubscript", Type.class, Block.class, long.class); - private static final MethodHandle METHOD_HANDLE_DOUBLE = methodHandle(ArraySubscriptOperator.class, "doubleSubscript", Type.class, Block.class, long.class); - private static final MethodHandle METHOD_HANDLE_SLICE = methodHandle(ArraySubscriptOperator.class, "sliceSubscript", Type.class, Block.class, long.class); - private static final MethodHandle METHOD_HANDLE_OBJECT = methodHandle(ArraySubscriptOperator.class, "objectSubscript", Type.class, Block.class, long.class); + private static final MethodHandle GET_POSITION; + private static final MethodHandle IS_POSITION_NULL; - protected ArraySubscriptOperator() + static { + try { + GET_POSITION = lookup().findStatic(ArraySubscriptOperator.class, "getPosition", methodType(int.class, Block.class, long.class)); + IS_POSITION_NULL = lookup().findVirtual(Block.class, "isNull", methodType(boolean.class, int.class)); + } + catch (ReflectiveOperationException e) { + throw new ExceptionInInitializerError(e); + } + } + + private ArraySubscriptOperator() { super(FunctionMetadata.scalarBuilder() .signature(Signature.builder() @@ -64,29 +80,31 @@ protected ArraySubscriptOperator() } @Override - protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) + public FunctionDependencyDeclaration getFunctionDependencies() + { + return FunctionDependencyDeclaration.builder() + .addOperatorSignature(READ_VALUE, ImmutableList.of(new TypeSignature("E"))) + .build(); + } + + @Override + public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) { Type elementType = boundSignature.getReturnType(); + MethodHandle methodHandle = functionDependencies.getOperatorImplementation( + READ_VALUE, + ImmutableList.of(elementType), + simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)) + .getMethodHandle(); + Class expectedReturnType = methodType(elementType.getJavaType()).wrap().returnType(); + methodHandle = explicitCastArguments(methodHandle, methodHandle.type().changeReturnType(expectedReturnType)); + methodHandle = guardWithTest( + IS_POSITION_NULL, + empty(methodHandle.type()), + methodHandle); + methodHandle = collectArguments(methodHandle, 1, GET_POSITION); + methodHandle = permuteArguments(methodHandle, methodHandle.type().dropParameterTypes(1, 2), 0, 0, 1); - MethodHandle methodHandle; - if (elementType.getJavaType() == boolean.class) { - methodHandle = METHOD_HANDLE_BOOLEAN; - } - else if (elementType.getJavaType() == long.class) { - methodHandle = METHOD_HANDLE_LONG; - } - else if (elementType.getJavaType() == double.class) { - methodHandle = METHOD_HANDLE_DOUBLE; - } - else if (elementType.getJavaType() == Slice.class) { - methodHandle = METHOD_HANDLE_SLICE; - } - else { - methodHandle = METHOD_HANDLE_OBJECT.asType( - METHOD_HANDLE_OBJECT.type().changeReturnType(elementType.getJavaType())); - } - methodHandle = methodHandle.bindTo(elementType); - requireNonNull(methodHandle, "methodHandle is null"); return new ChoicesSpecializedSqlScalarFunction( boundSignature, NULLABLE_RETURN, @@ -94,64 +112,14 @@ else if (elementType.getJavaType() == Slice.class) { methodHandle); } - @UsedByGeneratedCode - public static Long longSubscript(Type elementType, Block array, long index) - { - checkIndex(array, index); - int position = toIntExact(index - 1); - if (array.isNull(position)) { - return null; - } - - return elementType.getLong(array, position); - } - - @UsedByGeneratedCode - public static Boolean booleanSubscript(Type elementType, Block array, long index) + private static int getPosition(Block array, long index) { - checkIndex(array, index); - int position = toIntExact(index - 1); - if (array.isNull(position)) { - return null; - } - - return elementType.getBoolean(array, position); - } - - @UsedByGeneratedCode - public static Double doubleSubscript(Type elementType, Block array, long index) - { - checkIndex(array, index); - int position = toIntExact(index - 1); - if (array.isNull(position)) { - return null; - } - - return elementType.getDouble(array, position); - } - - @UsedByGeneratedCode - public static Slice sliceSubscript(Type elementType, Block array, long index) - { - checkIndex(array, index); - int position = toIntExact(index - 1); - if (array.isNull(position)) { - return null; + checkArrayIndex(index); + if (index > array.getPositionCount()) { + throw new TrinoException(INVALID_FUNCTION_ARGUMENT, format("Array subscript must be less than or equal to array length: %s > %s", index, array.getPositionCount())); } - - return elementType.getSlice(array, position); - } - - @UsedByGeneratedCode - public static Object objectSubscript(Type elementType, Block array, long index) - { - checkIndex(array, index); int position = toIntExact(index - 1); - if (array.isNull(position)) { - return null; - } - - return elementType.getObject(array, position); + return position; } public static void checkArrayIndex(long index) @@ -163,12 +131,4 @@ public static void checkArrayIndex(long index) throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "Array subscript is negative: " + index); } } - - public static void checkIndex(Block array, long index) - { - checkArrayIndex(index); - if (index > array.getPositionCount()) { - throw new TrinoException(INVALID_FUNCTION_ARGUMENT, format("Array subscript must be less than or equal to array length: %s > %s", index, array.getPositionCount())); - } - } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToArrayCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToArrayCast.java index 1738c4fb9480..60e6125a62fb 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToArrayCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayToArrayCast.java @@ -21,13 +21,12 @@ import io.trino.spi.function.ScalarOperator; import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; -import io.trino.spi.function.TypeParameterSpecialization; import io.trino.spi.type.Type; import java.lang.invoke.MethodHandle; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.OperatorType.CAST; @ScalarOperator(CAST) @@ -37,11 +36,10 @@ private ArrayToArrayCast() {} @TypeParameter("F") @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = long.class) @SqlType("array(T)") - public static Block filterLong( + public static Block filter( @TypeParameter("T") Type resultType, - @CastDependency(fromType = "F", toType = "T", convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = NULLABLE_RETURN, session = true)) MethodHandle cast, + @CastDependency(fromType = "F", toType = "T", convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = BLOCK_BUILDER, session = true)) MethodHandle cast, ConnectorSession session, @SqlType("array(F)") Block array) throws Throwable @@ -49,92 +47,12 @@ public static Block filterLong( int positionCount = array.getPositionCount(); BlockBuilder resultBuilder = resultType.createBlockBuilder(null, positionCount); for (int position = 0; position < positionCount; position++) { - if (!array.isNull(position)) { - Long value = (Long) cast.invokeExact(session, array, position); - if (value != null) { - resultType.writeLong(resultBuilder, value); - continue; - } + if (array.isNull(position)) { + resultBuilder.appendNull(); } - resultBuilder.appendNull(); - } - return resultBuilder.build(); - } - - @TypeParameter("F") - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = double.class) - @SqlType("array(T)") - public static Block filterDouble( - @TypeParameter("T") Type resultType, - @CastDependency(fromType = "F", toType = "T", convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = NULLABLE_RETURN, session = true)) MethodHandle cast, - ConnectorSession session, - @SqlType("array(F)") Block array) - throws Throwable - { - int positionCount = array.getPositionCount(); - BlockBuilder resultBuilder = resultType.createBlockBuilder(null, positionCount); - for (int position = 0; position < positionCount; position++) { - if (!array.isNull(position)) { - Double value = (Double) cast.invokeExact(session, array, position); - if (value != null) { - resultType.writeDouble(resultBuilder, value); - continue; - } - } - resultBuilder.appendNull(); - } - return resultBuilder.build(); - } - - @TypeParameter("F") - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = boolean.class) - @SqlType("array(T)") - public static Block filterBoolean( - @TypeParameter("T") Type resultType, - @CastDependency(fromType = "F", toType = "T", convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = NULLABLE_RETURN, session = true)) MethodHandle cast, - ConnectorSession session, - @SqlType("array(F)") Block array) - throws Throwable - { - int positionCount = array.getPositionCount(); - BlockBuilder resultBuilder = resultType.createBlockBuilder(null, positionCount); - for (int position = 0; position < positionCount; position++) { - if (!array.isNull(position)) { - Boolean value = (Boolean) cast.invokeExact(session, array, position); - if (value != null) { - resultType.writeBoolean(resultBuilder, value); - continue; - } - } - resultBuilder.appendNull(); - } - return resultBuilder.build(); - } - - @TypeParameter("F") - @TypeParameter("T") - @TypeParameterSpecialization(name = "T", nativeContainerType = Object.class) - @SqlType("array(T)") - public static Block filterObject( - @TypeParameter("T") Type resultType, - @CastDependency(fromType = "F", toType = "T", convention = @Convention(arguments = BLOCK_POSITION_NOT_NULL, result = NULLABLE_RETURN, session = true)) MethodHandle cast, - ConnectorSession session, - @SqlType("array(F)") Block array) - throws Throwable - { - int positionCount = array.getPositionCount(); - BlockBuilder resultBuilder = resultType.createBlockBuilder(null, positionCount); - for (int position = 0; position < positionCount; position++) { - if (!array.isNull(position)) { - Object value = (Object) cast.invoke(session, array, position); - if (value != null) { - resultType.writeObject(resultBuilder, value); - continue; - } + else { + cast.invokeExact(session, array, position, resultBuilder); } - resultBuilder.appendNull(); } return resultBuilder.build(); } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java index af215facaef4..8db7d8638c12 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ChoicesSpecializedSqlScalarFunction.java @@ -112,6 +112,7 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(InvocationCo ScalarImplementationChoice bestChoice = Collections.max(choices, comparingInt(ScalarImplementationChoice::getScore)); MethodHandle methodHandle = ScalarFunctionAdapter.adapt( bestChoice.getMethodHandle(), + boundSignature.getReturnType(), boundSignature.getArgumentTypes(), bestChoice.getInvocationConvention(), invocationConvention); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/GenericReadValueOperator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericReadValueOperator.java new file mode 100644 index 000000000000..ce2903cd59c8 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/GenericReadValueOperator.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.scalar; + +import io.trino.metadata.SqlScalarFunction; +import io.trino.spi.function.BoundSignature; +import io.trino.spi.function.FunctionMetadata; +import io.trino.spi.function.ScalarFunctionImplementation; +import io.trino.spi.function.Signature; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeOperators; +import io.trino.spi.type.TypeSignature; + +import java.lang.invoke.MethodHandle; + +import static io.trino.spi.function.OperatorType.READ_VALUE; +import static java.util.Objects.requireNonNull; + +public class GenericReadValueOperator + extends SqlScalarFunction +{ + private final TypeOperators typeOperators; + + public GenericReadValueOperator(TypeOperators typeOperators) + { + super(FunctionMetadata.scalarBuilder() + .signature(Signature.builder() + .operatorType(READ_VALUE) + .typeVariable("T") + .returnType(new TypeSignature("T")) + .argumentType(new TypeSignature("T")) + .build()) + .build()); + this.typeOperators = requireNonNull(typeOperators, "typeOperators is null"); + } + + @Override + protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) + { + Type type = boundSignature.getArgumentType(0); + return invocationConvention -> { + MethodHandle methodHandle = typeOperators.getReadValueOperator(type, invocationConvention); + return ScalarFunctionImplementation.builder() + .methodHandle(methodHandle) + .build(); + }; + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/JoniRegexpReplaceLambdaFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/JoniRegexpReplaceLambdaFunction.java index 57ec4bc86a45..51b8d11d3e03 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/JoniRegexpReplaceLambdaFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/JoniRegexpReplaceLambdaFunction.java @@ -13,20 +13,19 @@ */ package io.trino.operator.scalar; -import com.google.common.collect.ImmutableList; import io.airlift.joni.Matcher; import io.airlift.joni.Region; import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; -import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BufferedArrayValueBuilder; import io.trino.spi.function.Description; import io.trino.spi.function.LiteralParameters; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlNullable; import io.trino.spi.function.SqlType; +import io.trino.spi.type.ArrayType; import io.trino.sql.gen.lambda.UnaryFunctionInterface; import io.trino.type.JoniRegexp; import io.trino.type.JoniRegexpType; @@ -39,7 +38,7 @@ @Description("Replaces substrings matching a regular expression using a lambda function") public final class JoniRegexpReplaceLambdaFunction { - private final PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(VARCHAR)); + private final BufferedArrayValueBuilder arrayValueBuilder = BufferedArrayValueBuilder.createBuffered(new ArrayType(VARCHAR)); @LiteralParameters("x") @SqlType("varchar") @@ -57,13 +56,6 @@ public Slice regexpReplace( SliceOutput output = new DynamicSliceOutput(source.length()); - // Prepare a BlockBuilder that will be used to create the target block - // that will be passed to the lambda function. - if (pageBuilder.isFull()) { - pageBuilder.reset(); - } - BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); - int groupCount = pattern.regex().numberOfCaptures(); int appendPosition = 0; int nextStart; @@ -90,17 +82,17 @@ public Slice regexpReplace( // Append the capturing groups to the target block that will be passed to lambda Region matchedRegion = matcher.getEagerRegion(); - for (int i = 1; i <= groupCount; i++) { - // Add to the block builder if the matched region is not null. In Joni null is represented as [-1, -1] - if (matchedRegion.beg[i] >= 0 && matchedRegion.end[i] >= 0) { - VARCHAR.writeSlice(blockBuilder, source, matchedRegion.beg[i], matchedRegion.end[i] - matchedRegion.beg[i]); + Block target = arrayValueBuilder.build(groupCount, elementBuilder -> { + for (int i = 1; i <= groupCount; i++) { + // Add to the block builder if the matched region is not null. In Joni null is represented as [-1, -1] + if (matchedRegion.beg[i] >= 0 && matchedRegion.end[i] >= 0) { + VARCHAR.writeSlice(elementBuilder, source, matchedRegion.beg[i], matchedRegion.end[i] - matchedRegion.beg[i]); + } + else { + elementBuilder.appendNull(); + } } - else { - blockBuilder.appendNull(); - } - } - pageBuilder.declarePositions(groupCount); - Block target = blockBuilder.getRegion(blockBuilder.getPositionCount() - groupCount, groupCount); + }); // Call the lambda function to replace the block, and append the result to output Slice replaced = (Slice) replaceFunction.apply(target); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/Re2JRegexpReplaceLambdaFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/Re2JRegexpReplaceLambdaFunction.java index 7cd6274fa6e1..2e5a09be21e2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/Re2JRegexpReplaceLambdaFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/Re2JRegexpReplaceLambdaFunction.java @@ -13,19 +13,18 @@ */ package io.trino.operator.scalar; -import com.google.common.collect.ImmutableList; import com.google.re2j.Matcher; import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; -import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; -import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.BufferedArrayValueBuilder; import io.trino.spi.function.Description; import io.trino.spi.function.LiteralParameters; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlNullable; import io.trino.spi.function.SqlType; +import io.trino.spi.type.ArrayType; import io.trino.sql.gen.lambda.UnaryFunctionInterface; import io.trino.type.Re2JRegexp; import io.trino.type.Re2JRegexpType; @@ -36,7 +35,7 @@ @Description("Replaces substrings matching a regular expression using a lambda function") public final class Re2JRegexpReplaceLambdaFunction { - private final PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(VARCHAR)); + private final BufferedArrayValueBuilder arrayValueBuilder = BufferedArrayValueBuilder.createBuffered(new ArrayType(VARCHAR)); @LiteralParameters("x") @SqlType("varchar") @@ -54,13 +53,6 @@ public Slice regexpReplace( SliceOutput output = new DynamicSliceOutput(source.length()); - // Prepare a BlockBuilder that will be used to create the target block - // that will be passed to the lambda function. - if (pageBuilder.isFull()) { - pageBuilder.reset(); - } - BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); - int groupCount = matcher.groupCount(); int appendPosition = 0; @@ -75,17 +67,17 @@ public Slice regexpReplace( appendPosition = end; // Append the capturing groups to the target block that will be passed to lambda - for (int i = 1; i <= groupCount; i++) { - Slice matchedGroupSlice = matcher.group(i); - if (matchedGroupSlice != null) { - VARCHAR.writeSlice(blockBuilder, matchedGroupSlice); + Block target = arrayValueBuilder.build(groupCount, elementBuilder -> { + for (int i = 1; i <= groupCount; i++) { + Slice matchedGroupSlice = matcher.group(i); + if (matchedGroupSlice != null) { + VARCHAR.writeSlice(elementBuilder, matchedGroupSlice); + } + else { + elementBuilder.appendNull(); + } } - else { - blockBuilder.appendNull(); - } - } - pageBuilder.declarePositions(groupCount); - Block target = blockBuilder.getRegion(blockBuilder.getPositionCount() - groupCount, groupCount); + }); // Call the lambda function to replace the block, and append the result to output Slice replaced = (Slice) replaceFunction.apply(target); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/OperatorValidator.java b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/OperatorValidator.java index dd3a5d71da26..4fc039614e7b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/OperatorValidator.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/annotations/OperatorValidator.java @@ -73,6 +73,7 @@ public static void validateOperator(OperatorType operatorType, TypeSignature ret case IS_DISTINCT_FROM: case XX_HASH_64: case INDETERMINATE: + case READ_VALUE: // TODO } } diff --git a/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java b/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java index 6bac17912cba..c8ab2bb40922 100644 --- a/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java +++ b/core/trino-main/src/test/java/io/trino/type/AbstractTestType.java @@ -40,6 +40,7 @@ import io.trino.type.BlockTypeOperators.BlockPositionXxHash64; import org.testng.annotations.Test; +import java.lang.invoke.MethodHandle; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -56,7 +57,9 @@ import static io.trino.spi.connector.SortOrder.DESC_NULLS_FIRST; import static io.trino.spi.connector.SortOrder.DESC_NULLS_LAST; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.DEFAULT_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; @@ -84,7 +87,11 @@ public abstract class AbstractTestType private final Class objectValueType; private final Block testBlock; protected final Type type; + private final TypeOperators typeOperators; + private final MethodHandle readBlockMethod; + private final MethodHandle writeBlockMethod; + protected final BlockTypeOperators blockTypeOperators; private final BlockPositionEqual equalOperator; private final BlockPositionHashCode hashCodeOperator; @@ -103,6 +110,9 @@ protected AbstractTestType(Type type, Class objectValueType, Block testBlock, { this.type = requireNonNull(type, "type is null"); typeOperators = new TypeOperators(); + readBlockMethod = typeOperators.getReadValueOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)); + writeBlockMethod = typeOperators.getReadValueOperator(type, simpleConvention(BLOCK_BUILDER, NEVER_NULL)); + blockTypeOperators = new BlockTypeOperators(typeOperators); if (type.isComparable()) { equalOperator = blockTypeOperators.getEqualOperator(type); @@ -182,6 +192,7 @@ protected PlannerContext createPlannerContext() @Test public void testBlock() + throws Throwable { for (Entry entry : expectedStackValues.entrySet()) { assertPositionEquals(testBlock, entry.getKey(), entry.getValue(), expectedObjectValues.get(entry.getKey())); @@ -233,6 +244,7 @@ protected Object getSampleValue() } protected void assertPositionEquals(Block block, int position, Object expectedStackValue, Object expectedObjectValue) + throws Throwable { long hash = 0; if (type.isComparable()) { @@ -247,9 +259,16 @@ protected void assertPositionEquals(Block block, int position, Object expectedSt BlockBuilder blockBuilder = type.createBlockBuilder(null, 1); type.appendTo(block, position, blockBuilder); assertPositionValue(blockBuilder.build(), 0, expectedStackValue, hash, expectedObjectValue); + + if (expectedStackValue != null) { + blockBuilder = type.createBlockBuilder(null, 1); + writeBlockMethod.invoke(expectedStackValue, blockBuilder); + assertPositionValue(blockBuilder.build(), 0, expectedStackValue, hash, expectedObjectValue); + } } private void assertPositionValue(Block block, int position, Object expectedStackValue, long expectedHash, Object expectedObjectValue) + throws Throwable { assertEquals(block.isNull(position), expectedStackValue == null); @@ -330,18 +349,21 @@ private void assertPositionValue(Block block, int position, Object expectedStack assertThatThrownBy(() -> type.getLong(block, position)).isInstanceOf(UnsupportedOperationException.class); assertThatThrownBy(() -> type.getDouble(block, position)).isInstanceOf(UnsupportedOperationException.class); assertThatThrownBy(() -> type.getObject(block, position)).isInstanceOf(UnsupportedOperationException.class); + assertEquals((boolean) readBlockMethod.invokeExact(block, position), expectedStackValue); } else if (type.getJavaType() == long.class) { assertEquals(type.getLong(block, position), expectedStackValue); assertThatThrownBy(() -> type.getBoolean(block, position)).isInstanceOf(UnsupportedOperationException.class); assertThatThrownBy(() -> type.getDouble(block, position)).isInstanceOf(UnsupportedOperationException.class); assertThatThrownBy(() -> type.getObject(block, position)).isInstanceOf(UnsupportedOperationException.class); + assertEquals((long) readBlockMethod.invokeExact(block, position), expectedStackValue); } else if (type.getJavaType() == double.class) { assertEquals(type.getDouble(block, position), expectedStackValue); assertThatThrownBy(() -> type.getBoolean(block, position)).isInstanceOf(UnsupportedOperationException.class); assertThatThrownBy(() -> type.getLong(block, position)).isInstanceOf(UnsupportedOperationException.class); assertThatThrownBy(() -> type.getObject(block, position)).isInstanceOf(UnsupportedOperationException.class); + assertEquals((double) readBlockMethod.invokeExact(block, position), expectedStackValue); } else if (type.getJavaType() == Slice.class) { assertEquals(type.getSlice(block, position), expectedStackValue); @@ -349,26 +371,34 @@ else if (type.getJavaType() == Slice.class) { assertThatThrownBy(() -> type.getBoolean(block, position)).isInstanceOf(UnsupportedOperationException.class); assertThatThrownBy(() -> type.getLong(block, position)).isInstanceOf(UnsupportedOperationException.class); assertThatThrownBy(() -> type.getDouble(block, position)).isInstanceOf(UnsupportedOperationException.class); + assertEquals((Slice) readBlockMethod.invokeExact(block, position), expectedStackValue); } else if (type.getJavaType() == Block.class) { - SliceOutput actualSliceOutput = new DynamicSliceOutput(100); - writeBlock(blockEncodingSerde, actualSliceOutput, (Block) type.getObject(block, position)); - SliceOutput expectedSliceOutput = new DynamicSliceOutput(actualSliceOutput.size()); - writeBlock(blockEncodingSerde, expectedSliceOutput, (Block) expectedStackValue); - assertEquals(actualSliceOutput.slice(), expectedSliceOutput.slice()); + assertBlockEquals((Block) type.getObject(block, position), (Block) expectedStackValue); assertThatThrownBy(() -> type.getBoolean(block, position)).isInstanceOf(UnsupportedOperationException.class); assertThatThrownBy(() -> type.getLong(block, position)).isInstanceOf(UnsupportedOperationException.class); assertThatThrownBy(() -> type.getDouble(block, position)).isInstanceOf(UnsupportedOperationException.class); assertThatThrownBy(() -> type.getSlice(block, position)).isInstanceOf(UnsupportedOperationException.class); + assertBlockEquals((Block) readBlockMethod.invokeExact(block, position), (Block) expectedStackValue); } else { assertEquals(type.getObject(block, position), expectedStackValue); assertThatThrownBy(() -> type.getBoolean(block, position)).isInstanceOf(UnsupportedOperationException.class); assertThatThrownBy(() -> type.getLong(block, position)).isInstanceOf(UnsupportedOperationException.class); assertThatThrownBy(() -> type.getDouble(block, position)).isInstanceOf(UnsupportedOperationException.class); + assertEquals(readBlockMethod.invoke(block, position), expectedStackValue); } } + private void assertBlockEquals(Block actualValue, Block expectedValue) + { + SliceOutput actualSliceOutput = new DynamicSliceOutput(100); + writeBlock(blockEncodingSerde, actualSliceOutput, actualValue); + SliceOutput expectedSliceOutput = new DynamicSliceOutput(actualSliceOutput.size()); + writeBlock(blockEncodingSerde, expectedSliceOutput, expectedValue); + assertEquals(actualSliceOutput.slice(), expectedSliceOutput.slice()); + } + private void verifyInvalidPositionHandling(Block block) { assertThatThrownBy(() -> type.getObjectValue(SESSION, block, -1)) @@ -427,6 +457,14 @@ private void verifyInvalidPositionHandling(Block block) .hasMessage("Invalid position %d in block with %d positions", block.getPositionCount(), block.getPositionCount()); } + assertThatThrownBy(() -> readBlockMethod.invoke(block, -1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid position -1 in block with %d positions", block.getPositionCount()); + + assertThatThrownBy(() -> readBlockMethod.invoke(block, block.getPositionCount())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid position %d in block with %d positions", block.getPositionCount(), block.getPositionCount()); + if (type.getJavaType() == boolean.class) { assertThatThrownBy(() -> type.getBoolean(block, -1)) .isInstanceOf(IllegalArgumentException.class) diff --git a/core/trino-spi/pom.xml b/core/trino-spi/pom.xml index 6442299d5422..96c2f43a0d00 100644 --- a/core/trino-spi/pom.xml +++ b/core/trino-spi/pom.xml @@ -450,6 +450,11 @@ private Constructor should have never been public as there is a static factory method + + java.method.numberOfParametersChanged + method java.lang.invoke.MethodHandle io.trino.spi.function.ScalarFunctionAdapter::adapt(java.lang.invoke.MethodHandle, java.util.List<io.trino.spi.type.Type>, io.trino.spi.function.InvocationConvention, io.trino.spi.function.InvocationConvention) + method java.lang.invoke.MethodHandle io.trino.spi.function.ScalarFunctionAdapter::adapt(java.lang.invoke.MethodHandle, io.trino.spi.type.Type, java.util.List<io.trino.spi.type.Type>, io.trino.spi.function.InvocationConvention, io.trino.spi.function.InvocationConvention) + diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java b/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java index a968403ada37..bc5491307bec 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/InvocationConvention.java @@ -163,10 +163,10 @@ public enum InvocationReturnConvention /** * The function will never return a null value. * It is not possible to adapt a NEVER_NULL argument to a - * BOXED_NULLABLE or NULL_FLAG argument when the this return + * BOXED_NULLABLE or NULL_FLAG argument when this return * convention is used. */ - FAIL_ON_NULL(false), + FAIL_ON_NULL(false, 0), /** * When a null is passed to a never null argument, the function * will not be invoked, and the Java default value for the return @@ -174,24 +174,37 @@ public enum InvocationReturnConvention * This can not be used as an actual function return convention, * and instead is only used for adaptation. */ - DEFAULT_ON_NULL(false), + DEFAULT_ON_NULL(false, 0), /** * The function may return a null value. * When a null is passed to a never null argument, the function * will not be invoked, and a null value is returned. */ - NULLABLE_RETURN(true); + NULLABLE_RETURN(true, 0), + /** + * Return value is witten to a BlockBuilder passed as the last argument. + * When a null is passed to a never null argument, the function + * will not be invoked, and a null is written to the block builder. + */ + BLOCK_BUILDER(true, 1); private final boolean nullable; + private final int parameterCount; - InvocationReturnConvention(boolean nullable) + InvocationReturnConvention(boolean nullable, int parameterCount) { this.nullable = nullable; + this.parameterCount = parameterCount; } public boolean isNullable() { return nullable; } + + public int getParameterCount() + { + return parameterCount; + } } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/OperatorType.java b/core/trino-spi/src/main/java/io/trino/spi/function/OperatorType.java index f2cf5b321d36..c86be88034fc 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/OperatorType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/OperatorType.java @@ -38,7 +38,9 @@ public enum OperatorType SATURATED_FLOOR_CAST("SATURATED FLOOR CAST", 1), IS_DISTINCT_FROM("IS DISTINCT FROM", 2), XX_HASH_64("XX HASH 64", 1), - INDETERMINATE("INDETERMINATE", 1); + INDETERMINATE("INDETERMINATE", 1), + READ_VALUE("READ VALUE", 1), + /**/; private final String operator; private final int argumentCount; diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java index 4e6b4acbcb29..ed86019277db 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java @@ -18,6 +18,7 @@ import io.trino.spi.StandardErrorCode; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; import io.trino.spi.function.InvocationConvention.InvocationReturnConvention; import io.trino.spi.type.Type; @@ -35,6 +36,7 @@ import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.DEFAULT_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; @@ -57,6 +59,7 @@ public final class ScalarFunctionAdapter { private static final MethodHandle IS_NULL_METHOD = lookupIsNullMethod(); + private static final MethodHandle APPEND_NULL_METHOD = lookupAppendNullMethod(); private ScalarFunctionAdapter() {} @@ -104,16 +107,12 @@ private static boolean canAdaptReturn( return true; } - if (expectedReturnConvention == NULLABLE_RETURN && actualReturnConvention == FAIL_ON_NULL) { - return true; - } - - if (expectedReturnConvention == DEFAULT_ON_NULL - && (actualReturnConvention == NULLABLE_RETURN || actualReturnConvention == FAIL_ON_NULL)) { - return true; - } - - return false; + return switch (actualReturnConvention) { + case FAIL_ON_NULL -> true; + case NULLABLE_RETURN -> expectedReturnConvention.isNullable() || expectedReturnConvention == DEFAULT_ON_NULL; + case BLOCK_BUILDER -> false; + case DEFAULT_ON_NULL -> throw new IllegalArgumentException("actual return convention cannot be DEFAULT_ON_NULL"); + }; } private static boolean canAdaptParameter( @@ -155,6 +154,7 @@ private static boolean canAdaptParameter( */ public static MethodHandle adapt( MethodHandle methodHandle, + Type returnType, List actualArgumentTypes, InvocationConvention actualConvention, InvocationConvention expectedConvention) @@ -167,14 +167,14 @@ public static MethodHandle adapt( } if (actualConvention.supportsSession() && !expectedConvention.supportsSession()) { - throw new IllegalArgumentException("Session method can not be adapted to no session"); + throw new IllegalArgumentException("Session method cannot be adapted to no session"); } if (!(expectedConvention.supportsInstanceFactory() || !actualConvention.supportsInstanceFactory())) { - throw new IllegalArgumentException("Instance method can not be adapted to no instance"); + throw new IllegalArgumentException("Instance method cannot be adapted to no instance"); } // adapt return first, since return-null-on-null parameter convention must know if the return type is nullable - methodHandle = adaptReturn(methodHandle, actualConvention.getReturnConvention(), expectedConvention.getReturnConvention()); + methodHandle = adaptReturn(methodHandle, returnType, actualConvention.getReturnConvention(), expectedConvention.getReturnConvention()); // adapt parameters one at a time int parameterIndex = 0; @@ -204,6 +204,7 @@ public static MethodHandle adapt( private static MethodHandle adaptReturn( MethodHandle methodHandle, + Type returnType, InvocationReturnConvention actualReturnConvention, InvocationReturnConvention expectedReturnConvention) { @@ -211,16 +212,30 @@ private static MethodHandle adaptReturn( return methodHandle; } - Class returnType = methodHandle.type().returnType(); if (expectedReturnConvention == NULLABLE_RETURN) { if (actualReturnConvention == FAIL_ON_NULL) { // box return - return explicitCastArguments(methodHandle, methodHandle.type().changeReturnType(wrap(returnType))); + return explicitCastArguments(methodHandle, methodHandle.type().changeReturnType(wrap(methodHandle.type().returnType()))); } } + if (expectedReturnConvention == BLOCK_BUILDER) { + // write the result to block builder + // type.writeValue(BlockBuilder, value), f(a,b)::value => method(BlockBuilder, a, b)::void + methodHandle = collectArguments(writeBlockValue(returnType), 1, methodHandle); + // f(BlockBuilder, a, b)::void => f(a, b, BlockBuilder) + MethodType newType = methodHandle.type() + .dropParameterTypes(0, 1) + .appendParameterTypes(BlockBuilder.class); + int[] reorder = IntStream.range(0, newType.parameterCount()) + .map(i -> i > 0 ? i - 1 : newType.parameterCount() - 1) + .toArray(); + methodHandle = permuteArguments(methodHandle, newType, reorder); + return methodHandle; + } + if (expectedReturnConvention == FAIL_ON_NULL && actualReturnConvention == NULLABLE_RETURN) { - throw new IllegalArgumentException("Nullable return can not be adapted fail on null"); + throw new IllegalArgumentException("Nullable return cannot be adapted fail on null"); } if (expectedReturnConvention == DEFAULT_ON_NULL) { @@ -229,11 +244,11 @@ private static MethodHandle adaptReturn( } if (actualReturnConvention == NULLABLE_RETURN) { // perform unboxing, which converts nulls to Java primitive default value - methodHandle = explicitCastArguments(methodHandle, methodHandle.type().changeReturnType(unwrap(returnType))); + methodHandle = explicitCastArguments(methodHandle, methodHandle.type().changeReturnType(unwrap(returnType.getJavaType()))); return methodHandle; } } - throw new IllegalArgumentException("Unsupported return convention: " + actualReturnConvention); + throw new IllegalArgumentException("%s return convention cannot be adapted to %s".formatted(actualReturnConvention, expectedReturnConvention)); } private static MethodHandle adaptParameter( @@ -280,7 +295,7 @@ private static MethodHandle adaptParameter( methodHandle = explicitCastArguments(methodHandle, targetType); if (returnConvention == FAIL_ON_NULL) { - throw new IllegalArgumentException("RETURN_NULL_ON_NULL adaptation can not be used with FAIL_ON_NULL return convention"); + throw new IllegalArgumentException("RETURN_NULL_ON_NULL adaptation cannot be used with FAIL_ON_NULL return convention"); } MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention); return guardWithTest( @@ -321,7 +336,7 @@ private static MethodHandle adaptParameter( if (actualArgumentConvention == NEVER_NULL) { // if caller sets the null flag, return null, otherwise invoke target if (returnConvention == FAIL_ON_NULL) { - throw new IllegalArgumentException("RETURN_NULL_ON_NULL adaptation can not be used with FAIL_ON_NULL return convention"); + throw new IllegalArgumentException("RETURN_NULL_ON_NULL adaptation cannot be used with FAIL_ON_NULL return convention"); } // add a null flag to call methodHandle = dropArguments(methodHandle, parameterIndex + 1, boolean.class); @@ -334,7 +349,7 @@ private static MethodHandle adaptParameter( } if (actualArgumentConvention == BOXED_NULLABLE) { - return collectArguments(methodHandle, parameterIndex, boxedToNullFlagFilter(methodHandle.type().parameterType(parameterIndex))); + return collectArguments(methodHandle, parameterIndex, boxedToNullFlagFilter(returnConvention, methodHandle.type().parameterType(parameterIndex))); } } @@ -385,7 +400,7 @@ private static MethodHandle adaptParameter( getBlockValue = explicitCastArguments(getBlockValue, getBlockValue.type().changeReturnType(wrap(getBlockValue.type().returnType()))); getBlockValue = guardWithTest( isBlockPositionNull(getBlockValue.type(), 0), - returnNull(getBlockValue.type()), + returnNull(getBlockValue.type(), returnConvention), getBlockValue); methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); return methodHandle; @@ -399,7 +414,7 @@ private static MethodHandle adaptParameter( // long, Block, int => Block, int, Block, int getBlockValue = guardWithTest( isBlockPositionNull(getBlockValue.type(), 0), - returnNull(getBlockValue.type()), + returnNull(getBlockValue.type(), returnConvention), getBlockValue); methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); @@ -450,7 +465,7 @@ private static MethodHandle adaptParameter( getInOutValue = explicitCastArguments(getInOutValue, getInOutValue.type().changeReturnType(wrap(getInOutValue.type().returnType()))); getInOutValue = guardWithTest( isInOutNull(getInOutValue.type(), 0), - returnNull(getInOutValue.type()), + returnNull(getInOutValue.type(), returnConvention), getInOutValue); methodHandle = collectArguments(methodHandle, parameterIndex, getInOutValue); return methodHandle; @@ -464,7 +479,7 @@ private static MethodHandle adaptParameter( // long, InOut => InOut, InOut getInOutValue = guardWithTest( isInOutNull(getInOutValue.type(), 0), - returnNull(getInOutValue.type()), + returnNull(getInOutValue.type(), returnConvention), getInOutValue); methodHandle = collectArguments(methodHandle, parameterIndex, getInOutValue); @@ -478,7 +493,7 @@ private static MethodHandle adaptParameter( } } - throw new IllegalArgumentException("Can not convert argument %s to %s with return convention %s".formatted(expectedArgumentConvention, actualArgumentConvention, returnConvention)); + throw new IllegalArgumentException("Cannot convert argument %s to %s with return convention %s".formatted(expectedArgumentConvention, actualArgumentConvention, returnConvention)); } private static MethodHandle getBlockValue(Type argumentType, Class expectedType) @@ -512,6 +527,37 @@ else if (methodArgumentType == Slice.class) { } } + private static MethodHandle writeBlockValue(Type type) + { + Class methodArgumentType = type.getJavaType(); + String getterName; + if (methodArgumentType == boolean.class) { + getterName = "writeBoolean"; + } + else if (methodArgumentType == long.class) { + getterName = "writeLong"; + } + else if (methodArgumentType == double.class) { + getterName = "writeDouble"; + } + else if (methodArgumentType == Slice.class) { + getterName = "writeSlice"; + } + else { + getterName = "writeObject"; + methodArgumentType = Object.class; + } + + try { + return lookup().findVirtual(Type.class, getterName, methodType(void.class, BlockBuilder.class, methodArgumentType)) + .bindTo(type) + .asType(methodType(void.class, BlockBuilder.class, type.getJavaType())); + } + catch (ReflectiveOperationException e) { + throw new AssertionError(e); + } + } + private static MethodHandle getInOutValue(Type argumentType, Class expectedType) { Class methodArgumentType = argumentType.getJavaType(); @@ -539,7 +585,7 @@ else if (methodArgumentType == double.class) { } } - private static MethodHandle boxedToNullFlagFilter(Class argumentType) + private static MethodHandle boxedToNullFlagFilter(InvocationReturnConvention returnConvention, Class argumentType) { // Start with identity MethodHandle handle = identity(argumentType); @@ -552,7 +598,7 @@ private static MethodHandle boxedToNullFlagFilter(Class argumentType) // if the flag is true, return null, otherwise invoke identity return guardWithTest( isTrueNullFlag(handle.type(), 0), - returnNull(handle.type()), + returnNull(handle.type(), returnConvention), handle); } @@ -602,14 +648,12 @@ private static MethodHandle isInOutNull(MethodType methodType, int index) private static MethodHandle lookupIsNullMethod() { - MethodHandle isNull; try { - isNull = lookup().findStatic(Objects.class, "isNull", methodType(boolean.class, Object.class)); + return lookup().findStatic(Objects.class, "isNull", methodType(boolean.class, Object.class)); } catch (ReflectiveOperationException e) { throw new AssertionError(e); } - return isNull; } private static MethodHandle getNullShortCircuitResult(MethodHandle methodHandle, InvocationReturnConvention returnConvention) @@ -619,7 +663,7 @@ private static MethodHandle getNullShortCircuitResult(MethodHandle methodHandle, nullReturnValue = returnDefault(methodHandle.type()); } else { - nullReturnValue = returnNull(methodHandle.type()); + nullReturnValue = returnNull(methodHandle.type(), returnConvention); } return nullReturnValue; } @@ -637,8 +681,12 @@ private static MethodHandle returnDefault(MethodType methodType) return returnDefault; } - private static MethodHandle returnNull(MethodType methodType) + private static MethodHandle returnNull(MethodType methodType, InvocationReturnConvention returnConvention) { + if (returnConvention == BLOCK_BUILDER) { + return permuteArguments(APPEND_NULL_METHOD, methodType, methodType.parameterCount() - 1); + } + // Start with a constant null value of the expected return type: f():R MethodHandle returnNull = constant(wrap(methodType.returnType()), null); @@ -650,6 +698,17 @@ private static MethodHandle returnNull(MethodType methodType) return returnNull; } + private static MethodHandle lookupAppendNullMethod() + { + try { + return lookup().findVirtual(BlockBuilder.class, "appendNull", methodType(BlockBuilder.class)) + .asType(methodType(void.class, BlockBuilder.class)); + } + catch (ReflectiveOperationException e) { + throw new AssertionError(e); + } + } + private static MethodHandle throwTrinoNullArgumentException(MethodType type) { MethodHandle throwException = collectArguments(throwException(type.returnType(), TrinoException.class), 0, trinoNullArgumentException()); diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperatorDeclaration.java b/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperatorDeclaration.java index 679ec50f2424..80b67b22b039 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperatorDeclaration.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperatorDeclaration.java @@ -14,6 +14,7 @@ package io.trino.spi.type; import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.function.BlockIndex; import io.trino.spi.function.BlockPosition; @@ -41,6 +42,7 @@ import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; import static io.trino.spi.function.InvocationConvention.simpleConvention; @@ -52,6 +54,7 @@ public final class TypeOperatorDeclaration { public static final TypeOperatorDeclaration NO_TYPE_OPERATOR_DECLARATION = builder(boolean.class).build(); + private final Collection readValueOperators; private final Collection equalOperators; private final Collection hashCodeOperators; private final Collection xxHash64Operators; @@ -63,6 +66,7 @@ public final class TypeOperatorDeclaration private final Collection lessThanOrEqualOperators; private TypeOperatorDeclaration( + Collection readValueOperators, Collection equalOperators, Collection hashCodeOperators, Collection xxHash64Operators, @@ -73,6 +77,7 @@ private TypeOperatorDeclaration( Collection lessThanOperators, Collection lessThanOrEqualOperators) { + this.readValueOperators = List.copyOf(requireNonNull(readValueOperators, "readValueOperators is null")); this.equalOperators = List.copyOf(requireNonNull(equalOperators, "equalOperators is null")); this.hashCodeOperators = List.copyOf(requireNonNull(hashCodeOperators, "hashCodeOperators is null")); this.xxHash64Operators = List.copyOf(requireNonNull(xxHash64Operators, "xxHash64Operators is null")); @@ -94,6 +99,11 @@ public boolean isOrderable() return !comparisonUnorderedLastOperators.isEmpty(); } + public Collection getReadValueOperators() + { + return readValueOperators; + } + public Collection getEqualOperators() { return equalOperators; @@ -155,6 +165,7 @@ public static class Builder { private final Class typeJavaType; + private final Collection readValueOperators = new ArrayList<>(); private final Collection equalOperators = new ArrayList<>(); private final Collection hashCodeOperators = new ArrayList<>(); private final Collection xxHash64Operators = new ArrayList<>(); @@ -173,6 +184,7 @@ private Builder(Class typeJavaType) public Builder addOperators(TypeOperatorDeclaration operatorDeclaration) { + operatorDeclaration.getReadValueOperators().forEach(this::addReadValueOperator); operatorDeclaration.getEqualOperators().forEach(this::addEqualOperator); operatorDeclaration.getHashCodeOperators().forEach(this::addHashCodeOperator); operatorDeclaration.getXxHash64Operators().forEach(this::addXxHash64Operator); @@ -185,6 +197,13 @@ public Builder addOperators(TypeOperatorDeclaration operatorDeclaration) return this; } + public Builder addReadValueOperator(OperatorMethodHandle readValueOperator) + { + verifyMethodHandleSignature(1, typeJavaType, readValueOperator); + this.readValueOperators.add(readValueOperator); + return this; + } + public Builder addEqualOperator(OperatorMethodHandle equalOperator) { verifyMethodHandleSignature(2, boolean.class, equalOperator); @@ -348,6 +367,9 @@ public Builder addOperators(Class operatorsClass, Lookup lookup) } switch (operatorType) { + case READ_VALUE: + addReadValueOperator(new OperatorMethodHandle(parseInvocationConvention(operatorType, typeJavaType, method, typeJavaType), methodHandle)); + break; case EQUAL: addEqualOperator(new OperatorMethodHandle(parseInvocationConvention(operatorType, typeJavaType, method, boolean.class), methodHandle)); break; @@ -400,6 +422,7 @@ private void verifyMethodHandleSignature(int expectedArgumentCount, Class ret int expectedParameterCount = convention.getArgumentConventions().stream() .mapToInt(InvocationArgumentConvention::getParameterCount) .sum(); + expectedParameterCount += convention.getReturnConvention().getParameterCount(); checkArgument(expectedParameterCount == methodType.parameterCount(), "Expected %s method parameters, but got %s", expectedParameterCount, methodType.parameterCount()); @@ -445,6 +468,12 @@ private void verifyMethodHandleSignature(int expectedArgumentCount, Class ret checkArgument(methodType.returnType().equals(wrap(returnJavaType)), "Expected return type to be %s, but is %s", returnJavaType, wrap(methodType.returnType())); break; + case BLOCK_BUILDER: + checkArgument(methodType.lastParameterType().equals(BlockBuilder.class), + "Expected last argument type to be BlockBuilder, but is %s", methodType.returnType()); + checkArgument(methodType.returnType().equals(void.class), + "Expected return type to be void, but is %s", methodType.returnType()); + break; default: throw new UnsupportedOperationException("Unknown return convention: " + returnConvention); } @@ -458,6 +487,8 @@ private static InvocationConvention parseInvocationConvention(OperatorType opera List> parameterTypes = List.of(method.getParameterTypes()); List parameterAnnotations = List.of(method.getParameterAnnotations()); + parameterTypes = parameterTypes.subList(0, parameterTypes.size() - returnConvention.getParameterCount()); + parameterAnnotations = parameterAnnotations.subList(0, parameterAnnotations.size() - returnConvention.getParameterCount()); InvocationArgumentConvention leftArgumentConvention = extractNextArgumentConvention(typeJavaType, parameterTypes, parameterAnnotations, operatorType, method); if (leftArgumentConvention.getParameterCount() == parameterTypes.size()) { @@ -491,6 +522,11 @@ private static InvocationReturnConvention getReturnConvention(Class expectedR else if (method.isAnnotationPresent(SqlNullable.class) && method.getReturnType().equals(wrap(expectedReturnType))) { returnConvention = NULLABLE_RETURN; } + else if (method.getReturnType().equals(void.class) && + method.getParameterCount() >= 1 && + method.getParameterTypes()[method.getParameterCount() - 1].equals(BlockBuilder.class)) { + returnConvention = BLOCK_BUILDER; + } else { throw new IllegalArgumentException(format("Expected %s operator to return %s: %s", operatorType, expectedReturnType, method)); } @@ -569,6 +605,7 @@ public TypeOperatorDeclaration build() } return new TypeOperatorDeclaration( + readValueOperators, equalOperators, hashCodeOperators, xxHash64Operators, diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperators.java b/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperators.java index 9dec200ae0cf..445507311ca7 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperators.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/TypeOperators.java @@ -13,8 +13,10 @@ */ package io.trino.spi.type; +import io.airlift.slice.Slice; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.SortOrder; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.InvocationConvention.InvocationArgumentConvention; @@ -27,6 +29,7 @@ import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodHandles.Lookup; import java.lang.invoke.MethodType; +import java.util.ArrayList; import java.util.Collection; import java.util.Comparator; import java.util.List; @@ -37,8 +40,10 @@ import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; import static io.trino.spi.function.InvocationConvention.simpleConvention; @@ -47,6 +52,10 @@ import static io.trino.spi.function.OperatorType.EQUAL; import static io.trino.spi.function.OperatorType.LESS_THAN; import static io.trino.spi.function.OperatorType.LESS_THAN_OR_EQUAL; +import static io.trino.spi.function.OperatorType.READ_VALUE; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.IntegerType.INTEGER; import static java.lang.String.format; import static java.lang.invoke.MethodHandles.collectArguments; import static java.lang.invoke.MethodHandles.dropArguments; @@ -59,6 +68,9 @@ public class TypeOperators { + private static final InvocationConvention READ_BLOCK_NOT_NULL_CALLING_CONVENTION = simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL); + private static final InvocationConvention WRITE_BLOCK_CALLING_CONVENTION = simpleConvention(BLOCK_BUILDER, NEVER_NULL); + private final BiFunction, Object> cache; public TypeOperators() @@ -79,6 +91,11 @@ public TypeOperators(BiFunction, Object> cache) this.cache = cache; } + public MethodHandle getReadValueOperator(Type type, InvocationConvention callingConvention) + { + return getOperatorAdaptor(type, callingConvention, READ_VALUE).get(); + } + public MethodHandle getEqualOperator(Type type, InvocationConvention callingConvention) { if (!type.isComparable()) { @@ -200,6 +217,7 @@ private static MethodHandle adaptOperator(OperatorConvention operatorConvention, { return ScalarFunctionAdapter.adapt( operatorMethodHandle.getMethodHandle(), + getOperatorReturnType(operatorConvention), getOperatorArgumentTypes(operatorConvention), operatorMethodHandle.getCallingConvention(), operatorConvention.callingConvention()); @@ -233,6 +251,16 @@ private Collection getOperatorMethodHandles(OperatorConven TypeOperatorDeclaration typeOperatorDeclaration = operatorConvention.type().getTypeOperatorDeclaration(TypeOperators.this); requireNonNull(typeOperatorDeclaration, "typeOperators is null for " + operatorConvention.type()); return switch (operatorConvention.operatorType()) { + case READ_VALUE -> { + List readValueOperators = new ArrayList<>(typeOperatorDeclaration.getReadValueOperators()); + if (readValueOperators.stream().map(OperatorMethodHandle::getCallingConvention).noneMatch(READ_BLOCK_NOT_NULL_CALLING_CONVENTION::equals)) { + readValueOperators.add(new OperatorMethodHandle(READ_BLOCK_NOT_NULL_CALLING_CONVENTION, getDefaultReadBlockMethod(operatorConvention.type()))); + } + if (readValueOperators.stream().map(OperatorMethodHandle::getCallingConvention).noneMatch(WRITE_BLOCK_CALLING_CONVENTION::equals)) { + readValueOperators.add(new OperatorMethodHandle(WRITE_BLOCK_CALLING_CONVENTION, getDefaultWriteMethod(operatorConvention.type()))); + } + yield readValueOperators; + } case EQUAL -> typeOperatorDeclaration.getEqualOperators(); case HASH_CODE -> { Collection hashCodeOperators = typeOperatorDeclaration.getHashCodeOperators(); @@ -296,6 +324,44 @@ private Collection getOperatorMethodHandles(OperatorConven }; } + private static MethodHandle getDefaultReadBlockMethod(Type type) + { + Class javaType = type.getJavaType(); + if (boolean.class.equals(javaType)) { + return TYPE_GET_BOOLEAN.bindTo(type); + } + if (long.class.equals(javaType)) { + return TYPE_GET_LONG.bindTo(type); + } + if (double.class.equals(javaType)) { + return TYPE_GET_DOUBLE.bindTo(type); + } + if (Slice.class.equals(javaType)) { + return TYPE_GET_SLICE.bindTo(type); + } + return TYPE_GET_OBJECT + .asType(TYPE_GET_OBJECT.type().changeReturnType(type.getJavaType())) + .bindTo(type); + } + + private static MethodHandle getDefaultWriteMethod(Type type) + { + Class javaType = type.getJavaType(); + if (boolean.class.equals(javaType)) { + return TYPE_WRITE_BOOLEAN.bindTo(type); + } + if (long.class.equals(javaType)) { + return TYPE_WRITE_LONG.bindTo(type); + } + if (double.class.equals(javaType)) { + return TYPE_WRITE_DOUBLE.bindTo(type); + } + if (Slice.class.equals(javaType)) { + return TYPE_WRITE_SLICE.bindTo(type); + } + return TYPE_WRITE_OBJECT.bindTo(type); + } + private OperatorMethodHandle generateDistinctFromOperator(OperatorConvention operatorConvention) { if (operatorConvention.callingConvention().getArgumentConventions().equals(List.of(BLOCK_POSITION, BLOCK_POSITION))) { @@ -348,12 +414,23 @@ private OperatorMethodHandle generateOrderingOperator(OperatorConvention operato return adaptNeverNullComparisonToOrdering(sortOrder, comparisonInvoker); } + private static Type getOperatorReturnType(OperatorConvention operatorConvention) + { + return switch (operatorConvention.operatorType()) { + case EQUAL, IS_DISTINCT_FROM, LESS_THAN, LESS_THAN_OR_EQUAL, INDETERMINATE -> BOOLEAN; + case COMPARISON_UNORDERED_LAST, COMPARISON_UNORDERED_FIRST -> INTEGER; + case HASH_CODE, XX_HASH_64 -> BIGINT; + case READ_VALUE -> operatorConvention.type(); + default -> throw new IllegalArgumentException("Unsupported operator type: " + operatorConvention.operatorType()); + }; + } + private static List getOperatorArgumentTypes(OperatorConvention operatorConvention) { return switch (operatorConvention.operatorType()) { case EQUAL, IS_DISTINCT_FROM, COMPARISON_UNORDERED_LAST, COMPARISON_UNORDERED_FIRST, LESS_THAN, LESS_THAN_OR_EQUAL -> List.of(operatorConvention.type(), operatorConvention.type()); - case HASH_CODE, XX_HASH_64, INDETERMINATE -> + case READ_VALUE, HASH_CODE, XX_HASH_64, INDETERMINATE -> List.of(operatorConvention.type()); default -> throw new IllegalArgumentException("Unsupported operator type: " + operatorConvention.operatorType()); }; @@ -395,6 +472,18 @@ private record OperatorConvention(Type type, OperatorType operatorType, Optional private static final MethodHandle ORDER_COMPARISON_RESULT; private static final MethodHandle BLOCK_IS_NULL; + private static final MethodHandle TYPE_GET_BOOLEAN; + private static final MethodHandle TYPE_GET_LONG; + private static final MethodHandle TYPE_GET_DOUBLE; + private static final MethodHandle TYPE_GET_SLICE; + private static final MethodHandle TYPE_GET_OBJECT; + + private static final MethodHandle TYPE_WRITE_BOOLEAN; + private static final MethodHandle TYPE_WRITE_LONG; + private static final MethodHandle TYPE_WRITE_DOUBLE; + private static final MethodHandle TYPE_WRITE_SLICE; + private static final MethodHandle TYPE_WRITE_OBJECT; + static { try { Lookup lookup = lookup(); @@ -410,12 +499,33 @@ private record OperatorConvention(Type type, OperatorType operatorType, Optional ORDER_NULLS = lookup.findStatic(TypeOperators.class, "orderNulls", MethodType.methodType(int.class, SortOrder.class, boolean.class, boolean.class)); ORDER_COMPARISON_RESULT = lookup.findStatic(TypeOperators.class, "orderComparisonResult", MethodType.methodType(int.class, SortOrder.class, long.class)); BLOCK_IS_NULL = lookup.findVirtual(Block.class, "isNull", MethodType.methodType(boolean.class, int.class)); + + TYPE_GET_BOOLEAN = lookup.findVirtual(Type.class, "getBoolean", MethodType.methodType(boolean.class, Block.class, int.class)); + TYPE_GET_LONG = lookup.findVirtual(Type.class, "getLong", MethodType.methodType(long.class, Block.class, int.class)); + TYPE_GET_DOUBLE = lookup.findVirtual(Type.class, "getDouble", MethodType.methodType(double.class, Block.class, int.class)); + TYPE_GET_SLICE = lookup.findVirtual(Type.class, "getSlice", MethodType.methodType(Slice.class, Block.class, int.class)); + TYPE_GET_OBJECT = lookup.findVirtual(Type.class, "getObject", MethodType.methodType(Object.class, Block.class, int.class)); + + TYPE_WRITE_BOOLEAN = lookupWriteBlockBuilderMethod(lookup, "writeBoolean", boolean.class); + TYPE_WRITE_LONG = lookupWriteBlockBuilderMethod(lookup, "writeLong", long.class); + TYPE_WRITE_DOUBLE = lookupWriteBlockBuilderMethod(lookup, "writeDouble", double.class); + TYPE_WRITE_SLICE = lookupWriteBlockBuilderMethod(lookup, "writeSlice", Slice.class); + TYPE_WRITE_OBJECT = lookupWriteBlockBuilderMethod(lookup, "writeObject", Object.class); } catch (NoSuchMethodException | IllegalAccessException e) { throw new RuntimeException(e); } } + private static MethodHandle lookupWriteBlockBuilderMethod(Lookup lookup, String methodName, Class javaType) + throws NoSuchMethodException, IllegalAccessException + { + return permuteArguments( + lookup.findVirtual(Type.class, methodName, MethodType.methodType(void.class, BlockBuilder.class, javaType)), + MethodType.methodType(void.class, Type.class, javaType, BlockBuilder.class), + 0, 2, 1); + } + // // Adapt equal to is distinct from // diff --git a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java index 80be2d78eee8..47c3eda14f0d 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java +++ b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java @@ -50,6 +50,7 @@ import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.IN_OUT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -62,6 +63,7 @@ import static java.lang.invoke.MethodType.methodType; import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNull; @@ -72,6 +74,7 @@ public class TestScalarFunctionAdapter private static final ArrayType ARRAY_TYPE = new ArrayType(BIGINT); private static final CharType CHAR_TYPE = createCharType(7); private static final TimestampType TIMESTAMP_TYPE = createTimestampType(9); + private static final Type RETURN_TYPE = BOOLEAN; private static final List ARGUMENT_TYPES = ImmutableList.of(BOOLEAN, BIGINT, DOUBLE, VARCHAR, ARRAY_TYPE); private static final List OBJECTS_ARGUMENT_TYPES = ImmutableList.of(VARCHAR, ARRAY_TYPE, CHAR_TYPE, TIMESTAMP_TYPE); @@ -85,7 +88,7 @@ public void testAdaptFromNeverNull() false, true); String methodName = "neverNull"; - verifyAllAdaptations(actualConvention, methodName, ARGUMENT_TYPES); + verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, ARGUMENT_TYPES); } @Test @@ -98,23 +101,25 @@ public void testAdaptFromNeverNullObjects() false, true); String methodName = "neverNullObjects"; - verifyAllAdaptations(actualConvention, methodName, OBJECTS_ARGUMENT_TYPES); + verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); } private static void verifyAllAdaptations( InvocationConvention actualConvention, String methodName, + Type returnType, List argumentTypes) throws Throwable { MethodType type = methodType(actualConvention.getReturnConvention() == FAIL_ON_NULL ? boolean.class : Boolean.class, toCallArgumentTypes(actualConvention, argumentTypes)); MethodHandle methodHandle = lookup().findVirtual(Target.class, methodName, type); - verifyAllAdaptations(actualConvention, methodHandle, argumentTypes); + verifyAllAdaptations(actualConvention, methodHandle, returnType, argumentTypes); } private static void verifyAllAdaptations( InvocationConvention actualConvention, MethodHandle methodHandle, + Type returnType, List argumentTypes) throws Throwable { @@ -128,6 +133,7 @@ private static void verifyAllAdaptations( methodHandle, actualConvention, expectedConvention, + returnType, argumentTypes); } } @@ -137,6 +143,7 @@ private static void adaptAndVerify( MethodHandle methodHandle, InvocationConvention actualConvention, InvocationConvention expectedConvention, + Type returnType, List argumentTypes) throws Throwable { @@ -144,6 +151,7 @@ private static void adaptAndVerify( try { adaptedMethodHandle = ScalarFunctionAdapter.adapt( methodHandle, + returnType, argumentTypes, actualConvention, expectedConvention); @@ -167,7 +175,9 @@ private static void adaptAndVerify( // crete an exact invoker to the handle, so we can use object invoke interface without type coercion concerns MethodHandle exactInvoker = MethodHandles.exactInvoker(adaptedMethodHandle.type()) .bindTo(adaptedMethodHandle); - exactInvoker = MethodHandles.explicitCastArguments(exactInvoker, exactInvoker.type().changeReturnType(Boolean.class)); + if (expectedConvention.getReturnConvention() != BLOCK_BUILDER) { + exactInvoker = MethodHandles.explicitCastArguments(exactInvoker, exactInvoker.type().changeReturnType(Boolean.class)); + } // try all combinations of null and not null arguments for (int notNullMask = 0; notNullMask < (1 << actualConvention.getArgumentConventions().size()); notNullMask++) { @@ -178,6 +188,18 @@ private static void adaptAndVerify( Target target = new Target(); List argumentValues = toCallArgumentValues(newCallingConvention, nullArguments, target, argumentTypes); try { + if (expectedConvention.getReturnConvention() == BLOCK_BUILDER) { + BlockBuilder blockBuilder = returnType.createBlockBuilder(null, 1); + argumentValues.add(blockBuilder); + exactInvoker.invokeWithArguments(argumentValues); + Block result = blockBuilder.build(); + assertThat(result.getPositionCount()).isEqualTo(1); + if (!result.isNull(0)) { + assertTrue(BOOLEAN.getBoolean(result, 0)); + } + return; + } + Boolean result = (Boolean) exactInvoker.invokeWithArguments(argumentValues); switch (expectedConvention.getReturnConvention()) { case FAIL_ON_NULL -> assertTrue(result);