diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java index ebbd67bf7cd15..630bc98a9656e 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java @@ -222,6 +222,10 @@ static MethodHandle lookupMethod(PainlessLookup painlessLookup, FunctionTable fu String signature = (String) args[upTo++]; int numCaptures = Integer.parseInt(signature.substring(signature.indexOf(',')+1)); arity -= numCaptures; + // arity in painlessLookup does not include 'this' reference + if (signature.charAt(1) == 't') { + arity--; + } } } @@ -251,11 +255,12 @@ static MethodHandle lookupMethod(PainlessLookup painlessLookup, FunctionTable fu String signature = (String) args[upTo++]; int separator = signature.lastIndexOf('.'); int separator2 = signature.indexOf(','); - String type = signature.substring(1, separator); + String type = signature.substring(2, separator); + boolean needsScriptInstance = signature.charAt(1) == 't'; String call = signature.substring(separator+1, separator2); int numCaptures = Integer.parseInt(signature.substring(separator2+1)); MethodHandle filter; - Class interfaceType = method.typeParameters.get(i - 1 - replaced); + Class interfaceType = method.typeParameters.get(i - 1 - replaced - (needsScriptInstance ? 1 : 0)); if (signature.charAt(0) == 'S') { // the implementation is strongly typed, now that we know the interface type, // we have everything. @@ -266,7 +271,8 @@ static MethodHandle lookupMethod(PainlessLookup painlessLookup, FunctionTable fu interfaceType, type, call, - numCaptures + numCaptures, + needsScriptInstance ); } else if (signature.charAt(0) == 'D') { // the interface type is now known, but we need to get the implementation. @@ -292,7 +298,7 @@ static MethodHandle lookupMethod(PainlessLookup painlessLookup, FunctionTable fu } // the filter now ignores the signature (placeholder) on the stack filter = MethodHandles.dropArguments(filter, 0, String.class); - handle = MethodHandles.collectArguments(handle, i, filter); + handle = MethodHandles.collectArguments(handle, i - (needsScriptInstance ? 1 : 0), filter); i += numCaptures; replaced += numCaptures; } @@ -328,20 +334,23 @@ static MethodHandle lookupReference(PainlessLookup painlessLookup, FunctionTable return lookupReferenceInternal(painlessLookup, functions, constants, methodHandlesLookup, interfaceType, PainlessLookupUtility.typeToCanonicalTypeName(implMethod.targetClass), - implMethod.javaMethod.getName(), 1); + implMethod.javaMethod.getName(), 1, false); } /** Returns a method handle to an implementation of clazz, given method reference signature. */ private static MethodHandle lookupReferenceInternal( PainlessLookup painlessLookup, FunctionTable functions, Map constants, - MethodHandles.Lookup methodHandlesLookup, Class clazz, String type, String call, int captures - ) throws Throwable { + MethodHandles.Lookup methodHandlesLookup, Class clazz, String type, String call, int captures, + boolean needsScriptInstance) throws Throwable { - final FunctionRef ref = FunctionRef.create(painlessLookup, functions, null, clazz, type, call, captures, constants); + final FunctionRef ref = + FunctionRef.create(painlessLookup, functions, null, clazz, type, call, captures, constants, needsScriptInstance); + Class[] parameters = ref.factoryMethodParameters(needsScriptInstance ? methodHandlesLookup.lookupClass() : null); + MethodType factoryMethodType = MethodType.methodType(clazz, parameters); final CallSite callSite = LambdaBootstrap.lambdaBootstrap( methodHandlesLookup, ref.interfaceMethodName, - ref.factoryMethodType, + factoryMethodType, ref.interfaceMethodType, ref.delegateClassName, ref.delegateInvokeType, @@ -351,7 +360,7 @@ private static MethodHandle lookupReferenceInternal( ref.isDelegateAugmented ? 1 : 0, ref.delegateInjections ); - return callSite.dynamicInvoker().asType(MethodType.methodType(clazz, ref.factoryMethodType.parameterArray())); + return callSite.dynamicInvoker().asType(MethodType.methodType(clazz, parameters)); } /** @@ -1268,4 +1277,30 @@ static MethodHandle arrayIndexNormalizer(Class arrayType) { private ArrayIndexNormalizeHelper() {} } + + public static class Encoding { + public final boolean isStatic; + public final boolean needsInstance; + public final String symbol; + public final String methodName; + public final int numCaptures; + public final String encoding; + + public Encoding(boolean isStatic, boolean needsInstance, String symbol, String methodName, int numCaptures) { + this.isStatic = isStatic; + this.needsInstance = needsInstance; + this.symbol = symbol; + this.methodName = methodName; + this.numCaptures = numCaptures; + this.encoding = (isStatic ? "S" : "D") + (needsInstance ? "t" : "f") + + symbol + "." + + methodName + "," + + numCaptures; + } + + @Override + public String toString() { + return encoding; + } + } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/FunctionRef.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/FunctionRef.java index 6d6d6651c053a..ff99e0d28da90 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/FunctionRef.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/FunctionRef.java @@ -14,6 +14,7 @@ import org.elasticsearch.painless.lookup.PainlessMethod; import org.elasticsearch.painless.symbol.FunctionTable; import org.elasticsearch.painless.symbol.FunctionTable.LocalFunction; +import org.objectweb.asm.Type; import java.lang.invoke.MethodType; import java.lang.reflect.Modifier; @@ -21,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; import static org.elasticsearch.painless.WriterConstants.CLASS_NAME; import static org.objectweb.asm.Opcodes.H_INVOKEINTERFACE; @@ -44,9 +46,11 @@ public class FunctionRef { * @param methodName the right hand side of a method reference expression * @param numberOfCaptures number of captured arguments * @param constants constants used for injection when necessary + * @param needsScriptInstance uses an instance method and so receiver must be captured. */ public static FunctionRef create(PainlessLookup painlessLookup, FunctionTable functionTable, Location location, - Class targetClass, String typeName, String methodName, int numberOfCaptures, Map constants) { + Class targetClass, String typeName, String methodName, int numberOfCaptures, Map constants, + boolean needsScriptInstance) { Objects.requireNonNull(painlessLookup); Objects.requireNonNull(targetClass); @@ -98,7 +102,7 @@ public static FunctionRef create(PainlessLookup painlessLookup, FunctionTable fu delegateClassName = CLASS_NAME; isDelegateInterface = false; isDelegateAugmented = false; - delegateInvokeType = H_INVOKESTATIC; + delegateInvokeType = needsScriptInstance ? H_INVOKEVIRTUAL : H_INVOKESTATIC; delegateMethodName = localFunction.getMangledName(); delegateMethodType = localFunction.getMethodType(); delegateInjections = new Object[0]; @@ -213,7 +217,7 @@ public static FunctionRef create(PainlessLookup painlessLookup, FunctionTable fu return new FunctionRef(interfaceMethodName, interfaceMethodType, delegateClassName, isDelegateInterface, isDelegateAugmented, delegateInvokeType, delegateMethodName, delegateMethodType, delegateInjections, - factoryMethodType + factoryMethodType, needsScriptInstance ? WriterConstants.CLASS_TYPE : null ); } catch (IllegalArgumentException iae) { if (location != null) { @@ -243,13 +247,15 @@ public static FunctionRef create(PainlessLookup painlessLookup, FunctionTable fu /** injected constants */ public final Object[] delegateInjections; /** factory (CallSite) method signature */ - public final MethodType factoryMethodType; + private final MethodType factoryMethodType; + /** factory (CallSite) method receiver, this modifies the method descriptor for the factory method */ + public final Type factoryMethodReceiver; private FunctionRef( String interfaceMethodName, MethodType interfaceMethodType, String delegateClassName, boolean isDelegateInterface, boolean isDelegateAugmented, int delegateInvokeType, String delegateMethodName, MethodType delegateMethodType, Object[] delegateInjections, - MethodType factoryMethodType) { + MethodType factoryMethodType, Type factoryMethodReceiver) { this.interfaceMethodName = interfaceMethodName; this.interfaceMethodType = interfaceMethodType; @@ -261,5 +267,27 @@ private FunctionRef( this.delegateMethodType = delegateMethodType; this.delegateInjections = delegateInjections; this.factoryMethodType = factoryMethodType; + this.factoryMethodReceiver = factoryMethodReceiver; + } + + /** Get the factory method type, with updated receiver if {@code factoryMethodReceiver} is set */ + public String getFactoryMethodDescriptor() { + if (factoryMethodReceiver == null) { + return factoryMethodType.toMethodDescriptorString(); + } + List arguments = factoryMethodType.parameterList().stream().map(Type::getType).collect(Collectors.toList()); + arguments.add(0, factoryMethodReceiver); + Type[] argArray = new Type[arguments.size()]; + arguments.toArray(argArray); + return Type.getMethodDescriptor(Type.getType(factoryMethodType.returnType()), argArray); + } + + /** Get the factory method type, updating the receiver if {@code factoryMethodReceiverClass} is non-null */ + public Class[] factoryMethodParameters(Class factoryMethodReceiverClass) { + List> parameters = new ArrayList<>(factoryMethodType.parameterList()); + if (factoryMethodReceiverClass != null) { + parameters.add(0, factoryMethodReceiverClass); + } + return parameters.toArray(new Class[0]); } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/LambdaBootstrap.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/LambdaBootstrap.java index 6492ae361b1f9..6d8d58e43ede3 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/LambdaBootstrap.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/LambdaBootstrap.java @@ -23,6 +23,8 @@ import java.lang.invoke.MethodType; import java.security.AccessController; import java.security.PrivilegedAction; +import java.util.List; +import java.util.stream.Collectors; import static java.lang.invoke.MethodHandles.Lookup; import static org.elasticsearch.painless.WriterConstants.CLASS_VERSION; @@ -392,6 +394,8 @@ private static void generateInterfaceMethod( // Loads any passed in arguments onto the stack. iface.loadArgs(); + String functionalInterfaceWithCaptures; + // Handles the case for a lambda function or a static reference method. // interfaceMethodType and delegateMethodType both have the captured types // inserted into their type signatures. This later allows the delegate @@ -402,6 +406,7 @@ private static void generateInterfaceMethod( if (delegateInvokeType == H_INVOKESTATIC) { interfaceMethodType = interfaceMethodType.insertParameterTypes(0, factoryMethodType.parameterArray()); + functionalInterfaceWithCaptures = interfaceMethodType.toMethodDescriptorString(); delegateMethodType = delegateMethodType.insertParameterTypes(0, factoryMethodType.parameterArray()); } else if (delegateInvokeType == H_INVOKEVIRTUAL || @@ -414,19 +419,32 @@ private static void generateInterfaceMethod( Class clazz = delegateMethodType.parameterType(0); delegateClassType = Type.getType(clazz); delegateMethodType = delegateMethodType.dropParameterTypes(0, 1); + functionalInterfaceWithCaptures = interfaceMethodType.toMethodDescriptorString(); // Handles the case for a virtual or interface reference method with 'this' // captured. interfaceMethodType inserts the 'this' type into its // method signature. This later allows the delegate // method to be invoked dynamically and have the interface method types // appropriately converted to the delegate method types. // Example: something::toString - } else if (captures.length == 1) { + } else { Class clazz = factoryMethodType.parameterType(0); delegateClassType = Type.getType(clazz); - interfaceMethodType = interfaceMethodType.insertParameterTypes(0, clazz); - } else { - throw new LambdaConversionException( - "unexpected number of captures [ " + captures.length + "]"); + + // functionalInterfaceWithCaptures needs to add the receiver and other captures + List parameters = interfaceMethodType.parameterList().stream().map(Type::getType).collect(Collectors.toList()); + parameters.add(0, delegateClassType); + for (int i = 1; i < captures.length; i++) { + parameters.add(i, captures[i].type); + } + Type[] parametersArray = parameters.toArray(new Type[0]); + functionalInterfaceWithCaptures = Type.getMethodDescriptor(Type.getType(interfaceMethodType.returnType()), parametersArray); + + // delegateMethod does not need the receiver + List> factoryParameters = factoryMethodType.parameterList(); + if (factoryParameters.size() > 1) { + List> factoryParametersWithReceiver = factoryParameters.subList(1, factoryParameters.size()); + delegateMethodType = delegateMethodType.insertParameterTypes(0, factoryParametersWithReceiver); + } } } else { throw new IllegalStateException( @@ -445,7 +463,7 @@ private static void generateInterfaceMethod( System.arraycopy(injections, 0, args, 2, injections.length); iface.invokeDynamic( delegateMethodName, - Type.getMethodType(interfaceMethodType.toMethodDescriptorString()).getDescriptor(), + Type.getMethodType(functionalInterfaceWithCaptures).getDescriptor(), DELEGATE_BOOTSTRAP_HANDLE, args); diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java index 07f0b122ff1a5..237a006293a8d 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java @@ -514,11 +514,6 @@ public void invokeLambdaCall(FunctionRef functionRef) { args[6] = functionRef.isDelegateAugmented ? 1 : 0; System.arraycopy(functionRef.delegateInjections, 0, args, 7, functionRef.delegateInjections.length); - invokeDynamic( - functionRef.interfaceMethodName, - functionRef.factoryMethodType.toMethodDescriptorString(), - LAMBDA_BOOTSTRAP_HANDLE, - args - ); + invokeDynamic(functionRef.interfaceMethodName, functionRef.getFactoryMethodDescriptor(), LAMBDA_BOOTSTRAP_HANDLE, args); } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/antlr/Walker.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/antlr/Walker.java index 91c1ba2c7f7a6..608ad449be3bb 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/antlr/Walker.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/antlr/Walker.java @@ -273,7 +273,7 @@ public ANode visitFunction(FunctionContext ctx) { } return new SFunction(nextIdentifier(), location(ctx), - rtnType, name, paramTypes, paramNames, new SBlock(nextIdentifier(), location(ctx), statements), false, true, false, false); + rtnType, name, paramTypes, paramNames, new SBlock(nextIdentifier(), location(ctx), statements), false, false, false, false); } @Override diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultIRTreeToASMBytesPhase.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultIRTreeToASMBytesPhase.java index 5abc28a2a1e80..8fc2a377afc43 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultIRTreeToASMBytesPhase.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultIRTreeToASMBytesPhase.java @@ -99,6 +99,7 @@ import org.elasticsearch.painless.symbol.IRDecorations.IRCCaptureBox; import org.elasticsearch.painless.symbol.IRDecorations.IRCContinuous; import org.elasticsearch.painless.symbol.IRDecorations.IRCInitialize; +import org.elasticsearch.painless.symbol.IRDecorations.IRCInstanceCapture; import org.elasticsearch.painless.symbol.IRDecorations.IRCStatic; import org.elasticsearch.painless.symbol.IRDecorations.IRCSynthetic; import org.elasticsearch.painless.symbol.IRDecorations.IRCVarArgs; @@ -1226,6 +1227,11 @@ public void visitDefInterfaceReference(DefInterfaceReferenceNode irDefInterfaceR // which is resolved and replace at runtime methodWriter.push((String)null); + if (irDefInterfaceReferenceNode.hasCondition(IRCInstanceCapture.class)) { + Variable capturedThis = writeScope.getInternalVariable("this"); + methodWriter.visitVarInsn(CLASS_TYPE.getOpcode(Opcodes.ILOAD), capturedThis.getSlot()); + } + List captureNames = irDefInterfaceReferenceNode.getDecorationValue(IRDCaptureNames.class); boolean captureBox = irDefInterfaceReferenceNode.hasCondition(IRCCaptureBox.class); @@ -1247,6 +1253,11 @@ public void visitTypedInterfaceReference(TypedInterfaceReferenceNode irTypedInte MethodWriter methodWriter = writeScope.getMethodWriter(); methodWriter.writeDebugInfo(irTypedInterfaceReferenceNode.getLocation()); + if (irTypedInterfaceReferenceNode.hasCondition(IRCInstanceCapture.class)) { + Variable capturedThis = writeScope.getInternalVariable("this"); + methodWriter.visitVarInsn(CLASS_TYPE.getOpcode(Opcodes.ILOAD), capturedThis.getSlot()); + } + List captureNames = irTypedInterfaceReferenceNode.getDecorationValue(IRDCaptureNames.class); boolean captureBox = irTypedInterfaceReferenceNode.hasCondition(IRCCaptureBox.class); @@ -1576,7 +1587,12 @@ public void visitInvokeCallDef(InvokeCallDefNode irInvokeCallDefNode, WriteScope DefInterfaceReferenceNode defInterfaceReferenceNode = (DefInterfaceReferenceNode)irArgumentNode; List captureNames = defInterfaceReferenceNode.getDecorationValueOrDefault(IRDCaptureNames.class, Collections.emptyList()); - boostrapArguments.add(defInterfaceReferenceNode.getDecorationValue(IRDDefReferenceEncoding.class)); + boostrapArguments.add(defInterfaceReferenceNode.getDecorationValue(IRDDefReferenceEncoding.class).toString()); + + if (defInterfaceReferenceNode.hasCondition(IRCInstanceCapture.class)) { + capturedCount++; + typeParameters.add(ScriptThis.class); + } // the encoding uses a char to indicate the number of captures // where the value is the number of current arguments plus the @@ -1596,7 +1612,12 @@ public void visitInvokeCallDef(InvokeCallDefNode irInvokeCallDefNode, WriteScope Type[] asmParameterTypes = new Type[typeParameters.size()]; for (int index = 0; index < asmParameterTypes.length; ++index) { - asmParameterTypes[index] = MethodWriter.getType(typeParameters.get(index)); + Class typeParameter = typeParameters.get(index); + if (typeParameter.equals(ScriptThis.class)) { + asmParameterTypes[index] = CLASS_TYPE; + } else { + asmParameterTypes[index] = MethodWriter.getType(typeParameters.get(index)); + } } String methodName = irInvokeCallDefNode.getDecorationValue(IRDName.class); @@ -1763,4 +1784,7 @@ public void visitDup(DupNode irDupNode, WriteScope writeScope) { methodWriter.writeDup(size, depth); } + + // placeholder class referring to the script instance + private static final class ScriptThis {} } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultSemanticAnalysisPhase.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultSemanticAnalysisPhase.java index fc744b51325f8..769702afe94b3 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultSemanticAnalysisPhase.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultSemanticAnalysisPhase.java @@ -88,6 +88,8 @@ import org.elasticsearch.painless.symbol.Decorations.ExpressionPainlessCast; import org.elasticsearch.painless.symbol.Decorations.GetterPainlessMethod; import org.elasticsearch.painless.symbol.Decorations.InLoop; +import org.elasticsearch.painless.symbol.Decorations.InstanceCapturingFunctionRef; +import org.elasticsearch.painless.symbol.Decorations.InstanceCapturingLambda; import org.elasticsearch.painless.symbol.Decorations.InstanceType; import org.elasticsearch.painless.symbol.Decorations.Internal; import org.elasticsearch.painless.symbol.Decorations.IterablePainlessMethod; @@ -1726,7 +1728,9 @@ public void visitCallLocal(ECallLocal userCallLocalNode, SemanticScope semanticS localFunction = null; } - if (localFunction == null) { + if (localFunction != null) { + semanticScope.setUsesInstanceMethod(); + } else { importedMethod = scriptScope.getPainlessLookup().lookupImportedPainlessMethod(methodName, userArgumentsSize); if (importedMethod == null) { @@ -2195,6 +2199,10 @@ public void visitLambda(ELambda userLambdaNode, SemanticScope semanticScope) { semanticScope.setCondition(userBlockNode, LastSource.class); visit(userBlockNode, lambdaScope); + if (lambdaScope.usesInstanceMethod()) { + semanticScope.setCondition(userLambdaNode, InstanceCapturingLambda.class); + } + if (semanticScope.getCondition(userBlockNode, MethodEscape.class) == false) { throw userLambdaNode.createError(new IllegalArgumentException("not all paths return a value for lambda")); } @@ -2214,18 +2222,19 @@ public void visitLambda(ELambda userLambdaNode, SemanticScope semanticScope) { // desugar lambda body into a synthetic method String name = scriptScope.getNextSyntheticName("lambda"); - scriptScope.getFunctionTable().addFunction(name, returnType, typeParametersWithCaptures, true, true); + boolean isStatic = lambdaScope.usesInstanceMethod() == false; + scriptScope.getFunctionTable().addFunction(name, returnType, typeParametersWithCaptures, true, isStatic); Class valueType; // setup method reference to synthetic method if (targetType == null) { - String defReferenceEncoding = "Sthis." + name + "," + capturedVariables.size(); valueType = String.class; - semanticScope.putDecoration(userLambdaNode, new EncodingDecoration(defReferenceEncoding)); + semanticScope.putDecoration(userLambdaNode, + new EncodingDecoration(true, lambdaScope.usesInstanceMethod(), "this", name, capturedVariables.size())); } else { FunctionRef ref = FunctionRef.create(scriptScope.getPainlessLookup(), scriptScope.getFunctionTable(), location, targetType.getTargetType(), "this", name, capturedVariables.size(), - scriptScope.getCompilerSettings().asMap()); + scriptScope.getCompilerSettings().asMap(), lambdaScope.usesInstanceMethod()); valueType = targetType.getTargetType(); semanticScope.putDecoration(userLambdaNode, new ReferenceDecoration(ref)); } @@ -2256,7 +2265,8 @@ public void visitFunctionRef(EFunctionRef userFunctionRefNode, SemanticScope sem TargetType targetType = semanticScope.getDecoration(userFunctionRefNode, TargetType.class); Class valueType; - if (symbol.equals("this") || type != null) { + boolean isInstanceReference = "this".equals(symbol); + if (isInstanceReference || type != null) { if (semanticScope.getCondition(userFunctionRefNode, Write.class)) { throw userFunctionRefNode.createError(new IllegalArgumentException( "invalid assignment: cannot assign a value to function reference [" + symbol + ":" + methodName + "]")); @@ -2267,14 +2277,16 @@ public void visitFunctionRef(EFunctionRef userFunctionRefNode, SemanticScope sem "not a statement: function reference [" + symbol + ":" + methodName + "] not used")); } + if (isInstanceReference) { + semanticScope.setCondition(userFunctionRefNode, InstanceCapturingFunctionRef.class); + } if (targetType == null) { valueType = String.class; - String defReferenceEncoding = "S" + symbol + "." + methodName + ",0"; - semanticScope.putDecoration(userFunctionRefNode, new EncodingDecoration(defReferenceEncoding)); + semanticScope.putDecoration(userFunctionRefNode, new EncodingDecoration(true, isInstanceReference, symbol, methodName, 0)); } else { FunctionRef ref = FunctionRef.create(scriptScope.getPainlessLookup(), scriptScope.getFunctionTable(), location, targetType.getTargetType(), symbol, methodName, 0, - scriptScope.getCompilerSettings().asMap()); + scriptScope.getCompilerSettings().asMap(), isInstanceReference); valueType = targetType.getTargetType(); semanticScope.putDecoration(userFunctionRefNode, new ReferenceDecoration(ref)); } @@ -2297,23 +2309,23 @@ public void visitFunctionRef(EFunctionRef userFunctionRefNode, SemanticScope sem } if (targetType == null) { - String defReferenceEncoding; + EncodingDecoration encodingDecoration; if (captured.getType() == def.class) { // dynamic implementation - defReferenceEncoding = "D" + symbol + "." + methodName + ",1"; + encodingDecoration = new EncodingDecoration(false, false, symbol, methodName, 1); } else { // typed implementation - defReferenceEncoding = "S" + captured.getCanonicalTypeName() + "." + methodName + ",1"; + encodingDecoration = new EncodingDecoration(true, false, captured.getCanonicalTypeName(), methodName, 1); } valueType = String.class; - semanticScope.putDecoration(userFunctionRefNode, new EncodingDecoration(defReferenceEncoding)); + semanticScope.putDecoration(userFunctionRefNode, encodingDecoration); } else { valueType = targetType.getTargetType(); // static case if (captured.getType() != def.class) { FunctionRef ref = FunctionRef.create(scriptScope.getPainlessLookup(), scriptScope.getFunctionTable(), location, targetType.getTargetType(), captured.getCanonicalTypeName(), methodName, 1, - scriptScope.getCompilerSettings().asMap()); + scriptScope.getCompilerSettings().asMap(), false); semanticScope.putDecoration(userFunctionRefNode, new ReferenceDecoration(ref)); } } @@ -2357,13 +2369,12 @@ public void visitNewArrayFunctionRef(ENewArrayFunctionRef userNewArrayFunctionRe semanticScope.putDecoration(userNewArrayFunctionRefNode, new MethodNameDecoration(name)); if (targetType == null) { - String defReferenceEncoding = "Sthis." + name + ",0"; valueType = String.class; - scriptScope.putDecoration(userNewArrayFunctionRefNode, new EncodingDecoration(defReferenceEncoding)); + scriptScope.putDecoration(userNewArrayFunctionRefNode, new EncodingDecoration(true, false, "this", name, 0)); } else { FunctionRef ref = FunctionRef.create(scriptScope.getPainlessLookup(), scriptScope.getFunctionTable(), userNewArrayFunctionRefNode.getLocation(), targetType.getTargetType(), "this", name, 0, - scriptScope.getCompilerSettings().asMap()); + scriptScope.getCompilerSettings().asMap(), false); valueType = targetType.getTargetType(); semanticScope.putDecoration(userNewArrayFunctionRefNode, new ReferenceDecoration(ref)); } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultUserTreeToIRTreePhase.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultUserTreeToIRTreePhase.java index 0bf3826f19a45..1c004c985b5b2 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultUserTreeToIRTreePhase.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultUserTreeToIRTreePhase.java @@ -8,6 +8,7 @@ package org.elasticsearch.painless.phase; +import org.elasticsearch.painless.Def; import org.elasticsearch.painless.DefBootstrap; import org.elasticsearch.painless.FunctionRef; import org.elasticsearch.painless.Location; @@ -158,6 +159,8 @@ import org.elasticsearch.painless.symbol.Decorations.ExpressionPainlessCast; import org.elasticsearch.painless.symbol.Decorations.GetterPainlessMethod; import org.elasticsearch.painless.symbol.Decorations.IRNodeDecoration; +import org.elasticsearch.painless.symbol.Decorations.InstanceCapturingLambda; +import org.elasticsearch.painless.symbol.Decorations.InstanceCapturingFunctionRef; import org.elasticsearch.painless.symbol.Decorations.InstanceType; import org.elasticsearch.painless.symbol.Decorations.IterablePainlessMethod; import org.elasticsearch.painless.symbol.Decorations.ListShortcut; @@ -193,6 +196,7 @@ import org.elasticsearch.painless.symbol.IRDecorations.IRCCaptureBox; import org.elasticsearch.painless.symbol.IRDecorations.IRCContinuous; import org.elasticsearch.painless.symbol.IRDecorations.IRCInitialize; +import org.elasticsearch.painless.symbol.IRDecorations.IRCInstanceCapture; import org.elasticsearch.painless.symbol.IRDecorations.IRCRead; import org.elasticsearch.painless.symbol.IRDecorations.IRCStatic; import org.elasticsearch.painless.symbol.IRDecorations.IRCSynthetic; @@ -1356,7 +1360,12 @@ public void visitLambda(ELambda userLambdaNode, ScriptScope scriptScope) { new ArrayList<>(scriptScope.getDecoration(userLambdaNode, TypeParameters.class).getTypeParameters()))); irFunctionNode.attachDecoration(new IRDParameterNames( new ArrayList<>(scriptScope.getDecoration(userLambdaNode, ParameterNames.class).getParameterNames()))); - irFunctionNode.attachCondition(IRCStatic.class); + if (scriptScope.getCondition(userLambdaNode, InstanceCapturingLambda.class)) { + irFunctionNode.attachCondition(IRCInstanceCapture.class); + irExpressionNode.attachCondition(IRCInstanceCapture.class); + } else { + irFunctionNode.attachCondition(IRCStatic.class); + } irFunctionNode.attachCondition(IRCSynthetic.class); irFunctionNode.attachDecoration(new IRDMaxLoopCounter(scriptScope.getCompilerSettings().getMaxLoopCounter())); irClassNode.addFunctionNode(irFunctionNode); @@ -1386,9 +1395,12 @@ public void visitFunctionRef(EFunctionRef userFunctionRefNode, ScriptScope scrip CapturesDecoration capturesDecoration = scriptScope.getDecoration(userFunctionRefNode, CapturesDecoration.class); if (targetType == null) { - String encoding = scriptScope.getDecoration(userFunctionRefNode, EncodingDecoration.class).getEncoding(); + Def.Encoding encoding = scriptScope.getDecoration(userFunctionRefNode, EncodingDecoration.class).getEncoding(); DefInterfaceReferenceNode defInterfaceReferenceNode = new DefInterfaceReferenceNode(userFunctionRefNode.getLocation()); defInterfaceReferenceNode.attachDecoration(new IRDDefReferenceEncoding(encoding)); + if (scriptScope.getCondition(userFunctionRefNode, InstanceCapturingFunctionRef.class)) { + defInterfaceReferenceNode.attachCondition(IRCInstanceCapture.class); + } irReferenceNode = defInterfaceReferenceNode; } else if (capturesDecoration != null && capturesDecoration.getCaptures().get(0).getType() == def.class) { TypedCaptureReferenceNode typedCaptureReferenceNode = new TypedCaptureReferenceNode(userFunctionRefNode.getLocation()); @@ -1398,6 +1410,9 @@ public void visitFunctionRef(EFunctionRef userFunctionRefNode, ScriptScope scrip FunctionRef reference = scriptScope.getDecoration(userFunctionRefNode, ReferenceDecoration.class).getReference(); TypedInterfaceReferenceNode typedInterfaceReferenceNode = new TypedInterfaceReferenceNode(userFunctionRefNode.getLocation()); typedInterfaceReferenceNode.attachDecoration(new IRDReference(reference)); + if (scriptScope.getCondition(userFunctionRefNode, InstanceCapturingFunctionRef.class)) { + typedInterfaceReferenceNode.attachCondition(IRCInstanceCapture.class); + } irReferenceNode = typedInterfaceReferenceNode; } @@ -1427,7 +1442,7 @@ public void visitNewArrayFunctionRef(ENewArrayFunctionRef userNewArrayFunctionRe typedInterfaceReferenceNode.attachDecoration(new IRDReference(reference)); irReferenceNode = typedInterfaceReferenceNode; } else { - String encoding = scriptScope.getDecoration(userNewArrayFunctionRefNode, EncodingDecoration.class).getEncoding(); + Def.Encoding encoding = scriptScope.getDecoration(userNewArrayFunctionRefNode, EncodingDecoration.class).getEncoding(); DefInterfaceReferenceNode defInterfaceReferenceNode = new DefInterfaceReferenceNode(userNewArrayFunctionRefNode.getLocation()); defInterfaceReferenceNode.attachDecoration(new IRDDefReferenceEncoding(encoding)); irReferenceNode = defInterfaceReferenceNode; diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java index 5c5f503433fd3..bce418be4ce12 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java @@ -8,6 +8,7 @@ package org.elasticsearch.painless.symbol; +import org.elasticsearch.painless.Def; import org.elasticsearch.painless.FunctionRef; import org.elasticsearch.painless.ir.IRNode; import org.elasticsearch.painless.lookup.PainlessCast; @@ -513,13 +514,13 @@ public FunctionRef getReference() { public static class EncodingDecoration implements Decoration { - private final String encoding; + private final Def.Encoding encoding; - public EncodingDecoration(String encoding) { - this.encoding = Objects.requireNonNull(encoding); + public EncodingDecoration(boolean isStatic, boolean needsInstance, String symbol, String methodName, int captures) { + this.encoding = new Def.Encoding(isStatic, needsInstance, symbol, methodName, captures); } - public String getEncoding() { + public Def.Encoding getEncoding() { return encoding; } } @@ -610,4 +611,14 @@ public LocalFunction getConverter() { public interface IsDocument extends Condition { } + + // Does the lambda need to capture the enclosing instance? + public interface InstanceCapturingLambda extends Condition { + + } + + // Does the function reference need to capture the enclosing instance? + public interface InstanceCapturingFunctionRef extends Condition { + + } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/IRDecorations.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/IRDecorations.java index 7e0086b932272..303d5d3506f5e 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/IRDecorations.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/IRDecorations.java @@ -8,6 +8,7 @@ package org.elasticsearch.painless.symbol; +import org.elasticsearch.painless.Def; import org.elasticsearch.painless.FunctionRef; import org.elasticsearch.painless.Operation; import org.elasticsearch.painless.ir.IRNode.IRCondition; @@ -165,9 +166,9 @@ public IRDName(String value) { } /** describes an encoding used to resolve references and lambdas at runtime */ - public static class IRDDefReferenceEncoding extends IRDecoration { + public static class IRDDefReferenceEncoding extends IRDecoration { - public IRDDefReferenceEncoding(String value) { + public IRDDefReferenceEncoding(Def.Encoding value) { super(value); } } @@ -337,6 +338,14 @@ private IRCSynthetic() { } } + /** describes if a method needs to capture the script "this" */ + public static class IRCInstanceCapture implements IRCondition { + + private IRCInstanceCapture() { + + } + } + /** describes the maximum number of loop iterations possible in a method */ public static class IRDMaxLoopCounter extends IRDecoration { diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/SemanticScope.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/SemanticScope.java index 15d4e87de5aa6..ff29353ef761a 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/SemanticScope.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/SemanticScope.java @@ -137,6 +137,7 @@ public static class LambdaScope extends SemanticScope { protected final SemanticScope parent; protected final Class returnType; protected final Set captures = new HashSet<>(); + protected boolean usesInstanceMethod = false; protected LambdaScope(SemanticScope parent, Class returnType) { super(parent.scriptScope, parent.usedVariables); @@ -190,6 +191,19 @@ public String getReturnCanonicalTypeName() { public Set getCaptures() { return Collections.unmodifiableSet(captures); } + + @Override + public void setUsesInstanceMethod() { + if (usesInstanceMethod) { + return; + } + usesInstanceMethod = true; + } + + @Override + public boolean usesInstanceMethod() { + return usesInstanceMethod; + } } /** @@ -340,6 +354,13 @@ public Variable defineVariable(Location location, Class type, String name, bo public abstract boolean isVariableDefined(String name); public abstract Variable getVariable(Location location, String name); + // We only want to track instance method use inside of lambdas for "this" injection. It's a noop for other scopes. + public void setUsesInstanceMethod() {} + + public boolean usesInstanceMethod() { + return false; + } + public Variable defineInternalVariable(Location location, Class type, String name, boolean isReadOnly) { return defineVariable(location, type, "#" + name, isReadOnly); } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/toxcontent/DecorationToXContent.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/toxcontent/DecorationToXContent.java index 8f47164a2b2ff..069113f88b522 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/toxcontent/DecorationToXContent.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/toxcontent/DecorationToXContent.java @@ -291,8 +291,8 @@ public static void ToXContent(ReferenceDecoration referenceDecoration, XContentB builder.endArray(); } - builder.field("factoryMethodType"); - ToXContent(ref.factoryMethodType, builder); + builder.field("factoryMethodDescriptor", ref.getFactoryMethodDescriptor()); + builder.endObject(); } diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/EmitTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/EmitTests.java new file mode 100644 index 0000000000000..037c73015dc20 --- /dev/null +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/EmitTests.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.painless; + +import org.elasticsearch.painless.spi.Whitelist; +import org.elasticsearch.painless.spi.WhitelistLoader; +import org.elasticsearch.script.ScriptContext; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class EmitTests extends ScriptTestCase { + @Override + protected Map, List> scriptContexts() { + Map, List> contexts = new HashMap<>(); + List whitelists = new ArrayList<>(Whitelist.BASE_WHITELISTS); + whitelists.add(WhitelistLoader.loadFromResourceFiles(Whitelist.class, "org.elasticsearch.painless.test")); + contexts.put(TestFieldScript.CONTEXT, whitelists); + return contexts; + } + + @Override + public TestFieldScript exec(String script) { + TestFieldScript.Factory factory = scriptEngine.compile(null, script, TestFieldScript.CONTEXT, new HashMap<>()); + TestFieldScript testScript = factory.newInstance(); + testScript.execute(); + return testScript; + } + + public void testEmit() { + TestFieldScript script = exec("emit(1L)"); + assertNotNull(script); + assertArrayEquals(new long[]{1L}, script.fetchValues()); + } + + public void testEmitFromUserFunction() { + TestFieldScript script = exec("void doEmit(long l) { emit(l) } doEmit(1L); doEmit(100L)"); + assertNotNull(script); + assertArrayEquals(new long[]{1L, 100L}, script.fetchValues()); + } +} diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/FunctionRefTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/FunctionRefTests.java index b95a1ea35558b..d7bf274370a69 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/FunctionRefTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/FunctionRefTests.java @@ -150,12 +150,12 @@ public void testCapturingMethodReferenceMultipleLambdasDefEverywhere() { "return test.twoFunctionsOfX(x::concat, y::substring);")); } - public void testOwnStaticMethodReference() { + public void testOwnMethodReference() { assertEquals(2, exec("int mycompare(int i, int j) { j - i } " + "List l = new ArrayList(); l.add(2); l.add(1); l.sort(this::mycompare); return l.get(0);")); } - public void testOwnStaticMethodReferenceDef() { + public void testOwnMethodReferenceDef() { assertEquals(2, exec("int mycompare(int i, int j) { j - i } " + "def l = new ArrayList(); l.add(2); l.add(1); l.sort(this::mycompare); return l.get(0);")); } diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/TestFieldScript.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/TestFieldScript.java new file mode 100644 index 0000000000000..9982bddf46e1d --- /dev/null +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/TestFieldScript.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.painless; + +import org.elasticsearch.script.ScriptContext; + +import java.util.ArrayList; +import java.util.List; + +public abstract class TestFieldScript { + private final List values = new ArrayList<>(); + + @SuppressWarnings("unused") + public static final String[] PARAMETERS = {}; + public interface Factory { + TestFieldScript newInstance(); + } + + public static final ScriptContext CONTEXT = + new ScriptContext<>("painless_test_fieldscript", TestFieldScript.Factory.class); + + public static class Emit { + private final TestFieldScript script; + + public Emit(TestFieldScript script) { + this.script = script; + } + + public void emit(long v) { + script.emit(v); + } + } + + public abstract void execute(); + + public final void emit(long v) { + values.add(v); + } + + public long[] fetchValues() { + return values.stream().mapToLong(i->i).toArray(); + } +} diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/ToXContentTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ToXContentTests.java index e6a74a7cff64e..4ef3953ae8065 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/ToXContentTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ToXContentTests.java @@ -25,7 +25,7 @@ public class ToXContentTests extends ScriptTestCase { public void testUserFunction() { Map func = getFunction("def twofive(int i) { return 25 + i; } int j = 23; twofive(j)", "twofive"); assertFalse((Boolean)func.get("isInternal")); - assertTrue((Boolean)func.get("isStatic")); + assertFalse((Boolean)func.get("isStatic")); assertEquals("SFunction", func.get("node")); assertEquals("def", func.get("returns")); assertEquals(List.of("int"), func.get("parameterTypes")); diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/UserFunctionTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/UserFunctionTests.java index 6ac4ac1483c07..175fa03614d98 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/UserFunctionTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/UserFunctionTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.painless; import java.util.List; +import java.util.Map; public class UserFunctionTests extends ScriptTestCase { public void testZeroArgumentUserFunction() { @@ -28,10 +29,126 @@ public void testUserFunctionDefCallRef() { "if (getSource().startsWith('sour')) { l.add(255); }\n" + "return l;"; assertEquals(List.of(1, 49, 100, 255), exec(source)); - assertBytecodeExists(source, "public static &getSource()Ljava/lang/String"); - assertBytecodeExists(source, "public static &getMulti()I"); - assertBytecodeExists(source, "INVOKESTATIC org/elasticsearch/painless/PainlessScript$Script.&getMulti ()I"); - assertBytecodeExists(source, "public static &myCompare(II)I"); - assertBytecodeExists(source, "INVOKESTATIC org/elasticsearch/painless/PainlessScript$Script.&myCompare (II)I"); + assertBytecodeExists(source, "public &getSource()Ljava/lang/String"); + assertBytecodeExists(source, "public &getMulti()I"); + assertBytecodeExists(source, "INVOKEVIRTUAL org/elasticsearch/painless/PainlessScript$Script.&getMulti ()I"); + assertBytecodeExists(source, "public &myCompare(II)I"); + assertBytecodeExists(source, "INVOKEVIRTUAL org/elasticsearch/painless/PainlessScript$Script.&myCompare (II)I"); + } + + public void testChainedUserMethods() { + String source = "int myCompare(int a, int b) { getMulti() * (a - b) }\n" + + "int getMulti() { -1 }\n" + + "List l = [1, 100, -100];\n" + + "l.sort(this::myCompare);\n" + + "l;\n"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + } + + + public void testChainedUserMethodsLambda() { + String source = "int myCompare(int a, int b) { getMulti() * (a - b) }\n" + + "int getMulti() { -1 }\n" + + "List l = [1, 100, -100];\n" + + "l.sort((a, b) -> myCompare(a, b));\n" + + "l;\n"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + } + + public void testChainedUserMethodsDef() { + String source = "int myCompare(int a, int b) { getMulti() * (a - b) }\n" + + "int getMulti() { -1 }\n" + + "def l = [1, 100, -100];\n" + + "l.sort(this::myCompare);\n" + + "l;\n"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + } + + + public void testChainedUserMethodsLambdaDef() { + String source = "int myCompare(int a, int b) { getMulti() * (a - b) }\n" + + "int getMulti() { -1 }\n" + + "def l = [1, 100, -100];\n" + + "l.sort((a, b) -> myCompare(a, b));\n" + + "l;\n"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + } + + public void testChainedUserMethodsLambdaCaptureDef() { + String source = "int myCompare(int a, int b, int x, int m) { getMulti(m) * (a - b + x) }\n" + + "int getMulti(int m) { -1 * m }\n" + + "def l = [1, 100, -100];\n" + + "int cx = 100;\n" + + "int cm = 1;\n" + + "l.sort((a, b) -> myCompare(a, b, cx, cm));\n" + + "l;\n"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + } + + public void testMethodReferenceInUserFunction() { + String source = "int myCompare(int a, int b, String s) { " + + " Map m = ['f': 5];" + + " a - b + m.computeIfAbsent(s, this::getLength) " + + "}\n" + + "int getLength(String s) { s.length() }\n" + + "def l = [1, 0, -2];\n" + + "String s = 'g';\n" + + "l.sort((a, b) -> myCompare(a, b, s));\n" + + "l;\n"; + assertEquals(List.of(-2, 1, 0), exec(source, Map.of("a", 1), false)); + } + + public void testUserFunctionVirtual() { + String source = "int myCompare(int x, int y) { return -1 * (x - y) }\n" + + "return myCompare(100, 90);"; + assertEquals(-10, exec(source, Map.of("a", 1), false)); + assertBytecodeExists(source, "INVOKEVIRTUAL org/elasticsearch/painless/PainlessScript$Script.&myCompare (II)I"); + } + + public void testUserFunctionRef() { + String source = "int myCompare(int x, int y) { return -1 * x - y }\n" + + "List l = [1, 100, -100];\n" + + "l.sort(this::myCompare);\n" + + "return l;"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + assertBytecodeExists(source, "public &myCompare(II)I"); + } + + public void testUserFunctionRefEmpty() { + String source = "int myCompare(int x, int y) { return -1 * x - y }\n" + + "[].sort((a, b) -> myCompare(a, b));\n"; + assertNull(exec(source, Map.of("a", 1), false)); + assertBytecodeExists(source, "public &myCompare(II)I"); + assertBytecodeExists(source, "INVOKEVIRTUAL org/elasticsearch/painless/PainlessScript$Script.&myCompare (II)I"); + } + + public void testUserFunctionCallInLambda() { + String source = "int myCompare(int x, int y) { -1 * ( x - y ) }\n" + + "List l = [1, 100, -100];\n" + + "l.sort((a, b) -> myCompare(a, b));\n" + + "return l;"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + assertBytecodeExists(source, "public &myCompare(II)I"); + assertBytecodeExists(source, "INVOKEVIRTUAL org/elasticsearch/painless/PainlessScript$Script.&myCompare (II)I"); + } + + public void testUserFunctionLambdaCapture() { + String source = "int myCompare(Object o, int x, int y) { return o != null ? -1 * ( x - y ) : ( x - y ) }\n" + + "List l = [1, 100, -100];\n" + + "Object q = '';\n" + + "l.sort((a, b) -> myCompare(q, a, b));\n" + + "return l;"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + assertBytecodeExists(source, "public &myCompare(Ljava/lang/Object;II)I"); + assertBytecodeExists(source, "INVOKEVIRTUAL org/elasticsearch/painless/PainlessScript$Script.&myCompare (Ljava/lang/Object;II)I"); + } + + public void testLambdaCapture() { + String source = "List l = [1, 100, -100];\n" + + "int q = -1;\n" + + "l.sort((a, b) -> q * ( a - b ));\n" + + "return l;"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + assertBytecodeExists(source, "public static synthetic lambda$synthetic$0(ILjava/lang/Object;Ljava/lang/Object;)I"); } } diff --git a/modules/lang-painless/src/test/resources/org/elasticsearch/painless/spi/org.elasticsearch.painless.test b/modules/lang-painless/src/test/resources/org/elasticsearch/painless/spi/org.elasticsearch.painless.test index 28c032418f373..c6554db9169ac 100644 --- a/modules/lang-painless/src/test/resources/org/elasticsearch/painless/spi/org.elasticsearch.painless.test +++ b/modules/lang-painless/src/test/resources/org/elasticsearch/painless/spi/org.elasticsearch.painless.test @@ -15,6 +15,13 @@ class org.elasticsearch.painless.api.Json { class org.elasticsearch.painless.BindingsTests$BindingsTestScript { } +# Runtime-field-like test objects +class org.elasticsearch.painless.TestFieldScript @no_import { +} +class org.elasticsearch.painless.TestFieldScript$Factory @no_import { +} + + class org.elasticsearch.painless.FeatureTestObject @no_import { int z () @@ -55,4 +62,5 @@ static_import { int classMul(int, int) from_class org.elasticsearch.painless.BindingsTests @compile_time_only int compileTimeBlowUp(int, int) from_class org.elasticsearch.painless.BindingsTests @compile_time_only List fancyConstant(String, String) from_class org.elasticsearch.painless.BindingsTests @compile_time_only + void emit(org.elasticsearch.painless.TestFieldScript, long) bound_to org.elasticsearch.painless.TestFieldScript$Emit }