diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java index 799f29f248e7..35f23ca6647b 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java @@ -183,23 +183,33 @@ public Void visitLikePredicate(final LikePredicate node, final TypeContext conte @Override public Void visitFunctionCall(final FunctionCall node, final TypeContext context) { - final List argumentTypes = new ArrayList<>(); final FunctionName functionName = node.getName(); + + final TypeContext contextCopy = context.getCopy(); + final List argumentTypes = new ArrayList<>(); + final List typeContextsForChildren = new ArrayList<>(); final boolean hasLambda = node.hasLambdaFunctionCallArguments(); for (final Expression argExpr : node.getArguments()) { - final TypeContext childContext = context.getCopy(); + final TypeContext childContext; + if (argExpr instanceof LambdaFunctionCall) { + childContext = contextCopy.getCopy(); + } else { + childContext = context.getCopy(); + } + + typeContextsForChildren.add(childContext); final SqlType resolvedArgType = - expressionTypeManager.getExpressionSqlType(argExpr, childContext); + expressionTypeManager.getExpressionSqlType(argExpr, childContext.getCopy()); if (argExpr instanceof LambdaFunctionCall) { argumentTypes.add( SqlArgument.of( - SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType()))); + SqlLambda.of(contextCopy.getLambdaInputTypes(), resolvedArgType))); } else { argumentTypes.add(SqlArgument.of(resolvedArgType)); // for lambdas - we save the type information to resolve the lambda generics if (hasLambda) { - context.visitType(resolvedArgType); + contextCopy.visitType(resolvedArgType); } } } @@ -211,8 +221,9 @@ public Void visitFunctionCall(final FunctionCall node, final TypeContext context function.newInstance(ksqlConfig) ); - for (final Expression argExpr : node.getArguments()) { - process(argExpr, context.getCopy()); + final List arguments = node.getArguments(); + for (int i = 0; i < arguments.size(); i++) { + process(arguments.get(i), typeContextsForChildren.get(i)); } return null; } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java index 501b16817ecc..36466f0336e1 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java @@ -452,21 +452,32 @@ public Pair visitFunctionCall( final String instanceName = funNameToCodeName.apply(functionName); final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName()); + + final TypeContext contextCopy = context.getCopy(); final List argumentSchemas = new ArrayList<>(); + final List typeContextsForChildren = new ArrayList<>(); final boolean hasLambda = node.hasLambdaFunctionCallArguments(); + for (final Expression argExpr : node.getArguments()) { - final TypeContext childContext = context.getCopy(); + final TypeContext childContext; + if (argExpr instanceof LambdaFunctionCall) { + childContext = contextCopy.getCopy(); + } else { + childContext = context.getCopy(); + } + + typeContextsForChildren.add(childContext.getCopy()); final SqlType resolvedArgType = expressionTypeManager.getExpressionSqlType(argExpr, childContext); if (argExpr instanceof LambdaFunctionCall) { argumentSchemas.add( SqlArgument.of( - SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType()))); + SqlLambda.of(contextCopy.getLambdaInputTypes(), resolvedArgType))); } else { argumentSchemas.add(SqlArgument.of(resolvedArgType)); // for lambdas - we save the type information to resolve the lambda generics if (hasLambda) { - context.visitType(resolvedArgType); + contextCopy.visitType(resolvedArgType); } } } @@ -494,7 +505,7 @@ public Pair visitFunctionCall( } joiner.add( - process(convertArgument(arg, sqlType, paramType), context.getCopy()) + process(convertArgument(arg, sqlType, paramType),typeContextsForChildren.get(i)) .getLeft()); } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java index 688a150838f6..f1364fda4b86 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java @@ -474,22 +474,30 @@ public Void visitFunctionCall( final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName()); final List argTypes = new ArrayList<>(); + final TypeContext contextCopy = expressionTypeContext.getCopy(); final boolean hasLambda = node.hasLambdaFunctionCallArguments(); for (final Expression expression : node.getArguments()) { - final TypeContext childContext = expressionTypeContext.getCopy(); + final TypeContext childContext; + if (expression instanceof LambdaFunctionCall) { + childContext = contextCopy.getCopy(); + } else { + childContext = expressionTypeContext.getCopy(); + } process(expression, childContext); final SqlType resolvedArgType = childContext.getSqlType(); + if (expression instanceof LambdaFunctionCall) { argTypes.add( SqlArgument.of( - SqlLambda.of(expressionTypeContext.getLambdaInputTypes(), - childContext.getSqlType()))); + SqlLambda.of( + contextCopy.getLambdaInputTypes(), + resolvedArgType))); } else { argTypes.add(SqlArgument.of(resolvedArgType)); // for lambdas - we save the type information to resolve the lambda generics if (hasLambda) { - expressionTypeContext.visitType(resolvedArgType); + contextCopy.visitType(resolvedArgType); } } } diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java index 74ddbe887910..c1c128a656d3 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java @@ -30,7 +30,6 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.Mockito.mock; @@ -65,7 +64,6 @@ import io.confluent.ksql.execution.expression.tree.StringLiteral; import io.confluent.ksql.execution.expression.tree.SubscriptExpression; import io.confluent.ksql.execution.expression.tree.TimeLiteral; -import io.confluent.ksql.execution.expression.tree.TimestampLiteral; import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.WhenClause; import io.confluent.ksql.function.FunctionRegistry; @@ -74,18 +72,16 @@ import io.confluent.ksql.function.types.ArrayType; import io.confluent.ksql.function.types.GenericType; import io.confluent.ksql.function.types.LambdaType; -import io.confluent.ksql.function.types.MapType; -import io.confluent.ksql.function.types.ParamType; import io.confluent.ksql.function.types.ParamTypes; import io.confluent.ksql.function.udf.UdfMetadata; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.schema.Operator; import io.confluent.ksql.schema.ksql.types.SqlPrimitiveType; +import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.util.KsqlConfig; import java.math.BigDecimal; -import java.sql.Timestamp; import java.util.Collections; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; @@ -243,10 +239,10 @@ public void shouldPostfixFunctionInstancesWithUniqueId() { final KsqlScalarFunction ssFunction = mock(KsqlScalarFunction.class); final UdfFactory catFactory = mock(UdfFactory.class); final KsqlScalarFunction catFunction = mock(KsqlScalarFunction.class); - givenUdf("SUBSTRING", ssFactory, ssFunction); + givenUdf("SUBSTRING", ssFactory, ssFunction, SqlTypes.STRING); when(ssFunction.parameters()) .thenReturn(ImmutableList.of(ParamTypes.STRING, ParamTypes.INTEGER, ParamTypes.INTEGER)); - givenUdf("CONCAT", catFactory, catFunction); + givenUdf("CONCAT", catFactory, catFunction, SqlTypes.STRING); when(catFunction.parameters()) .thenReturn(ImmutableList.of(ParamTypes.STRING, ParamTypes.STRING)); final FunctionName ssName = FunctionName.of("SUBSTRING"); @@ -284,7 +280,7 @@ public void shouldImplicitlyCastFunctionCallParameters() { // Given: final UdfFactory udfFactory = mock(UdfFactory.class); final KsqlScalarFunction udf = mock(KsqlScalarFunction.class); - givenUdf("FOO", udfFactory, udf); + givenUdf("FOO", udfFactory, udf, SqlTypes.STRING); when(udf.parameters()).thenReturn(ImmutableList.of(ParamTypes.DOUBLE, ParamTypes.LONG)); // When: @@ -312,7 +308,7 @@ public void shouldImplicitlyCastFunctionCallParametersVariadic() { // Given: final UdfFactory udfFactory = mock(UdfFactory.class); final KsqlScalarFunction udf = mock(KsqlScalarFunction.class); - givenUdf("FOO", udfFactory, udf); + givenUdf("FOO", udfFactory, udf, SqlTypes.STRING); when(udf.parameters()).thenReturn(ImmutableList.of(ParamTypes.DOUBLE, ArrayType.of(ParamTypes.LONG))); when(udf.isVariadic()).thenReturn(true); @@ -344,7 +340,7 @@ public void shouldHandleFunctionCallsWithGenerics() { // Given: final UdfFactory udfFactory = mock(UdfFactory.class); final KsqlScalarFunction udf = mock(KsqlScalarFunction.class); - givenUdf("FOO", udfFactory, udf); + givenUdf("FOO", udfFactory, udf, SqlTypes.STRING); when(udf.parameters()).thenReturn(ImmutableList.of(GenericType.of("T"), GenericType.of("T"))); // When: @@ -893,12 +889,12 @@ public void shouldGenerateCorrectCodeForInPredicate() { } @Test - public void shouldGenerateCorrectCodeForTransformLambdaExpression() { + public void shouldGenerateCorrectCodeForLambdaExpression() { // Given: final UdfFactory udfFactory = mock(UdfFactory.class); final KsqlScalarFunction udf = mock(KsqlScalarFunction.class); - givenUdf("ABS", udfFactory, udf); - givenUdf("TRANSFORM", udfFactory, udf); + givenUdf("ABS", udfFactory, udf, SqlTypes.STRING); + givenUdf("TRANSFORM", udfFactory, udf, SqlTypes.STRING); when(udf.parameters()). thenReturn(ImmutableList.of( ArrayType.of(ParamTypes.DOUBLE), @@ -925,11 +921,11 @@ javaExpression, equalTo( } @Test - public void shouldGenerateCorrectCodeForReduceLambdaExpression() { + public void shouldGenerateCorrectCodeForLambdaExpressionWithTwoArguments() { // Given: final UdfFactory udfFactory = mock(UdfFactory.class); final KsqlScalarFunction udf = mock(KsqlScalarFunction.class); - givenUdf("REDUCE", udfFactory, udf); + givenUdf("REDUCE", udfFactory, udf, SqlTypes.STRING); when(udf.parameters()). thenReturn(ImmutableList.of( ArrayType.of(ParamTypes.DOUBLE), @@ -967,6 +963,155 @@ javaExpression, equalTo( " }\n" + "}))")); } + + @Test + public void shouldGenerateCorrectCodeForFunctionWithMultipleLambdas() { + // Given: + final UdfFactory udfFactory = mock(UdfFactory.class); + final KsqlScalarFunction udf = mock(KsqlScalarFunction.class); + givenUdf("function", udfFactory, udf, SqlTypes.STRING); + when(udf.parameters()). + thenReturn(ImmutableList.of( + ArrayType.of(ParamTypes.DOUBLE), + ParamTypes.STRING, + LambdaType.of( + ImmutableList.of(ParamTypes.DOUBLE, ParamTypes.STRING), + ParamTypes.DOUBLE), + LambdaType.of( + ImmutableList.of(ParamTypes.DOUBLE, ParamTypes.STRING), + ParamTypes.STRING) + )); + + final Expression expression = new FunctionCall ( + FunctionName.of("function"), + ImmutableList.of( + ARRAYCOL, + COL1, + new LambdaFunctionCall( + ImmutableList.of("X", "S"), + new ArithmeticBinaryExpression( + Operator.ADD, + new LambdaVariable("X"), + new LambdaVariable("X")) + ), + new LambdaFunctionCall( + ImmutableList.of("X", "S"), + new SearchedCaseExpression( + ImmutableList.of( + new WhenClause( + new ComparisonExpression( + ComparisonExpression.Type.LESS_THAN, new LambdaVariable("X"), new IntegerLiteral(10)), + new StringLiteral("test") + ), + new WhenClause( + new ComparisonExpression( + ComparisonExpression.Type.LESS_THAN, new LambdaVariable("X"), new IntegerLiteral(100)), + new StringLiteral("test2") + ) + ), + Optional.of(new LambdaVariable("S")) + ) + ))); + + // When: + final String javaExpression = sqlToJavaVisitor.process(expression); + + // Then + assertThat( + javaExpression, equalTo("((String) function_0.evaluate(COL4, COL1, new BiFunction() {\n" + + " @Override\n" + + " public Object apply(Object arg1, Object arg2) {\n" + + " final Double X = (Double) arg1;\n" + + " final String S = (String) arg2;\n" + + " return (X + X);\n" + + " }\n" + + "}, new BiFunction() {\n" + + " @Override\n" + + " public Object apply(Object arg1, Object arg2) {\n" + + " final Double X = (Double) arg1;\n" + + " final String S = (String) arg2;\n" + + " return ((java.lang.String)SearchedCaseFunction.searchedCaseFunction(ImmutableList.copyOf(Arrays.asList( SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(X)) == null || ((Object)(10)) == null) ? false : (X < 10)); }}, new Supplier() { @Override public java.lang.String get() { return \"test\"; }}), SearchedCaseFunction.whenClause( new Supplier() { @Override public Boolean get() { return ((((Object)(X)) == null || ((Object)(100)) == null) ? false : (X < 100)); }}, new Supplier() { @Override public java.lang.String get() { return \"test2\"; }}))), new Supplier() { @Override public java.lang.String get() { return S; }}));\n" + + " }\n" + + "}))")); + } + + @Test + public void shouldGenerateCorrectCodeForNestedLambdas() { + // Given: + final UdfFactory udfFactory = mock(UdfFactory.class); + final KsqlScalarFunction udf = mock(KsqlScalarFunction.class); + givenUdf("nested", udfFactory, udf, SqlTypes.DOUBLE); + when(udf.parameters()). + thenReturn(ImmutableList.of( + ArrayType.of(ParamTypes.DOUBLE), + ParamTypes.DOUBLE, + LambdaType.of( + ImmutableList.of(ParamTypes.DOUBLE, ParamTypes.DOUBLE), + ParamTypes.DOUBLE)) + ); + + final Expression expression = new ArithmeticBinaryExpression( + Operator.ADD, + new FunctionCall( + FunctionName.of("nested"), + ImmutableList.of( + ARRAYCOL, + new IntegerLiteral(0), + new LambdaFunctionCall( + ImmutableList.of("A", "B"), + new ArithmeticBinaryExpression( + Operator.ADD, + new FunctionCall( + FunctionName.of("nested"), + ImmutableList.of( + ARRAYCOL, + new IntegerLiteral(0), + new LambdaFunctionCall( + ImmutableList.of("Q", "V"), + new ArithmeticBinaryExpression( + Operator.ADD, + new LambdaVariable("Q"), + new LambdaVariable("V")) + ))), + new LambdaVariable("B")) + ))), + new IntegerLiteral(5) + ); + + // When: + final String javaExpression = sqlToJavaVisitor.process(expression); + + // Then + assertThat( + javaExpression, equalTo( + "(((Double) nested_0.evaluate(COL4, (Double)NullSafe.apply(0,new Function() {\n" + + " @Override\n" + + " public Object apply(Object arg1) {\n" + + " final Integer val = (Integer) arg1;\n" + + " return val.doubleValue();\n" + + " }\n" + + "}), new BiFunction() {\n" + + " @Override\n" + + " public Object apply(Object arg1, Object arg2) {\n" + + " final Double A = (Double) arg1;\n" + + " final Integer B = (Integer) arg2;\n" + + " return (((Double) nested_1.evaluate(COL4, (Double)NullSafe.apply(0,new Function() {\n" + + " @Override\n" + + " public Object apply(Object arg1) {\n" + + " final Integer val = (Integer) arg1;\n" + + " return val.doubleValue();\n" + + " }\n" + + "}), new BiFunction() {\n" + + " @Override\n" + + " public Object apply(Object arg1, Object arg2) {\n" + + " final Double Q = (Double) arg1;\n" + + " final Integer V = (Integer) arg2;\n" + + " return (Q + V);\n" + + " }\n" + + "})) + B);\n" + + " }\n" + + "})) + 5)")); + } @Test public void shouldThrowErrorOnEmptyLambdaInput() { @@ -1005,12 +1150,15 @@ public void shouldThrowOnTimeLiteral() { } private void givenUdf( - final String name, final UdfFactory factory, final KsqlScalarFunction function + final String name, + final UdfFactory factory, + final KsqlScalarFunction function, + final SqlType returnType ) { when(functionRegistry.isAggregate(FunctionName.of(name))).thenReturn(false); when(functionRegistry.getUdfFactory(FunctionName.of(name))).thenReturn(factory); when(factory.getFunction(anyList())).thenReturn(function); - when(function.getReturnType(anyList())).thenReturn(SqlTypes.STRING); + when(function.getReturnType(anyList())).thenReturn(returnType); final UdfMetadata metadata = mock(UdfMetadata.class); when(factory.getMetadata()).thenReturn(metadata); } diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java index c5c8ffed4ec6..61bc47f9c260 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java @@ -524,6 +524,78 @@ public void shouldFailToEvaluateLambdaWithMismatchedArgumentNumber() { "Was expecting 1 arguments but found 2, [X, Y]. Check your lambda statement.")); } + @Test + public void shouldHandleMultipleLambdasInSameFunctionCallWithDifferentVariableNames() { + // Given: + givenUdfWithNameAndReturnType("TRANSFORM", SqlTypes.INTEGER); + final Expression expression = new ArithmeticBinaryExpression( + Operator.ADD, + new FunctionCall( + FunctionName.of("TRANSFORM"), + ImmutableList.of( + ARRAYCOL, + new IntegerLiteral(0), + new LambdaFunctionCall( + ImmutableList.of("A", "B"), + new ArithmeticBinaryExpression( + Operator.ADD, + new LambdaVariable("A"), + new LambdaVariable("B")) + ), + new LambdaFunctionCall( + ImmutableList.of("K", "V"), + new ArithmeticBinaryExpression( + Operator.ADD, + new LambdaVariable("K"), + new LambdaVariable("V")) + ))), + new IntegerLiteral(5) + ); + + // When: + final SqlType result = expressionTypeManager.getExpressionSqlType(expression); + + assertThat(result, is(SqlTypes.INTEGER)); + } + + @Test + public void shouldHandleNestedLambdas() { + // Given: + givenUdfWithNameAndReturnType("TRANSFORM", SqlTypes.INTEGER); + final Expression expression = new ArithmeticBinaryExpression( + Operator.ADD, + new FunctionCall( + FunctionName.of("TRANSFORM"), + ImmutableList.of( + ARRAYCOL, + new IntegerLiteral(0), + new LambdaFunctionCall( + ImmutableList.of("A", "B"), + new ArithmeticBinaryExpression( + Operator.ADD, + new FunctionCall( + FunctionName.of("TRANSFORM"), + ImmutableList.of( + ARRAYCOL, + new IntegerLiteral(0), + new LambdaFunctionCall( + ImmutableList.of("Q", "V"), + new ArithmeticBinaryExpression( + Operator.ADD, + new LambdaVariable("Q"), + new LambdaVariable("V")) + ))), + new LambdaVariable("B")) + ))), + new IntegerLiteral(5) + ); + + // When: + final SqlType result = expressionTypeManager.getExpressionSqlType(expression); + + assertThat(result, is(SqlTypes.INTEGER)); + } + @Test public void shouldHandleStructFieldDereference() { // Given: