From 53efab1b5fccde5963697b2e382cc729597ce75d Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sat, 15 Jul 2023 18:34:58 -0700 Subject: [PATCH] Add test coverage for BLOCK_POSITION and BLOCK_POSITION_NOT_NULL --- .../spi/function/ScalarFunctionAdapter.java | 4 +- .../function/TestScalarFunctionAdapter.java | 148 ++++++++++++++++-- 2 files changed, 136 insertions(+), 16 deletions(-) diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java index 2a2c40157ff6..2f5f62881e8f 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/ScalarFunctionAdapter.java @@ -140,7 +140,7 @@ private static boolean canAdaptParameter( //noinspection DataFlowIssue case NEVER_NULL -> true; }; - case BLOCK_POSITION_NOT_NULL -> expectedArgumentConvention == BLOCK_POSITION && returnConvention.isNullable(); + case BLOCK_POSITION_NOT_NULL -> expectedArgumentConvention == BLOCK_POSITION && (returnConvention.isNullable() || returnConvention == DEFAULT_ON_NULL); case BLOCK_POSITION -> expectedArgumentConvention == BLOCK_POSITION_NOT_NULL; case BOXED_NULLABLE, NULL_FLAG -> true; case IN_OUT -> false; @@ -490,7 +490,7 @@ private static MethodHandle adaptParameter( } } - throw new IllegalArgumentException("Cannot convert argument %s to %s with return convention %s".formatted(expectedArgumentConvention, actualArgumentConvention, returnConvention)); + throw new IllegalArgumentException("Cannot convert argument %s to %s with return convention %s".formatted(actualArgumentConvention, expectedArgumentConvention, returnConvention)); } private static MethodHandle getBlockValue(Type argumentType, Class expectedType) diff --git a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java index 5b3fb9188185..8a4930526d18 100644 --- a/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java +++ b/core/trino-spi/src/test/java/io/trino/spi/function/TestScalarFunctionAdapter.java @@ -88,8 +88,7 @@ public void testAdaptFromNeverNull() FAIL_ON_NULL, false, true); - String methodName = "neverNull"; - verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, ARGUMENT_TYPES); + verifyAllAdaptations(actualConvention, "neverNull", RETURN_TYPE, ARGUMENT_TYPES); } @Test @@ -101,8 +100,7 @@ public void testAdaptFromNeverNullObjects() FAIL_ON_NULL, false, true); - String methodName = "neverNullObjects"; - verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + verifyAllAdaptations(actualConvention, "neverNullObjects", RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); } @Test @@ -114,8 +112,7 @@ public void testAdaptFromBoxedNull() FAIL_ON_NULL, false, true); - String methodName = "boxedNull"; - verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, ARGUMENT_TYPES); + verifyAllAdaptations(actualConvention, "boxedNull", RETURN_TYPE, ARGUMENT_TYPES); } @Test @@ -127,8 +124,7 @@ public void testAdaptFromBoxedNullObjects() FAIL_ON_NULL, false, true); - String methodName = "boxedNullObjects"; - verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + verifyAllAdaptations(actualConvention, "boxedNullObjects", RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); } @Test @@ -140,8 +136,7 @@ public void testAdaptFromNullFlag() FAIL_ON_NULL, false, true); - String methodName = "nullFlag"; - verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, ARGUMENT_TYPES); + verifyAllAdaptations(actualConvention, "nullFlag", RETURN_TYPE, ARGUMENT_TYPES); } @Test @@ -153,8 +148,55 @@ public void testAdaptFromNullFlagObjects() FAIL_ON_NULL, false, true); - String methodName = "nullFlagObjects"; - verifyAllAdaptations(actualConvention, methodName, RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + verifyAllAdaptations(actualConvention, "nullFlagObjects", RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromBlockPosition() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(ARGUMENT_TYPES.size(), BLOCK_POSITION), + FAIL_ON_NULL, + false, + true); + verifyAllAdaptations(actualConvention, "blockPosition", RETURN_TYPE, ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromBlockPositionObjects() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(OBJECTS_ARGUMENT_TYPES.size(), BLOCK_POSITION), + FAIL_ON_NULL, + false, + true); + verifyAllAdaptations(actualConvention, "blockPositionObjects", RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromBlockPositionNotNull() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(ARGUMENT_TYPES.size(), BLOCK_POSITION_NOT_NULL), + FAIL_ON_NULL, + false, + true); + verifyAllAdaptations(actualConvention, "blockPosition", RETURN_TYPE, ARGUMENT_TYPES); + } + + @Test + public void testAdaptFromBlockPositionNotNullObjects() + throws Throwable + { + InvocationConvention actualConvention = new InvocationConvention( + nCopies(OBJECTS_ARGUMENT_TYPES.size(), BLOCK_POSITION_NOT_NULL), + FAIL_ON_NULL, + false, + true); + verifyAllAdaptations(actualConvention, "blockPositionObjects", RETURN_TYPE, OBJECTS_ARGUMENT_TYPES); } private static void verifyAllAdaptations( @@ -212,8 +254,11 @@ private static void adaptAndVerify( } catch (IllegalArgumentException e) { if (!ScalarFunctionAdapter.canAdapt(actualConvention, expectedConvention)) { - assertSame(expectedConvention.getReturnConvention(), FAIL_ON_NULL); if (hasNullableToNoNullableAdaptation(actualConvention, expectedConvention)) { + assertSame(expectedConvention.getReturnConvention(), FAIL_ON_NULL); + return; + } + if (actualConvention.getArgumentConventions().stream().anyMatch(convention -> convention == BLOCK_POSITION || convention == BLOCK_POSITION_NOT_NULL)) { return; } } @@ -588,6 +633,80 @@ public boolean nullFlagObjects( return true; } + @SuppressWarnings("unused") + public boolean blockPosition( + Block doubleBlock, int doublePosition, + Block sliceBlock, int slicePosition, + Block blockBlock, int blockPosition) + { + checkState(!invoked, "Already invoked"); + invoked = true; + objectsMethod = false; + + if (doubleBlock.isNull(doublePosition)) { + this.doubleValue = null; + } + else { + this.doubleValue = DOUBLE.getDouble(doubleBlock, doublePosition); + } + + if (sliceBlock.isNull(slicePosition)) { + this.sliceValue = null; + } + else { + this.sliceValue = VARCHAR.getSlice(sliceBlock, slicePosition); + } + + if (blockBlock.isNull(blockPosition)) { + this.blockValue = null; + } + else { + this.blockValue = ARRAY_TYPE.getObject(blockBlock, blockPosition); + } + return true; + } + + @SuppressWarnings("unused") + public boolean blockPositionObjects( + Block sliceBlock, int slicePosition, + Block blockBlock, int blockPosition, + Block objectCharBlock, int objectCharPosition, + Block objectTimestampBlock, int objectTimestampPosition) + { + checkState(!invoked, "Already invoked"); + invoked = true; + objectsMethod = true; + + if (sliceBlock.isNull(slicePosition)) { + this.sliceValue = null; + } + else { + this.sliceValue = VARCHAR.getSlice(sliceBlock, slicePosition); + } + + if (blockBlock.isNull(blockPosition)) { + this.blockValue = null; + } + else { + this.blockValue = ARRAY_TYPE.getObject(blockBlock, blockPosition); + } + + if (objectCharBlock.isNull(objectCharPosition)) { + this.objectCharValue = null; + } + else { + this.objectCharValue = CHAR_TYPE.getObject(objectCharBlock, objectCharPosition); + } + + if (objectTimestampBlock.isNull(objectTimestampPosition)) { + this.objectTimestampValue = null; + } + else { + this.objectTimestampValue = TIMESTAMP_TYPE.getObject(objectTimestampBlock, objectTimestampPosition); + } + return true; + } + public void verify( InvocationConvention actualConvention, BitSet nullArguments, @@ -628,7 +747,8 @@ public void verify( private static boolean shouldFunctionBeInvoked(InvocationConvention actualConvention, BitSet nullArguments) { for (int i = 0; i < actualConvention.getArgumentConventions().size(); i++) { - if (actualConvention.getArgumentConvention(i) == NEVER_NULL && nullArguments.get(i)) { + InvocationArgumentConvention argumentConvention = actualConvention.getArgumentConvention(i); + if ((argumentConvention == NEVER_NULL || argumentConvention == BLOCK_POSITION_NOT_NULL) && nullArguments.get(i)) { return false; } }