Skip to content

Commit

Permalink
feat: implement correct logic for nested lambdas and more complex lam…
Browse files Browse the repository at this point in the history
…bda expressions
  • Loading branch information
stevenpyzhang committed Feb 24, 2021
1 parent d6529a3 commit 8225a82
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,24 +183,31 @@ public Void visitLikePredicate(final LikePredicate node, final TypeContext conte

@Override
public Void visitFunctionCall(final FunctionCall node, final TypeContext context) {
final List<SqlArgument> argumentTypes = new ArrayList<>();
final FunctionName functionName = node.getName();

final TypeContext contextCopy = context.getCopy();
final List<SqlArgument> argumentTypes = new ArrayList<>();
final boolean hasLambda = node.hasLambdaFunctionCallArguments();
for (final Expression argExpr : node.getArguments()) {
final TypeContext childContext = context.getCopy();
final TypeContext childContext;
if (argExpr instanceof LambdaFunctionCall) {
childContext = contextCopy.getCopy();
} else {
childContext = context.getCopy();
}
final SqlType resolvedArgType =
expressionTypeManager.getExpressionSqlType(argExpr, childContext);
process(argExpr, context.getCopy());
expressionTypeManager.getExpressionSqlType(argExpr, childContext.getCopy());

process(argExpr, childContext.getCopy());
if (argExpr instanceof LambdaFunctionCall) {
argumentTypes.add(
SqlArgument.of(
SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType())));
SqlLambda.of(contextCopy.getLambdaInputTypes(), resolvedArgType)));
} else {
argumentTypes.add(SqlArgument.of(resolvedArgType));
// for lambdas - we save the type information to resolve the lambda generics
if (hasLambda) {
context.visitType(resolvedArgType);
contextCopy.visitType(resolvedArgType);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,21 +452,32 @@ public Pair<String, SqlType> visitFunctionCall(
final String instanceName = funNameToCodeName.apply(functionName);

final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName());

final TypeContext contextCopy = context.getCopy();
final List<SqlArgument> argumentSchemas = new ArrayList<>();
final List<TypeContext> typeContextsForChildren = new ArrayList<>();
final boolean hasLambda = node.hasLambdaFunctionCallArguments();

for (final Expression argExpr : node.getArguments()) {
final TypeContext childContext = context.getCopy();
final TypeContext childContext;
if (argExpr instanceof LambdaFunctionCall) {
childContext = contextCopy.getCopy();
} else {
childContext = context.getCopy();
}

typeContextsForChildren.add(childContext.getCopy());
final SqlType resolvedArgType =
expressionTypeManager.getExpressionSqlType(argExpr, childContext);
if (argExpr instanceof LambdaFunctionCall) {
argumentSchemas.add(
SqlArgument.of(
SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType())));
SqlLambda.of(contextCopy.getLambdaInputTypes(), resolvedArgType)));
} else {
argumentSchemas.add(SqlArgument.of(resolvedArgType));
// for lambdas - we save the type information to resolve the lambda generics
if (hasLambda) {
context.visitType(resolvedArgType);
contextCopy.visitType(resolvedArgType);
}
}
}
Expand Down Expand Up @@ -494,7 +505,7 @@ public Pair<String, SqlType> visitFunctionCall(
}

joiner.add(
process(convertArgument(arg, sqlType, paramType), context.getCopy())
process(convertArgument(arg, sqlType, paramType),typeContextsForChildren.get(i))
.getLeft());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,22 +474,30 @@ public Void visitFunctionCall(
final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName());

final List<SqlArgument> argTypes = new ArrayList<>();
final TypeContext contextCopy = expressionTypeContext.getCopy();

final boolean hasLambda = node.hasLambdaFunctionCallArguments();
for (final Expression expression : node.getArguments()) {
final TypeContext childContext = expressionTypeContext.getCopy();
final TypeContext childContext;
if (expression instanceof LambdaFunctionCall) {
childContext = contextCopy.getCopy();
} else {
childContext = expressionTypeContext.getCopy();
}
process(expression, childContext);
final SqlType resolvedArgType = childContext.getSqlType();

if (expression instanceof LambdaFunctionCall) {
argTypes.add(
SqlArgument.of(
SqlLambda.of(expressionTypeContext.getLambdaInputTypes(),
childContext.getSqlType())));
SqlLambda.of(
contextCopy.getLambdaInputTypes(),
resolvedArgType)));
} else {
argTypes.add(SqlArgument.of(resolvedArgType));
// for lambdas - we save the type information to resolve the lambda generics
if (hasLambda) {
expressionTypeContext.visitType(resolvedArgType);
contextCopy.visitType(resolvedArgType);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -65,7 +64,6 @@
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.execution.expression.tree.SubscriptExpression;
import io.confluent.ksql.execution.expression.tree.TimeLiteral;
import io.confluent.ksql.execution.expression.tree.TimestampLiteral;
import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.WhenClause;
import io.confluent.ksql.function.FunctionRegistry;
Expand All @@ -74,18 +72,16 @@
import io.confluent.ksql.function.types.ArrayType;
import io.confluent.ksql.function.types.GenericType;
import io.confluent.ksql.function.types.LambdaType;
import io.confluent.ksql.function.types.MapType;
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.function.types.ParamTypes;
import io.confluent.ksql.function.udf.UdfMetadata;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.schema.Operator;
import io.confluent.ksql.schema.ksql.types.SqlPrimitiveType;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.util.KsqlConfig;
import java.math.BigDecimal;
import java.sql.Timestamp;
import java.util.Collections;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -243,10 +239,10 @@ public void shouldPostfixFunctionInstancesWithUniqueId() {
final KsqlScalarFunction ssFunction = mock(KsqlScalarFunction.class);
final UdfFactory catFactory = mock(UdfFactory.class);
final KsqlScalarFunction catFunction = mock(KsqlScalarFunction.class);
givenUdf("SUBSTRING", ssFactory, ssFunction);
givenUdf("SUBSTRING", ssFactory, ssFunction, SqlTypes.STRING);
when(ssFunction.parameters())
.thenReturn(ImmutableList.of(ParamTypes.STRING, ParamTypes.INTEGER, ParamTypes.INTEGER));
givenUdf("CONCAT", catFactory, catFunction);
givenUdf("CONCAT", catFactory, catFunction, SqlTypes.STRING);
when(catFunction.parameters())
.thenReturn(ImmutableList.of(ParamTypes.STRING, ParamTypes.STRING));
final FunctionName ssName = FunctionName.of("SUBSTRING");
Expand Down Expand Up @@ -284,7 +280,7 @@ public void shouldImplicitlyCastFunctionCallParameters() {
// Given:
final UdfFactory udfFactory = mock(UdfFactory.class);
final KsqlScalarFunction udf = mock(KsqlScalarFunction.class);
givenUdf("FOO", udfFactory, udf);
givenUdf("FOO", udfFactory, udf, SqlTypes.STRING);
when(udf.parameters()).thenReturn(ImmutableList.of(ParamTypes.DOUBLE, ParamTypes.LONG));

// When:
Expand Down Expand Up @@ -312,7 +308,7 @@ public void shouldImplicitlyCastFunctionCallParametersVariadic() {
// Given:
final UdfFactory udfFactory = mock(UdfFactory.class);
final KsqlScalarFunction udf = mock(KsqlScalarFunction.class);
givenUdf("FOO", udfFactory, udf);
givenUdf("FOO", udfFactory, udf, SqlTypes.STRING);
when(udf.parameters()).thenReturn(ImmutableList.of(ParamTypes.DOUBLE, ArrayType.of(ParamTypes.LONG)));
when(udf.isVariadic()).thenReturn(true);

Expand Down Expand Up @@ -344,7 +340,7 @@ public void shouldHandleFunctionCallsWithGenerics() {
// Given:
final UdfFactory udfFactory = mock(UdfFactory.class);
final KsqlScalarFunction udf = mock(KsqlScalarFunction.class);
givenUdf("FOO", udfFactory, udf);
givenUdf("FOO", udfFactory, udf, SqlTypes.STRING);
when(udf.parameters()).thenReturn(ImmutableList.of(GenericType.of("T"), GenericType.of("T")));

// When:
Expand Down Expand Up @@ -893,12 +889,12 @@ public void shouldGenerateCorrectCodeForInPredicate() {
}

@Test
public void shouldGenerateCorrectCodeForTransformLambdaExpression() {
public void shouldGenerateCorrectCodeForLambdaExpression() {
// Given:
final UdfFactory udfFactory = mock(UdfFactory.class);
final KsqlScalarFunction udf = mock(KsqlScalarFunction.class);
givenUdf("ABS", udfFactory, udf);
givenUdf("TRANSFORM", udfFactory, udf);
givenUdf("ABS", udfFactory, udf, SqlTypes.STRING);
givenUdf("TRANSFORM", udfFactory, udf, SqlTypes.STRING);
when(udf.parameters()).
thenReturn(ImmutableList.of(
ArrayType.of(ParamTypes.DOUBLE),
Expand All @@ -925,11 +921,11 @@ javaExpression, equalTo(
}

@Test
public void shouldGenerateCorrectCodeForReduceLambdaExpression() {
public void shouldGenerateCorrectCodeForLambdaExpressionWithTwoArguments() {
// Given:
final UdfFactory udfFactory = mock(UdfFactory.class);
final KsqlScalarFunction udf = mock(KsqlScalarFunction.class);
givenUdf("REDUCE", udfFactory, udf);
givenUdf("REDUCE", udfFactory, udf, SqlTypes.STRING);
when(udf.parameters()).
thenReturn(ImmutableList.of(
ArrayType.of(ParamTypes.DOUBLE),
Expand Down Expand Up @@ -967,6 +963,155 @@ javaExpression, equalTo(
" }\n" +
"}))"));
}

@Test
public void shouldGenerateCorrectCodeForFunctionWithMultipleLambdas() {
// Given:
final UdfFactory udfFactory = mock(UdfFactory.class);
final KsqlScalarFunction udf = mock(KsqlScalarFunction.class);
givenUdf("function", udfFactory, udf, SqlTypes.STRING);
when(udf.parameters()).
thenReturn(ImmutableList.of(
ArrayType.of(ParamTypes.DOUBLE),
ParamTypes.STRING,
LambdaType.of(
ImmutableList.of(ParamTypes.DOUBLE, ParamTypes.STRING),
ParamTypes.DOUBLE),
LambdaType.of(
ImmutableList.of(ParamTypes.DOUBLE, ParamTypes.STRING),
ParamTypes.STRING)
));

final Expression expression = new FunctionCall (
FunctionName.of("function"),
ImmutableList.of(
ARRAYCOL,
COL1,
new LambdaFunctionCall(
ImmutableList.of("X", "S"),
new ArithmeticBinaryExpression(
Operator.ADD,
new LambdaVariable("X"),
new LambdaVariable("X"))
),
new LambdaFunctionCall(
ImmutableList.of("X", "S"),
new SearchedCaseExpression(
ImmutableList.of(
new WhenClause(
new ComparisonExpression(
ComparisonExpression.Type.LESS_THAN, new LambdaVariable("X"), new IntegerLiteral(10)),
new StringLiteral("test")
),
new WhenClause(
new ComparisonExpression(
ComparisonExpression.Type.LESS_THAN, new LambdaVariable("X"), new IntegerLiteral(100)),
new StringLiteral("test2")
)
),
Optional.of(new LambdaVariable("S"))
)
)));

// When:
final String javaExpression = sqlToJavaVisitor.process(expression);

// Then
assertThat(
javaExpression, equalTo("((String) function_0.evaluate(COL4, COL1, new BiFunction() {\n"
+ " @Override\n"
+ " public Object apply(Object arg1, Object arg2) {\n"
+ " final Double X = (Double) arg1;\n"
+ " final String S = (String) arg2;\n"
+ " return (X + X);\n"
+ " }\n"
+ "}, new BiFunction() {\n"
+ " @Override\n"
+ " public Object apply(Object arg1, Object arg2) {\n"
+ " final Double X = (Double) arg1;\n"
+ " final String S = (String) arg2;\n"
+ " return ((java.lang.String)SearchedCaseFunction.searchedCaseFunction(ImmutableList.copyOf(Arrays.asList( SearchedCaseFunction.whenClause( new Supplier<Boolean>() { @Override public Boolean get() { return ((((Object)(X)) == null || ((Object)(10)) == null) ? false : (X < 10)); }}, new Supplier<java.lang.String>() { @Override public java.lang.String get() { return \"test\"; }}), SearchedCaseFunction.whenClause( new Supplier<Boolean>() { @Override public Boolean get() { return ((((Object)(X)) == null || ((Object)(100)) == null) ? false : (X < 100)); }}, new Supplier<java.lang.String>() { @Override public java.lang.String get() { return \"test2\"; }}))), new Supplier<java.lang.String>() { @Override public java.lang.String get() { return S; }}));\n"
+ " }\n"
+ "}))"));
}

@Test
public void shouldGenerateCorrectCodeForNestedLambdas() {
// Given:
final UdfFactory udfFactory = mock(UdfFactory.class);
final KsqlScalarFunction udf = mock(KsqlScalarFunction.class);
givenUdf("nested", udfFactory, udf, SqlTypes.DOUBLE);
when(udf.parameters()).
thenReturn(ImmutableList.of(
ArrayType.of(ParamTypes.DOUBLE),
ParamTypes.DOUBLE,
LambdaType.of(
ImmutableList.of(ParamTypes.DOUBLE, ParamTypes.DOUBLE),
ParamTypes.DOUBLE))
);

final Expression expression = new ArithmeticBinaryExpression(
Operator.ADD,
new FunctionCall(
FunctionName.of("nested"),
ImmutableList.of(
ARRAYCOL,
new IntegerLiteral(0),
new LambdaFunctionCall(
ImmutableList.of("A", "B"),
new ArithmeticBinaryExpression(
Operator.ADD,
new FunctionCall(
FunctionName.of("nested"),
ImmutableList.of(
ARRAYCOL,
new IntegerLiteral(0),
new LambdaFunctionCall(
ImmutableList.of("Q", "V"),
new ArithmeticBinaryExpression(
Operator.ADD,
new LambdaVariable("Q"),
new LambdaVariable("V"))
))),
new LambdaVariable("B"))
))),
new IntegerLiteral(5)
);

// When:
final String javaExpression = sqlToJavaVisitor.process(expression);

// Then
assertThat(
javaExpression, equalTo(
"(((Double) nested_0.evaluate(COL4, (Double)NullSafe.apply(0,new Function() {\n"
+ " @Override\n"
+ " public Object apply(Object arg1) {\n"
+ " final Integer val = (Integer) arg1;\n"
+ " return val.doubleValue();\n"
+ " }\n"
+ "}), new BiFunction() {\n"
+ " @Override\n"
+ " public Object apply(Object arg1, Object arg2) {\n"
+ " final Double A = (Double) arg1;\n"
+ " final Integer B = (Integer) arg2;\n"
+ " return (((Double) nested_1.evaluate(COL4, (Double)NullSafe.apply(0,new Function() {\n"
+ " @Override\n"
+ " public Object apply(Object arg1) {\n"
+ " final Integer val = (Integer) arg1;\n"
+ " return val.doubleValue();\n"
+ " }\n"
+ "}), new BiFunction() {\n"
+ " @Override\n"
+ " public Object apply(Object arg1, Object arg2) {\n"
+ " final Double Q = (Double) arg1;\n"
+ " final Integer V = (Integer) arg2;\n"
+ " return (Q + V);\n"
+ " }\n"
+ "})) + B);\n"
+ " }\n"
+ "})) + 5)"));
}

@Test
public void shouldThrowErrorOnEmptyLambdaInput() {
Expand Down Expand Up @@ -1005,12 +1150,15 @@ public void shouldThrowOnTimeLiteral() {
}

private void givenUdf(
final String name, final UdfFactory factory, final KsqlScalarFunction function
final String name,
final UdfFactory factory,
final KsqlScalarFunction function,
final SqlType returnType
) {
when(functionRegistry.isAggregate(FunctionName.of(name))).thenReturn(false);
when(functionRegistry.getUdfFactory(FunctionName.of(name))).thenReturn(factory);
when(factory.getFunction(anyList())).thenReturn(function);
when(function.getReturnType(anyList())).thenReturn(SqlTypes.STRING);
when(function.getReturnType(anyList())).thenReturn(returnType);
final UdfMetadata metadata = mock(UdfMetadata.class);
when(factory.getMetadata()).thenReturn(metadata);
}
Expand Down
Loading

0 comments on commit 8225a82

Please sign in to comment.