Skip to content

Commit

Permalink
Script: User funcs are script instance methods (#74268) (#74853)
Browse files Browse the repository at this point in the history
User defined functions are instance methods on the Script class.

Update lambdas and method references to capture the script `this`
reference.

Def method encoding string takes an extra char at index 1, whether
to capture the script reference.

For runtime fields, this means emit, which is an script instance
method already, now works in user defined functions.

Fixes: #69742
Refs: #68235
Backport: e26fa4e
  • Loading branch information
stu-elastic authored Jul 1, 2021
1 parent 57420f4 commit e9800bd
Show file tree
Hide file tree
Showing 18 changed files with 457 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ static MethodHandle lookupMethod(PainlessLookup painlessLookup, FunctionTable fu
String signature = (String) args[upTo++];
int numCaptures = Integer.parseInt(signature.substring(signature.indexOf(',')+1));
arity -= numCaptures;
// arity in painlessLookup does not include 'this' reference
if (signature.charAt(1) == 't') {
arity--;
}
}
}

Expand Down Expand Up @@ -251,11 +255,12 @@ static MethodHandle lookupMethod(PainlessLookup painlessLookup, FunctionTable fu
String signature = (String) args[upTo++];
int separator = signature.lastIndexOf('.');
int separator2 = signature.indexOf(',');
String type = signature.substring(1, separator);
String type = signature.substring(2, separator);
boolean needsScriptInstance = signature.charAt(1) == 't';
String call = signature.substring(separator+1, separator2);
int numCaptures = Integer.parseInt(signature.substring(separator2+1));
MethodHandle filter;
Class<?> interfaceType = method.typeParameters.get(i - 1 - replaced);
Class<?> interfaceType = method.typeParameters.get(i - 1 - replaced - (needsScriptInstance ? 1 : 0));
if (signature.charAt(0) == 'S') {
// the implementation is strongly typed, now that we know the interface type,
// we have everything.
Expand All @@ -266,7 +271,8 @@ static MethodHandle lookupMethod(PainlessLookup painlessLookup, FunctionTable fu
interfaceType,
type,
call,
numCaptures
numCaptures,
needsScriptInstance
);
} else if (signature.charAt(0) == 'D') {
// the interface type is now known, but we need to get the implementation.
Expand All @@ -292,7 +298,7 @@ static MethodHandle lookupMethod(PainlessLookup painlessLookup, FunctionTable fu
}
// the filter now ignores the signature (placeholder) on the stack
filter = MethodHandles.dropArguments(filter, 0, String.class);
handle = MethodHandles.collectArguments(handle, i, filter);
handle = MethodHandles.collectArguments(handle, i - (needsScriptInstance ? 1 : 0), filter);
i += numCaptures;
replaced += numCaptures;
}
Expand Down Expand Up @@ -328,20 +334,23 @@ static MethodHandle lookupReference(PainlessLookup painlessLookup, FunctionTable

return lookupReferenceInternal(painlessLookup, functions, constants,
methodHandlesLookup, interfaceType, PainlessLookupUtility.typeToCanonicalTypeName(implMethod.targetClass),
implMethod.javaMethod.getName(), 1);
implMethod.javaMethod.getName(), 1, false);
}

/** Returns a method handle to an implementation of clazz, given method reference signature. */
private static MethodHandle lookupReferenceInternal(
PainlessLookup painlessLookup, FunctionTable functions, Map<String, Object> constants,
MethodHandles.Lookup methodHandlesLookup, Class<?> clazz, String type, String call, int captures
) throws Throwable {
MethodHandles.Lookup methodHandlesLookup, Class<?> clazz, String type, String call, int captures,
boolean needsScriptInstance) throws Throwable {

final FunctionRef ref = FunctionRef.create(painlessLookup, functions, null, clazz, type, call, captures, constants);
final FunctionRef ref =
FunctionRef.create(painlessLookup, functions, null, clazz, type, call, captures, constants, needsScriptInstance);
Class<?>[] parameters = ref.factoryMethodParameters(needsScriptInstance ? methodHandlesLookup.lookupClass() : null);
MethodType factoryMethodType = MethodType.methodType(clazz, parameters);
final CallSite callSite = LambdaBootstrap.lambdaBootstrap(
methodHandlesLookup,
ref.interfaceMethodName,
ref.factoryMethodType,
factoryMethodType,
ref.interfaceMethodType,
ref.delegateClassName,
ref.delegateInvokeType,
Expand All @@ -351,7 +360,7 @@ private static MethodHandle lookupReferenceInternal(
ref.isDelegateAugmented ? 1 : 0,
ref.delegateInjections
);
return callSite.dynamicInvoker().asType(MethodType.methodType(clazz, ref.factoryMethodType.parameterArray()));
return callSite.dynamicInvoker().asType(MethodType.methodType(clazz, parameters));
}

/**
Expand Down Expand Up @@ -1268,4 +1277,30 @@ static MethodHandle arrayIndexNormalizer(Class<?> arrayType) {

private ArrayIndexNormalizeHelper() {}
}

public static class Encoding {
public final boolean isStatic;
public final boolean needsInstance;
public final String symbol;
public final String methodName;
public final int numCaptures;
public final String encoding;

public Encoding(boolean isStatic, boolean needsInstance, String symbol, String methodName, int numCaptures) {
this.isStatic = isStatic;
this.needsInstance = needsInstance;
this.symbol = symbol;
this.methodName = methodName;
this.numCaptures = numCaptures;
this.encoding = (isStatic ? "S" : "D") + (needsInstance ? "t" : "f") +
symbol + "." +
methodName + "," +
numCaptures;
}

@Override
public String toString() {
return encoding;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
import org.elasticsearch.painless.lookup.PainlessMethod;
import org.elasticsearch.painless.symbol.FunctionTable;
import org.elasticsearch.painless.symbol.FunctionTable.LocalFunction;
import org.objectweb.asm.Type;

import java.lang.invoke.MethodType;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.painless.WriterConstants.CLASS_NAME;
import static org.objectweb.asm.Opcodes.H_INVOKEINTERFACE;
Expand All @@ -44,9 +46,11 @@ public class FunctionRef {
* @param methodName the right hand side of a method reference expression
* @param numberOfCaptures number of captured arguments
* @param constants constants used for injection when necessary
* @param needsScriptInstance uses an instance method and so receiver must be captured.
*/
public static FunctionRef create(PainlessLookup painlessLookup, FunctionTable functionTable, Location location,
Class<?> targetClass, String typeName, String methodName, int numberOfCaptures, Map<String, Object> constants) {
Class<?> targetClass, String typeName, String methodName, int numberOfCaptures, Map<String, Object> constants,
boolean needsScriptInstance) {

Objects.requireNonNull(painlessLookup);
Objects.requireNonNull(targetClass);
Expand Down Expand Up @@ -98,7 +102,7 @@ public static FunctionRef create(PainlessLookup painlessLookup, FunctionTable fu
delegateClassName = CLASS_NAME;
isDelegateInterface = false;
isDelegateAugmented = false;
delegateInvokeType = H_INVOKESTATIC;
delegateInvokeType = needsScriptInstance ? H_INVOKEVIRTUAL : H_INVOKESTATIC;
delegateMethodName = localFunction.getMangledName();
delegateMethodType = localFunction.getMethodType();
delegateInjections = new Object[0];
Expand Down Expand Up @@ -213,7 +217,7 @@ public static FunctionRef create(PainlessLookup painlessLookup, FunctionTable fu
return new FunctionRef(interfaceMethodName, interfaceMethodType,
delegateClassName, isDelegateInterface, isDelegateAugmented,
delegateInvokeType, delegateMethodName, delegateMethodType, delegateInjections,
factoryMethodType
factoryMethodType, needsScriptInstance ? WriterConstants.CLASS_TYPE : null
);
} catch (IllegalArgumentException iae) {
if (location != null) {
Expand Down Expand Up @@ -243,13 +247,15 @@ public static FunctionRef create(PainlessLookup painlessLookup, FunctionTable fu
/** injected constants */
public final Object[] delegateInjections;
/** factory (CallSite) method signature */
public final MethodType factoryMethodType;
private final MethodType factoryMethodType;
/** factory (CallSite) method receiver, this modifies the method descriptor for the factory method */
public final Type factoryMethodReceiver;

private FunctionRef(
String interfaceMethodName, MethodType interfaceMethodType,
String delegateClassName, boolean isDelegateInterface, boolean isDelegateAugmented,
int delegateInvokeType, String delegateMethodName, MethodType delegateMethodType, Object[] delegateInjections,
MethodType factoryMethodType) {
MethodType factoryMethodType, Type factoryMethodReceiver) {

this.interfaceMethodName = interfaceMethodName;
this.interfaceMethodType = interfaceMethodType;
Expand All @@ -261,5 +267,27 @@ private FunctionRef(
this.delegateMethodType = delegateMethodType;
this.delegateInjections = delegateInjections;
this.factoryMethodType = factoryMethodType;
this.factoryMethodReceiver = factoryMethodReceiver;
}

/** Get the factory method type, with updated receiver if {@code factoryMethodReceiver} is set */
public String getFactoryMethodDescriptor() {
if (factoryMethodReceiver == null) {
return factoryMethodType.toMethodDescriptorString();
}
List<Type> arguments = factoryMethodType.parameterList().stream().map(Type::getType).collect(Collectors.toList());
arguments.add(0, factoryMethodReceiver);
Type[] argArray = new Type[arguments.size()];
arguments.toArray(argArray);
return Type.getMethodDescriptor(Type.getType(factoryMethodType.returnType()), argArray);
}

/** Get the factory method type, updating the receiver if {@code factoryMethodReceiverClass} is non-null */
public Class<?>[] factoryMethodParameters(Class<?> factoryMethodReceiverClass) {
List<Class<?>> parameters = new ArrayList<>(factoryMethodType.parameterList());
if (factoryMethodReceiverClass != null) {
parameters.add(0, factoryMethodReceiverClass);
}
return parameters.toArray(new Class<?>[0]);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.lang.invoke.MethodType;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.List;
import java.util.stream.Collectors;

import static java.lang.invoke.MethodHandles.Lookup;
import static org.elasticsearch.painless.WriterConstants.CLASS_VERSION;
Expand Down Expand Up @@ -392,6 +394,8 @@ private static void generateInterfaceMethod(
// Loads any passed in arguments onto the stack.
iface.loadArgs();

String functionalInterfaceWithCaptures;

// Handles the case for a lambda function or a static reference method.
// interfaceMethodType and delegateMethodType both have the captured types
// inserted into their type signatures. This later allows the delegate
Expand All @@ -402,6 +406,7 @@ private static void generateInterfaceMethod(
if (delegateInvokeType == H_INVOKESTATIC) {
interfaceMethodType =
interfaceMethodType.insertParameterTypes(0, factoryMethodType.parameterArray());
functionalInterfaceWithCaptures = interfaceMethodType.toMethodDescriptorString();
delegateMethodType =
delegateMethodType.insertParameterTypes(0, factoryMethodType.parameterArray());
} else if (delegateInvokeType == H_INVOKEVIRTUAL ||
Expand All @@ -414,19 +419,32 @@ private static void generateInterfaceMethod(
Class<?> clazz = delegateMethodType.parameterType(0);
delegateClassType = Type.getType(clazz);
delegateMethodType = delegateMethodType.dropParameterTypes(0, 1);
functionalInterfaceWithCaptures = interfaceMethodType.toMethodDescriptorString();
// Handles the case for a virtual or interface reference method with 'this'
// captured. interfaceMethodType inserts the 'this' type into its
// method signature. This later allows the delegate
// method to be invoked dynamically and have the interface method types
// appropriately converted to the delegate method types.
// Example: something::toString
} else if (captures.length == 1) {
} else {
Class<?> clazz = factoryMethodType.parameterType(0);
delegateClassType = Type.getType(clazz);
interfaceMethodType = interfaceMethodType.insertParameterTypes(0, clazz);
} else {
throw new LambdaConversionException(
"unexpected number of captures [ " + captures.length + "]");

// functionalInterfaceWithCaptures needs to add the receiver and other captures
List<Type> parameters = interfaceMethodType.parameterList().stream().map(Type::getType).collect(Collectors.toList());
parameters.add(0, delegateClassType);
for (int i = 1; i < captures.length; i++) {
parameters.add(i, captures[i].type);
}
Type[] parametersArray = parameters.toArray(new Type[0]);
functionalInterfaceWithCaptures = Type.getMethodDescriptor(Type.getType(interfaceMethodType.returnType()), parametersArray);

// delegateMethod does not need the receiver
List<Class<?>> factoryParameters = factoryMethodType.parameterList();
if (factoryParameters.size() > 1) {
List<Class<?>> factoryParametersWithReceiver = factoryParameters.subList(1, factoryParameters.size());
delegateMethodType = delegateMethodType.insertParameterTypes(0, factoryParametersWithReceiver);
}
}
} else {
throw new IllegalStateException(
Expand All @@ -445,7 +463,7 @@ private static void generateInterfaceMethod(
System.arraycopy(injections, 0, args, 2, injections.length);
iface.invokeDynamic(
delegateMethodName,
Type.getMethodType(interfaceMethodType.toMethodDescriptorString()).getDescriptor(),
Type.getMethodType(functionalInterfaceWithCaptures).getDescriptor(),
DELEGATE_BOOTSTRAP_HANDLE,
args);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,11 +514,6 @@ public void invokeLambdaCall(FunctionRef functionRef) {
args[6] = functionRef.isDelegateAugmented ? 1 : 0;
System.arraycopy(functionRef.delegateInjections, 0, args, 7, functionRef.delegateInjections.length);

invokeDynamic(
functionRef.interfaceMethodName,
functionRef.factoryMethodType.toMethodDescriptorString(),
LAMBDA_BOOTSTRAP_HANDLE,
args
);
invokeDynamic(functionRef.interfaceMethodName, functionRef.getFactoryMethodDescriptor(), LAMBDA_BOOTSTRAP_HANDLE, args);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ public ANode visitFunction(FunctionContext ctx) {
}

return new SFunction(nextIdentifier(), location(ctx),
rtnType, name, paramTypes, paramNames, new SBlock(nextIdentifier(), location(ctx), statements), false, true, false, false);
rtnType, name, paramTypes, paramNames, new SBlock(nextIdentifier(), location(ctx), statements), false, false, false, false);
}

@Override
Expand Down
Loading

0 comments on commit e9800bd

Please sign in to comment.