Skip to content

Commit

Permalink
Script: Mangle user function names (elastic#72892)
Browse files Browse the repository at this point in the history
Prepend `&` to user function names.  In future changes user
functions will switch from being static methods to member methods.
The mangled user function names will prohibit users from overriden
other script methods.

Refs: elastic#69742
Backport: f6bf99c
  • Loading branch information
stu-elastic committed May 10, 2021
1 parent b0cad22 commit 4c9e1f6
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public static FunctionRef create(PainlessLookup painlessLookup, FunctionTable fu
isDelegateInterface = false;
isDelegateAugmented = false;
delegateInvokeType = H_INVOKESTATIC;
delegateMethodName = localFunction.getFunctionName();
delegateMethodName = localFunction.getMangledName();
delegateMethodType = localFunction.getMethodType();
delegateInjections = new Object[0];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public void visitFunction(SFunction userFunctionNode, ScriptScope scriptScope) {
typeParameters.add(paramType);
}

functionTable.addFunction(functionName, returnType, typeParameters, userFunctionNode.isInternal(), userFunctionNode.isStatic());
functionTable.addMangledFunction(functionName, returnType, typeParameters, userFunctionNode.isInternal(),
userFunctionNode.isStatic());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,11 @@ public void visitFunction(SFunction userFunctionNode, ScriptScope scriptScope) {

FunctionNode irFunctionNode = new FunctionNode(userFunctionNode.getLocation());
irFunctionNode.setBlockNode(irBlockNode);
irFunctionNode.attachDecoration(new IRDName(userFunctionNode.getFunctionName()));
String mangledName = scriptScope.getFunctionTable().getFunction(
userFunctionNode.getFunctionName(),
userFunctionNode.getCanonicalTypeNameParameters().size()
).getMangledName();
irFunctionNode.attachDecoration(new IRDName(mangledName));
irFunctionNode.attachDecoration(new IRDReturnType(returnType));
irFunctionNode.attachDecoration(new IRDTypeParameters(new ArrayList<>(localFunction.getTypeParameters())));
irFunctionNode.attachDecoration(new IRDParameterNames(new ArrayList<>(userFunctionNode.getParameterNames())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@
*/
public class FunctionTable {

private static final String MANGLED_FUNCTION_NAME_PREFIX = "&";

public static class LocalFunction {

protected final String functionName;
protected final String mangledName;
protected final Class<?> returnType;
protected final List<Class<?>> typeParameters;
protected final boolean isInternal;
Expand All @@ -38,8 +41,14 @@ public static class LocalFunction {

public LocalFunction(
String functionName, Class<?> returnType, List<Class<?>> typeParameters, boolean isInternal, boolean isStatic) {
this(functionName, "", returnType, typeParameters, isInternal, isStatic);
}

private LocalFunction(String functionName, String mangle,
Class<?> returnType, List<Class<?>> typeParameters, boolean isInternal, boolean isStatic) {

this.functionName = Objects.requireNonNull(functionName);
this.mangledName = Objects.requireNonNull(mangle) + this.functionName;
this.returnType = Objects.requireNonNull(returnType);
this.typeParameters = Collections.unmodifiableList(Objects.requireNonNull(typeParameters));
this.isInternal = isInternal;
Expand All @@ -49,12 +58,12 @@ public LocalFunction(
Class<?>[] javaTypeParameters = typeParameters.stream().map(PainlessLookupUtility::typeToJavaType).toArray(Class<?>[]::new);

this.methodType = MethodType.methodType(javaReturnType, javaTypeParameters);
this.asmMethod = new org.objectweb.asm.commons.Method(functionName,
this.asmMethod = new org.objectweb.asm.commons.Method(mangledName,
MethodType.methodType(javaReturnType, javaTypeParameters).toMethodDescriptorString());
}

public String getFunctionName() {
return functionName;
public String getMangledName() {
return mangledName;
}

public Class<?> getReturnType() {
Expand Down Expand Up @@ -103,8 +112,11 @@ public LocalFunction addFunction(
return function;
}

public LocalFunction addFunction(LocalFunction function) {
String functionKey = buildLocalFunctionKey(function.getFunctionName(), function.getTypeParameters().size());
public LocalFunction addMangledFunction(String functionName,
Class<?> returnType, List<Class<?>> typeParameters, boolean isInternal, boolean isStatic) {
String functionKey = buildLocalFunctionKey(functionName, typeParameters.size());
LocalFunction function =
new LocalFunction(functionName, MANGLED_FUNCTION_NAME_PREFIX, returnType, typeParameters, isInternal, isStatic);
localFunctions.put(functionKey, function);
return function;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ public static void ToXContent(PainlessMethod method, XContentBuilderWrapper buil

public static void ToXContent(FunctionTable.LocalFunction localFunction, XContentBuilderWrapper builder) {
builder.startObject();
builder.field("functionName", localFunction.getFunctionName());
builder.field("mangledName", localFunction.getMangledName());
builder.field("returnType", localFunction.getReturnType().getSimpleName());
if (localFunction.getTypeParameters().isEmpty() == false) {
builder.field("typeParameters", classNames(localFunction.getTypeParameters()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,23 @@ public void testZeroArgumentUserFunction() {
String source = "def twofive() { return 25; } twofive()";
assertEquals(25, exec(source));
}

public void testUserFunctionDefCallRef() {
String source =
"String getSource() { 'source'; }\n" +
"int myCompare(int a, int b) { getMulti() * Integer.compare(a, b) }\n" +
"int getMulti() { return -1 }\n" +
"def l = [1, 100, -100];\n" +
"if (myCompare(10, 50) > 0) { l.add(50 + getMulti()) }\n" +
"l.sort(this::myCompare);\n" +
"if (l[0] == 100) { l.remove(l.size() - 1) ; l.sort((a, b) -> -1 * myCompare(a, b)) } \n"+
"if (getSource().startsWith('sour')) { l.add(255); }\n" +
"return l;";
assertEquals(org.elasticsearch.common.collect.List.of(1, 49, 100, 255), exec(source));
assertBytecodeExists(source, "public static &getSource()Ljava/lang/String");
assertBytecodeExists(source, "public static &getMulti()I");
assertBytecodeExists(source, "INVOKESTATIC org/elasticsearch/painless/PainlessScript$Script.&getMulti ()I");
assertBytecodeExists(source, "public static &myCompare(II)I");
assertBytecodeExists(source, "INVOKESTATIC org/elasticsearch/painless/PainlessScript$Script.&myCompare (II)I");
}
}

0 comments on commit 4c9e1f6

Please sign in to comment.