From 852ef40c2fba689f049c706b1ee7750317be24c3 Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sat, 15 Jul 2023 17:36:42 -0700 Subject: [PATCH 1/2] Add test coverage for NULL_FLAG and BOXED_NULL and fix bugs --- .../spi/function/ScalarFunctionAdapter.java | 73 +++------- .../function/TestScalarFunctionAdapter.java | 125 ++++++++++++------ 2 files changed, 99 insertions(+), 99 deletions(-) diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java index ed86019277db..2a2c40157ff6 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java @@ -41,8 +41,8 @@ import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; import static java.lang.invoke.MethodHandles.collectArguments; -import static java.lang.invoke.MethodHandles.constant; import static java.lang.invoke.MethodHandles.dropArguments; +import static java.lang.invoke.MethodHandles.empty; import static java.lang.invoke.MethodHandles.explicitCastArguments; import static java.lang.invoke.MethodHandles.filterArguments; import static java.lang.invoke.MethodHandles.guardWithTest; @@ -52,7 +52,6 @@ import static java.lang.invoke.MethodHandles.permuteArguments; import static java.lang.invoke.MethodHandles.publicLookup; import static java.lang.invoke.MethodHandles.throwException; -import static java.lang.invoke.MethodHandles.zero; import static java.lang.invoke.MethodType.methodType; import static java.util.Objects.requireNonNull; @@ -297,10 +296,9 @@ private static MethodHandle adaptParameter( if (returnConvention == FAIL_ON_NULL) { throw new IllegalArgumentException("RETURN_NULL_ON_NULL adaptation cannot be used with FAIL_ON_NULL return convention"); } - MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention); return guardWithTest( isNullArgument(methodHandle.type(), parameterIndex), - nullReturnValue, + getNullShortCircuitResult(methodHandle, returnConvention), methodHandle); } @@ -341,15 +339,14 @@ private static MethodHandle adaptParameter( // add a null flag to call methodHandle = dropArguments(methodHandle, parameterIndex + 1, boolean.class); - MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention); return guardWithTest( isTrueNullFlag(methodHandle.type(), parameterIndex), - nullReturnValue, + getNullShortCircuitResult(methodHandle, returnConvention), methodHandle); } if (actualArgumentConvention == BOXED_NULLABLE) { - return collectArguments(methodHandle, parameterIndex, boxedToNullFlagFilter(returnConvention, methodHandle.type().parameterType(parameterIndex))); + return collectArguments(methodHandle, parameterIndex, boxedToNullFlagFilter(methodHandle.type().parameterType(parameterIndex))); } } @@ -381,10 +378,9 @@ private static MethodHandle adaptParameter( // if caller sets the null flag, return null, otherwise invoke target methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); - MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention); return guardWithTest( isBlockPositionNull(methodHandle.type(), parameterIndex), - nullReturnValue, + getNullShortCircuitResult(methodHandle, returnConvention), methodHandle); } @@ -400,7 +396,7 @@ private static MethodHandle adaptParameter( getBlockValue = explicitCastArguments(getBlockValue, getBlockValue.type().changeReturnType(wrap(getBlockValue.type().returnType()))); getBlockValue = guardWithTest( isBlockPositionNull(getBlockValue.type(), 0), - returnNull(getBlockValue.type(), returnConvention), + empty(getBlockValue.type()), getBlockValue); methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); return methodHandle; @@ -411,11 +407,13 @@ private static MethodHandle adaptParameter( MethodHandle isNull = isBlockPositionNull(getBlockValue.type(), 0); methodHandle = collectArguments(methodHandle, parameterIndex + 1, isNull); - // long, Block, int => Block, int, Block, int + // convert get block value to be null safe getBlockValue = guardWithTest( isBlockPositionNull(getBlockValue.type(), 0), - returnNull(getBlockValue.type(), returnConvention), + empty(getBlockValue.type()), getBlockValue); + + // long, Block, int => Block, int, Block, int methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue); int[] reorder = IntStream.range(0, methodHandle.type().parameterCount()) @@ -427,7 +425,7 @@ private static MethodHandle adaptParameter( } if (actualArgumentConvention == BLOCK_POSITION_NOT_NULL) { - if (returnConvention.isNullable()) { + if (returnConvention != FAIL_ON_NULL) { MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention); return guardWithTest( isBlockPositionNull(methodHandle.type(), parameterIndex), @@ -446,10 +444,9 @@ private static MethodHandle adaptParameter( // if caller sets the null flag, return null, otherwise invoke target methodHandle = collectArguments(methodHandle, parameterIndex, getInOutValue); - MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention); return guardWithTest( isInOutNull(methodHandle.type(), parameterIndex), - nullReturnValue, + getNullShortCircuitResult(methodHandle, returnConvention), methodHandle); } @@ -465,7 +462,7 @@ private static MethodHandle adaptParameter( getInOutValue = explicitCastArguments(getInOutValue, getInOutValue.type().changeReturnType(wrap(getInOutValue.type().returnType()))); getInOutValue = guardWithTest( isInOutNull(getInOutValue.type(), 0), - returnNull(getInOutValue.type(), returnConvention), + empty(getInOutValue.type()), getInOutValue); methodHandle = collectArguments(methodHandle, parameterIndex, getInOutValue); return methodHandle; @@ -479,7 +476,7 @@ private static MethodHandle adaptParameter( // long, InOut => InOut, InOut getInOutValue = guardWithTest( isInOutNull(getInOutValue.type(), 0), - returnNull(getInOutValue.type(), returnConvention), + empty(getInOutValue.type()), getInOutValue); methodHandle = collectArguments(methodHandle, parameterIndex, getInOutValue); @@ -585,7 +582,7 @@ else if (methodArgumentType == double.class) { } } - private static MethodHandle boxedToNullFlagFilter(InvocationReturnConvention returnConvention, Class argumentType) + private static MethodHandle boxedToNullFlagFilter(Class argumentType) { // Start with identity MethodHandle handle = identity(argumentType); @@ -598,7 +595,7 @@ private static MethodHandle boxedToNullFlagFilter(InvocationReturnConvention ret // if the flag is true, return null, otherwise invoke identity return guardWithTest( isTrueNullFlag(handle.type(), 0), - returnNull(handle.type(), returnConvention), + empty(handle.type()), handle); } @@ -657,45 +654,11 @@ private static MethodHandle lookupIsNullMethod() } private static MethodHandle getNullShortCircuitResult(MethodHandle methodHandle, InvocationReturnConvention returnConvention) - { - MethodHandle nullReturnValue; - if (returnConvention == DEFAULT_ON_NULL) { - nullReturnValue = returnDefault(methodHandle.type()); - } - else { - nullReturnValue = returnNull(methodHandle.type(), returnConvention); - } - return nullReturnValue; - } - - private static MethodHandle returnDefault(MethodType methodType) - { - // Start with a constant default value of the expected return type: f():R - MethodHandle returnDefault = zero(methodType.returnType()); - - // Add extra argument to match expected method type: f(a, b, c, ..., n):R - returnDefault = permuteArguments(returnDefault, methodType.changeReturnType(methodType.returnType())); - - // Convert return to a primitive is necessary: f(a, b, c, ..., n):r - returnDefault = explicitCastArguments(returnDefault, methodType); - return returnDefault; - } - - private static MethodHandle returnNull(MethodType methodType, InvocationReturnConvention returnConvention) { if (returnConvention == BLOCK_BUILDER) { - return permuteArguments(APPEND_NULL_METHOD, methodType, methodType.parameterCount() - 1); + return permuteArguments(APPEND_NULL_METHOD, methodHandle.type(), methodHandle.type().parameterCount() - 1); } - - // Start with a constant null value of the expected return type: f():R - MethodHandle returnNull = constant(wrap(methodType.returnType()), null); - - // Add extra argument to match expected method type: f(a, b, c, ..., n):R - returnNull = permuteArguments(returnNull, methodType.changeReturnType(wrap(methodType.returnType()))); - - // Convert return to a primitive is necessary: f(a, b, c, ..., n):r - returnNull = explicitCastArguments(returnNull, methodType); - return returnNull; + return empty(methodHandle.type()); } private static MethodHandle lookupAppendNullMethod() diff --git a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java index 47c3eda14f0d..5b3fb9188185 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java +++ b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java @@ -67,6 +67,7 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertSame; import static org.testng.Assert.assertTrue; public class TestScalarFunctionAdapter @@ -75,7 +76,7 @@ public class TestScalarFunctionAdapter private static final CharType CHAR_TYPE = createCharType(7); private static final TimestampType TIMESTAMP_TYPE = createTimestampType(9); private static final Type RETURN_TYPE = BOOLEAN; - private static final List ARGUMENT_TYPES = ImmutableList.of(BOOLEAN, BIGINT, DOUBLE, VARCHAR, ARRAY_TYPE); + private static final List ARGUMENT_TYPES = ImmutableList.of(DOUBLE, VARCHAR, ARRAY_TYPE); private static final List OBJECTS_ARGUMENT_TYPES = ImmutableList.of(VARCHAR, ARRAY_TYPE, CHAR_TYPE, TIMESTAMP_TYPE); @Test @@ -104,6 +105,58 @@ public void testAdaptFromNeverNullObjects() verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); } + @Test + public void testAdaptFromBoxedNull() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(ARGUMENT_TYPES.size(), BOXED_NULLABLE), + FAIL_ON_NULL, + false, + true); + String methodName = "boxedNull"; + verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromBoxedNullObjects() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(OBJECTS_ARGUMENT_TYPES.size(), BOXED_NULLABLE), + FAIL_ON_NULL, + false, + true); + String methodName = "boxedNullObjects"; + verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromNullFlag() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(ARGUMENT_TYPES.size(), NULL_FLAG), + FAIL_ON_NULL, + false, + true); + String methodName = "nullFlag"; + verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromNullFlagObjects() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(OBJECTS_ARGUMENT_TYPES.size(), NULL_FLAG), + FAIL_ON_NULL, + false, + true); + String methodName = "nullFlagObjects"; + verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + } + private static void verifyAllAdaptations( InvocationConvention actualConvention, String methodName, @@ -158,10 +211,11 @@ private static void adaptAndVerify( assertTrue(ScalarFunctionAdapter.canAdapt(actualConvention, expectedConvention)); } catch (IllegalArgumentException e) { - assertFalse(ScalarFunctionAdapter.canAdapt(actualConvention, expectedConvention)); - assertTrue((expectedConvention.getReturnConvention() == FAIL_ON_NULL)); - if (hasNullableToNoNullableAdaptation(actualConvention, expectedConvention)) { - return; + if (!ScalarFunctionAdapter.canAdapt(actualConvention, expectedConvention)) { + assertSame(expectedConvention.getReturnConvention(), FAIL_ON_NULL); + if (hasNullableToNoNullableAdaptation(actualConvention, expectedConvention)) { + return; + } } throw new AssertionError("Adaptation failed but no illegal conversions found", e); } @@ -188,14 +242,16 @@ private static void adaptAndVerify( Target target = new Target(); List argumentValues = toCallArgumentValues(newCallingConvention, nullArguments, target, argumentTypes); try { + boolean expectNull = expectNullReturn(actualConvention, nullArguments); if (expectedConvention.getReturnConvention() == BLOCK_BUILDER) { BlockBuilder blockBuilder = returnType.createBlockBuilder(null, 1); argumentValues.add(blockBuilder); exactInvoker.invokeWithArguments(argumentValues); Block result = blockBuilder.build(); assertThat(result.getPositionCount()).isEqualTo(1); - if (!result.isNull(0)) { - assertTrue(BOOLEAN.getBoolean(result, 0)); + assertThat(result.isNull(0)).isEqualTo(expectNull); + if (!expectNull) { + assertThat(BOOLEAN.getBoolean(result, 0)).isTrue(); } return; } @@ -203,8 +259,8 @@ private static void adaptAndVerify( Boolean result = (Boolean) exactInvoker.invokeWithArguments(argumentValues); switch (expectedConvention.getReturnConvention()) { case FAIL_ON_NULL -> assertTrue(result); - case DEFAULT_ON_NULL -> assertEquals(result, (Boolean) nullArguments.isEmpty()); - case NULLABLE_RETURN -> assertEquals(result, nullArguments.isEmpty() ? true : null); + case DEFAULT_ON_NULL -> assertEquals(result, (Boolean) !expectNull); + case NULLABLE_RETURN -> assertEquals(result, !expectNull ? true : null); default -> throw new UnsupportedOperationException(); } } @@ -243,6 +299,17 @@ private static boolean canCallConventionWithNullArguments(InvocationConvention c return true; } + private static boolean expectNullReturn(InvocationConvention convention, BitSet nullArguments) + { + for (int i = 0; i < convention.getArgumentConventions().size(); i++) { + InvocationArgumentConvention argumentConvention = convention.getArgumentConvention(i); + if (nullArguments.get(i) && !argumentConvention.isNullable()) { + return true; + } + } + return false; + } + private static List> toCallArgumentTypes(InvocationConvention callingConvention, List argumentTypes) { List> expectedArguments = new ArrayList<>(); @@ -380,8 +447,6 @@ private static class Target { private boolean invoked; private boolean objectsMethod; - private Boolean booleanValue; - private Long longValue; private Double doubleValue; private Slice sliceValue; private Block blockValue; @@ -389,14 +454,12 @@ private static class Target private Object objectTimestampValue; @SuppressWarnings("unused") - public boolean neverNull(boolean booleanValue, long longValue, double doubleValue, Slice sliceValue, Block blockValue) + public boolean neverNull(double doubleValue, Slice sliceValue, Block blockValue) { checkState(!invoked, "Already invoked"); invoked = true; objectsMethod = false; - this.booleanValue = booleanValue; - this.longValue = longValue; this.doubleValue = doubleValue; this.sliceValue = sliceValue; this.blockValue = blockValue; @@ -418,14 +481,12 @@ public boolean neverNullObjects(Slice sliceValue, Block blockValue, Object objec } @SuppressWarnings("unused") - public boolean boxedNull(Boolean booleanValue, Long longValue, Double doubleValue, Slice sliceValue, Block blockValue) + public boolean boxedNull(Double doubleValue, Slice sliceValue, Block blockValue) { checkState(!invoked, "Already invoked"); invoked = true; objectsMethod = false; - this.booleanValue = booleanValue; - this.longValue = longValue; this.doubleValue = doubleValue; this.sliceValue = sliceValue; this.blockValue = blockValue; @@ -448,8 +509,6 @@ public boolean boxedNullObjects(Slice sliceValue, Block blockValue, Object objec @SuppressWarnings("unused") public boolean nullFlag( - boolean booleanValue, boolean booleanNull, - long longValue, boolean longNull, double doubleValue, boolean doubleNull, Slice sliceValue, boolean sliceNull, Block blockValue, boolean blockNull) @@ -458,22 +517,6 @@ public boolean nullFlag( invoked = true; objectsMethod = false; - if (booleanNull) { - assertFalse(booleanValue); - this.booleanValue = null; - } - else { - this.booleanValue = booleanValue; - } - - if (longNull) { - assertEquals(longValue, 0); - this.longValue = null; - } - else { - this.longValue = longValue; - } - if (doubleNull) { assertEquals(doubleValue, 0.0); this.doubleValue = null; @@ -553,11 +596,9 @@ public void verify( if (shouldFunctionBeInvoked(actualConvention, nullArguments)) { assertTrue(invoked, "function not invoked"); if (!objectsMethod) { - assertArgumentValue(this.booleanValue, 0, actualConvention, nullArguments, argumentTypes); - assertArgumentValue(this.longValue, 1, actualConvention, nullArguments, argumentTypes); - assertArgumentValue(this.doubleValue, 2, actualConvention, nullArguments, argumentTypes); - assertArgumentValue(this.sliceValue, 3, actualConvention, nullArguments, argumentTypes); - assertArgumentValue(this.blockValue, 4, actualConvention, nullArguments, argumentTypes); + assertArgumentValue(this.doubleValue, 0, actualConvention, nullArguments, argumentTypes); + assertArgumentValue(this.sliceValue, 1, actualConvention, nullArguments, argumentTypes); + assertArgumentValue(this.blockValue, 2, actualConvention, nullArguments, argumentTypes); } else { assertArgumentValue(this.sliceValue, 0, actualConvention, nullArguments, argumentTypes); @@ -568,8 +609,6 @@ public void verify( } else { assertFalse(invoked, "Function should not be invoked when null is passed to a NEVER_NULL argument"); - assertNull(this.booleanValue); - assertNull(this.longValue); assertNull(this.doubleValue); assertNull(this.sliceValue); assertNull(this.blockValue); @@ -579,8 +618,6 @@ public void verify( this.invoked = false; this.objectsMethod = false; - this.booleanValue = null; - this.longValue = null; this.doubleValue = null; this.sliceValue = null; this.blockValue = null; From 07b1691802328780bbf10282cf67c8527b72b99e Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sat, 15 Jul 2023 18:34:58 -0700 Subject: [PATCH 2/2] Add test coverage for BLOCK_POSITION and BLOCK_POSITION_NOT_NULL --- .../spi/function/ScalarFunctionAdapter.java | 4 +- .../function/TestScalarFunctionAdapter.java | 148 ++++++++++++++++-- 2 files changed, 136 insertions(+), 16 deletions(-) diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java index 2a2c40157ff6..2f5f62881e8f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java @@ -140,7 +140,7 @@ private static boolean canAdaptParameter( //noinspection DataFlowIssue case NEVER_NULL -> true; }; - case BLOCK_POSITION_NOT_NULL -> expectedArgumentConvention == BLOCK_POSITION && returnConvention.isNullable(); + case BLOCK_POSITION_NOT_NULL -> expectedArgumentConvention == BLOCK_POSITION && (returnConvention.isNullable() || returnConvention == DEFAULT_ON_NULL); case BLOCK_POSITION -> expectedArgumentConvention == BLOCK_POSITION_NOT_NULL; case BOXED_NULLABLE, NULL_FLAG -> true; case IN_OUT -> false; @@ -490,7 +490,7 @@ private static MethodHandle adaptParameter( } } - throw new IllegalArgumentException("Cannot convert argument %s to %s with return convention %s".formatted(expectedArgumentConvention, actualArgumentConvention, returnConvention)); + throw new IllegalArgumentException("Cannot convert argument %s to %s with return convention %s".formatted(actualArgumentConvention, expectedArgumentConvention, returnConvention)); } private static MethodHandle getBlockValue(Type argumentType, Class expectedType) diff --git a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java index 5b3fb9188185..8a4930526d18 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java +++ b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java @@ -88,8 +88,7 @@ public void testAdaptFromNeverNull() FAIL_ON_NULL, false, true); - String methodName = "neverNull"; - verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, ARGUMENT_TYPES); + verifyAllAdaptations(actualConvention, "neverNull", RETURN_TYPE, ARGUMENT_TYPES); } @Test @@ -101,8 +100,7 @@ public void testAdaptFromNeverNullObjects() FAIL_ON_NULL, false, true); - String methodName = "neverNullObjects"; - verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + verifyAllAdaptations(actualConvention, "neverNullObjects", RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); } @Test @@ -114,8 +112,7 @@ public void testAdaptFromBoxedNull() FAIL_ON_NULL, false, true); - String methodName = "boxedNull"; - verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, ARGUMENT_TYPES); + verifyAllAdaptations(actualConvention, "boxedNull", RETURN_TYPE, ARGUMENT_TYPES); } @Test @@ -127,8 +124,7 @@ public void testAdaptFromBoxedNullObjects() FAIL_ON_NULL, false, true); - String methodName = "boxedNullObjects"; - verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + verifyAllAdaptations(actualConvention, "boxedNullObjects", RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); } @Test @@ -140,8 +136,7 @@ public void testAdaptFromNullFlag() FAIL_ON_NULL, false, true); - String methodName = "nullFlag"; - verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, ARGUMENT_TYPES); + verifyAllAdaptations(actualConvention, "nullFlag", RETURN_TYPE, ARGUMENT_TYPES); } @Test @@ -153,8 +148,55 @@ public void testAdaptFromNullFlagObjects() FAIL_ON_NULL, false, true); - String methodName = "nullFlagObjects"; - verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + verifyAllAdaptations(actualConvention, "nullFlagObjects", RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromBlockPosition() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(ARGUMENT_TYPES.size(), BLOCK_POSITION), + FAIL_ON_NULL, + false, + true); + verifyAllAdaptations(actualConvention, "blockPosition", RETURN_TYPE, ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromBlockPositionObjects() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(OBJECTS_ARGUMENT_TYPES.size(), BLOCK_POSITION), + FAIL_ON_NULL, + false, + true); + verifyAllAdaptations(actualConvention, "blockPositionObjects", RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromBlockPositionNotNull() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(ARGUMENT_TYPES.size(), BLOCK_POSITION_NOT_NULL), + FAIL_ON_NULL, + false, + true); + verifyAllAdaptations(actualConvention, "blockPosition", RETURN_TYPE, ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromBlockPositionNotNullObjects() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(OBJECTS_ARGUMENT_TYPES.size(), BLOCK_POSITION_NOT_NULL), + FAIL_ON_NULL, + false, + true); + verifyAllAdaptations(actualConvention, "blockPositionObjects", RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); } private static void verifyAllAdaptations( @@ -212,8 +254,11 @@ private static void adaptAndVerify( } catch (IllegalArgumentException e) { if (!ScalarFunctionAdapter.canAdapt(actualConvention, expectedConvention)) { - assertSame(expectedConvention.getReturnConvention(), FAIL_ON_NULL); if (hasNullableToNoNullableAdaptation(actualConvention, expectedConvention)) { + assertSame(expectedConvention.getReturnConvention(), FAIL_ON_NULL); + return; + } + if (actualConvention.getArgumentConventions().stream().anyMatch(convention -> convention == BLOCK_POSITION || convention == BLOCK_POSITION_NOT_NULL)) { return; } } @@ -588,6 +633,80 @@ public boolean nullFlagObjects( return true; } + @SuppressWarnings("unused") + public boolean blockPosition( + Block doubleBlock, int doublePosition, + Block sliceBlock, int slicePosition, + Block blockBlock, int blockPosition) + { + checkState(!invoked, "Already invoked"); + invoked = true; + objectsMethod = false; + + if (doubleBlock.isNull(doublePosition)) { + this.doubleValue = null; + } + else { + this.doubleValue = DOUBLE.getDouble(doubleBlock, doublePosition); + } + + if (sliceBlock.isNull(slicePosition)) { + this.sliceValue = null; + } + else { + this.sliceValue = VARCHAR.getSlice(sliceBlock, slicePosition); + } + + if (blockBlock.isNull(blockPosition)) { + this.blockValue = null; + } + else { + this.blockValue = ARRAY_TYPE.getObject(blockBlock, blockPosition); + } + return true; + } + + @SuppressWarnings("unused") + public boolean blockPositionObjects( + Block sliceBlock, int slicePosition, + Block blockBlock, int blockPosition, + Block objectCharBlock, int objectCharPosition, + Block objectTimestampBlock, int objectTimestampPosition) + { + checkState(!invoked, "Already invoked"); + invoked = true; + objectsMethod = true; + + if (sliceBlock.isNull(slicePosition)) { + this.sliceValue = null; + } + else { + this.sliceValue = VARCHAR.getSlice(sliceBlock, slicePosition); + } + + if (blockBlock.isNull(blockPosition)) { + this.blockValue = null; + } + else { + this.blockValue = ARRAY_TYPE.getObject(blockBlock, blockPosition); + } + + if (objectCharBlock.isNull(objectCharPosition)) { + this.objectCharValue = null; + } + else { + this.objectCharValue = CHAR_TYPE.getObject(objectCharBlock, objectCharPosition); + } + + if (objectTimestampBlock.isNull(objectTimestampPosition)) { + this.objectTimestampValue = null; + } + else { + this.objectTimestampValue = TIMESTAMP_TYPE.getObject(objectTimestampBlock, objectTimestampPosition); + } + return true; + } + public void verify( InvocationConvention actualConvention, BitSet nullArguments, @@ -628,7 +747,8 @@ public void verify( private static boolean shouldFunctionBeInvoked(InvocationConvention actualConvention, BitSet nullArguments) { for (int i = 0; i < actualConvention.getArgumentConventions().size(); i++) { - if (actualConvention.getArgumentConvention(i) == NEVER_NULL && nullArguments.get(i)) { + InvocationArgumentConvention argumentConvention = actualConvention.getArgumentConvention(i); + if ((argumentConvention == NEVER_NULL || argumentConvention == BLOCK_POSITION_NOT_NULL) && nullArguments.get(i)) { return false; } }