From e26fa4e6da7c3e2d2526f4da1c2c57e46dd1fca4 Mon Sep 17 00:00:00 2001 From: Stuart Tettemer Date: Thu, 1 Jul 2021 08:41:44 -0500 Subject: [PATCH] Script: User funcs are script instance methods (#74268) User defined functions are instance methods on the Script class. Update lambdas and method references to capture the script `this` reference. Def method encoding string takes an extra char at index 1, whether to capture the script reference. For runtime fields, this means emit, which is an script instance method already, now works in user defined functions. Fixes: #69742 Refs: #68235 --- .../java/org/elasticsearch/painless/Def.java | 55 ++++++-- .../elasticsearch/painless/FunctionRef.java | 38 +++++- .../painless/LambdaBootstrap.java | 30 ++++- .../elasticsearch/painless/MethodWriter.java | 7 +- .../elasticsearch/painless/antlr/Walker.java | 2 +- .../phase/DefaultIRTreeToASMBytesPhase.java | 28 +++- .../phase/DefaultSemanticAnalysisPhase.java | 45 ++++--- .../phase/DefaultUserTreeToIRTreePhase.java | 21 ++- .../painless/symbol/Decorations.java | 19 ++- .../painless/symbol/IRDecorations.java | 13 +- .../painless/symbol/SemanticScope.java | 21 +++ .../toxcontent/DecorationToXContent.java | 4 +- .../org/elasticsearch/painless/EmitTests.java | 49 +++++++ .../painless/FunctionRefTests.java | 4 +- .../painless/TestFieldScript.java | 49 +++++++ .../painless/ToXContentTests.java | 2 +- .../painless/UserFunctionTests.java | 127 +++++++++++++++++- .../spi/org.elasticsearch.painless.test | 8 ++ 18 files changed, 456 insertions(+), 66 deletions(-) create mode 100644 modules/lang-painless/src/test/java/org/elasticsearch/painless/EmitTests.java create mode 100644 modules/lang-painless/src/test/java/org/elasticsearch/painless/TestFieldScript.java diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java index ebbd67bf7cd15..630bc98a9656e 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java @@ -222,6 +222,10 @@ static MethodHandle lookupMethod(PainlessLookup painlessLookup, FunctionTable fu String signature = (String) args[upTo++]; int numCaptures = Integer.parseInt(signature.substring(signature.indexOf(',')+1)); arity -= numCaptures; + // arity in painlessLookup does not include 'this' reference + if (signature.charAt(1) == 't') { + arity--; + } } } @@ -251,11 +255,12 @@ static MethodHandle lookupMethod(PainlessLookup painlessLookup, FunctionTable fu String signature = (String) args[upTo++]; int separator = signature.lastIndexOf('.'); int separator2 = signature.indexOf(','); - String type = signature.substring(1, separator); + String type = signature.substring(2, separator); + boolean needsScriptInstance = signature.charAt(1) == 't'; String call = signature.substring(separator+1, separator2); int numCaptures = Integer.parseInt(signature.substring(separator2+1)); MethodHandle filter; - Class interfaceType = method.typeParameters.get(i - 1 - replaced); + Class interfaceType = method.typeParameters.get(i - 1 - replaced - (needsScriptInstance ? 1 : 0)); if (signature.charAt(0) == 'S') { // the implementation is strongly typed, now that we know the interface type, // we have everything. @@ -266,7 +271,8 @@ static MethodHandle lookupMethod(PainlessLookup painlessLookup, FunctionTable fu interfaceType, type, call, - numCaptures + numCaptures, + needsScriptInstance ); } else if (signature.charAt(0) == 'D') { // the interface type is now known, but we need to get the implementation. @@ -292,7 +298,7 @@ static MethodHandle lookupMethod(PainlessLookup painlessLookup, FunctionTable fu } // the filter now ignores the signature (placeholder) on the stack filter = MethodHandles.dropArguments(filter, 0, String.class); - handle = MethodHandles.collectArguments(handle, i, filter); + handle = MethodHandles.collectArguments(handle, i - (needsScriptInstance ? 1 : 0), filter); i += numCaptures; replaced += numCaptures; } @@ -328,20 +334,23 @@ static MethodHandle lookupReference(PainlessLookup painlessLookup, FunctionTable return lookupReferenceInternal(painlessLookup, functions, constants, methodHandlesLookup, interfaceType, PainlessLookupUtility.typeToCanonicalTypeName(implMethod.targetClass), - implMethod.javaMethod.getName(), 1); + implMethod.javaMethod.getName(), 1, false); } /** Returns a method handle to an implementation of clazz, given method reference signature. */ private static MethodHandle lookupReferenceInternal( PainlessLookup painlessLookup, FunctionTable functions, Map constants, - MethodHandles.Lookup methodHandlesLookup, Class clazz, String type, String call, int captures - ) throws Throwable { + MethodHandles.Lookup methodHandlesLookup, Class clazz, String type, String call, int captures, + boolean needsScriptInstance) throws Throwable { - final FunctionRef ref = FunctionRef.create(painlessLookup, functions, null, clazz, type, call, captures, constants); + final FunctionRef ref = + FunctionRef.create(painlessLookup, functions, null, clazz, type, call, captures, constants, needsScriptInstance); + Class[] parameters = ref.factoryMethodParameters(needsScriptInstance ? methodHandlesLookup.lookupClass() : null); + MethodType factoryMethodType = MethodType.methodType(clazz, parameters); final CallSite callSite = LambdaBootstrap.lambdaBootstrap( methodHandlesLookup, ref.interfaceMethodName, - ref.factoryMethodType, + factoryMethodType, ref.interfaceMethodType, ref.delegateClassName, ref.delegateInvokeType, @@ -351,7 +360,7 @@ private static MethodHandle lookupReferenceInternal( ref.isDelegateAugmented ? 1 : 0, ref.delegateInjections ); - return callSite.dynamicInvoker().asType(MethodType.methodType(clazz, ref.factoryMethodType.parameterArray())); + return callSite.dynamicInvoker().asType(MethodType.methodType(clazz, parameters)); } /** @@ -1268,4 +1277,30 @@ static MethodHandle arrayIndexNormalizer(Class arrayType) { private ArrayIndexNormalizeHelper() {} } + + public static class Encoding { + public final boolean isStatic; + public final boolean needsInstance; + public final String symbol; + public final String methodName; + public final int numCaptures; + public final String encoding; + + public Encoding(boolean isStatic, boolean needsInstance, String symbol, String methodName, int numCaptures) { + this.isStatic = isStatic; + this.needsInstance = needsInstance; + this.symbol = symbol; + this.methodName = methodName; + this.numCaptures = numCaptures; + this.encoding = (isStatic ? "S" : "D") + (needsInstance ? "t" : "f") + + symbol + "." + + methodName + "," + + numCaptures; + } + + @Override + public String toString() { + return encoding; + } + } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/FunctionRef.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/FunctionRef.java index 6d6d6651c053a..ff99e0d28da90 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/FunctionRef.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/FunctionRef.java @@ -14,6 +14,7 @@ import org.elasticsearch.painless.lookup.PainlessMethod; import org.elasticsearch.painless.symbol.FunctionTable; import org.elasticsearch.painless.symbol.FunctionTable.LocalFunction; +import org.objectweb.asm.Type; import java.lang.invoke.MethodType; import java.lang.reflect.Modifier; @@ -21,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; import static org.elasticsearch.painless.WriterConstants.CLASS_NAME; import static org.objectweb.asm.Opcodes.H_INVOKEINTERFACE; @@ -44,9 +46,11 @@ public class FunctionRef { * @param methodName the right hand side of a method reference expression * @param numberOfCaptures number of captured arguments * @param constants constants used for injection when necessary + * @param needsScriptInstance uses an instance method and so receiver must be captured. */ public static FunctionRef create(PainlessLookup painlessLookup, FunctionTable functionTable, Location location, - Class targetClass, String typeName, String methodName, int numberOfCaptures, Map constants) { + Class targetClass, String typeName, String methodName, int numberOfCaptures, Map constants, + boolean needsScriptInstance) { Objects.requireNonNull(painlessLookup); Objects.requireNonNull(targetClass); @@ -98,7 +102,7 @@ public static FunctionRef create(PainlessLookup painlessLookup, FunctionTable fu delegateClassName = CLASS_NAME; isDelegateInterface = false; isDelegateAugmented = false; - delegateInvokeType = H_INVOKESTATIC; + delegateInvokeType = needsScriptInstance ? H_INVOKEVIRTUAL : H_INVOKESTATIC; delegateMethodName = localFunction.getMangledName(); delegateMethodType = localFunction.getMethodType(); delegateInjections = new Object[0]; @@ -213,7 +217,7 @@ public static FunctionRef create(PainlessLookup painlessLookup, FunctionTable fu return new FunctionRef(interfaceMethodName, interfaceMethodType, delegateClassName, isDelegateInterface, isDelegateAugmented, delegateInvokeType, delegateMethodName, delegateMethodType, delegateInjections, - factoryMethodType + factoryMethodType, needsScriptInstance ? WriterConstants.CLASS_TYPE : null ); } catch (IllegalArgumentException iae) { if (location != null) { @@ -243,13 +247,15 @@ public static FunctionRef create(PainlessLookup painlessLookup, FunctionTable fu /** injected constants */ public final Object[] delegateInjections; /** factory (CallSite) method signature */ - public final MethodType factoryMethodType; + private final MethodType factoryMethodType; + /** factory (CallSite) method receiver, this modifies the method descriptor for the factory method */ + public final Type factoryMethodReceiver; private FunctionRef( String interfaceMethodName, MethodType interfaceMethodType, String delegateClassName, boolean isDelegateInterface, boolean isDelegateAugmented, int delegateInvokeType, String delegateMethodName, MethodType delegateMethodType, Object[] delegateInjections, - MethodType factoryMethodType) { + MethodType factoryMethodType, Type factoryMethodReceiver) { this.interfaceMethodName = interfaceMethodName; this.interfaceMethodType = interfaceMethodType; @@ -261,5 +267,27 @@ private FunctionRef( this.delegateMethodType = delegateMethodType; this.delegateInjections = delegateInjections; this.factoryMethodType = factoryMethodType; + this.factoryMethodReceiver = factoryMethodReceiver; + } + + /** Get the factory method type, with updated receiver if {@code factoryMethodReceiver} is set */ + public String getFactoryMethodDescriptor() { + if (factoryMethodReceiver == null) { + return factoryMethodType.toMethodDescriptorString(); + } + List arguments = factoryMethodType.parameterList().stream().map(Type::getType).collect(Collectors.toList()); + arguments.add(0, factoryMethodReceiver); + Type[] argArray = new Type[arguments.size()]; + arguments.toArray(argArray); + return Type.getMethodDescriptor(Type.getType(factoryMethodType.returnType()), argArray); + } + + /** Get the factory method type, updating the receiver if {@code factoryMethodReceiverClass} is non-null */ + public Class[] factoryMethodParameters(Class factoryMethodReceiverClass) { + List> parameters = new ArrayList<>(factoryMethodType.parameterList()); + if (factoryMethodReceiverClass != null) { + parameters.add(0, factoryMethodReceiverClass); + } + return parameters.toArray(new Class[0]); } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/LambdaBootstrap.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/LambdaBootstrap.java index 6492ae361b1f9..6d8d58e43ede3 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/LambdaBootstrap.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/LambdaBootstrap.java @@ -23,6 +23,8 @@ import java.lang.invoke.MethodType; import java.security.AccessController; import java.security.PrivilegedAction; +import java.util.List; +import java.util.stream.Collectors; import static java.lang.invoke.MethodHandles.Lookup; import static org.elasticsearch.painless.WriterConstants.CLASS_VERSION; @@ -392,6 +394,8 @@ private static void generateInterfaceMethod( // Loads any passed in arguments onto the stack. iface.loadArgs(); + String functionalInterfaceWithCaptures; + // Handles the case for a lambda function or a static reference method. // interfaceMethodType and delegateMethodType both have the captured types // inserted into their type signatures. This later allows the delegate @@ -402,6 +406,7 @@ private static void generateInterfaceMethod( if (delegateInvokeType == H_INVOKESTATIC) { interfaceMethodType = interfaceMethodType.insertParameterTypes(0, factoryMethodType.parameterArray()); + functionalInterfaceWithCaptures = interfaceMethodType.toMethodDescriptorString(); delegateMethodType = delegateMethodType.insertParameterTypes(0, factoryMethodType.parameterArray()); } else if (delegateInvokeType == H_INVOKEVIRTUAL || @@ -414,19 +419,32 @@ private static void generateInterfaceMethod( Class clazz = delegateMethodType.parameterType(0); delegateClassType = Type.getType(clazz); delegateMethodType = delegateMethodType.dropParameterTypes(0, 1); + functionalInterfaceWithCaptures = interfaceMethodType.toMethodDescriptorString(); // Handles the case for a virtual or interface reference method with 'this' // captured. interfaceMethodType inserts the 'this' type into its // method signature. This later allows the delegate // method to be invoked dynamically and have the interface method types // appropriately converted to the delegate method types. // Example: something::toString - } else if (captures.length == 1) { + } else { Class clazz = factoryMethodType.parameterType(0); delegateClassType = Type.getType(clazz); - interfaceMethodType = interfaceMethodType.insertParameterTypes(0, clazz); - } else { - throw new LambdaConversionException( - "unexpected number of captures [ " + captures.length + "]"); + + // functionalInterfaceWithCaptures needs to add the receiver and other captures + List parameters = interfaceMethodType.parameterList().stream().map(Type::getType).collect(Collectors.toList()); + parameters.add(0, delegateClassType); + for (int i = 1; i < captures.length; i++) { + parameters.add(i, captures[i].type); + } + Type[] parametersArray = parameters.toArray(new Type[0]); + functionalInterfaceWithCaptures = Type.getMethodDescriptor(Type.getType(interfaceMethodType.returnType()), parametersArray); + + // delegateMethod does not need the receiver + List> factoryParameters = factoryMethodType.parameterList(); + if (factoryParameters.size() > 1) { + List> factoryParametersWithReceiver = factoryParameters.subList(1, factoryParameters.size()); + delegateMethodType = delegateMethodType.insertParameterTypes(0, factoryParametersWithReceiver); + } } } else { throw new IllegalStateException( @@ -445,7 +463,7 @@ private static void generateInterfaceMethod( System.arraycopy(injections, 0, args, 2, injections.length); iface.invokeDynamic( delegateMethodName, - Type.getMethodType(interfaceMethodType.toMethodDescriptorString()).getDescriptor(), + Type.getMethodType(functionalInterfaceWithCaptures).getDescriptor(), DELEGATE_BOOTSTRAP_HANDLE, args); diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java index 07f0b122ff1a5..237a006293a8d 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java @@ -514,11 +514,6 @@ public void invokeLambdaCall(FunctionRef functionRef) { args[6] = functionRef.isDelegateAugmented ? 1 : 0; System.arraycopy(functionRef.delegateInjections, 0, args, 7, functionRef.delegateInjections.length); - invokeDynamic( - functionRef.interfaceMethodName, - functionRef.factoryMethodType.toMethodDescriptorString(), - LAMBDA_BOOTSTRAP_HANDLE, - args - ); + invokeDynamic(functionRef.interfaceMethodName, functionRef.getFactoryMethodDescriptor(), LAMBDA_BOOTSTRAP_HANDLE, args); } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/antlr/Walker.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/antlr/Walker.java index 91c1ba2c7f7a6..608ad449be3bb 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/antlr/Walker.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/antlr/Walker.java @@ -273,7 +273,7 @@ public ANode visitFunction(FunctionContext ctx) { } return new SFunction(nextIdentifier(), location(ctx), - rtnType, name, paramTypes, paramNames, new SBlock(nextIdentifier(), location(ctx), statements), false, true, false, false); + rtnType, name, paramTypes, paramNames, new SBlock(nextIdentifier(), location(ctx), statements), false, false, false, false); } @Override diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultIRTreeToASMBytesPhase.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultIRTreeToASMBytesPhase.java index 5abc28a2a1e80..8fc2a377afc43 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultIRTreeToASMBytesPhase.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultIRTreeToASMBytesPhase.java @@ -99,6 +99,7 @@ import org.elasticsearch.painless.symbol.IRDecorations.IRCCaptureBox; import org.elasticsearch.painless.symbol.IRDecorations.IRCContinuous; import org.elasticsearch.painless.symbol.IRDecorations.IRCInitialize; +import org.elasticsearch.painless.symbol.IRDecorations.IRCInstanceCapture; import org.elasticsearch.painless.symbol.IRDecorations.IRCStatic; import org.elasticsearch.painless.symbol.IRDecorations.IRCSynthetic; import org.elasticsearch.painless.symbol.IRDecorations.IRCVarArgs; @@ -1226,6 +1227,11 @@ public void visitDefInterfaceReference(DefInterfaceReferenceNode irDefInterfaceR // which is resolved and replace at runtime methodWriter.push((String)null); + if (irDefInterfaceReferenceNode.hasCondition(IRCInstanceCapture.class)) { + Variable capturedThis = writeScope.getInternalVariable("this"); + methodWriter.visitVarInsn(CLASS_TYPE.getOpcode(Opcodes.ILOAD), capturedThis.getSlot()); + } + List captureNames = irDefInterfaceReferenceNode.getDecorationValue(IRDCaptureNames.class); boolean captureBox = irDefInterfaceReferenceNode.hasCondition(IRCCaptureBox.class); @@ -1247,6 +1253,11 @@ public void visitTypedInterfaceReference(TypedInterfaceReferenceNode irTypedInte MethodWriter methodWriter = writeScope.getMethodWriter(); methodWriter.writeDebugInfo(irTypedInterfaceReferenceNode.getLocation()); + if (irTypedInterfaceReferenceNode.hasCondition(IRCInstanceCapture.class)) { + Variable capturedThis = writeScope.getInternalVariable("this"); + methodWriter.visitVarInsn(CLASS_TYPE.getOpcode(Opcodes.ILOAD), capturedThis.getSlot()); + } + List captureNames = irTypedInterfaceReferenceNode.getDecorationValue(IRDCaptureNames.class); boolean captureBox = irTypedInterfaceReferenceNode.hasCondition(IRCCaptureBox.class); @@ -1576,7 +1587,12 @@ public void visitInvokeCallDef(InvokeCallDefNode irInvokeCallDefNode, WriteScope DefInterfaceReferenceNode defInterfaceReferenceNode = (DefInterfaceReferenceNode)irArgumentNode; List captureNames = defInterfaceReferenceNode.getDecorationValueOrDefault(IRDCaptureNames.class, Collections.emptyList()); - boostrapArguments.add(defInterfaceReferenceNode.getDecorationValue(IRDDefReferenceEncoding.class)); + boostrapArguments.add(defInterfaceReferenceNode.getDecorationValue(IRDDefReferenceEncoding.class).toString()); + + if (defInterfaceReferenceNode.hasCondition(IRCInstanceCapture.class)) { + capturedCount++; + typeParameters.add(ScriptThis.class); + } // the encoding uses a char to indicate the number of captures // where the value is the number of current arguments plus the @@ -1596,7 +1612,12 @@ public void visitInvokeCallDef(InvokeCallDefNode irInvokeCallDefNode, WriteScope Type[] asmParameterTypes = new Type[typeParameters.size()]; for (int index = 0; index < asmParameterTypes.length; ++index) { - asmParameterTypes[index] = MethodWriter.getType(typeParameters.get(index)); + Class typeParameter = typeParameters.get(index); + if (typeParameter.equals(ScriptThis.class)) { + asmParameterTypes[index] = CLASS_TYPE; + } else { + asmParameterTypes[index] = MethodWriter.getType(typeParameters.get(index)); + } } String methodName = irInvokeCallDefNode.getDecorationValue(IRDName.class); @@ -1763,4 +1784,7 @@ public void visitDup(DupNode irDupNode, WriteScope writeScope) { methodWriter.writeDup(size, depth); } + + // placeholder class referring to the script instance + private static final class ScriptThis {} } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultSemanticAnalysisPhase.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultSemanticAnalysisPhase.java index fc744b51325f8..769702afe94b3 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultSemanticAnalysisPhase.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultSemanticAnalysisPhase.java @@ -88,6 +88,8 @@ import org.elasticsearch.painless.symbol.Decorations.ExpressionPainlessCast; import org.elasticsearch.painless.symbol.Decorations.GetterPainlessMethod; import org.elasticsearch.painless.symbol.Decorations.InLoop; +import org.elasticsearch.painless.symbol.Decorations.InstanceCapturingFunctionRef; +import org.elasticsearch.painless.symbol.Decorations.InstanceCapturingLambda; import org.elasticsearch.painless.symbol.Decorations.InstanceType; import org.elasticsearch.painless.symbol.Decorations.Internal; import org.elasticsearch.painless.symbol.Decorations.IterablePainlessMethod; @@ -1726,7 +1728,9 @@ public void visitCallLocal(ECallLocal userCallLocalNode, SemanticScope semanticS localFunction = null; } - if (localFunction == null) { + if (localFunction != null) { + semanticScope.setUsesInstanceMethod(); + } else { importedMethod = scriptScope.getPainlessLookup().lookupImportedPainlessMethod(methodName, userArgumentsSize); if (importedMethod == null) { @@ -2195,6 +2199,10 @@ public void visitLambda(ELambda userLambdaNode, SemanticScope semanticScope) { semanticScope.setCondition(userBlockNode, LastSource.class); visit(userBlockNode, lambdaScope); + if (lambdaScope.usesInstanceMethod()) { + semanticScope.setCondition(userLambdaNode, InstanceCapturingLambda.class); + } + if (semanticScope.getCondition(userBlockNode, MethodEscape.class) == false) { throw userLambdaNode.createError(new IllegalArgumentException("not all paths return a value for lambda")); } @@ -2214,18 +2222,19 @@ public void visitLambda(ELambda userLambdaNode, SemanticScope semanticScope) { // desugar lambda body into a synthetic method String name = scriptScope.getNextSyntheticName("lambda"); - scriptScope.getFunctionTable().addFunction(name, returnType, typeParametersWithCaptures, true, true); + boolean isStatic = lambdaScope.usesInstanceMethod() == false; + scriptScope.getFunctionTable().addFunction(name, returnType, typeParametersWithCaptures, true, isStatic); Class valueType; // setup method reference to synthetic method if (targetType == null) { - String defReferenceEncoding = "Sthis." + name + "," + capturedVariables.size(); valueType = String.class; - semanticScope.putDecoration(userLambdaNode, new EncodingDecoration(defReferenceEncoding)); + semanticScope.putDecoration(userLambdaNode, + new EncodingDecoration(true, lambdaScope.usesInstanceMethod(), "this", name, capturedVariables.size())); } else { FunctionRef ref = FunctionRef.create(scriptScope.getPainlessLookup(), scriptScope.getFunctionTable(), location, targetType.getTargetType(), "this", name, capturedVariables.size(), - scriptScope.getCompilerSettings().asMap()); + scriptScope.getCompilerSettings().asMap(), lambdaScope.usesInstanceMethod()); valueType = targetType.getTargetType(); semanticScope.putDecoration(userLambdaNode, new ReferenceDecoration(ref)); } @@ -2256,7 +2265,8 @@ public void visitFunctionRef(EFunctionRef userFunctionRefNode, SemanticScope sem TargetType targetType = semanticScope.getDecoration(userFunctionRefNode, TargetType.class); Class valueType; - if (symbol.equals("this") || type != null) { + boolean isInstanceReference = "this".equals(symbol); + if (isInstanceReference || type != null) { if (semanticScope.getCondition(userFunctionRefNode, Write.class)) { throw userFunctionRefNode.createError(new IllegalArgumentException( "invalid assignment: cannot assign a value to function reference [" + symbol + ":" + methodName + "]")); @@ -2267,14 +2277,16 @@ public void visitFunctionRef(EFunctionRef userFunctionRefNode, SemanticScope sem "not a statement: function reference [" + symbol + ":" + methodName + "] not used")); } + if (isInstanceReference) { + semanticScope.setCondition(userFunctionRefNode, InstanceCapturingFunctionRef.class); + } if (targetType == null) { valueType = String.class; - String defReferenceEncoding = "S" + symbol + "." + methodName + ",0"; - semanticScope.putDecoration(userFunctionRefNode, new EncodingDecoration(defReferenceEncoding)); + semanticScope.putDecoration(userFunctionRefNode, new EncodingDecoration(true, isInstanceReference, symbol, methodName, 0)); } else { FunctionRef ref = FunctionRef.create(scriptScope.getPainlessLookup(), scriptScope.getFunctionTable(), location, targetType.getTargetType(), symbol, methodName, 0, - scriptScope.getCompilerSettings().asMap()); + scriptScope.getCompilerSettings().asMap(), isInstanceReference); valueType = targetType.getTargetType(); semanticScope.putDecoration(userFunctionRefNode, new ReferenceDecoration(ref)); } @@ -2297,23 +2309,23 @@ public void visitFunctionRef(EFunctionRef userFunctionRefNode, SemanticScope sem } if (targetType == null) { - String defReferenceEncoding; + EncodingDecoration encodingDecoration; if (captured.getType() == def.class) { // dynamic implementation - defReferenceEncoding = "D" + symbol + "." + methodName + ",1"; + encodingDecoration = new EncodingDecoration(false, false, symbol, methodName, 1); } else { // typed implementation - defReferenceEncoding = "S" + captured.getCanonicalTypeName() + "." + methodName + ",1"; + encodingDecoration = new EncodingDecoration(true, false, captured.getCanonicalTypeName(), methodName, 1); } valueType = String.class; - semanticScope.putDecoration(userFunctionRefNode, new EncodingDecoration(defReferenceEncoding)); + semanticScope.putDecoration(userFunctionRefNode, encodingDecoration); } else { valueType = targetType.getTargetType(); // static case if (captured.getType() != def.class) { FunctionRef ref = FunctionRef.create(scriptScope.getPainlessLookup(), scriptScope.getFunctionTable(), location, targetType.getTargetType(), captured.getCanonicalTypeName(), methodName, 1, - scriptScope.getCompilerSettings().asMap()); + scriptScope.getCompilerSettings().asMap(), false); semanticScope.putDecoration(userFunctionRefNode, new ReferenceDecoration(ref)); } } @@ -2357,13 +2369,12 @@ public void visitNewArrayFunctionRef(ENewArrayFunctionRef userNewArrayFunctionRe semanticScope.putDecoration(userNewArrayFunctionRefNode, new MethodNameDecoration(name)); if (targetType == null) { - String defReferenceEncoding = "Sthis." + name + ",0"; valueType = String.class; - scriptScope.putDecoration(userNewArrayFunctionRefNode, new EncodingDecoration(defReferenceEncoding)); + scriptScope.putDecoration(userNewArrayFunctionRefNode, new EncodingDecoration(true, false, "this", name, 0)); } else { FunctionRef ref = FunctionRef.create(scriptScope.getPainlessLookup(), scriptScope.getFunctionTable(), userNewArrayFunctionRefNode.getLocation(), targetType.getTargetType(), "this", name, 0, - scriptScope.getCompilerSettings().asMap()); + scriptScope.getCompilerSettings().asMap(), false); valueType = targetType.getTargetType(); semanticScope.putDecoration(userNewArrayFunctionRefNode, new ReferenceDecoration(ref)); } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultUserTreeToIRTreePhase.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultUserTreeToIRTreePhase.java index 0bf3826f19a45..1c004c985b5b2 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultUserTreeToIRTreePhase.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultUserTreeToIRTreePhase.java @@ -8,6 +8,7 @@ package org.elasticsearch.painless.phase; +import org.elasticsearch.painless.Def; import org.elasticsearch.painless.DefBootstrap; import org.elasticsearch.painless.FunctionRef; import org.elasticsearch.painless.Location; @@ -158,6 +159,8 @@ import org.elasticsearch.painless.symbol.Decorations.ExpressionPainlessCast; import org.elasticsearch.painless.symbol.Decorations.GetterPainlessMethod; import org.elasticsearch.painless.symbol.Decorations.IRNodeDecoration; +import org.elasticsearch.painless.symbol.Decorations.InstanceCapturingLambda; +import org.elasticsearch.painless.symbol.Decorations.InstanceCapturingFunctionRef; import org.elasticsearch.painless.symbol.Decorations.InstanceType; import org.elasticsearch.painless.symbol.Decorations.IterablePainlessMethod; import org.elasticsearch.painless.symbol.Decorations.ListShortcut; @@ -193,6 +196,7 @@ import org.elasticsearch.painless.symbol.IRDecorations.IRCCaptureBox; import org.elasticsearch.painless.symbol.IRDecorations.IRCContinuous; import org.elasticsearch.painless.symbol.IRDecorations.IRCInitialize; +import org.elasticsearch.painless.symbol.IRDecorations.IRCInstanceCapture; import org.elasticsearch.painless.symbol.IRDecorations.IRCRead; import org.elasticsearch.painless.symbol.IRDecorations.IRCStatic; import org.elasticsearch.painless.symbol.IRDecorations.IRCSynthetic; @@ -1356,7 +1360,12 @@ public void visitLambda(ELambda userLambdaNode, ScriptScope scriptScope) { new ArrayList<>(scriptScope.getDecoration(userLambdaNode, TypeParameters.class).getTypeParameters()))); irFunctionNode.attachDecoration(new IRDParameterNames( new ArrayList<>(scriptScope.getDecoration(userLambdaNode, ParameterNames.class).getParameterNames()))); - irFunctionNode.attachCondition(IRCStatic.class); + if (scriptScope.getCondition(userLambdaNode, InstanceCapturingLambda.class)) { + irFunctionNode.attachCondition(IRCInstanceCapture.class); + irExpressionNode.attachCondition(IRCInstanceCapture.class); + } else { + irFunctionNode.attachCondition(IRCStatic.class); + } irFunctionNode.attachCondition(IRCSynthetic.class); irFunctionNode.attachDecoration(new IRDMaxLoopCounter(scriptScope.getCompilerSettings().getMaxLoopCounter())); irClassNode.addFunctionNode(irFunctionNode); @@ -1386,9 +1395,12 @@ public void visitFunctionRef(EFunctionRef userFunctionRefNode, ScriptScope scrip CapturesDecoration capturesDecoration = scriptScope.getDecoration(userFunctionRefNode, CapturesDecoration.class); if (targetType == null) { - String encoding = scriptScope.getDecoration(userFunctionRefNode, EncodingDecoration.class).getEncoding(); + Def.Encoding encoding = scriptScope.getDecoration(userFunctionRefNode, EncodingDecoration.class).getEncoding(); DefInterfaceReferenceNode defInterfaceReferenceNode = new DefInterfaceReferenceNode(userFunctionRefNode.getLocation()); defInterfaceReferenceNode.attachDecoration(new IRDDefReferenceEncoding(encoding)); + if (scriptScope.getCondition(userFunctionRefNode, InstanceCapturingFunctionRef.class)) { + defInterfaceReferenceNode.attachCondition(IRCInstanceCapture.class); + } irReferenceNode = defInterfaceReferenceNode; } else if (capturesDecoration != null && capturesDecoration.getCaptures().get(0).getType() == def.class) { TypedCaptureReferenceNode typedCaptureReferenceNode = new TypedCaptureReferenceNode(userFunctionRefNode.getLocation()); @@ -1398,6 +1410,9 @@ public void visitFunctionRef(EFunctionRef userFunctionRefNode, ScriptScope scrip FunctionRef reference = scriptScope.getDecoration(userFunctionRefNode, ReferenceDecoration.class).getReference(); TypedInterfaceReferenceNode typedInterfaceReferenceNode = new TypedInterfaceReferenceNode(userFunctionRefNode.getLocation()); typedInterfaceReferenceNode.attachDecoration(new IRDReference(reference)); + if (scriptScope.getCondition(userFunctionRefNode, InstanceCapturingFunctionRef.class)) { + typedInterfaceReferenceNode.attachCondition(IRCInstanceCapture.class); + } irReferenceNode = typedInterfaceReferenceNode; } @@ -1427,7 +1442,7 @@ public void visitNewArrayFunctionRef(ENewArrayFunctionRef userNewArrayFunctionRe typedInterfaceReferenceNode.attachDecoration(new IRDReference(reference)); irReferenceNode = typedInterfaceReferenceNode; } else { - String encoding = scriptScope.getDecoration(userNewArrayFunctionRefNode, EncodingDecoration.class).getEncoding(); + Def.Encoding encoding = scriptScope.getDecoration(userNewArrayFunctionRefNode, EncodingDecoration.class).getEncoding(); DefInterfaceReferenceNode defInterfaceReferenceNode = new DefInterfaceReferenceNode(userNewArrayFunctionRefNode.getLocation()); defInterfaceReferenceNode.attachDecoration(new IRDDefReferenceEncoding(encoding)); irReferenceNode = defInterfaceReferenceNode; diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java index 5c5f503433fd3..bce418be4ce12 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/Decorations.java @@ -8,6 +8,7 @@ package org.elasticsearch.painless.symbol; +import org.elasticsearch.painless.Def; import org.elasticsearch.painless.FunctionRef; import org.elasticsearch.painless.ir.IRNode; import org.elasticsearch.painless.lookup.PainlessCast; @@ -513,13 +514,13 @@ public FunctionRef getReference() { public static class EncodingDecoration implements Decoration { - private final String encoding; + private final Def.Encoding encoding; - public EncodingDecoration(String encoding) { - this.encoding = Objects.requireNonNull(encoding); + public EncodingDecoration(boolean isStatic, boolean needsInstance, String symbol, String methodName, int captures) { + this.encoding = new Def.Encoding(isStatic, needsInstance, symbol, methodName, captures); } - public String getEncoding() { + public Def.Encoding getEncoding() { return encoding; } } @@ -610,4 +611,14 @@ public LocalFunction getConverter() { public interface IsDocument extends Condition { } + + // Does the lambda need to capture the enclosing instance? + public interface InstanceCapturingLambda extends Condition { + + } + + // Does the function reference need to capture the enclosing instance? + public interface InstanceCapturingFunctionRef extends Condition { + + } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/IRDecorations.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/IRDecorations.java index 7e0086b932272..303d5d3506f5e 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/IRDecorations.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/IRDecorations.java @@ -8,6 +8,7 @@ package org.elasticsearch.painless.symbol; +import org.elasticsearch.painless.Def; import org.elasticsearch.painless.FunctionRef; import org.elasticsearch.painless.Operation; import org.elasticsearch.painless.ir.IRNode.IRCondition; @@ -165,9 +166,9 @@ public IRDName(String value) { } /** describes an encoding used to resolve references and lambdas at runtime */ - public static class IRDDefReferenceEncoding extends IRDecoration { + public static class IRDDefReferenceEncoding extends IRDecoration { - public IRDDefReferenceEncoding(String value) { + public IRDDefReferenceEncoding(Def.Encoding value) { super(value); } } @@ -337,6 +338,14 @@ private IRCSynthetic() { } } + /** describes if a method needs to capture the script "this" */ + public static class IRCInstanceCapture implements IRCondition { + + private IRCInstanceCapture() { + + } + } + /** describes the maximum number of loop iterations possible in a method */ public static class IRDMaxLoopCounter extends IRDecoration { diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/SemanticScope.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/SemanticScope.java index 15d4e87de5aa6..ff29353ef761a 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/SemanticScope.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/symbol/SemanticScope.java @@ -137,6 +137,7 @@ public static class LambdaScope extends SemanticScope { protected final SemanticScope parent; protected final Class returnType; protected final Set captures = new HashSet<>(); + protected boolean usesInstanceMethod = false; protected LambdaScope(SemanticScope parent, Class returnType) { super(parent.scriptScope, parent.usedVariables); @@ -190,6 +191,19 @@ public String getReturnCanonicalTypeName() { public Set getCaptures() { return Collections.unmodifiableSet(captures); } + + @Override + public void setUsesInstanceMethod() { + if (usesInstanceMethod) { + return; + } + usesInstanceMethod = true; + } + + @Override + public boolean usesInstanceMethod() { + return usesInstanceMethod; + } } /** @@ -340,6 +354,13 @@ public Variable defineVariable(Location location, Class type, String name, bo public abstract boolean isVariableDefined(String name); public abstract Variable getVariable(Location location, String name); + // We only want to track instance method use inside of lambdas for "this" injection. It's a noop for other scopes. + public void setUsesInstanceMethod() {} + + public boolean usesInstanceMethod() { + return false; + } + public Variable defineInternalVariable(Location location, Class type, String name, boolean isReadOnly) { return defineVariable(location, type, "#" + name, isReadOnly); } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/toxcontent/DecorationToXContent.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/toxcontent/DecorationToXContent.java index 8f47164a2b2ff..069113f88b522 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/toxcontent/DecorationToXContent.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/toxcontent/DecorationToXContent.java @@ -291,8 +291,8 @@ public static void ToXContent(ReferenceDecoration referenceDecoration, XContentB builder.endArray(); } - builder.field("factoryMethodType"); - ToXContent(ref.factoryMethodType, builder); + builder.field("factoryMethodDescriptor", ref.getFactoryMethodDescriptor()); + builder.endObject(); } diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/EmitTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/EmitTests.java new file mode 100644 index 0000000000000..037c73015dc20 --- /dev/null +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/EmitTests.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.painless; + +import org.elasticsearch.painless.spi.Whitelist; +import org.elasticsearch.painless.spi.WhitelistLoader; +import org.elasticsearch.script.ScriptContext; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class EmitTests extends ScriptTestCase { + @Override + protected Map, List> scriptContexts() { + Map, List> contexts = new HashMap<>(); + List whitelists = new ArrayList<>(Whitelist.BASE_WHITELISTS); + whitelists.add(WhitelistLoader.loadFromResourceFiles(Whitelist.class, "org.elasticsearch.painless.test")); + contexts.put(TestFieldScript.CONTEXT, whitelists); + return contexts; + } + + @Override + public TestFieldScript exec(String script) { + TestFieldScript.Factory factory = scriptEngine.compile(null, script, TestFieldScript.CONTEXT, new HashMap<>()); + TestFieldScript testScript = factory.newInstance(); + testScript.execute(); + return testScript; + } + + public void testEmit() { + TestFieldScript script = exec("emit(1L)"); + assertNotNull(script); + assertArrayEquals(new long[]{1L}, script.fetchValues()); + } + + public void testEmitFromUserFunction() { + TestFieldScript script = exec("void doEmit(long l) { emit(l) } doEmit(1L); doEmit(100L)"); + assertNotNull(script); + assertArrayEquals(new long[]{1L, 100L}, script.fetchValues()); + } +} diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/FunctionRefTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/FunctionRefTests.java index b95a1ea35558b..d7bf274370a69 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/FunctionRefTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/FunctionRefTests.java @@ -150,12 +150,12 @@ public void testCapturingMethodReferenceMultipleLambdasDefEverywhere() { "return test.twoFunctionsOfX(x::concat, y::substring);")); } - public void testOwnStaticMethodReference() { + public void testOwnMethodReference() { assertEquals(2, exec("int mycompare(int i, int j) { j - i } " + "List l = new ArrayList(); l.add(2); l.add(1); l.sort(this::mycompare); return l.get(0);")); } - public void testOwnStaticMethodReferenceDef() { + public void testOwnMethodReferenceDef() { assertEquals(2, exec("int mycompare(int i, int j) { j - i } " + "def l = new ArrayList(); l.add(2); l.add(1); l.sort(this::mycompare); return l.get(0);")); } diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/TestFieldScript.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/TestFieldScript.java new file mode 100644 index 0000000000000..9982bddf46e1d --- /dev/null +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/TestFieldScript.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.painless; + +import org.elasticsearch.script.ScriptContext; + +import java.util.ArrayList; +import java.util.List; + +public abstract class TestFieldScript { + private final List values = new ArrayList<>(); + + @SuppressWarnings("unused") + public static final String[] PARAMETERS = {}; + public interface Factory { + TestFieldScript newInstance(); + } + + public static final ScriptContext CONTEXT = + new ScriptContext<>("painless_test_fieldscript", TestFieldScript.Factory.class); + + public static class Emit { + private final TestFieldScript script; + + public Emit(TestFieldScript script) { + this.script = script; + } + + public void emit(long v) { + script.emit(v); + } + } + + public abstract void execute(); + + public final void emit(long v) { + values.add(v); + } + + public long[] fetchValues() { + return values.stream().mapToLong(i->i).toArray(); + } +} diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/ToXContentTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ToXContentTests.java index e6a74a7cff64e..4ef3953ae8065 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/ToXContentTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/ToXContentTests.java @@ -25,7 +25,7 @@ public class ToXContentTests extends ScriptTestCase { public void testUserFunction() { Map func = getFunction("def twofive(int i) { return 25 + i; } int j = 23; twofive(j)", "twofive"); assertFalse((Boolean)func.get("isInternal")); - assertTrue((Boolean)func.get("isStatic")); + assertFalse((Boolean)func.get("isStatic")); assertEquals("SFunction", func.get("node")); assertEquals("def", func.get("returns")); assertEquals(List.of("int"), func.get("parameterTypes")); diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/UserFunctionTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/UserFunctionTests.java index 6ac4ac1483c07..175fa03614d98 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/UserFunctionTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/UserFunctionTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.painless; import java.util.List; +import java.util.Map; public class UserFunctionTests extends ScriptTestCase { public void testZeroArgumentUserFunction() { @@ -28,10 +29,126 @@ public void testUserFunctionDefCallRef() { "if (getSource().startsWith('sour')) { l.add(255); }\n" + "return l;"; assertEquals(List.of(1, 49, 100, 255), exec(source)); - assertBytecodeExists(source, "public static &getSource()Ljava/lang/String"); - assertBytecodeExists(source, "public static &getMulti()I"); - assertBytecodeExists(source, "INVOKESTATIC org/elasticsearch/painless/PainlessScript$Script.&getMulti ()I"); - assertBytecodeExists(source, "public static &myCompare(II)I"); - assertBytecodeExists(source, "INVOKESTATIC org/elasticsearch/painless/PainlessScript$Script.&myCompare (II)I"); + assertBytecodeExists(source, "public &getSource()Ljava/lang/String"); + assertBytecodeExists(source, "public &getMulti()I"); + assertBytecodeExists(source, "INVOKEVIRTUAL org/elasticsearch/painless/PainlessScript$Script.&getMulti ()I"); + assertBytecodeExists(source, "public &myCompare(II)I"); + assertBytecodeExists(source, "INVOKEVIRTUAL org/elasticsearch/painless/PainlessScript$Script.&myCompare (II)I"); + } + + public void testChainedUserMethods() { + String source = "int myCompare(int a, int b) { getMulti() * (a - b) }\n" + + "int getMulti() { -1 }\n" + + "List l = [1, 100, -100];\n" + + "l.sort(this::myCompare);\n" + + "l;\n"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + } + + + public void testChainedUserMethodsLambda() { + String source = "int myCompare(int a, int b) { getMulti() * (a - b) }\n" + + "int getMulti() { -1 }\n" + + "List l = [1, 100, -100];\n" + + "l.sort((a, b) -> myCompare(a, b));\n" + + "l;\n"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + } + + public void testChainedUserMethodsDef() { + String source = "int myCompare(int a, int b) { getMulti() * (a - b) }\n" + + "int getMulti() { -1 }\n" + + "def l = [1, 100, -100];\n" + + "l.sort(this::myCompare);\n" + + "l;\n"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + } + + + public void testChainedUserMethodsLambdaDef() { + String source = "int myCompare(int a, int b) { getMulti() * (a - b) }\n" + + "int getMulti() { -1 }\n" + + "def l = [1, 100, -100];\n" + + "l.sort((a, b) -> myCompare(a, b));\n" + + "l;\n"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + } + + public void testChainedUserMethodsLambdaCaptureDef() { + String source = "int myCompare(int a, int b, int x, int m) { getMulti(m) * (a - b + x) }\n" + + "int getMulti(int m) { -1 * m }\n" + + "def l = [1, 100, -100];\n" + + "int cx = 100;\n" + + "int cm = 1;\n" + + "l.sort((a, b) -> myCompare(a, b, cx, cm));\n" + + "l;\n"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + } + + public void testMethodReferenceInUserFunction() { + String source = "int myCompare(int a, int b, String s) { " + + " Map m = ['f': 5];" + + " a - b + m.computeIfAbsent(s, this::getLength) " + + "}\n" + + "int getLength(String s) { s.length() }\n" + + "def l = [1, 0, -2];\n" + + "String s = 'g';\n" + + "l.sort((a, b) -> myCompare(a, b, s));\n" + + "l;\n"; + assertEquals(List.of(-2, 1, 0), exec(source, Map.of("a", 1), false)); + } + + public void testUserFunctionVirtual() { + String source = "int myCompare(int x, int y) { return -1 * (x - y) }\n" + + "return myCompare(100, 90);"; + assertEquals(-10, exec(source, Map.of("a", 1), false)); + assertBytecodeExists(source, "INVOKEVIRTUAL org/elasticsearch/painless/PainlessScript$Script.&myCompare (II)I"); + } + + public void testUserFunctionRef() { + String source = "int myCompare(int x, int y) { return -1 * x - y }\n" + + "List l = [1, 100, -100];\n" + + "l.sort(this::myCompare);\n" + + "return l;"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + assertBytecodeExists(source, "public &myCompare(II)I"); + } + + public void testUserFunctionRefEmpty() { + String source = "int myCompare(int x, int y) { return -1 * x - y }\n" + + "[].sort((a, b) -> myCompare(a, b));\n"; + assertNull(exec(source, Map.of("a", 1), false)); + assertBytecodeExists(source, "public &myCompare(II)I"); + assertBytecodeExists(source, "INVOKEVIRTUAL org/elasticsearch/painless/PainlessScript$Script.&myCompare (II)I"); + } + + public void testUserFunctionCallInLambda() { + String source = "int myCompare(int x, int y) { -1 * ( x - y ) }\n" + + "List l = [1, 100, -100];\n" + + "l.sort((a, b) -> myCompare(a, b));\n" + + "return l;"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + assertBytecodeExists(source, "public &myCompare(II)I"); + assertBytecodeExists(source, "INVOKEVIRTUAL org/elasticsearch/painless/PainlessScript$Script.&myCompare (II)I"); + } + + public void testUserFunctionLambdaCapture() { + String source = "int myCompare(Object o, int x, int y) { return o != null ? -1 * ( x - y ) : ( x - y ) }\n" + + "List l = [1, 100, -100];\n" + + "Object q = '';\n" + + "l.sort((a, b) -> myCompare(q, a, b));\n" + + "return l;"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + assertBytecodeExists(source, "public &myCompare(Ljava/lang/Object;II)I"); + assertBytecodeExists(source, "INVOKEVIRTUAL org/elasticsearch/painless/PainlessScript$Script.&myCompare (Ljava/lang/Object;II)I"); + } + + public void testLambdaCapture() { + String source = "List l = [1, 100, -100];\n" + + "int q = -1;\n" + + "l.sort((a, b) -> q * ( a - b ));\n" + + "return l;"; + assertEquals(List.of(100, 1, -100), exec(source, Map.of("a", 1), false)); + assertBytecodeExists(source, "public static synthetic lambda$synthetic$0(ILjava/lang/Object;Ljava/lang/Object;)I"); } } diff --git a/modules/lang-painless/src/test/resources/org/elasticsearch/painless/spi/org.elasticsearch.painless.test b/modules/lang-painless/src/test/resources/org/elasticsearch/painless/spi/org.elasticsearch.painless.test index 28c032418f373..c6554db9169ac 100644 --- a/modules/lang-painless/src/test/resources/org/elasticsearch/painless/spi/org.elasticsearch.painless.test +++ b/modules/lang-painless/src/test/resources/org/elasticsearch/painless/spi/org.elasticsearch.painless.test @@ -15,6 +15,13 @@ class org.elasticsearch.painless.api.Json { class org.elasticsearch.painless.BindingsTests$BindingsTestScript { } +# Runtime-field-like test objects +class org.elasticsearch.painless.TestFieldScript @no_import { +} +class org.elasticsearch.painless.TestFieldScript$Factory @no_import { +} + + class org.elasticsearch.painless.FeatureTestObject @no_import { int z () @@ -55,4 +62,5 @@ static_import { int classMul(int, int) from_class org.elasticsearch.painless.BindingsTests @compile_time_only int compileTimeBlowUp(int, int) from_class org.elasticsearch.painless.BindingsTests @compile_time_only List fancyConstant(String, String) from_class org.elasticsearch.painless.BindingsTests @compile_time_only + void emit(org.elasticsearch.painless.TestFieldScript, long) bound_to org.elasticsearch.painless.TestFieldScript$Emit }