Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix scalar function adapter #18473

Merged
merged 2 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static java.lang.invoke.MethodHandles.collectArguments;
import static java.lang.invoke.MethodHandles.constant;
import static java.lang.invoke.MethodHandles.dropArguments;
import static java.lang.invoke.MethodHandles.empty;
import static java.lang.invoke.MethodHandles.explicitCastArguments;
import static java.lang.invoke.MethodHandles.filterArguments;
import static java.lang.invoke.MethodHandles.guardWithTest;
Expand All @@ -52,7 +52,6 @@
import static java.lang.invoke.MethodHandles.permuteArguments;
import static java.lang.invoke.MethodHandles.publicLookup;
import static java.lang.invoke.MethodHandles.throwException;
import static java.lang.invoke.MethodHandles.zero;
import static java.lang.invoke.MethodType.methodType;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -141,7 +140,7 @@ private static boolean canAdaptParameter(
//noinspection DataFlowIssue
case NEVER_NULL -> true;
};
case BLOCK_POSITION_NOT_NULL -> expectedArgumentConvention == BLOCK_POSITION && returnConvention.isNullable();
case BLOCK_POSITION_NOT_NULL -> expectedArgumentConvention == BLOCK_POSITION && (returnConvention.isNullable() || returnConvention == DEFAULT_ON_NULL);
case BLOCK_POSITION -> expectedArgumentConvention == BLOCK_POSITION_NOT_NULL;
case BOXED_NULLABLE, NULL_FLAG -> true;
case IN_OUT -> false;
Expand Down Expand Up @@ -297,10 +296,9 @@ private static MethodHandle adaptParameter(
if (returnConvention == FAIL_ON_NULL) {
throw new IllegalArgumentException("RETURN_NULL_ON_NULL adaptation cannot be used with FAIL_ON_NULL return convention");
}
MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention);
return guardWithTest(
isNullArgument(methodHandle.type(), parameterIndex),
nullReturnValue,
getNullShortCircuitResult(methodHandle, returnConvention),
methodHandle);
}

Expand Down Expand Up @@ -341,15 +339,14 @@ private static MethodHandle adaptParameter(
// add a null flag to call
methodHandle = dropArguments(methodHandle, parameterIndex + 1, boolean.class);

MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention);
return guardWithTest(
isTrueNullFlag(methodHandle.type(), parameterIndex),
nullReturnValue,
getNullShortCircuitResult(methodHandle, returnConvention),
methodHandle);
}

if (actualArgumentConvention == BOXED_NULLABLE) {
return collectArguments(methodHandle, parameterIndex, boxedToNullFlagFilter(returnConvention, methodHandle.type().parameterType(parameterIndex)));
return collectArguments(methodHandle, parameterIndex, boxedToNullFlagFilter(methodHandle.type().parameterType(parameterIndex)));
}
}

Expand Down Expand Up @@ -381,10 +378,9 @@ private static MethodHandle adaptParameter(
// if caller sets the null flag, return null, otherwise invoke target
methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue);

MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention);
return guardWithTest(
isBlockPositionNull(methodHandle.type(), parameterIndex),
nullReturnValue,
getNullShortCircuitResult(methodHandle, returnConvention),
methodHandle);
}

Expand All @@ -400,7 +396,7 @@ private static MethodHandle adaptParameter(
getBlockValue = explicitCastArguments(getBlockValue, getBlockValue.type().changeReturnType(wrap(getBlockValue.type().returnType())));
getBlockValue = guardWithTest(
isBlockPositionNull(getBlockValue.type(), 0),
returnNull(getBlockValue.type(), returnConvention),
empty(getBlockValue.type()),
getBlockValue);
methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue);
return methodHandle;
Expand All @@ -411,11 +407,13 @@ private static MethodHandle adaptParameter(
MethodHandle isNull = isBlockPositionNull(getBlockValue.type(), 0);
methodHandle = collectArguments(methodHandle, parameterIndex + 1, isNull);

// long, Block, int => Block, int, Block, int
// convert get block value to be null safe
getBlockValue = guardWithTest(
isBlockPositionNull(getBlockValue.type(), 0),
returnNull(getBlockValue.type(), returnConvention),
empty(getBlockValue.type()),
getBlockValue);

// long, Block, int => Block, int, Block, int
methodHandle = collectArguments(methodHandle, parameterIndex, getBlockValue);

int[] reorder = IntStream.range(0, methodHandle.type().parameterCount())
Expand All @@ -427,7 +425,7 @@ private static MethodHandle adaptParameter(
}

if (actualArgumentConvention == BLOCK_POSITION_NOT_NULL) {
if (returnConvention.isNullable()) {
if (returnConvention != FAIL_ON_NULL) {
MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention);
return guardWithTest(
isBlockPositionNull(methodHandle.type(), parameterIndex),
Expand All @@ -446,10 +444,9 @@ private static MethodHandle adaptParameter(
// if caller sets the null flag, return null, otherwise invoke target
methodHandle = collectArguments(methodHandle, parameterIndex, getInOutValue);

MethodHandle nullReturnValue = getNullShortCircuitResult(methodHandle, returnConvention);
return guardWithTest(
isInOutNull(methodHandle.type(), parameterIndex),
nullReturnValue,
getNullShortCircuitResult(methodHandle, returnConvention),
methodHandle);
}

Expand All @@ -465,7 +462,7 @@ private static MethodHandle adaptParameter(
getInOutValue = explicitCastArguments(getInOutValue, getInOutValue.type().changeReturnType(wrap(getInOutValue.type().returnType())));
getInOutValue = guardWithTest(
isInOutNull(getInOutValue.type(), 0),
returnNull(getInOutValue.type(), returnConvention),
empty(getInOutValue.type()),
getInOutValue);
methodHandle = collectArguments(methodHandle, parameterIndex, getInOutValue);
return methodHandle;
Expand All @@ -479,7 +476,7 @@ private static MethodHandle adaptParameter(
// long, InOut => InOut, InOut
getInOutValue = guardWithTest(
isInOutNull(getInOutValue.type(), 0),
returnNull(getInOutValue.type(), returnConvention),
empty(getInOutValue.type()),
getInOutValue);
methodHandle = collectArguments(methodHandle, parameterIndex, getInOutValue);

Expand All @@ -493,7 +490,7 @@ private static MethodHandle adaptParameter(
}
}

throw new IllegalArgumentException("Cannot convert argument %s to %s with return convention %s".formatted(expectedArgumentConvention, actualArgumentConvention, returnConvention));
throw new IllegalArgumentException("Cannot convert argument %s to %s with return convention %s".formatted(actualArgumentConvention, expectedArgumentConvention, returnConvention));
}

private static MethodHandle getBlockValue(Type argumentType, Class<?> expectedType)
Expand Down Expand Up @@ -585,7 +582,7 @@ else if (methodArgumentType == double.class) {
}
}

private static MethodHandle boxedToNullFlagFilter(InvocationReturnConvention returnConvention, Class<?> argumentType)
private static MethodHandle boxedToNullFlagFilter(Class<?> argumentType)
{
// Start with identity
MethodHandle handle = identity(argumentType);
Expand All @@ -598,7 +595,7 @@ private static MethodHandle boxedToNullFlagFilter(InvocationReturnConvention ret
// if the flag is true, return null, otherwise invoke identity
return guardWithTest(
isTrueNullFlag(handle.type(), 0),
returnNull(handle.type(), returnConvention),
empty(handle.type()),
handle);
}

Expand Down Expand Up @@ -657,45 +654,11 @@ private static MethodHandle lookupIsNullMethod()
}

private static MethodHandle getNullShortCircuitResult(MethodHandle methodHandle, InvocationReturnConvention returnConvention)
{
MethodHandle nullReturnValue;
if (returnConvention == DEFAULT_ON_NULL) {
nullReturnValue = returnDefault(methodHandle.type());
}
else {
nullReturnValue = returnNull(methodHandle.type(), returnConvention);
}
return nullReturnValue;
}

private static MethodHandle returnDefault(MethodType methodType)
{
// Start with a constant default value of the expected return type: f():R
MethodHandle returnDefault = zero(methodType.returnType());

// Add extra argument to match expected method type: f(a, b, c, ..., n):R
returnDefault = permuteArguments(returnDefault, methodType.changeReturnType(methodType.returnType()));

// Convert return to a primitive is necessary: f(a, b, c, ..., n):r
returnDefault = explicitCastArguments(returnDefault, methodType);
return returnDefault;
}

private static MethodHandle returnNull(MethodType methodType, InvocationReturnConvention returnConvention)
{
if (returnConvention == BLOCK_BUILDER) {
return permuteArguments(APPEND_NULL_METHOD, methodType, methodType.parameterCount() - 1);
return permuteArguments(APPEND_NULL_METHOD, methodHandle.type(), methodHandle.type().parameterCount() - 1);
}

// Start with a constant null value of the expected return type: f():R
MethodHandle returnNull = constant(wrap(methodType.returnType()), null);

// Add extra argument to match expected method type: f(a, b, c, ..., n):R
returnNull = permuteArguments(returnNull, methodType.changeReturnType(wrap(methodType.returnType())));

// Convert return to a primitive is necessary: f(a, b, c, ..., n):r
returnNull = explicitCastArguments(returnNull, methodType);
return returnNull;
return empty(methodHandle.type());
}

private static MethodHandle lookupAppendNullMethod()
Expand Down
Loading