From a7cc65f38c6f248afdc6c27c393146ed447e3c7a Mon Sep 17 00:00:00 2001 From: Ryan Ernst Date: Tue, 31 Jul 2018 08:28:03 -0700 Subject: [PATCH] Scripting: Fix painless compiler loader to know about context classes (#32385) This commit fixes the painless compiler classloader to know about the classes from the script context. This fixes an issue when a custom context is used from a plugin which caused a ClassNotFoundException for the script class and its factory classes. --- .../org/elasticsearch/painless/Compiler.java | 46 +++++++++++++------ .../painless/LambdaBootstrap.java | 7 ++- .../painless/PainlessScriptEngine.java | 4 +- .../painless/BaseClassTests.java | 44 +++++++++--------- .../org/elasticsearch/painless/Debugger.java | 2 +- 5 files changed, 61 insertions(+), 42 deletions(-) diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Compiler.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Compiler.java index 03345fcfff35a..807b4409d7ae9 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Compiler.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Compiler.java @@ -69,17 +69,14 @@ final class Compiler { /** * A secure class loader used to define Painless scripts. */ - static final class Loader extends SecureClassLoader { + final class Loader extends SecureClassLoader { private final AtomicInteger lambdaCounter = new AtomicInteger(0); - private final PainlessLookup painlessLookup; /** * @param parent The parent ClassLoader. */ - Loader(ClassLoader parent, PainlessLookup painlessLookup) { + Loader(ClassLoader parent) { super(parent); - - this.painlessLookup = painlessLookup; } /** @@ -90,6 +87,15 @@ static final class Loader extends SecureClassLoader { */ @Override public Class findClass(String name) throws ClassNotFoundException { + if (scriptClass.getName().equals(name)) { + return scriptClass; + } + if (factoryClass != null && factoryClass.getName().equals(name)) { + return factoryClass; + } + if (statefulFactoryClass != null && statefulFactoryClass.getName().equals(name)) { + return statefulFactoryClass; + } Class found = painlessLookup.getClassFromBinaryName(name); return found != null ? found : super.findClass(name); @@ -139,13 +145,23 @@ int newLambdaIdentifier() { * {@link Compiler}'s specified {@link PainlessLookup}. */ public Loader createLoader(ClassLoader parent) { - return new Loader(parent, painlessLookup); + return new Loader(parent); } /** - * The class/interface the script is guaranteed to derive/implement. + * The class/interface the script will implement. + */ + private final Class scriptClass; + + /** + * The class/interface to create the {@code scriptClass} instance. + */ + private final Class factoryClass; + + /** + * An optional class/interface to create the {@code factoryClass} instance. */ - private final Class base; + private final Class statefulFactoryClass; /** * The whitelist the script will use. @@ -154,11 +170,15 @@ public Loader createLoader(ClassLoader parent) { /** * Standard constructor. - * @param base The class/interface the script is guaranteed to derive/implement. + * @param scriptClass The class/interface the script will implement. + * @param factoryClass An optional class/interface to create the {@code scriptClass} instance. + * @param statefulFactoryClass An optional class/interface to create the {@code factoryClass} instance. * @param painlessLookup The whitelist the script will use. */ - Compiler(Class base, PainlessLookup painlessLookup) { - this.base = base; + Compiler(Class scriptClass, Class factoryClass, Class statefulFactoryClass, PainlessLookup painlessLookup) { + this.scriptClass = scriptClass; + this.factoryClass = factoryClass; + this.statefulFactoryClass = statefulFactoryClass; this.painlessLookup = painlessLookup; } @@ -177,7 +197,7 @@ Constructor compile(Loader loader, MainMethodReserved reserved, String name, " plugin if a script longer than this length is a requirement."); } - ScriptClassInfo scriptClassInfo = new ScriptClassInfo(painlessLookup, base); + ScriptClassInfo scriptClassInfo = new ScriptClassInfo(painlessLookup, scriptClass); SSource root = Walker.buildPainlessTree(scriptClassInfo, reserved, name, source, settings, painlessLookup, null); root.analyze(painlessLookup); @@ -209,7 +229,7 @@ byte[] compile(String name, String source, CompilerSettings settings, Printer de " plugin if a script longer than this length is a requirement."); } - ScriptClassInfo scriptClassInfo = new ScriptClassInfo(painlessLookup, base); + ScriptClassInfo scriptClassInfo = new ScriptClassInfo(painlessLookup, scriptClass); SSource root = Walker.buildPainlessTree(scriptClassInfo, new MainMethodReserved(), name, source, settings, painlessLookup, debugStream); root.analyze(painlessLookup); diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/LambdaBootstrap.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/LambdaBootstrap.java index 3fc8554b271e2..d95dc4266889b 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/LambdaBootstrap.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/LambdaBootstrap.java @@ -36,7 +36,6 @@ import java.security.PrivilegedAction; import static java.lang.invoke.MethodHandles.Lookup; -import static org.elasticsearch.painless.Compiler.Loader; import static org.elasticsearch.painless.WriterConstants.CLASS_VERSION; import static org.elasticsearch.painless.WriterConstants.CTOR_METHOD_NAME; import static org.elasticsearch.painless.WriterConstants.DELEGATE_BOOTSTRAP_HANDLE; @@ -207,7 +206,7 @@ public static CallSite lambdaBootstrap( MethodType delegateMethodType, int isDelegateInterface) throws LambdaConversionException { - Loader loader = (Loader)lookup.lookupClass().getClassLoader(); + Compiler.Loader loader = (Compiler.Loader)lookup.lookupClass().getClassLoader(); String lambdaClassName = Type.getInternalName(lookup.lookupClass()) + "$$Lambda" + loader.newLambdaIdentifier(); Type lambdaClassType = Type.getObjectType(lambdaClassName); Type delegateClassType = Type.getObjectType(delegateClassName.replace('.', '/')); @@ -457,11 +456,11 @@ private static void endLambdaClass(ClassWriter cw) { } /** - * Defines the {@link Class} for the lambda class using the same {@link Loader} + * Defines the {@link Class} for the lambda class using the same {@link Compiler.Loader} * that originally defined the class for the Painless script. */ private static Class createLambdaClass( - Loader loader, + Compiler.Loader loader, ClassWriter cw, Type lambdaClassType) { diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/PainlessScriptEngine.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/PainlessScriptEngine.java index a5a9823d13018..3a2a6d1452df1 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/PainlessScriptEngine.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/PainlessScriptEngine.java @@ -102,10 +102,10 @@ public PainlessScriptEngine(Settings settings, Map, List, List> entry : contexts.entrySet()) { ScriptContext context = entry.getKey(); if (context.instanceClazz.equals(SearchScript.class) || context.instanceClazz.equals(ExecutableScript.class)) { - contextsToCompilers.put(context, new Compiler(GenericElasticsearchScript.class, + contextsToCompilers.put(context, new Compiler(GenericElasticsearchScript.class, null, null, PainlessLookupBuilder.buildFromWhitelists(entry.getValue()))); } else { - contextsToCompilers.put(context, new Compiler(context.instanceClazz, + contextsToCompilers.put(context, new Compiler(context.instanceClazz, context.factoryClazz, context.statefulFactoryClazz, PainlessLookupBuilder.buildFromWhitelists(entry.getValue()))); } } diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/BaseClassTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/BaseClassTests.java index c852d5a41dec1..c68302bde56f2 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/BaseClassTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/BaseClassTests.java @@ -69,7 +69,7 @@ public Map getTestMap() { } public void testGets() { - Compiler compiler = new Compiler(Gets.class, painlessLookup); + Compiler compiler = new Compiler(Gets.class, null, null, painlessLookup); Map map = new HashMap<>(); map.put("s", 1); @@ -87,7 +87,7 @@ public abstract static class NoArgs { public abstract Object execute(); } public void testNoArgs() { - Compiler compiler = new Compiler(NoArgs.class, painlessLookup); + Compiler compiler = new Compiler(NoArgs.class, null, null, painlessLookup); assertEquals(1, ((NoArgs)scriptEngine.compile(compiler, null, "1", emptyMap())).execute()); assertEquals("foo", ((NoArgs)scriptEngine.compile(compiler, null, "'foo'", emptyMap())).execute()); @@ -111,13 +111,13 @@ public abstract static class OneArg { public abstract Object execute(Object arg); } public void testOneArg() { - Compiler compiler = new Compiler(OneArg.class, painlessLookup); + Compiler compiler = new Compiler(OneArg.class, null, null, painlessLookup); Object rando = randomInt(); assertEquals(rando, ((OneArg)scriptEngine.compile(compiler, null, "arg", emptyMap())).execute(rando)); rando = randomAlphaOfLength(5); assertEquals(rando, ((OneArg)scriptEngine.compile(compiler, null, "arg", emptyMap())).execute(rando)); - Compiler noargs = new Compiler(NoArgs.class, painlessLookup); + Compiler noargs = new Compiler(NoArgs.class, null, null, painlessLookup); Exception e = expectScriptThrows(IllegalArgumentException.class, () -> scriptEngine.compile(noargs, null, "doc", emptyMap())); assertEquals("Variable [doc] is not defined.", e.getMessage()); @@ -132,7 +132,7 @@ public abstract static class ArrayArg { public abstract Object execute(String[] arg); } public void testArrayArg() { - Compiler compiler = new Compiler(ArrayArg.class, painlessLookup); + Compiler compiler = new Compiler(ArrayArg.class, null, null, painlessLookup); String rando = randomAlphaOfLength(5); assertEquals(rando, ((ArrayArg)scriptEngine.compile(compiler, null, "arg[0]", emptyMap())).execute(new String[] {rando, "foo"})); } @@ -142,7 +142,7 @@ public abstract static class PrimitiveArrayArg { public abstract Object execute(int[] arg); } public void testPrimitiveArrayArg() { - Compiler compiler = new Compiler(PrimitiveArrayArg.class, painlessLookup); + Compiler compiler = new Compiler(PrimitiveArrayArg.class, null, null, painlessLookup); int rando = randomInt(); assertEquals(rando, ((PrimitiveArrayArg)scriptEngine.compile(compiler, null, "arg[0]", emptyMap())).execute(new int[] {rando, 10})); } @@ -152,7 +152,7 @@ public abstract static class DefArrayArg { public abstract Object execute(Object[] arg); } public void testDefArrayArg() { - Compiler compiler = new Compiler(DefArrayArg.class, painlessLookup); + Compiler compiler = new Compiler(DefArrayArg.class, null, null, painlessLookup); Object rando = randomInt(); assertEquals(rando, ((DefArrayArg)scriptEngine.compile(compiler, null, "arg[0]", emptyMap())).execute(new Object[] {rando, 10})); rando = randomAlphaOfLength(5); @@ -170,7 +170,7 @@ public abstract static class ManyArgs { public abstract boolean needsD(); } public void testManyArgs() { - Compiler compiler = new Compiler(ManyArgs.class, painlessLookup); + Compiler compiler = new Compiler(ManyArgs.class, null, null, painlessLookup); int rando = randomInt(); assertEquals(rando, ((ManyArgs)scriptEngine.compile(compiler, null, "a", emptyMap())).execute(rando, 0, 0, 0)); assertEquals(10, ((ManyArgs)scriptEngine.compile(compiler, null, "a + b + c + d", emptyMap())).execute(1, 2, 3, 4)); @@ -198,7 +198,7 @@ public abstract static class VarargTest { public abstract Object execute(String... arg); } public void testVararg() { - Compiler compiler = new Compiler(VarargTest.class, painlessLookup); + Compiler compiler = new Compiler(VarargTest.class, null, null, painlessLookup); assertEquals("foo bar baz", ((VarargTest)scriptEngine.compile(compiler, null, "String.join(' ', Arrays.asList(arg))", emptyMap())) .execute("foo", "bar", "baz")); } @@ -214,7 +214,7 @@ public Object executeWithASingleOne(int a, int b, int c) { } } public void testDefaultMethods() { - Compiler compiler = new Compiler(DefaultMethods.class, painlessLookup); + Compiler compiler = new Compiler(DefaultMethods.class, null, null, painlessLookup); int rando = randomInt(); assertEquals(rando, ((DefaultMethods)scriptEngine.compile(compiler, null, "a", emptyMap())).execute(rando, 0, 0, 0)); assertEquals(rando, ((DefaultMethods)scriptEngine.compile(compiler, null, "a", emptyMap())).executeWithASingleOne(rando, 0, 0)); @@ -228,7 +228,7 @@ public abstract static class ReturnsVoid { public abstract void execute(Map map); } public void testReturnsVoid() { - Compiler compiler = new Compiler(ReturnsVoid.class, painlessLookup); + Compiler compiler = new Compiler(ReturnsVoid.class, null, null, painlessLookup); Map map = new HashMap<>(); ((ReturnsVoid)scriptEngine.compile(compiler, null, "map.a = 'foo'", emptyMap())).execute(map); assertEquals(singletonMap("a", "foo"), map); @@ -247,7 +247,7 @@ public abstract static class ReturnsPrimitiveBoolean { public abstract boolean execute(); } public void testReturnsPrimitiveBoolean() { - Compiler compiler = new Compiler(ReturnsPrimitiveBoolean.class, painlessLookup); + Compiler compiler = new Compiler(ReturnsPrimitiveBoolean.class, null, null, painlessLookup); assertEquals(true, ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "true", emptyMap())).execute()); assertEquals(false, ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "false", emptyMap())).execute()); @@ -289,7 +289,7 @@ public abstract static class ReturnsPrimitiveInt { public abstract int execute(); } public void testReturnsPrimitiveInt() { - Compiler compiler = new Compiler(ReturnsPrimitiveInt.class, painlessLookup); + Compiler compiler = new Compiler(ReturnsPrimitiveInt.class, null, null, painlessLookup); assertEquals(1, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "1", emptyMap())).execute()); assertEquals(1, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "(int) 1L", emptyMap())).execute()); @@ -331,7 +331,7 @@ public abstract static class ReturnsPrimitiveFloat { public abstract float execute(); } public void testReturnsPrimitiveFloat() { - Compiler compiler = new Compiler(ReturnsPrimitiveFloat.class, painlessLookup); + Compiler compiler = new Compiler(ReturnsPrimitiveFloat.class, null, null, painlessLookup); assertEquals(1.1f, ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "1.1f", emptyMap())).execute(), 0); assertEquals(1.1f, ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "(float) 1.1d", emptyMap())).execute(), 0); @@ -362,7 +362,7 @@ public abstract static class ReturnsPrimitiveDouble { public abstract double execute(); } public void testReturnsPrimitiveDouble() { - Compiler compiler = new Compiler(ReturnsPrimitiveDouble.class, painlessLookup); + Compiler compiler = new Compiler(ReturnsPrimitiveDouble.class, null, null, painlessLookup); assertEquals(1.0, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "1", emptyMap())).execute(), 0); assertEquals(1.0, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "1L", emptyMap())).execute(), 0); @@ -396,7 +396,7 @@ public abstract static class NoArgumentsConstant { public abstract Object execute(String foo); } public void testNoArgumentsConstant() { - Compiler compiler = new Compiler(NoArgumentsConstant.class, painlessLookup); + Compiler compiler = new Compiler(NoArgumentsConstant.class, null, null, painlessLookup); Exception e = expectScriptThrows(IllegalArgumentException.class, false, () -> scriptEngine.compile(compiler, null, "1", emptyMap())); assertThat(e.getMessage(), startsWith( @@ -409,7 +409,7 @@ public abstract static class WrongArgumentsConstant { public abstract Object execute(String foo); } public void testWrongArgumentsConstant() { - Compiler compiler = new Compiler(WrongArgumentsConstant.class, painlessLookup); + Compiler compiler = new Compiler(WrongArgumentsConstant.class, null, null, painlessLookup); Exception e = expectScriptThrows(IllegalArgumentException.class, false, () -> scriptEngine.compile(compiler, null, "1", emptyMap())); assertThat(e.getMessage(), startsWith( @@ -422,7 +422,7 @@ public abstract static class WrongLengthOfArgumentConstant { public abstract Object execute(String foo); } public void testWrongLengthOfArgumentConstant() { - Compiler compiler = new Compiler(WrongLengthOfArgumentConstant.class, painlessLookup); + Compiler compiler = new Compiler(WrongLengthOfArgumentConstant.class, null, null, painlessLookup); Exception e = expectScriptThrows(IllegalArgumentException.class, false, () -> scriptEngine.compile(compiler, null, "1", emptyMap())); assertThat(e.getMessage(), startsWith("[" + WrongLengthOfArgumentConstant.class.getName() + "#ARGUMENTS] has length [2] but [" @@ -434,7 +434,7 @@ public abstract static class UnknownArgType { public abstract Object execute(UnknownArgType foo); } public void testUnknownArgType() { - Compiler compiler = new Compiler(UnknownArgType.class, painlessLookup); + Compiler compiler = new Compiler(UnknownArgType.class, null, null, painlessLookup); Exception e = expectScriptThrows(IllegalArgumentException.class, false, () -> scriptEngine.compile(compiler, null, "1", emptyMap())); assertEquals("[foo] is of unknown type [" + UnknownArgType.class.getName() + ". Painless interfaces can only accept arguments " @@ -446,7 +446,7 @@ public abstract static class UnknownReturnType { public abstract UnknownReturnType execute(String foo); } public void testUnknownReturnType() { - Compiler compiler = new Compiler(UnknownReturnType.class, painlessLookup); + Compiler compiler = new Compiler(UnknownReturnType.class, null, null, painlessLookup); Exception e = expectScriptThrows(IllegalArgumentException.class, false, () -> scriptEngine.compile(compiler, null, "1", emptyMap())); assertEquals("Painless can only implement execute methods returning a whitelisted type but [" + UnknownReturnType.class.getName() @@ -458,7 +458,7 @@ public abstract static class UnknownArgTypeInArray { public abstract Object execute(UnknownArgTypeInArray[] foo); } public void testUnknownArgTypeInArray() { - Compiler compiler = new Compiler(UnknownArgTypeInArray.class, painlessLookup); + Compiler compiler = new Compiler(UnknownArgTypeInArray.class, null, null, painlessLookup); Exception e = expectScriptThrows(IllegalArgumentException.class, false, () -> scriptEngine.compile(compiler, null, "1", emptyMap())); assertEquals("[foo] is of unknown type [" + UnknownArgTypeInArray.class.getName() + ". Painless interfaces can only accept " @@ -470,7 +470,7 @@ public abstract static class TwoExecuteMethods { public abstract Object execute(boolean foo); } public void testTwoExecuteMethods() { - Compiler compiler = new Compiler(TwoExecuteMethods.class, painlessLookup); + Compiler compiler = new Compiler(TwoExecuteMethods.class, null, null, painlessLookup); Exception e = expectScriptThrows(IllegalArgumentException.class, false, () -> scriptEngine.compile(compiler, null, "null", emptyMap())); assertEquals("Painless can only implement interfaces that have a single method named [execute] but [" diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/Debugger.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/Debugger.java index 48af3898e0952..ae33ebfb6e9c5 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/Debugger.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/Debugger.java @@ -40,7 +40,7 @@ static String toString(Class iface, String source, CompilerSettings settings) PrintWriter outputWriter = new PrintWriter(output); Textifier textifier = new Textifier(); try { - new Compiler(iface, PainlessLookupBuilder.buildFromWhitelists(Whitelist.BASE_WHITELISTS)) + new Compiler(iface, null, null, PainlessLookupBuilder.buildFromWhitelists(Whitelist.BASE_WHITELISTS)) .compile("", source, settings, textifier); } catch (RuntimeException e) { textifier.print(outputWriter);