diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeUtils.java b/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeUtils.java index 53511c54f541..4c1a82b985eb 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeUtils.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/BytecodeUtils.java @@ -420,7 +420,7 @@ else if (type == Boolean.class) { public static BytecodeExpression invoke(Binding binding, String name) { // ensure that name doesn't have a special characters - return invokeDynamic(BOOTSTRAP_METHOD, ImmutableList.of(binding.getBindingId()), name.replaceAll("[^A-Za-z0-9_$]", "_"), binding.getType()); + return invokeDynamic(BOOTSTRAP_METHOD, ImmutableList.of(binding.getBindingId()), sanitizeName(name), binding.getType()); } public static BytecodeExpression invoke(Binding binding, BoundSignature signature) @@ -428,6 +428,14 @@ public static BytecodeExpression invoke(Binding binding, BoundSignature signatur return invoke(binding, signature.getName()); } + /** + * Replace characters that are not safe to use in a JVM identifier. + */ + public static String sanitizeName(String name) + { + return name.replaceAll("[^A-Za-z0-9_$]", "_"); + } + public static BytecodeNode generateWrite(CallSiteBinder callSiteBinder, Scope scope, Variable wasNullVariable, Type type) { Class valueJavaType = type.getJavaType(); diff --git a/presto-main/src/main/java/io/prestosql/sql/gen/LambdaBytecodeGenerator.java b/presto-main/src/main/java/io/prestosql/sql/gen/LambdaBytecodeGenerator.java index d6c046d0ab41..fbe95c8b8274 100644 --- a/presto-main/src/main/java/io/prestosql/sql/gen/LambdaBytecodeGenerator.java +++ b/presto-main/src/main/java/io/prestosql/sql/gen/LambdaBytecodeGenerator.java @@ -122,7 +122,7 @@ public static CompiledLambda preGenerateLambdaExpression( for (int i = 0; i < lambdaExpression.getArguments().size(); i++) { Class type = Primitives.wrap(lambdaExpression.getArgumentTypes().get(i).getJavaType()); String argumentName = lambdaExpression.getArguments().get(i); - Parameter arg = arg("lambda_" + argumentName, type); + Parameter arg = arg("lambda_" + i + "_" + BytecodeUtils.sanitizeName(argumentName), type); parameters.add(arg); parameterMapBuilder.put(argumentName, new ParameterAndType(arg, type)); } diff --git a/presto-main/src/test/java/io/prestosql/operator/scalar/TestLambdaExpression.java b/presto-main/src/test/java/io/prestosql/operator/scalar/TestLambdaExpression.java index cbae2060c4bd..619ba8dfce25 100644 --- a/presto-main/src/test/java/io/prestosql/operator/scalar/TestLambdaExpression.java +++ b/presto-main/src/test/java/io/prestosql/operator/scalar/TestLambdaExpression.java @@ -54,6 +54,14 @@ public void testBasic() assertFunction("apply(5 + RANDOM(1), x -> x + 1)", INTEGER, 6); } + @Test + public void testParameterName() + { + // parameter which is not valid identifier in Java + String nonLetters = "a.b c; d ' \n \\n \""; + assertFunction("apply(5, " + quote(nonLetters) + " -> " + quote(nonLetters) + " * 2)", INTEGER, 10); + } + @Test public void testNull() { @@ -158,4 +166,9 @@ public void testTypeCombinations() assertFunction("apply(ARRAY['abc', NULL, '123'], x -> x[2])", createVarcharType(3), null); assertFunction("apply(MAP(ARRAY['abc', 'def'], ARRAY[123, 456]), x -> map_keys(x))", new ArrayType(createVarcharType(3)), ImmutableList.of("abc", "def")); } + + private static String quote(String identifier) + { + return "\"" + identifier.replace("\"", "\"\"") + "\""; + } } diff --git a/presto-main/src/test/java/io/prestosql/sql/query/TestLambdaExpressions.java b/presto-main/src/test/java/io/prestosql/sql/query/TestLambdaExpressions.java index cf51c6125db4..3839733b7f23 100644 --- a/presto-main/src/test/java/io/prestosql/sql/query/TestLambdaExpressions.java +++ b/presto-main/src/test/java/io/prestosql/sql/query/TestLambdaExpressions.java @@ -63,6 +63,20 @@ public void testDuplicateLambdaExpressions() .matches("VALUES ROW(ARRAY[2, 3, 4], ARRAY[11e0, 21e0, 31e0])"); } + @Test + public void testParameterName() + { + // lambda may be using parameter which is not valid identifier in Java + assertThat(assertions.query("" + + "WITH t AS ( " + + " SELECT count(*) AS \"a.b c; d\" FROM (VALUES (42)) " + + " UNION ALL " + + " SELECT * FROM (VALUES (77)) v(\"a.b c; d\") " + + ") " + + "SELECT transform(ARRAY[1], x -> x + \"a.b c; d\") FROM t")) + .matches("VALUES ARRAY[BIGINT '2'], ARRAY[BIGINT '78']"); + } + @Test public void testNestedLambda() {