Skip to content

Commit

Permalink
Optimize lambda expression body
Browse files Browse the repository at this point in the history
  • Loading branch information
guiyanakuang authored Aug 12, 2023
1 parent 92b72dd commit 0f4b813
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,18 @@ protected Object visitLambdaExpression(LambdaExpression node, Object context)
if (optimize) {
// TODO: enable optimization related to lambda expression
// A mechanism to convert function type back into lambda expression need to exist to enable optimization
return node;
Object value = processWithExceptionHandling(node.getBody(), context);
Expression optimizedBody;

// value may be null, converted to an expression by toExpression(value, type)
if (value instanceof Expression) {
optimizedBody = (Expression) value;
}
else {
Type type = type(node.getBody());
optimizedBody = toExpression(value, type);
}
return new LambdaExpression(node.getArguments(), optimizedBody);
}

Expression body = node.getBody();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,25 @@ public void testIsNotNull()
assertOptimizedEquals("bound_decimal_long IS NOT NULL", "true");
}

@Test
public void testLambdaBody()
{
assertOptimizedEquals("transform(ARRAY[bound_long], n -> CAST(n as BIGINT))",
"transform(ARRAY[bound_long], n -> n)");
assertOptimizedEquals("transform(ARRAY[bound_long], n -> CAST(n as VARCHAR(5)))",
"transform(ARRAY[bound_long], n -> CAST(n as VARCHAR(5)))");
assertOptimizedEquals("transform(ARRAY[bound_long], n -> IF(false, 1, 0 / 0))",
"transform(ARRAY[bound_long], n -> 0 / 0)");
assertOptimizedEquals("transform(ARRAY[bound_long], n -> 5 / 0)",
"transform(ARRAY[bound_long], n -> 5 / 0)");
assertOptimizedEquals("transform(ARRAY[bound_long], n -> nullif(true, true))",
"transform(ARRAY[bound_long], n -> CAST(null AS Boolean))");
assertOptimizedEquals("transform(ARRAY[bound_long], n -> n + 10 * 10)",
"transform(ARRAY[bound_long], n -> n + 100)");
assertOptimizedEquals("reduce_agg(bound_long, 0, (a, b) -> IF(false, a, b), (a, b) -> IF(true, a, b))",
"reduce_agg(bound_long, 0, (a, b) -> b, (a, b) -> a)");
}

@Test
public void testNullIf()
{
Expand Down

0 comments on commit 0f4b813

Please sign in to comment.