diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java index fdd5e9268c90..28c1a6f5b9bd 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java @@ -264,7 +264,7 @@ private static MethodHandle adaptReturn( if (expectedReturnConvention == BLOCK_BUILDER) { // write the result to block builder // type.writeValue(BlockBuilder, value), f(a,b)::value => method(BlockBuilder, a, b)::void - methodHandle = collectArguments(writeBlockValue(returnType), 1, methodHandle); + methodHandle = collectArguments(writeBlockValue(returnType, actualReturnConvention.isNullable()), 1, methodHandle); // f(BlockBuilder, a, b)::void => f(a, b, BlockBuilder) MethodType newType = methodHandle.type() .dropParameterTypes(0, 1) @@ -668,7 +668,7 @@ else if (methodArgumentType == Slice.class) { } } - private static MethodHandle writeBlockValue(Type type) + private static MethodHandle writeBlockValue(Type type, boolean nullable) { Class methodArgumentType = type.getJavaType(); String getterName; @@ -689,14 +689,25 @@ else if (methodArgumentType == Slice.class) { methodArgumentType = Object.class; } + MethodHandle methodHandle; try { - return lookup().findVirtual(Type.class, getterName, methodType(void.class, BlockBuilder.class, methodArgumentType)) + methodHandle = lookup().findVirtual(Type.class, getterName, methodType(void.class, BlockBuilder.class, methodArgumentType)) .bindTo(type) .asType(methodType(void.class, BlockBuilder.class, type.getJavaType())); } catch (ReflectiveOperationException e) { throw new AssertionError(e); } + + if (!nullable) { + return methodHandle; + } + + methodHandle = methodHandle.asType(methodType(void.class, BlockBuilder.class, wrap(type.getJavaType()))); + return guardWithTest( + isNullArgument(methodHandle.type(), 1), + permuteArguments(APPEND_NULL_METHOD, methodHandle.type(), 0), + methodHandle); } private static MethodHandle getFlatValueNeverNull(Type argumentType, Class expectedType) diff --git a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java index ac0ba86d87d4..2f0c5a49d2b8 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java +++ b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java @@ -62,6 +62,7 @@ import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; import static io.trino.spi.function.InvocationConvention.simpleConvention; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; @@ -70,6 +71,7 @@ import static io.trino.spi.type.TimestampType.createTimestampType; import static io.trino.spi.type.TypeUtils.writeNativeValue; import static io.trino.spi.type.VarcharType.VARCHAR; +import static java.lang.invoke.MethodHandles.identity; import static java.lang.invoke.MethodHandles.lookup; import static java.lang.invoke.MethodType.methodType; import static java.util.Collections.nCopies; @@ -98,6 +100,29 @@ public void testAdaptFromNeverNull() verifyAllAdaptations(actualConvention, "neverNull", RETURN_TYPE, ARGUMENT_TYPES); } + @Test + public void testAdaptNullableReturnToBlockBuilder() + throws Throwable + { + // adapt identity(Double):Double to identity(Double, BlockBuilder):void + MethodHandle adaptedMethodHandle = ScalarFunctionAdapter.adapt( + identity(Double.class), + DOUBLE, + ImmutableList.of(DOUBLE), + simpleConvention(NULLABLE_RETURN, BOXED_NULLABLE), + simpleConvention(BLOCK_BUILDER, BOXED_NULLABLE)); + + // verify non-null and null value are written to the block + BlockBuilder blockBuilder = DOUBLE.createBlockBuilder(null, 1); + adaptedMethodHandle.invoke(1.1, blockBuilder); + adaptedMethodHandle.invoke(null, blockBuilder); + Block block = blockBuilder.buildValueBlock(); + assertThat(block.getPositionCount()).isEqualTo(2); + assertThat(block.isNull(0)).isFalse(); + assertThat(DOUBLE.getDouble(block, 0)).isEqualTo(1.1); + assertThat(block.isNull(1)).isTrue(); + } + @Test public void testAdaptFromNeverNullObjects() throws Throwable