From 3ff3dcaca752f0e0d28c0854465d27034f5c6db8 Mon Sep 17 00:00:00 2001 From: Steven Zhang Date: Tue, 16 Feb 2021 15:13:26 -0800 Subject: [PATCH 1/6] feat: add lambda type checking without adding lambda sql type --- .../confluent/ksql/function/GenericsUtil.java | 44 +++++++++- .../io/confluent/ksql/function/UdfIndex.java | 11 +-- .../ksql/function/types/ParamTypes.java | 27 ++++++- .../confluent/ksql/function/UdfIndexTest.java | 4 +- .../ksql/function/FunctionLoaderUtils.java | 22 ++--- .../ksql/function/udf/map/ReduceMap.java | 2 +- .../ksql/execution/codegen/CodeGenRunner.java | 51 +++++++++--- .../execution/codegen/SqlToJavaVisitor.java | 69 +++++++++++++--- .../execution/codegen/helpers/LambdaUtil.java | 2 +- .../execution/util/ExpressionTypeManager.java | 38 +++++++-- .../streaming/StreamedQueryResource.java | 4 +- .../ksql/schema/ksql/SqlArgument.java | 28 +++++-- .../ksql/schema/ksql/types/SqlLambda.java | 81 +++++++++++++++++++ 13 files changed, 328 insertions(+), 55 deletions(-) create mode 100644 ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java index 16580663b96d..9b38579dbd5b 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java @@ -20,12 +20,14 @@ import com.google.common.collect.Sets; import io.confluent.ksql.function.types.ArrayType; import io.confluent.ksql.function.types.GenericType; +import io.confluent.ksql.function.types.LambdaType; import io.confluent.ksql.function.types.MapType; import io.confluent.ksql.function.types.ParamType; import io.confluent.ksql.function.types.StructType; import io.confluent.ksql.schema.ksql.SchemaConverters; import io.confluent.ksql.schema.ksql.SqlArgument; import io.confluent.ksql.schema.ksql.types.SqlArray; +import io.confluent.ksql.schema.ksql.types.SqlLambda; import io.confluent.ksql.schema.ksql.types.SqlMap; import io.confluent.ksql.schema.ksql.types.SqlStruct.Builder; import io.confluent.ksql.schema.ksql.types.SqlType; @@ -35,6 +37,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -77,6 +80,14 @@ public static Set constituentGenerics(final ParamType type) { .collect(Collectors.toSet()); } else if (type instanceof GenericType) { return ImmutableSet.of(type); + } else if (type instanceof LambdaType) { + final Set inputSet = new HashSet<>(); + for (final ParamType paramType: ((LambdaType) type).inputTypes()) { + inputSet.addAll(constituentGenerics(paramType)); + } + return Sets.union( + inputSet, + constituentGenerics(((LambdaType) type).returnType())); } else { return ImmutableSet.of(); } @@ -173,12 +184,14 @@ public static Map resolveGenerics( return ImmutableMap.copyOf(mapping); } + // CHECKSTYLE_RULES.OFF: NPathComplexity // CHECKSTYLE_RULES.OFF: CyclomaticComplexity private static boolean resolveGenerics( final List> mapping, final ParamType schema, final SqlArgument instance ) { + // CHECKSTYLE_RULES.ON: NPathComplexity // CHECKSTYLE_RULES.ON: CyclomaticComplexity final SqlType sqlType = instance.getSqlType(); @@ -202,7 +215,9 @@ private static boolean resolveGenerics( if (schema instanceof ArrayType) { final SqlArray sqlArray = (SqlArray) sqlType; return resolveGenerics( - mapping, ((ArrayType) schema).element(), SqlArgument.of(sqlArray.getItemType())); + mapping, + ((ArrayType) schema).element(), + SqlArgument.of(((SqlArray) sqlType).getItemType())); } if (schema instanceof MapType) { @@ -216,10 +231,35 @@ private static boolean resolveGenerics( throw new KsqlException("Generic STRUCT is not yet supported"); } + if (schema instanceof LambdaType) { + final LambdaType lambdaType = (LambdaType) schema; + final SqlLambda sqlLambda = instance.getSqlLambda(); + boolean resolvedInputs = true; + if (sqlLambda.getInputType().size() != lambdaType.inputTypes().size()) { + throw new KsqlException( + "Number of lambda arguments don't match between schema and sql type"); + } + + int i = 0; + for (final ParamType paramType : lambdaType.inputTypes()) { + resolvedInputs = + resolvedInputs && resolveGenerics( + mapping, paramType, SqlArgument.of(sqlLambda.getInputType().get(i)) + ); + i++; + } + return resolvedInputs && resolveGenerics( + mapping, lambdaType.returnType(), SqlArgument.of(sqlLambda.getReturnType()) + ); + } + return true; } private static boolean matches(final ParamType schema, final SqlArgument instance) { + if (schema instanceof LambdaType && instance.getSqlLambda() != null) { + return true; + } final ParamType instanceParamType = SchemaConverters .sqlToFunctionConverter().toFunctionType(instance.getSqlType()); return schema.getClass() == instanceParamType.getClass(); @@ -233,7 +273,7 @@ private static boolean matches(final ParamType schema, final SqlArgument instanc public static boolean instanceOf(final ParamType schema, final SqlArgument instance) { final List> mappings = new ArrayList<>(); - if (!resolveGenerics(mappings, schema, instance)) { + if (!resolveGenerics(mappings, schema, SqlArgument.of(instance.getSqlType()))) { return false; } diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/UdfIndex.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/UdfIndex.java index 3e01e4ab0de3..c96fdbc4295f 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/UdfIndex.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/UdfIndex.java @@ -19,6 +19,7 @@ import com.google.common.collect.Iterables; import io.confluent.ksql.function.types.ArrayType; import io.confluent.ksql.function.types.GenericType; +import io.confluent.ksql.function.types.LambdaType; import io.confluent.ksql.function.types.ParamType; import io.confluent.ksql.function.types.ParamTypes; import io.confluent.ksql.schema.ksql.SqlArgument; @@ -351,7 +352,7 @@ public int hashCode() { // CHECKSTYLE_RULES.OFF: BooleanExpressionComplexity boolean accepts(final SqlArgument argument, final Map reservedGenerics, final boolean allowCasts) { - if (argument == null || argument.getSqlType() == null) { + if (argument == null || (argument.getSqlLambda() == null && argument.getSqlType() == null)) { return true; } @@ -368,12 +369,12 @@ private static boolean reserveGenerics( final SqlArgument argument, final Map reservedGenerics ) { - if (!GenericsUtil.instanceOf(schema, argument)) { + if (!(schema instanceof LambdaType) + && !GenericsUtil.instanceOf(schema, argument)) { return false; } - - final Map genericMapping = GenericsUtil - .resolveGenerics(schema, argument); + final Map genericMapping = + GenericsUtil.resolveGenerics(schema, argument); for (final Entry entry : genericMapping.entrySet()) { final SqlType old = reservedGenerics.putIfAbsent(entry.getKey(), entry.getValue()); diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java index ec7c29d31a61..a5706635b076 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java @@ -20,6 +20,7 @@ import io.confluent.ksql.schema.ksql.SqlArgument; import io.confluent.ksql.schema.ksql.types.SqlArray; import io.confluent.ksql.schema.ksql.types.SqlBaseType; +import io.confluent.ksql.schema.ksql.types.SqlLambda; import io.confluent.ksql.schema.ksql.types.SqlMap; import io.confluent.ksql.schema.ksql.types.SqlStruct; import io.confluent.ksql.schema.ksql.types.SqlStruct.Field; @@ -44,12 +45,17 @@ public static boolean areCompatible(final SqlType actual, final ParamType declar return areCompatible(SqlArgument.of(actual), declared, false); } + // CHECKSTYLE_RULES.OFF: CyclomaticComplexity + // CHECKSTYLE_RULES.OFF: NPathComplexity public static boolean areCompatible( final SqlArgument argument, final ParamType declared, final boolean allowCast ) { + // CHECKSTYLE_RULES.ON: CyclomaticComplexity + // CHECKSTYLE_RULES.ON: NPathComplexity final SqlType argumentSqlType = argument.getSqlType(); + final SqlLambda sqlLambda = argument.getSqlLambda(); if (argumentSqlType.baseType() == SqlBaseType.ARRAY && declared instanceof ArrayType) { return areCompatible( SqlArgument.of(((SqlArray) argumentSqlType).getItemType()), @@ -61,13 +67,32 @@ public static boolean areCompatible( final SqlMap sqlType = (SqlMap) argumentSqlType; final MapType mapType = (MapType) declared; return areCompatible(SqlArgument.of(sqlType.getKeyType()), mapType.key(), allowCast) - && areCompatible(SqlArgument.of(sqlType.getValueType()), mapType.value(), allowCast); + && areCompatible( + SqlArgument.of(sqlType.getValueType()), + mapType.value(), + allowCast + ); } if (argumentSqlType.baseType() == SqlBaseType.STRUCT && declared instanceof StructType) { return isStructCompatible(argumentSqlType, declared); } + if (sqlLambda != null && declared instanceof LambdaType) { + final LambdaType declaredLambda = (LambdaType) declared; + if (sqlLambda.getInputType().size() != declaredLambda.inputTypes().size()) { + return false; + } + int i = 0; + for (final ParamType paramType: declaredLambda.inputTypes()) { + if (!areCompatible(sqlLambda.getInputType().get(i), paramType)) { + return false; + } + i++; + } + return areCompatible(sqlLambda.getReturnType(), declaredLambda.returnType()); + } + return isPrimitiveMatch(argumentSqlType, declared, allowCast); } diff --git a/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java b/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java index 39a9576c6936..cdcc012d6ae9 100644 --- a/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java +++ b/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java @@ -459,7 +459,7 @@ public void shouldChooseNonVarargWithNullValuesOfDifferingSchemas() { ); // When: - final KsqlScalarFunction fun = udfIndex.getFunction(Arrays.asList(SqlArgument.of(null), SqlArgument.of(null))); + final KsqlScalarFunction fun = udfIndex.getFunction(Arrays.asList(SqlArgument.of(null, null), SqlArgument.of(null, null))); // Then: assertThat(fun.name(), equalTo(EXPECTED)); @@ -474,7 +474,7 @@ public void shouldChooseNonVarargWithNullValuesOfSameSchemas() { ); // When: - final KsqlScalarFunction fun = udfIndex.getFunction(Arrays.asList(SqlArgument.of(null), SqlArgument.of(null))); + final KsqlScalarFunction fun = udfIndex.getFunction(Arrays.asList(SqlArgument.of(null, null), SqlArgument.of(null, null))); // Then: assertThat(fun.name(), equalTo(EXPECTED)); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java index fa55003aed9e..1582e55c9009 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java @@ -18,6 +18,7 @@ import com.google.common.annotations.VisibleForTesting; import io.confluent.ksql.execution.function.UdfUtil; import io.confluent.ksql.function.types.GenericType; +import io.confluent.ksql.function.types.LambdaType; import io.confluent.ksql.function.types.ParamType; import io.confluent.ksql.function.types.ParamTypes; import io.confluent.ksql.function.udf.Udf; @@ -178,15 +179,18 @@ static SchemaProvider handleUdfReturnSchema( final Map genericMapping = new HashMap<>(); for (int i = 0; i < Math.min(parameters.size(), arguments.size()); i++) { final ParamType schema = parameters.get(i); - - // we resolve any variadic as if it were an array so that the type - // structure matches the input type - final SqlType instance = isVariadic && i == parameters.size() - 1 - ? SqlTypes.array(arguments.get(i).getSqlType()) - : arguments.get(i).getSqlType(); - - genericMapping.putAll( - GenericsUtil.resolveGenerics(schema, SqlArgument.of(instance))); + if (schema instanceof LambdaType) { + genericMapping.putAll(GenericsUtil.resolveGenerics(schema, arguments.get(i))); + } else { + // we resolve any variadic as if it were an array so that the type + // structure matches the input type + final SqlType instance = isVariadic && i == parameters.size() - 1 + ? SqlTypes.array(arguments.get(i).getSqlType()) + : arguments.get(i).getSqlType(); + genericMapping.putAll( + GenericsUtil.resolveGenerics(schema, SqlArgument.of(instance)) + ); + } } return GenericsUtil.applyResolved(javaReturnSchema, genericMapping); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/map/ReduceMap.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/map/ReduceMap.java index 94f8b74b89a0..5bace9f9dbac 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/map/ReduceMap.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/map/ReduceMap.java @@ -28,7 +28,7 @@ * Reduce a map using an initial state and function */ @UdfDescription( - name = "map_reduce", + name = "reduce_map", category = FunctionCategory.MAP, description = "Reduce the input map down to a single value " + "using an initial state and a function. " diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java index 598d04ae49b4..75562c4e95a2 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java @@ -40,6 +40,9 @@ import io.confluent.ksql.schema.ksql.SchemaConverters; import io.confluent.ksql.schema.ksql.SchemaConverters.SqlToJavaTypeConverter; import io.confluent.ksql.schema.ksql.SqlArgument; +import io.confluent.ksql.schema.ksql.types.SqlArray; +import io.confluent.ksql.schema.ksql.types.SqlLambda; +import io.confluent.ksql.schema.ksql.types.SqlMap; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlException; @@ -48,6 +51,7 @@ import java.util.Map.Entry; import java.util.stream.Collectors; import java.util.stream.Stream; + import org.apache.kafka.connect.data.Schema; import org.codehaus.commons.compiler.CompileException; import org.codehaus.commons.compiler.CompilerFactoryFactory; @@ -109,8 +113,8 @@ public CodeGenRunner( public CodeGenSpec getCodeGenSpec(final Expression expression) { final Visitor visitor = new Visitor(); - - visitor.process(expression, null); + final TypeContext context = new TypeContext(); + visitor.process(expression, context); return visitor.spec.build(); } @@ -140,9 +144,11 @@ public ExpressionMetadata buildCodeGenFromParseTree( return new ExpressionMetadata(ee, spec, returnType, expression); } catch (KsqlException | CompileException e) { + e.printStackTrace(); throw new KsqlException("Invalid " + type + ": " + e.getMessage() + ". expression:" + expression + ", schema:" + schema, e); } catch (final Exception e) { + e.printStackTrace(); throw new RuntimeException("Unexpected error generating code for " + type + ". expression:" + expression, e); } @@ -175,8 +181,8 @@ private Visitor() { @Override public Void visitLikePredicate(final LikePredicate node, final TypeContext context) { - process(node.getValue(), null); - process(node.getPattern(), null); + process(node.getValue(), context); + process(node.getPattern(), context); return null; } @@ -184,14 +190,36 @@ public Void visitLikePredicate(final LikePredicate node, final TypeContext conte public Void visitFunctionCall(final FunctionCall node, final TypeContext context) { final List argumentTypes = new ArrayList<>(); final FunctionName functionName = node.getName(); + final UdfFactory holder = functionRegistry.getUdfFactory(functionName); for (final Expression argExpr : node.getArguments()) { - process(argExpr, null); - argumentTypes.add(SqlArgument.of( - expressionTypeManager.getExpressionSqlType(argExpr) - )); + process(argExpr, context); + final SqlType newSqlType = expressionTypeManager.getExpressionSqlType(argExpr, context); + // for lambdas - if we see this it's the array/map being passed in we save the type + if (context.notAllInputsSeen()) { + if (newSqlType instanceof SqlArray) { + final SqlArray inputArray = (SqlArray) newSqlType; + context.addLambdaInputType(inputArray.getItemType()); + } else if (newSqlType instanceof SqlMap) { + final SqlMap inputMap = (SqlMap) newSqlType; + context.addLambdaInputType(inputMap.getKeyType()); + context.addLambdaInputType(inputMap.getValueType()); + } else { + context.addLambdaInputType(newSqlType); + } + } + + if (argExpr instanceof LambdaFunctionCall) { + argumentTypes.add( + SqlArgument.of( + SqlLambda.of(context.getLambdaInputTypes(), + newSqlType)) + ); + + } else { + argumentTypes.add(SqlArgument.of(newSqlType)); + } } - final UdfFactory holder = functionRegistry.getUdfFactory(functionName); final KsqlScalarFunction function = holder.getFunction(argumentTypes); spec.addFunction( function.name(), @@ -260,13 +288,14 @@ public Void visitUnqualifiedColumnReference( public Void visitDereferenceExpression( final DereferenceExpression node, final TypeContext context ) { - process(node.getBase(), null); + process(node.getBase(), context); return null; } @Override public Void visitLambdaExpression(final LambdaFunctionCall node, final TypeContext context) { - process(node.getBody(), null); + context.mapLambdaInputTypes(node.getArguments()); + process(node.getBody(), context); return null; } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java index 0730d7f3a277..994b06ef643c 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java @@ -29,6 +29,7 @@ import io.confluent.ksql.execution.codegen.helpers.ArrayBuilder; import io.confluent.ksql.execution.codegen.helpers.CastEvaluator; import io.confluent.ksql.execution.codegen.helpers.InListEvaluator; +import io.confluent.ksql.execution.codegen.helpers.LambdaUtil; import io.confluent.ksql.execution.codegen.helpers.LikeEvaluator; import io.confluent.ksql.execution.codegen.helpers.MapBuilder; import io.confluent.ksql.execution.codegen.helpers.NullSafe; @@ -94,6 +95,7 @@ import io.confluent.ksql.schema.ksql.types.SqlArray; import io.confluent.ksql.schema.ksql.types.SqlBaseType; import io.confluent.ksql.schema.ksql.types.SqlDecimal; +import io.confluent.ksql.schema.ksql.types.SqlLambda; import io.confluent.ksql.schema.ksql.types.SqlMap; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; @@ -104,6 +106,7 @@ import java.math.BigDecimal; import java.math.MathContext; import java.math.RoundingMode; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; @@ -226,8 +229,9 @@ public String process(final Expression expression) { } private String formatExpression(final Expression expression) { + final TypeContext context = new TypeContext(); final Pair expressionFormatterResult = - new Formatter(functionRegistry).process(expression, null); + new Formatter(functionRegistry).process(expression, context); return expressionFormatterResult.getLeft(); } @@ -360,15 +364,34 @@ public Pair visitNullLiteral( } @Override + // CHECKSTYLE_RULES.OFF: TodoComment public Pair visitLambdaExpression( final LambdaFunctionCall lambdaFunctionCall, final TypeContext context) { - return visitUnsupported(lambdaFunctionCall); + + context.mapLambdaInputTypes(lambdaFunctionCall.getArguments()); + + final Pair lambdaBody = process(lambdaFunctionCall.getBody(), context); + + final List>> argPairs = new ArrayList<>(); + + for (final String lambdaArg: lambdaFunctionCall.getArguments()) { + argPairs.add(new Pair<>( + lambdaArg, + SchemaConverters.sqlToJavaConverter().toJavaType(context.getLambdaType(lambdaArg)) + )); + } + return new Pair<>(LambdaUtil.function(argPairs, lambdaBody.getLeft()), + expressionTypeManager.getExpressionSqlType(lambdaFunctionCall, context)); } @Override public Pair visitLambdaVariable( - final LambdaVariable lambdaVariable, final TypeContext context) { - return visitUnsupported(lambdaVariable); + final LambdaVariable lambdaVariable, final TypeContext context + ) { + return new Pair<>( + lambdaVariable.getValue(), + context.getLambdaType(lambdaVariable.getValue()) + ); } @Override @@ -433,10 +456,30 @@ public Pair visitFunctionCall( final String instanceName = funNameToCodeName.apply(functionName); final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName()); - final List argumentSchemas = node.getArguments().stream() - .map(expressionTypeManager::getExpressionSqlType) - .map(SqlArgument::of) - .collect(Collectors.toList()); + final List argumentSchemas = new ArrayList<>(); + for (final Expression argExpr : node.getArguments()) { + final SqlType newSqlType = expressionTypeManager.getExpressionSqlType(argExpr, context); + // for lambdas: if it's the array/map being passed in we save the type for later + if (context.notAllInputsSeen()) { + if (newSqlType instanceof SqlArray) { + final SqlArray inputArray = (SqlArray) newSqlType; + context.addLambdaInputType(inputArray.getItemType()); + } else if (newSqlType instanceof SqlMap) { + final SqlMap inputMap = (SqlMap) newSqlType; + context.addLambdaInputType(inputMap.getKeyType()); + context.addLambdaInputType(inputMap.getValueType()); + } else { + context.addLambdaInputType(newSqlType); + } + } + if (argExpr instanceof LambdaFunctionCall) { + argumentSchemas.add( + SqlArgument.of(SqlLambda.of(context.getLambdaInputTypes(), + newSqlType))); + } else { + argumentSchemas.add(SqlArgument.of(newSqlType)); + } + } final KsqlFunction function = udfFactory.getFunction(argumentSchemas); @@ -458,7 +501,9 @@ public Pair visitFunctionCall( paramType = function.parameters().get(i); } - joiner.add(process(convertArgument(arg, sqlType, paramType), context).getLeft()); + final Pair pair = + process(convertArgument(arg, sqlType, paramType), context); + joiner.add(pair.getLeft()); } @@ -783,7 +828,8 @@ public Pair visitArithmeticBinary( final Pair left = process(node.getLeft(), context); final Pair right = process(node.getRight(), context); - final SqlType schema = expressionTypeManager.getExpressionSqlType(node); + final SqlType schema = + expressionTypeManager.getExpressionSqlType(node, context); if (schema.baseType() == SqlBaseType.DECIMAL) { final SqlDecimal decimal = (SqlDecimal) schema; @@ -837,7 +883,8 @@ public Pair visitSearchedCaseExpression( )) .collect(Collectors.toList()); - final SqlType resultSchema = expressionTypeManager.getExpressionSqlType(node); + final SqlType resultSchema = + expressionTypeManager.getExpressionSqlType(node, context); final String resultSchemaString = SchemaConverters.sqlToJavaConverter().toJavaType(resultSchema).getCanonicalName(); diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java index 9b6dd8f0cd37..72ddeb98fece 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java @@ -82,7 +82,7 @@ public static String function( } else { throw new KsqlException("Unsupported number of lambda arguments."); } - + final String function = "new " + functionType + " {\n" + " @Override\n" + functionApply diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java index 222cfa110f87..48bbb16c94f7 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java @@ -66,6 +66,7 @@ import io.confluent.ksql.schema.ksql.SqlArgument; import io.confluent.ksql.schema.ksql.types.SqlArray; import io.confluent.ksql.schema.ksql.types.SqlBaseType; +import io.confluent.ksql.schema.ksql.types.SqlLambda; import io.confluent.ksql.schema.ksql.types.SqlMap; import io.confluent.ksql.schema.ksql.types.SqlStruct; import io.confluent.ksql.schema.ksql.types.SqlStruct.Builder; @@ -97,6 +98,11 @@ public ExpressionTypeManager( public SqlType getExpressionSqlType(final Expression expression) { final TypeContext expressionTypeContext = new TypeContext(); + return getExpressionSqlType(expression, expressionTypeContext); + } + + public SqlType getExpressionSqlType( + final Expression expression, final TypeContext expressionTypeContext) { new Visitor().process(expression, expressionTypeContext); return expressionTypeContext.getSqlType(); } @@ -132,9 +138,8 @@ public Void visitArithmeticUnary( public Void visitLambdaExpression( final LambdaFunctionCall node, final TypeContext context ) { + context.mapLambdaInputTypes(node.getArguments()); process(node.getBody(), context); - // TODO: add proper type inference - context.setSqlType(SqlTypes.INTEGER); return null; } @@ -143,8 +148,7 @@ public Void visitLambdaExpression( public Void visitLambdaVariable( final LambdaVariable node, final TypeContext expressionTypeContext ) { - // TODO: add proper type inference - expressionTypeContext.setSqlType(SqlTypes.INTEGER); + expressionTypeContext.setSqlType(expressionTypeContext.getLambdaType(node.getValue())); return null; } @@ -168,8 +172,10 @@ public Void visitComparisonExpression( ) { process(node.getLeft(), expressionTypeContext); final SqlType leftSchema = expressionTypeContext.getSqlType(); + process(node.getRight(), expressionTypeContext); final SqlType rightSchema = expressionTypeContext.getSqlType(); + if (!ComparisonUtil.isValidComparison(leftSchema, node.getType(), rightSchema)) { throw new KsqlException("Cannot compare " + node.getLeft().toString() + " (" + leftSchema.toString() + ") to " @@ -423,11 +429,13 @@ public Void visitStructExpression( return null; } + // CHECKSTYLE_RULES.OFF: CyclomaticComplexity @Override public Void visitFunctionCall( final FunctionCall node, final TypeContext expressionTypeContext ) { + // CHECKSTYLE_RULES.ON: CyclomaticComplexity if (functionRegistry.isAggregate(node.getName())) { final SqlType schema = node.getArguments().isEmpty() ? FunctionRegistry.DEFAULT_FUNCTION_ARG_SCHEMA @@ -467,7 +475,27 @@ public Void visitFunctionCall( final List argTypes = new ArrayList<>(); for (final Expression expression : node.getArguments()) { process(expression, expressionTypeContext); - argTypes.add(SqlArgument.of(expressionTypeContext.getSqlType())); + final SqlType newSqlType = expressionTypeContext.getSqlType(); + if (expression instanceof LambdaFunctionCall) { + argTypes.add( + SqlArgument.of(SqlLambda.of(expressionTypeContext.getLambdaInputTypes(), + expressionTypeContext.getSqlType())) + ); + } else { + argTypes.add(SqlArgument.of(newSqlType)); + } + if (expressionTypeContext.notAllInputsSeen()) { + if (newSqlType instanceof SqlArray) { + final SqlArray inputArray = (SqlArray) newSqlType; + expressionTypeContext.addLambdaInputType(inputArray.getItemType()); + } else if (newSqlType instanceof SqlMap) { + final SqlMap inputMap = (SqlMap) newSqlType; + expressionTypeContext.addLambdaInputType(inputMap.getKeyType()); + expressionTypeContext.addLambdaInputType(inputMap.getValueType()); + } else { + expressionTypeContext.addLambdaInputType(newSqlType); + } + } } final SqlType returnSchema = udfFactory.getFunction(argTypes).getReturnType(argTypes); diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java index 8b4588b5c669..4128c3446fa3 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java @@ -393,8 +393,8 @@ private EndpointResponse handlePrintTopic( final String reverseSuggestion = possibleAlternatives.isEmpty() ? "" : possibleAlternatives.stream() - .map(name -> "\tprint " + name + ";") - .collect(Collectors.joining( + .map(name -> "\tprint " + name + ";") + .collect(Collectors.joining( System.lineSeparator(), System.lineSeparator() + "Did you mean:" + System.lineSeparator(), "" diff --git a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java index 081e9339eaae..a912b7e9d831 100644 --- a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java +++ b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java @@ -15,31 +15,48 @@ package io.confluent.ksql.schema.ksql; +import io.confluent.ksql.schema.ksql.types.SqlLambda; import io.confluent.ksql.schema.ksql.types.SqlType; import java.util.Objects; /** - * A wrapper class to bundle SqlTypes for UDF functions. + * A wrapper class to bundle SqlTypes and SqlLambdas for UDF functions that contain + * lambdas as an argument. This class allows us to properly find the matching UDF and + * resolve the return type for lambda UDFs based on the given sqlLambda. */ public class SqlArgument { private final SqlType sqlType; + private final SqlLambda sqlLambda; - public SqlArgument(final SqlType type) { + public SqlArgument(final SqlType type, final SqlLambda lambda) { sqlType = type; + sqlLambda = lambda; } public static SqlArgument of(final SqlType type) { - return new SqlArgument(type); + return new SqlArgument(type, null); + } + + public static SqlArgument of(final SqlLambda type) { + return new SqlArgument(null, type); + } + + public static SqlArgument of(final SqlType sqlType, final SqlLambda lambdaType) { + return new SqlArgument(sqlType, lambdaType); } public SqlType getSqlType() { return sqlType; } + public SqlLambda getSqlLambda() { + return sqlLambda; + } + @Override public int hashCode() { - return Objects.hashCode(sqlType); + return Objects.hash(sqlType, sqlLambda); } @Override @@ -51,6 +68,7 @@ public boolean equals(final Object o) { return false; } final SqlArgument that = (SqlArgument) o; - return that.sqlType == this.sqlType; + return that.sqlType == this.sqlType && that.sqlLambda == this.sqlLambda; } + } diff --git a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java new file mode 100644 index 000000000000..719ac5890aa0 --- /dev/null +++ b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java @@ -0,0 +1,81 @@ +/* + * Copyright 2021 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.schema.ksql.types; + +import static java.util.Objects.requireNonNull; + +import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.schema.utils.FormatOptions; +import java.util.List; +import java.util.Objects; + +/** + * An internal object to track the input types and return types for a lambda + * argument that's seen as an argument inside of a UDF. + */ +@Immutable +public final class SqlLambda { + + private final List inputTypes; + private final SqlType returnType; + + public static SqlLambda of( + final List inputType, + final SqlType returnType + ) { + return new SqlLambda(inputType, returnType); + } + + public SqlLambda( + final List inputTypes, + final SqlType returnType + ) { + this.inputTypes = requireNonNull(inputTypes, "inputType"); + this.returnType = requireNonNull(returnType, "returnType"); + } + + public List getInputType() { + return inputTypes; + } + + public SqlType getReturnType() { + return returnType; + } + + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final SqlLambda lambda = (SqlLambda) o; + return Objects.equals(inputTypes, lambda.inputTypes) + && Objects.equals(returnType, lambda.returnType); + } + + public int hashCode() { + return Objects.hash(inputTypes, returnType); + } + + public String toString() { + return toString(FormatOptions.none()); + } + + public String toString(final FormatOptions formatOptions) { + return "Lambda<" + inputTypes + ", " + returnType + ">"; + } +} From 55e968de2e99bd32ce75f10334fa3f14ff740c32 Mon Sep 17 00:00:00 2001 From: Leah Thomas Date: Thu, 18 Feb 2021 10:17:32 -0600 Subject: [PATCH 2/6] Adding unit tests --- ksqldb-common/pom.xml | 2 +- .../confluent/ksql/function/GenericsUtil.java | 1 - .../ksql/function/types/ParamTypes.java | 31 +++---- .../ksql/function/types/ParamTypesTest.java | 16 ++++ .../ksql/function/UdfLoaderTest.java | 42 +++++++++ .../ksql/execution/codegen/CodeGenRunner.java | 8 +- .../execution/codegen/SqlToJavaVisitor.java | 2 - .../codegen/SqlToJavaVisitorTest.java | 86 +++++++++++++++++++ .../util/ExpressionTypeManagerTest.java | 34 ++++++++ .../ksql/schema/ksql/SqlArgument.java | 1 - 10 files changed, 196 insertions(+), 27 deletions(-) diff --git a/ksqldb-common/pom.xml b/ksqldb-common/pom.xml index c00330718888..e961ceb18c05 100644 --- a/ksqldb-common/pom.xml +++ b/ksqldb-common/pom.xml @@ -42,7 +42,7 @@ io.confluent kafka-connect-avro-converter - ${io.confluent.schema-registry.version} + 6.2.0-363 diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java index 9b38579dbd5b..c82c8794f3f1 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java @@ -213,7 +213,6 @@ private static boolean resolveGenerics( } if (schema instanceof ArrayType) { - final SqlArray sqlArray = (SqlArray) sqlType; return resolveGenerics( mapping, ((ArrayType) schema).element(), diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java index a5706635b076..e27ed886a235 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java @@ -56,6 +56,22 @@ public static boolean areCompatible( // CHECKSTYLE_RULES.ON: NPathComplexity final SqlType argumentSqlType = argument.getSqlType(); final SqlLambda sqlLambda = argument.getSqlLambda(); + + if (sqlLambda != null && declared instanceof LambdaType) { + final LambdaType declaredLambda = (LambdaType) declared; + if (sqlLambda.getInputType().size() != declaredLambda.inputTypes().size()) { + return false; + } + int i = 0; + for (final ParamType paramType: declaredLambda.inputTypes()) { + if (!areCompatible(sqlLambda.getInputType().get(i), paramType)) { + return false; + } + i++; + } + return areCompatible(sqlLambda.getReturnType(), declaredLambda.returnType()); + } + if (argumentSqlType.baseType() == SqlBaseType.ARRAY && declared instanceof ArrayType) { return areCompatible( SqlArgument.of(((SqlArray) argumentSqlType).getItemType()), @@ -78,21 +94,6 @@ && areCompatible( return isStructCompatible(argumentSqlType, declared); } - if (sqlLambda != null && declared instanceof LambdaType) { - final LambdaType declaredLambda = (LambdaType) declared; - if (sqlLambda.getInputType().size() != declaredLambda.inputTypes().size()) { - return false; - } - int i = 0; - for (final ParamType paramType: declaredLambda.inputTypes()) { - if (!areCompatible(sqlLambda.getInputType().get(i), paramType)) { - return false; - } - i++; - } - return areCompatible(sqlLambda.getReturnType(), declaredLambda.returnType()); - } - return isPrimitiveMatch(argumentSqlType, declared, allowCast); } diff --git a/ksqldb-common/src/test/java/io/confluent/ksql/function/types/ParamTypesTest.java b/ksqldb-common/src/test/java/io/confluent/ksql/function/types/ParamTypesTest.java index 12528fc58d40..9bf843d03694 100644 --- a/ksqldb-common/src/test/java/io/confluent/ksql/function/types/ParamTypesTest.java +++ b/ksqldb-common/src/test/java/io/confluent/ksql/function/types/ParamTypesTest.java @@ -18,7 +18,10 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; +import com.google.common.collect.ImmutableList; import io.confluent.ksql.schema.ksql.SqlArgument; +import io.confluent.ksql.schema.ksql.types.SqlLambda; +import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; import org.junit.Test; @@ -48,6 +51,13 @@ public void shouldFailINonCompatibleSchemas() { SqlTypes.map(SqlTypes.decimal(1, 1), SqlTypes.INTEGER), MapType.of(ParamTypes.INTEGER, ParamTypes.INTEGER)), is(false)); + + + assertThat(ParamTypes.areCompatible( + SqlArgument.of(new SqlLambda(ImmutableList.of(SqlTypes.INTEGER), SqlTypes.INTEGER)), + LambdaType.of(ImmutableList.of(ParamTypes.STRING), ParamTypes.STRING), + false), + is(false)); } @Test @@ -68,6 +78,12 @@ public void shouldPassCompatibleSchemas() { SqlTypes.map(SqlTypes.INTEGER, SqlTypes.decimal(1, 1)), MapType.of(ParamTypes.INTEGER, ParamTypes.DECIMAL)), is(true)); + + assertThat(ParamTypes.areCompatible( + SqlArgument.of(new SqlLambda(ImmutableList.of(SqlTypes.STRING), SqlTypes.STRING)), + LambdaType.of(ImmutableList.of(ParamTypes.STRING), ParamTypes.STRING), + false), + is(true)); } @Test diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java index 213a8dce760d..a0d70c09c941 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java @@ -53,7 +53,10 @@ import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.schema.ksql.SqlArgument; import io.confluent.ksql.schema.ksql.SqlTypeParser; +import io.confluent.ksql.schema.ksql.types.SqlArray; import io.confluent.ksql.schema.ksql.types.SqlDecimal; +import io.confluent.ksql.schema.ksql.types.SqlLambda; +import io.confluent.ksql.schema.ksql.types.SqlMap; import io.confluent.ksql.schema.ksql.types.SqlStruct; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; @@ -187,6 +190,45 @@ public void shouldLoadDecimalUdfs() { assertThat(fun.name().text(), equalToIgnoringCase("floor")); } + @Test + public void shouldLoadLambdaReduceUdfs() { + // Given: + final SqlLambda schema = + SqlLambda.of( + ImmutableList.of(SqlTypes.INTEGER, SqlTypes.INTEGER, SqlTypes.INTEGER), + SqlTypes.INTEGER); + + // When: + final KsqlScalarFunction fun = FUNC_REG.getUdfFactory(FunctionName.of("reduce_map")) + .getFunction( + ImmutableList.of( + SqlArgument.of(SqlMap.of(SqlTypes.INTEGER, SqlTypes.INTEGER)), + SqlArgument.of(SqlTypes.INTEGER), + SqlArgument.of(schema))); + + // Then: + assertThat(fun.name().text(), equalToIgnoringCase("reduce_map")); + } + + @Test + public void shouldLoadLambdaTransformUdfs() { + // Given: + final SqlLambda schema = + SqlLambda.of( + ImmutableList.of(SqlTypes.INTEGER), + SqlTypes.INTEGER); + + // When: + final KsqlScalarFunction fun = FUNC_REG.getUdfFactory(FunctionName.of("array_transform")) + .getFunction( + ImmutableList.of( + SqlArgument.of(SqlArray.of(SqlTypes.INTEGER)), + SqlArgument.of(schema))); + + // Then: + assertThat(fun.name().text(), equalToIgnoringCase("array_transform")); + } + @Test public void shouldLoadFunctionsFromJarsInPluginDir() { final UdfFactory toString = FUNC_REG.getUdfFactory(FunctionName.of("tostring")); diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java index 75562c4e95a2..6def6e20dae9 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java @@ -51,7 +51,6 @@ import java.util.Map.Entry; import java.util.stream.Collectors; import java.util.stream.Stream; - import org.apache.kafka.connect.data.Schema; import org.codehaus.commons.compiler.CompileException; import org.codehaus.commons.compiler.CompilerFactoryFactory; @@ -144,11 +143,9 @@ public ExpressionMetadata buildCodeGenFromParseTree( return new ExpressionMetadata(ee, spec, returnType, expression); } catch (KsqlException | CompileException e) { - e.printStackTrace(); throw new KsqlException("Invalid " + type + ": " + e.getMessage() + ". expression:" + expression + ", schema:" + schema, e); } catch (final Exception e) { - e.printStackTrace(); throw new RuntimeException("Unexpected error generating code for " + type + ". expression:" + expression, e); } @@ -211,10 +208,7 @@ public Void visitFunctionCall(final FunctionCall node, final TypeContext context if (argExpr instanceof LambdaFunctionCall) { argumentTypes.add( SqlArgument.of( - SqlLambda.of(context.getLambdaInputTypes(), - newSqlType)) - ); - + SqlLambda.of(context.getLambdaInputTypes(), newSqlType))); } else { argumentTypes.add(SqlArgument.of(newSqlType)); } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java index 994b06ef643c..3f50fb894ecb 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java @@ -369,11 +369,9 @@ public Pair visitLambdaExpression( final LambdaFunctionCall lambdaFunctionCall, final TypeContext context) { context.mapLambdaInputTypes(lambdaFunctionCall.getArguments()); - final Pair lambdaBody = process(lambdaFunctionCall.getBody(), context); final List>> argPairs = new ArrayList<>(); - for (final String lambdaArg: lambdaFunctionCall.getArguments()) { argPairs.add(new Pair<>( lambdaArg, diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java index 9f3da6a0228d..783e449f65f9 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java @@ -56,6 +56,8 @@ import io.confluent.ksql.execution.expression.tree.InListExpression; import io.confluent.ksql.execution.expression.tree.InPredicate; import io.confluent.ksql.execution.expression.tree.IntegerLiteral; +import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; +import io.confluent.ksql.execution.expression.tree.LambdaVariable; import io.confluent.ksql.execution.expression.tree.LikePredicate; import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.SearchedCaseExpression; @@ -71,6 +73,9 @@ import io.confluent.ksql.function.UdfFactory; import io.confluent.ksql.function.types.ArrayType; import io.confluent.ksql.function.types.GenericType; +import io.confluent.ksql.function.types.LambdaType; +import io.confluent.ksql.function.types.MapType; +import io.confluent.ksql.function.types.ParamType; import io.confluent.ksql.function.types.ParamTypes; import io.confluent.ksql.function.udf.UdfMetadata; import io.confluent.ksql.name.ColumnName; @@ -887,6 +892,87 @@ public void shouldGenerateCorrectCodeForInPredicate() { assertThat(java, is("InListEvaluator.matches(COL0,1L,2L)")); } + @Test + public void shouldGenerateCorrectCodeForTransformLambdaExpression() { + // Given: + final UdfFactory udfFactory = mock(UdfFactory.class); + final KsqlScalarFunction udf = mock(KsqlScalarFunction.class); + givenUdf("ABS", udfFactory, udf); + givenUdf("TRANSFORM", udfFactory, udf); + when(udf.parameters()). + thenReturn(ImmutableList.of( + ArrayType.of(ParamTypes.DOUBLE), + LambdaType.of(ImmutableList.of( + ParamTypes.DOUBLE), + ParamTypes.DOUBLE)) + ); + + final Expression expression = new FunctionCall ( + FunctionName.of("TRANSFORM"), + ImmutableList.of( + ARRAYCOL, + new LambdaFunctionCall( + ImmutableList.of("x"), + (new FunctionCall(FunctionName.of("ABS"), ImmutableList.of(new LambdaVariable("X"))))))); + + // When: + final String javaExpression = sqlToJavaVisitor.process(expression); + + // Then + assertThat( + javaExpression, equalTo( + "((String) TRANSFORM_0.evaluate(COL4, new Function() {\n @Override\n public Object apply(Object arg1) {\n final Double x = (Double) arg1;\n return ((String) ABS_1.evaluate(X));\n }\n}))")); + } + + @Test + public void shouldGenerateCorrectCodeForReduceLambdaExpression() { + // Given: + final UdfFactory udfFactory = mock(UdfFactory.class); + final KsqlScalarFunction udf = mock(KsqlScalarFunction.class); + givenUdf("REDUCE", udfFactory, udf); + when(udf.parameters()). + thenReturn(ImmutableList.of( + ArrayType.of(ParamTypes.DOUBLE), + ParamTypes.DOUBLE, + LambdaType.of( + ImmutableList.of(ParamTypes.DOUBLE, ParamTypes.DOUBLE), + ParamTypes.DOUBLE)) + ); + + final Expression expression = new FunctionCall ( + FunctionName.of("REDUCE"), + ImmutableList.of( + ARRAYCOL, + COL3, + new LambdaFunctionCall( + ImmutableList.of("X", "S"), + (new ArithmeticBinaryExpression( + Operator.ADD, + new LambdaVariable("X"), + new LambdaVariable("S"))) + ))); + + // When: + final String javaExpression = sqlToJavaVisitor.process(expression); + + // Then + assertThat( + javaExpression, equalTo( + "((String) REDUCE_0.evaluate(COL4, COL3, new BiFunction() {\\n @Override\\n public Object apply(Object arg1, Object arg2) {\\n final Double X = (Double) arg1;\\n final Double S = (Double) arg2;\\n return (X + S);\\n }\\n}))")); + } + + @Test + public void shouldThrowErrorOnEmptyLambdaInput() { + // Given: + final Expression expression = new LambdaFunctionCall( + ImmutableList.of("x"), + (new FunctionCall(FunctionName.of("ABS"), ImmutableList.of(new LambdaVariable("X"))))); + + // When: + assertThrows(IllegalArgumentException.class, () -> sqlToJavaVisitor.process(expression)); + + } + @Test public void shouldThrowOnSimpleCase() { // Given: diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java index eeb580c5c2b5..a1ed17048d71 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java @@ -16,6 +16,7 @@ package io.confluent.ksql.execution.util; import static io.confluent.ksql.execution.testutil.TestExpressions.ADDRESS; +import static io.confluent.ksql.execution.testutil.TestExpressions.ARRAYCOL; import static io.confluent.ksql.execution.testutil.TestExpressions.COL1; import static io.confluent.ksql.execution.testutil.TestExpressions.COL2; import static io.confluent.ksql.execution.testutil.TestExpressions.COL3; @@ -50,6 +51,8 @@ import io.confluent.ksql.execution.expression.tree.InListExpression; import io.confluent.ksql.execution.expression.tree.InPredicate; import io.confluent.ksql.execution.expression.tree.IntegerLiteral; +import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; +import io.confluent.ksql.execution.expression.tree.LambdaVariable; import io.confluent.ksql.execution.expression.tree.LikePredicate; import io.confluent.ksql.execution.expression.tree.NotExpression; import io.confluent.ksql.execution.expression.tree.NullLiteral; @@ -73,6 +76,7 @@ import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.SqlArgument; import io.confluent.ksql.schema.ksql.SystemColumns; +import io.confluent.ksql.schema.ksql.types.SqlArray; import io.confluent.ksql.schema.ksql.types.SqlStruct; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; @@ -339,6 +343,36 @@ public void shouldHandleNestedUdfs() { assertThat(expressionTypeManager.getExpressionSqlType(expression), equalTo(SqlTypes.STRING)); } + @Test + public void shouldEvaluateTypeForLambdaUDF() { + // Given: + + givenUdfWithNameAndReturnType("transform_array", SqlArray.of(SqlTypes.STRING)); + final Expression expression = + new FunctionCall( + FunctionName.of("TRANSFORM_ARRAY"), + ImmutableList.of( + ARRAYCOL, + new LambdaFunctionCall( + ImmutableList.of("X"), + new ArithmeticBinaryExpression( + Operator.ADD, + new LambdaVariable("X"), + literal(5) + ) + ) + ) + ); + + // When: + final SqlType exprType = expressionTypeManager.getExpressionSqlType(expression); + + // Then: + assertThat(exprType, is(SqlArray.of(SqlTypes.STRING))); + verify(udfFactory).getFunction(ImmutableList.of(SqlArgument.of(SqlArray.of(SqlTypes.STRING), null))); + verify(function).getReturnType(ImmutableList.of(SqlArgument.of(SqlArray.of(SqlTypes.STRING), null))); + } + @Test public void shouldHandleStructFieldDereference() { // Given: diff --git a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java index a912b7e9d831..096ab1de9f32 100644 --- a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java +++ b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java @@ -70,5 +70,4 @@ public boolean equals(final Object o) { final SqlArgument that = (SqlArgument) o; return that.sqlType == this.sqlType && that.sqlLambda == this.sqlLambda; } - } From c034b584b3a8d308d8623515b7321e28a9f883bb Mon Sep 17 00:00:00 2001 From: Steven Zhang Date: Wed, 17 Feb 2021 17:17:32 -0800 Subject: [PATCH 3/6] additional unit tests --- ksqldb-common/pom.xml | 2 +- .../confluent/ksql/function/GenericsUtil.java | 16 +- .../io/confluent/ksql/function/UdfIndex.java | 17 +- .../confluent/ksql/function/UdfIndexTest.java | 128 +++++++++++++ .../confluent/ksql/util/GenericsUtilTest.java | 113 +++++++++++ .../ksql/execution/codegen/CodeGenRunner.java | 2 +- .../execution/codegen/SqlToJavaVisitor.java | 4 +- .../execution/codegen/helpers/LambdaUtil.java | 2 +- .../execution/util/ExpressionTypeManager.java | 11 +- .../codegen/SqlToJavaVisitorTest.java | 10 +- .../util/ExpressionTypeManagerTest.java | 177 ++++++++++++++++-- .../ksql/schema/ksql/SqlArgument.java | 12 +- .../ksql/schema/ksql/types/SqlLambda.java | 5 +- 13 files changed, 463 insertions(+), 36 deletions(-) diff --git a/ksqldb-common/pom.xml b/ksqldb-common/pom.xml index e961ceb18c05..c00330718888 100644 --- a/ksqldb-common/pom.xml +++ b/ksqldb-common/pom.xml @@ -42,7 +42,7 @@ io.confluent kafka-connect-avro-converter - 6.2.0-363 + ${io.confluent.schema-registry.version} diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java index c82c8794f3f1..8b5b099cd799 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java @@ -174,10 +174,13 @@ public static Map resolveGenerics( final SqlType old = mapping.putIfAbsent(entry.getKey(), entry.getValue()); if (old != null && !old.equals(entry.getValue())) { throw new KsqlException(String.format( - "Found invalid instance of generic schema. Cannot map %s to both %s and %s", + "Found invalid instance of generic schema when mapping %s to %s. " + + "Cannot map %s to both %s and %s", schema, + instance, + entry.getKey(), old, - instance)); + entry.getValue())); } } @@ -213,10 +216,11 @@ private static boolean resolveGenerics( } if (schema instanceof ArrayType) { + final SqlArray sqlArray = (SqlArray) sqlType; return resolveGenerics( mapping, ((ArrayType) schema).element(), - SqlArgument.of(((SqlArray) sqlType).getItemType())); + SqlArgument.of(sqlArray.getItemType())); } if (schema instanceof MapType) { @@ -236,7 +240,7 @@ private static boolean resolveGenerics( boolean resolvedInputs = true; if (sqlLambda.getInputType().size() != lambdaType.inputTypes().size()) { throw new KsqlException( - "Number of lambda arguments don't match between schema and sql type"); + "Number of lambda arguments doesn't match between schema and sql type"); } int i = 0; @@ -258,6 +262,8 @@ resolvedInputs && resolveGenerics( private static boolean matches(final ParamType schema, final SqlArgument instance) { if (schema instanceof LambdaType && instance.getSqlLambda() != null) { return true; + } else if (schema instanceof LambdaType || instance.getSqlLambda() != null) { + return false; } final ParamType instanceParamType = SchemaConverters .sqlToFunctionConverter().toFunctionType(instance.getSqlType()); @@ -272,7 +278,7 @@ private static boolean matches(final ParamType schema, final SqlArgument instanc public static boolean instanceOf(final ParamType schema, final SqlArgument instance) { final List> mappings = new ArrayList<>(); - if (!resolveGenerics(mappings, schema, SqlArgument.of(instance.getSqlType()))) { + if (!resolveGenerics(mappings, schema, instance)) { return false; } diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/UdfIndex.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/UdfIndex.java index c96fdbc4295f..0dec36b67895 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/UdfIndex.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/UdfIndex.java @@ -19,7 +19,6 @@ import com.google.common.collect.Iterables; import io.confluent.ksql.function.types.ArrayType; import io.confluent.ksql.function.types.GenericType; -import io.confluent.ksql.function.types.LambdaType; import io.confluent.ksql.function.types.ParamType; import io.confluent.ksql.function.types.ParamTypes; import io.confluent.ksql.schema.ksql.SqlArgument; @@ -199,7 +198,18 @@ private KsqlException createNoMatchingFunctionException(final List LOG.debug("Current UdfIndex:\n{}", describe()); final String requiredTypes = paramTypes.stream() - .map(type -> type == null ? "null" : type.getSqlType().toString(FormatOptions.noEscape())) + .map(argument -> { + if (argument == null) { + return "null"; + } else { + final SqlType sqlType = argument.getSqlType(); + if (sqlType != null) { + return sqlType.toString(FormatOptions.noEscape()); + } else { + return argument.getSqlLambda().toString(); + } + } + }) .collect(Collectors.joining(", ", "(", ")")); final String acceptedTypes = allFunctions.values().stream() @@ -369,8 +379,7 @@ private static boolean reserveGenerics( final SqlArgument argument, final Map reservedGenerics ) { - if (!(schema instanceof LambdaType) - && !GenericsUtil.instanceOf(schema, argument)) { + if (!GenericsUtil.instanceOf(schema, argument)) { return false; } final Map genericMapping = diff --git a/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java b/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java index cdcc012d6ae9..ab533d976eb3 100644 --- a/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java +++ b/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java @@ -13,6 +13,7 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.function.types.ArrayType; import io.confluent.ksql.function.types.GenericType; +import io.confluent.ksql.function.types.LambdaType; import io.confluent.ksql.function.types.MapType; import io.confluent.ksql.function.types.ParamType; import io.confluent.ksql.function.types.ParamTypes; @@ -21,6 +22,7 @@ import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.schema.ksql.SqlArgument; import io.confluent.ksql.schema.ksql.types.SqlArray; +import io.confluent.ksql.schema.ksql.types.SqlLambda; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.util.KsqlConfig; @@ -45,9 +47,14 @@ public class UdfIndexTest { private static final ParamType STRUCT2 = StructType.builder().field("b", INT).build(); private static final ParamType MAP1 = MapType.of(STRING, STRING); private static final ParamType MAP2 = MapType.of(INT, INT); + private static final ParamType LAMBDA_FUNCTION = LambdaType.of(ImmutableList.of(GenericType.of("A")), GenericType.of("B")); + private static final ParamType LAMBDA_BI_FUNCTION = LambdaType.of(ImmutableList.of(GenericType.of("A"), GenericType.of("B")), GenericType.of("C")); + private static final ParamType LAMBDA_BI_FUNCTION_STRING = LambdaType.of(ImmutableList.of(STRING, STRING), GenericType.of("A")); private static final ParamType GENERIC_LIST = ArrayType.of(GenericType.of("T")); + private static final ParamType GENERIC_MAP = MapType.of(GenericType.of("A"), GenericType.of("B")); + private static final SqlType ARRAY_ARG = SqlTypes.array(INTEGER); private static final SqlType MAP1_ARG = SqlTypes.map(SqlTypes.STRING, SqlTypes.STRING); private static final SqlType DECIMAL1_ARG = SqlTypes.decimal(4, 2); @@ -231,6 +238,127 @@ public void shouldChooseCorrectMap() { assertThat(fun.name(), equalTo(EXPECTED)); } + @Test + public void shouldChooseCorrectLambdaFunction() { + // Given: + givenFunctions( + function(EXPECTED, false, GENERIC_LIST, LAMBDA_FUNCTION) + ); + + // When: + final KsqlScalarFunction fun = udfIndex.getFunction( + ImmutableList.of( + SqlArgument.of(ARRAY_ARG), + SqlArgument.of( + SqlLambda.of( + ImmutableList.of(SqlTypes.STRING), + INTEGER)))); + + // Then: + assertThat(fun.name(), equalTo(EXPECTED)); + } + + @Test + public void shouldChooseCorrectLambdaBiFunction() { + // Given: + givenFunctions( + function(EXPECTED, false, GENERIC_MAP, LAMBDA_BI_FUNCTION) + ); + + // When: + final KsqlScalarFunction fun = udfIndex.getFunction( + ImmutableList.of( + SqlArgument.of(MAP1_ARG), + SqlArgument.of( + SqlLambda.of( + ImmutableList.of(SqlTypes.STRING, SqlTypes.STRING), + INTEGER)))); + + // Then: + assertThat(fun.name(), equalTo(EXPECTED)); + } + + @Test + public void shouldChooseCorrectLambdaForTypeSpecificCollections() { + // Given: + givenFunctions( + function(EXPECTED, false, MAP1, LAMBDA_BI_FUNCTION_STRING) + ); + + // When: + final KsqlScalarFunction fun1 = udfIndex.getFunction( + ImmutableList.of( + SqlArgument.of(MAP1_ARG), + SqlArgument.of( + SqlLambda.of( + ImmutableList.of(SqlTypes.STRING, SqlTypes.STRING), + SqlTypes.BOOLEAN)))); + + final KsqlScalarFunction fun2 = udfIndex.getFunction( + ImmutableList.of( + SqlArgument.of(MAP1_ARG), + SqlArgument.of( + SqlLambda.of( + ImmutableList.of(SqlTypes.STRING, SqlTypes.STRING), + INTEGER)))); + + final Exception e = assertThrows( + Exception.class, + () -> udfIndex.getFunction( + ImmutableList.of( + SqlArgument.of(MAP1_ARG), + SqlArgument.of( + SqlLambda.of( + ImmutableList.of(SqlTypes.BOOLEAN, INTEGER), + INTEGER)))) + ); + + // Then: + assertThat(fun1.name(), equalTo(EXPECTED)); + assertThat(fun2.name(), equalTo(EXPECTED)); + assertThat(e.getMessage(), containsString("Valid alternatives are:" + + lineSeparator() + + "expected(MAP, LAMBDA<[VARCHAR, VARCHAR], A>)")); + } + + @Test + public void shouldThrowOnInvalidLambdaMapping() { + // Given: + givenFunctions( + function(OTHER, false, GENERIC_MAP, LAMBDA_BI_FUNCTION) + ); + + // When: + final Exception e1 = assertThrows( + Exception.class, + () -> udfIndex.getFunction( + ImmutableList.of( + SqlArgument.of(MAP1_ARG), + SqlArgument.of( + SqlLambda.of( + ImmutableList.of(SqlTypes.BOOLEAN, SqlTypes.STRING), + INTEGER)))) + ); + + final Exception e2 = assertThrows( + Exception.class, + () -> udfIndex.getFunction( + ImmutableList.of( + SqlArgument.of(MAP1_ARG), + SqlArgument.of( + SqlLambda.of( + ImmutableList.of(SqlTypes.STRING,SqlTypes.STRING, SqlTypes.STRING), + INTEGER) + ))) + ); + + // Then: + assertThat(e1.getMessage(), containsString("Valid alternatives are:" + + lineSeparator() + + "other(MAP, LAMBDA<[A, B], C>)")); + assertThat(e2.getMessage(), containsString("Number of lambda arguments doesn't match between schema and sql type")); + } + @Test public void shouldAllowAnyDecimal() { // Given: diff --git a/ksqldb-common/src/test/java/io/confluent/ksql/util/GenericsUtilTest.java b/ksqldb-common/src/test/java/io/confluent/ksql/util/GenericsUtilTest.java index 63b43d07ef87..5dca6b6ff37b 100644 --- a/ksqldb-common/src/test/java/io/confluent/ksql/util/GenericsUtilTest.java +++ b/ksqldb-common/src/test/java/io/confluent/ksql/util/GenericsUtilTest.java @@ -17,19 +17,24 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThrows; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.confluent.ksql.function.GenericsUtil; import io.confluent.ksql.function.types.ArrayType; import io.confluent.ksql.function.types.GenericType; +import io.confluent.ksql.function.types.LambdaType; import io.confluent.ksql.function.types.MapType; import io.confluent.ksql.function.types.ParamType; import io.confluent.ksql.function.types.ParamTypes; import io.confluent.ksql.function.types.StructType; import io.confluent.ksql.schema.ksql.SqlArgument; +import io.confluent.ksql.schema.ksql.types.SqlLambda; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; import java.util.Map; @@ -78,6 +83,27 @@ public void shouldFindAllConstituentGenerics() { assertThat(generics, containsInAnyOrder(a, b, c, d)); } + @Test + public void shouldFindAllConstituentGenericsInLambdaType() { + // Given: + final GenericType a = GenericType.of("A"); + final GenericType b = GenericType.of("B"); + final GenericType c = GenericType.of("C"); + final GenericType d = GenericType.of("D"); + final ParamType lambda = LambdaType.of( + ImmutableList.of( + GenericType.of("C"), + GenericType.of("A"), + GenericType.of("B")), + GenericType.of("D")); + + // When: + final Set generics = GenericsUtil.constituentGenerics(lambda); + + // Then: + assertThat(generics, containsInAnyOrder(a, b, c, d)); + } + @Test public void shouldFindNoConstituentGenerics() { // Given: @@ -178,6 +204,93 @@ public void shouldIdentifyMapGeneric() { assertThat(mapping, hasEntry(a.value(), SqlTypes.BIGINT)); } + @Test + public void shouldIdentifyLambdaGenerics() { + // Given: + final GenericType typeA = GenericType.of("A"); + final GenericType typeB = GenericType.of("B"); + final LambdaType a = LambdaType.of(ImmutableList.of(typeA, typeB), typeB); + final SqlArgument instance = SqlArgument.of(SqlLambda.of(ImmutableList.of(SqlTypes.DOUBLE, SqlTypes.BIGINT), SqlTypes.BIGINT)); + + // When: + final Map mapping = GenericsUtil.resolveGenerics(a, instance); + + // Then: + assertThat(mapping, hasEntry(typeA, SqlTypes.DOUBLE)); + assertThat(mapping, hasEntry(typeB, SqlTypes.BIGINT)); + } + + @Test + public void shouldFailToIdentifyLambdasWithDifferentSchema() { + // Given: + final GenericType typeA = GenericType.of("A"); + final GenericType typeB = GenericType.of("B"); + final GenericType typeC = GenericType.of("C"); + final LambdaType a = LambdaType.of(ImmutableList.of(typeA, typeC), typeB); + final SqlArgument instance = SqlArgument.of(SqlLambda.of(ImmutableList.of(SqlTypes.DOUBLE), SqlTypes.BIGINT)); + + // When: + final Exception e = assertThrows( + KsqlException.class, + () -> GenericsUtil.resolveGenerics(a, instance) + ); + + // Then: + assertThat(e.getMessage(), containsString( + "Number of lambda arguments doesn't match between schema and sql type")); + } + + @Test + public void shouldFailToIdentifyMismatchedGenericsInLambda() { + // Given: + final GenericType typeA = GenericType.of("A"); + + final LambdaType a = LambdaType.of(ImmutableList.of(typeA), typeA); + final SqlArgument instance = SqlArgument.of(SqlLambda.of(ImmutableList.of(SqlTypes.DOUBLE), SqlTypes.BOOLEAN)); + + // When: + final Exception e = assertThrows( + KsqlException.class, + () -> GenericsUtil.resolveGenerics(a, instance) + ); + + // Then: + assertThat(e.getMessage(), containsString( + "Found invalid instance of generic schema when mapping LAMBDA<[A], A> to Lambda<[DOUBLE], BOOLEAN>. " + + "Cannot map A to both DOUBLE and BOOLEAN")); + } + + @Test + public void shouldIdentifyInstanceOfLambda() { + // Given: + final LambdaType lambda = LambdaType.of(ImmutableList.of(GenericType.of("A")), GenericType.of("B")); + final SqlArgument instance = SqlArgument.of(SqlLambda.of(ImmutableList.of(SqlTypes.INTEGER), SqlTypes.BIGINT)); + + // When: + final boolean isInstance = GenericsUtil.instanceOf(lambda, instance); + + // Then: + assertThat("expected instance of", isInstance); + } + + @Test + public void shouldNotIdentifyInstanceOfTypeMismatchLambda() { + // Given: + final MapType map = MapType.of(GenericType.of("A"), GenericType.of("B")); + final SqlArgument lambdaInstance = SqlArgument.of(SqlLambda.of(ImmutableList.of(SqlTypes.INTEGER), SqlTypes.BIGINT)); + + final LambdaType lambda = LambdaType.of(ImmutableList.of(GenericType.of("A")), GenericType.of("B")); + final SqlArgument mapInstance = SqlArgument.of(SqlTypes.map(SqlTypes.STRING, SqlTypes.BOOLEAN)); + + // When: + final boolean isInstance1 = GenericsUtil.instanceOf(map, lambdaInstance); + final boolean isInstance2 = GenericsUtil.instanceOf(lambda, mapInstance); + + // Then: + assertThat("expected not instance of", !isInstance1); + assertThat("expected not instance of", !isInstance2); + } + @Test public void shouldNotIdentifyInstanceOfTypeMismatch() { // Given: diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java index 6def6e20dae9..65ebd07db946 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java @@ -187,7 +187,6 @@ public Void visitLikePredicate(final LikePredicate node, final TypeContext conte public Void visitFunctionCall(final FunctionCall node, final TypeContext context) { final List argumentTypes = new ArrayList<>(); final FunctionName functionName = node.getName(); - final UdfFactory holder = functionRegistry.getUdfFactory(functionName); for (final Expression argExpr : node.getArguments()) { process(argExpr, context); final SqlType newSqlType = expressionTypeManager.getExpressionSqlType(argExpr, context); @@ -214,6 +213,7 @@ public Void visitFunctionCall(final FunctionCall node, final TypeContext context } } + final UdfFactory holder = functionRegistry.getUdfFactory(functionName); final KsqlScalarFunction function = holder.getFunction(argumentTypes); spec.addFunction( function.name(), diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java index 3f50fb894ecb..905117840a08 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java @@ -472,8 +472,8 @@ public Pair visitFunctionCall( } if (argExpr instanceof LambdaFunctionCall) { argumentSchemas.add( - SqlArgument.of(SqlLambda.of(context.getLambdaInputTypes(), - newSqlType))); + SqlArgument.of(SqlLambda.of(context.getLambdaInputTypes(), newSqlType)) + ); } else { argumentSchemas.add(SqlArgument.of(newSqlType)); } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java index 72ddeb98fece..9b6dd8f0cd37 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java @@ -82,7 +82,7 @@ public static String function( } else { throw new KsqlException("Unsupported number of lambda arguments."); } - + final String function = "new " + functionType + " {\n" + " @Override\n" + functionApply diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java index 48bbb16c94f7..dcfc0c21a3d4 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java @@ -102,7 +102,8 @@ public SqlType getExpressionSqlType(final Expression expression) { } public SqlType getExpressionSqlType( - final Expression expression, final TypeContext expressionTypeContext) { + final Expression expression, final TypeContext expressionTypeContext + ) { new Visitor().process(expression, expressionTypeContext); return expressionTypeContext.getSqlType(); } @@ -477,10 +478,10 @@ public Void visitFunctionCall( process(expression, expressionTypeContext); final SqlType newSqlType = expressionTypeContext.getSqlType(); if (expression instanceof LambdaFunctionCall) { - argTypes.add( - SqlArgument.of(SqlLambda.of(expressionTypeContext.getLambdaInputTypes(), - expressionTypeContext.getSqlType())) - ); + argTypes.add(SqlArgument.of( + SqlLambda.of(expressionTypeContext.getLambdaInputTypes(), + expressionTypeContext.getSqlType()) + )); } else { argTypes.add(SqlArgument.of(newSqlType)); } diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java index 783e449f65f9..74ddbe887910 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitorTest.java @@ -958,7 +958,14 @@ public void shouldGenerateCorrectCodeForReduceLambdaExpression() { // Then assertThat( javaExpression, equalTo( - "((String) REDUCE_0.evaluate(COL4, COL3, new BiFunction() {\\n @Override\\n public Object apply(Object arg1, Object arg2) {\\n final Double X = (Double) arg1;\\n final Double S = (Double) arg2;\\n return (X + S);\\n }\\n}))")); + "((String) REDUCE_0.evaluate(COL4, COL3, new BiFunction() {\n" + + " @Override\n" + + " public Object apply(Object arg1, Object arg2) {\n" + + " final Double X = (Double) arg1;\n" + + " final Double S = (Double) arg2;\n" + + " return (X + S);\n" + + " }\n" + + "}))")); } @Test @@ -970,7 +977,6 @@ public void shouldThrowErrorOnEmptyLambdaInput() { // When: assertThrows(IllegalArgumentException.class, () -> sqlToJavaVisitor.process(expression)); - } @Test diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java index a1ed17048d71..c5c8ffed4ec6 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ExpressionTypeManagerTest.java @@ -76,12 +76,15 @@ import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.SqlArgument; import io.confluent.ksql.schema.ksql.SystemColumns; -import io.confluent.ksql.schema.ksql.types.SqlArray; +import io.confluent.ksql.schema.ksql.types.SqlLambda; import io.confluent.ksql.schema.ksql.types.SqlStruct; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.util.KsqlException; +import java.util.Arrays; +import java.util.Collections; import java.util.Optional; +import org.hamcrest.Matchers; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -344,13 +347,12 @@ public void shouldHandleNestedUdfs() { } @Test - public void shouldEvaluateTypeForLambdaUDF() { + public void shouldEvaluateLambdaInUDFWithArray() { // Given: - - givenUdfWithNameAndReturnType("transform_array", SqlArray.of(SqlTypes.STRING)); + givenUdfWithNameAndReturnType("TRANSFORM", SqlTypes.DOUBLE); final Expression expression = new FunctionCall( - FunctionName.of("TRANSFORM_ARRAY"), + FunctionName.of("TRANSFORM"), ImmutableList.of( ARRAYCOL, new LambdaFunctionCall( @@ -358,19 +360,168 @@ public void shouldEvaluateTypeForLambdaUDF() { new ArithmeticBinaryExpression( Operator.ADD, new LambdaVariable("X"), - literal(5) - ) - ) - ) - ); + new IntegerLiteral(5)) + ))); // When: final SqlType exprType = expressionTypeManager.getExpressionSqlType(expression); // Then: - assertThat(exprType, is(SqlArray.of(SqlTypes.STRING))); - verify(udfFactory).getFunction(ImmutableList.of(SqlArgument.of(SqlArray.of(SqlTypes.STRING), null))); - verify(function).getReturnType(ImmutableList.of(SqlArgument.of(SqlArray.of(SqlTypes.STRING), null))); + assertThat(exprType, is(SqlTypes.DOUBLE)); + verify(udfFactory).getFunction( + ImmutableList.of( + SqlArgument.of(SqlTypes.array(SqlTypes.DOUBLE)), + SqlArgument.of(SqlLambda.of(ImmutableList.of(SqlTypes.DOUBLE), SqlTypes.DOUBLE)))); + verify(function).getReturnType( + ImmutableList.of( + SqlArgument.of(SqlTypes.array(SqlTypes.DOUBLE)), + SqlArgument.of(SqlLambda.of(ImmutableList.of(SqlTypes.DOUBLE), SqlTypes.DOUBLE)))); + } + + @Test + public void shouldEvaluateLambdaInUDFWithMap() { + // Given: + givenUdfWithNameAndReturnType("TRANSFORM", SqlTypes.DOUBLE); + final Expression expression1 = + new FunctionCall( + FunctionName.of("TRANSFORM"), + ImmutableList.of( + MAPCOL, + new LambdaFunctionCall( + ImmutableList.of("X"), + new ArithmeticBinaryExpression( + Operator.ADD, + new LambdaVariable("X"), + new IntegerLiteral(5)) + ))); + + final Expression expression2 = + new FunctionCall( + FunctionName.of("TRANSFORM"), + ImmutableList.of( + MAPCOL, + new LambdaFunctionCall( + ImmutableList.of("X", "Y"), + new ArithmeticBinaryExpression( + Operator.ADD, + new LambdaVariable("X"), + new IntegerLiteral(5)) + ))); + + // When: + final Exception e = assertThrows( + Exception.class, + () -> expressionTypeManager.getExpressionSqlType(expression1) + ); + final SqlType exprType = expressionTypeManager.getExpressionSqlType(expression2); + + // Then: + assertThat(exprType, is(SqlTypes.DOUBLE)); + verify(udfFactory).getFunction( + ImmutableList.of( + SqlArgument.of(SqlTypes.map(SqlTypes.BIGINT, SqlTypes.DOUBLE)), + SqlArgument.of(SqlLambda.of(ImmutableList.of(SqlTypes.BIGINT, SqlTypes.DOUBLE), SqlTypes.BIGINT)))); + verify(function).getReturnType( + ImmutableList.of( + SqlArgument.of(SqlTypes.map(SqlTypes.BIGINT, SqlTypes.DOUBLE)), + SqlArgument.of(SqlLambda.of(ImmutableList.of(SqlTypes.BIGINT, SqlTypes.DOUBLE), SqlTypes.BIGINT)))); + assertThat(e.getMessage(), Matchers.containsString( + "Was expecting 2 arguments but found 1, [X]. Check your lambda statement.")); + } + + @Test + public void shouldEvaluateAnyNumberOfArgumentLambda() { + // Given: + givenUdfWithNameAndReturnType("TRANSFORM", SqlTypes.STRING); + final Expression expression = + new FunctionCall( + FunctionName.of("TRANSFORM"), + ImmutableList.of( + ARRAYCOL, + new StringLiteral("Q"), + MAPCOL, + new LambdaFunctionCall( + ImmutableList.of("A", "B", "C", "D"), + new ArithmeticBinaryExpression( + Operator.ADD, + new LambdaVariable("C"), + new IntegerLiteral(5)) + ))); + + // When: + final SqlType exprType = expressionTypeManager.getExpressionSqlType(expression); + + // Then: + assertThat(exprType, is(SqlTypes.STRING)); + verify(udfFactory).getFunction( + ImmutableList.of( + SqlArgument.of(SqlTypes.array(SqlTypes.DOUBLE)), + SqlArgument.of(SqlTypes.STRING), + SqlArgument.of(SqlTypes.map(SqlTypes.BIGINT, SqlTypes.DOUBLE)), + SqlArgument.of(SqlLambda.of(ImmutableList.of(SqlTypes.DOUBLE, SqlTypes.STRING, SqlTypes.BIGINT, SqlTypes.DOUBLE), SqlTypes.BIGINT)))); + verify(function).getReturnType( + ImmutableList.of( + SqlArgument.of(SqlTypes.array(SqlTypes.DOUBLE)), + SqlArgument.of(SqlTypes.STRING), + SqlArgument.of(SqlTypes.map(SqlTypes.BIGINT, SqlTypes.DOUBLE)), + SqlArgument.of(SqlLambda.of(ImmutableList.of(SqlTypes.DOUBLE, SqlTypes.STRING, SqlTypes.BIGINT, SqlTypes.DOUBLE), SqlTypes.BIGINT)))); + } + + @Test + public void shouldEvaluateLambdaArgsToType() { + // Given: + givenUdfWithNameAndReturnType("TRANSFORM", SqlTypes.STRING); + final Expression expression = + new FunctionCall( + FunctionName.of("TRANSFORM"), + ImmutableList.of( + ARRAYCOL, + new StringLiteral("Q"), + new LambdaFunctionCall( + ImmutableList.of("A", "B"), + new ArithmeticBinaryExpression( + Operator.ADD, + new LambdaVariable("A"), + new LambdaVariable("B")) + ))); + + // When: + final Exception e = assertThrows( + Exception.class, + () -> expressionTypeManager.getExpressionSqlType(expression) + ); + + // Then: + assertThat(e.getMessage(), Matchers.containsString( + "Unsupported arithmetic types. DOUBLE STRING")); + } + + @Test + public void shouldFailToEvaluateLambdaWithMismatchedArgumentNumber() { + // Given: + givenUdfWithNameAndReturnType("TRANSFORM", SqlTypes.DOUBLE); + final Expression expression = + new FunctionCall( + FunctionName.of("TRANSFORM"), + ImmutableList.of( + ARRAYCOL, + new LambdaFunctionCall( + ImmutableList.of("X", "Y"), + new ArithmeticBinaryExpression( + Operator.ADD, + new LambdaVariable("X"), + new IntegerLiteral(5)) + ))); + + // When: + final Exception e = assertThrows( + Exception.class, + () -> expressionTypeManager.getExpressionSqlType(expression) + ); + + // Then: + assertThat(e.getMessage(), Matchers.containsString( + "Was expecting 1 arguments but found 2, [X, Y]. Check your lambda statement.")); } @Test diff --git a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java index 096ab1de9f32..01d0dd5129d7 100644 --- a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java +++ b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java @@ -68,6 +68,16 @@ public boolean equals(final Object o) { return false; } final SqlArgument that = (SqlArgument) o; - return that.sqlType == this.sqlType && that.sqlLambda == this.sqlLambda; + return Objects.equals(sqlType, that.sqlType) + && Objects.equals(sqlLambda, that.sqlLambda); + } + + @Override + public String toString() { + if (sqlType != null) { + return sqlType.toString(); + } else { + return sqlLambda.toString(); + } } } diff --git a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java index 719ac5890aa0..bde0dffe54f6 100644 --- a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java +++ b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java @@ -17,6 +17,7 @@ import static java.util.Objects.requireNonNull; +import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; import io.confluent.ksql.schema.utils.FormatOptions; import java.util.List; @@ -43,7 +44,7 @@ public SqlLambda( final List inputTypes, final SqlType returnType ) { - this.inputTypes = requireNonNull(inputTypes, "inputType"); + this.inputTypes = ImmutableList.copyOf(requireNonNull(inputTypes, "inputType")); this.returnType = requireNonNull(returnType, "returnType"); } @@ -55,6 +56,7 @@ public SqlType getReturnType() { return returnType; } + @Override public boolean equals(final Object o) { if (this == o) { return true; @@ -67,6 +69,7 @@ public boolean equals(final Object o) { && Objects.equals(returnType, lambda.returnType); } + @Override public int hashCode() { return Objects.hash(inputTypes, returnType); } From 15410339aacd29f9f15d2627e0b75229d7a1def7 Mon Sep 17 00:00:00 2001 From: Leah Thomas Date: Fri, 19 Feb 2021 10:26:20 -0600 Subject: [PATCH 4/6] review fixes --- .../confluent/ksql/function/GenericsUtil.java | 18 +++++----- .../ksql/function/types/LambdaType.java | 8 ++++- .../ksql/function/types/ParamTypes.java | 18 ++++++---- .../confluent/ksql/function/UdfIndexTest.java | 34 +++++++++++++++---- .../ksql/function/FunctionLoaderUtils.java | 3 ++ .../ksql/function/UdfLoaderTest.java | 4 +++ .../ksql/execution/codegen/CodeGenRunner.java | 23 ++++--------- .../execution/codegen/SqlToJavaVisitor.java | 31 +++++------------ .../ksql/execution/codegen/TypeContext.java | 21 ++++++++++++ .../codegen/helpers/CastEvaluator.java | 4 +-- .../execution/codegen/helpers/LambdaUtil.java | 6 ++-- .../execution/util/ExpressionTypeManager.java | 26 +++++--------- .../codegen/helpers/LambdaUtilTest.java | 8 ++--- .../codegen/helpers/NullSafeTest.java | 4 +-- .../streaming/StreamedQueryResource.java | 4 +-- .../ksql/schema/ksql/types/SqlLambda.java | 14 ++++++-- 16 files changed, 129 insertions(+), 97 deletions(-) diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java index 8b5b099cd799..fad20ab4424e 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java @@ -237,7 +237,6 @@ private static boolean resolveGenerics( if (schema instanceof LambdaType) { final LambdaType lambdaType = (LambdaType) schema; final SqlLambda sqlLambda = instance.getSqlLambda(); - boolean resolvedInputs = true; if (sqlLambda.getInputType().size() != lambdaType.inputTypes().size()) { throw new KsqlException( "Number of lambda arguments doesn't match between schema and sql type"); @@ -245,13 +244,14 @@ private static boolean resolveGenerics( int i = 0; for (final ParamType paramType : lambdaType.inputTypes()) { - resolvedInputs = - resolvedInputs && resolveGenerics( - mapping, paramType, SqlArgument.of(sqlLambda.getInputType().get(i)) - ); + if (!resolveGenerics( + mapping, paramType, SqlArgument.of(sqlLambda.getInputType().get(i)) + )) { + return false; + } i++; } - return resolvedInputs && resolveGenerics( + return resolveGenerics( mapping, lambdaType.returnType(), SqlArgument.of(sqlLambda.getReturnType()) ); } @@ -260,10 +260,8 @@ resolvedInputs && resolveGenerics( } private static boolean matches(final ParamType schema, final SqlArgument instance) { - if (schema instanceof LambdaType && instance.getSqlLambda() != null) { - return true; - } else if (schema instanceof LambdaType || instance.getSqlLambda() != null) { - return false; + if (instance.getSqlLambda() != null) { + return schema instanceof LambdaType; } final ParamType instanceParamType = SchemaConverters .sqlToFunctionConverter().toFunctionType(instance.getSqlType()); diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/types/LambdaType.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/types/LambdaType.java index 04ad120978b2..97b3023d1886 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/types/LambdaType.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/types/LambdaType.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; public final class LambdaType extends ObjectType { @@ -68,6 +69,11 @@ public int hashCode() { @Override public String toString() { - return "LAMBDA<" + inputTypes + ", " + returnType + ">"; + return "LAMBDA " + + inputTypes.stream() + .map(Object::toString) + .collect(Collectors.joining(", ", "(", ")")) + + " -> " + + returnType; } } diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java index e27ed886a235..3b25cb1f79ab 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java @@ -41,10 +41,6 @@ private ParamTypes() { public static final ParamType DECIMAL = DecimalType.INSTANCE; public static final TimestampType TIMESTAMP = TimestampType.INSTANCE; - public static boolean areCompatible(final SqlType actual, final ParamType declared) { - return areCompatible(SqlArgument.of(actual), declared, false); - } - // CHECKSTYLE_RULES.OFF: CyclomaticComplexity // CHECKSTYLE_RULES.OFF: NPathComplexity public static boolean areCompatible( @@ -64,12 +60,19 @@ public static boolean areCompatible( } int i = 0; for (final ParamType paramType: declaredLambda.inputTypes()) { - if (!areCompatible(sqlLambda.getInputType().get(i), paramType)) { + if (!areCompatible( + SqlArgument.of(sqlLambda.getInputType().get(i)), + paramType, + allowCast) + ) { return false; } i++; } - return areCompatible(sqlLambda.getReturnType(), declaredLambda.returnType()); + return areCompatible( + SqlArgument.of(sqlLambda.getReturnType()), + declaredLambda.returnType(), + allowCast); } if (argumentSqlType.baseType() == SqlBaseType.ARRAY && declared instanceof ArrayType) { @@ -109,7 +112,8 @@ private static boolean isStructCompatible(final SqlType actual, final ParamType final String k = entry.getKey(); final Optional field = actualStruct.field(k); // intentionally do not allow implicit casting within structs - if (!field.isPresent() || !areCompatible(field.get().type(), entry.getValue())) { + if (!field.isPresent() || + !areCompatible(SqlArgument.of(field.get().type()), entry.getValue(), false)) { return false; } } diff --git a/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java b/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java index ab533d976eb3..878e2b73f981 100644 --- a/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java +++ b/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java @@ -47,7 +47,8 @@ public class UdfIndexTest { private static final ParamType STRUCT2 = StructType.builder().field("b", INT).build(); private static final ParamType MAP1 = MapType.of(STRING, STRING); private static final ParamType MAP2 = MapType.of(INT, INT); - private static final ParamType LAMBDA_FUNCTION = LambdaType.of(ImmutableList.of(GenericType.of("A")), GenericType.of("B")); + private static final ParamType LAMBDA_KEY_FUNCTION = LambdaType.of(ImmutableList.of(GenericType.of("A")), GenericType.of("C")); + private static final ParamType LAMBDA_VALUE_FUNCTION = LambdaType.of(ImmutableList.of(GenericType.of("B")), GenericType.of("D")); private static final ParamType LAMBDA_BI_FUNCTION = LambdaType.of(ImmutableList.of(GenericType.of("A"), GenericType.of("B")), GenericType.of("C")); private static final ParamType LAMBDA_BI_FUNCTION_STRING = LambdaType.of(ImmutableList.of(STRING, STRING), GenericType.of("A")); @@ -56,6 +57,7 @@ public class UdfIndexTest { private static final SqlType ARRAY_ARG = SqlTypes.array(INTEGER); private static final SqlType MAP1_ARG = SqlTypes.map(SqlTypes.STRING, SqlTypes.STRING); + private static final SqlType MAP2_ARG = SqlTypes.map(SqlTypes.STRING, INTEGER); private static final SqlType DECIMAL1_ARG = SqlTypes.decimal(4, 2); private static final SqlType STRUCT1_ARG = SqlTypes.struct().field("a", SqlTypes.STRING).build(); @@ -63,6 +65,8 @@ public class UdfIndexTest { private static final FunctionName EXPECTED = FunctionName.of("expected"); private static final FunctionName OTHER = FunctionName.of("other"); + private static final FunctionName FIRST_FUNC = FunctionName.of("first_func"); + private static final FunctionName SECOND_FUNC = FunctionName.of("second_func"); private UdfIndex udfIndex; @@ -242,20 +246,32 @@ public void shouldChooseCorrectMap() { public void shouldChooseCorrectLambdaFunction() { // Given: givenFunctions( - function(EXPECTED, false, GENERIC_LIST, LAMBDA_FUNCTION) + function(FIRST_FUNC, false, GENERIC_MAP, LAMBDA_KEY_FUNCTION) + ); + givenFunctions( + function(SECOND_FUNC, false, GENERIC_MAP, LAMBDA_VALUE_FUNCTION) ); // When: - final KsqlScalarFunction fun = udfIndex.getFunction( + final KsqlScalarFunction first_fun = udfIndex.getFunction( ImmutableList.of( - SqlArgument.of(ARRAY_ARG), + SqlArgument.of(MAP2_ARG), SqlArgument.of( SqlLambda.of( ImmutableList.of(SqlTypes.STRING), + SqlTypes.STRING)))); + + final KsqlScalarFunction second_fun = udfIndex.getFunction( + ImmutableList.of( + SqlArgument.of(MAP2_ARG), + SqlArgument.of( + SqlLambda.of( + ImmutableList.of(INTEGER), INTEGER)))); // Then: - assertThat(fun.name(), equalTo(EXPECTED)); + assertThat(first_fun.name(), equalTo(FIRST_FUNC)); + assertThat(second_fun.name(), equalTo(SECOND_FUNC)); } @Test @@ -316,9 +332,11 @@ public void shouldChooseCorrectLambdaForTypeSpecificCollections() { // Then: assertThat(fun1.name(), equalTo(EXPECTED)); assertThat(fun2.name(), equalTo(EXPECTED)); + assertThat(e.getMessage(), containsString("does not accept parameters (" + + "MAP, LAMBDA (BOOLEAN, INTEGER) -> A).")); assertThat(e.getMessage(), containsString("Valid alternatives are:" + lineSeparator() - + "expected(MAP, LAMBDA<[VARCHAR, VARCHAR], A>)")); + + "expected(MAP, LAMBDA (VARCHAR, VARCHAR) -> A)")); } @Test @@ -353,9 +371,11 @@ public void shouldThrowOnInvalidLambdaMapping() { ); // Then: + assertThat(e1.getMessage(), containsString("does not accept parameters (" + + "MAP, LAMBDA (BOOLEAN, STRING) -> A).")); assertThat(e1.getMessage(), containsString("Valid alternatives are:" + lineSeparator() - + "other(MAP, LAMBDA<[A, B], C>)")); + + "other(MAP, LAMBDA (A, B) -> C)")); assertThat(e2.getMessage(), containsString("Number of lambda arguments doesn't match between schema and sql type")); } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java index 1582e55c9009..203e35c0ce04 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java @@ -180,6 +180,9 @@ static SchemaProvider handleUdfReturnSchema( for (int i = 0; i < Math.min(parameters.size(), arguments.size()); i++) { final ParamType schema = parameters.get(i); if (schema instanceof LambdaType) { + if (isVariadic) { + throw new KsqlException(String.format("Lambda function %s cannot be variadic.", arguments.get(i).toString())); + } genericMapping.putAll(GenericsUtil.resolveGenerics(schema, arguments.get(i))); } else { // we resolve any variadic as if it were an array so that the type diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java index a0d70c09c941..7847ee6b359d 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java @@ -41,6 +41,10 @@ import com.google.common.collect.ImmutableMap; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import io.confluent.ksql.execution.function.TableAggregationFunction; +import io.confluent.ksql.function.types.ArrayType; +import io.confluent.ksql.function.types.LambdaType; +import io.confluent.ksql.function.types.ParamType; +import io.confluent.ksql.function.types.ParamTypes; import io.confluent.ksql.function.udaf.TestUdaf; import io.confluent.ksql.function.udaf.Udaf; import io.confluent.ksql.function.udf.Kudf; diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java index 65ebd07db946..92331220878f 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java @@ -189,27 +189,16 @@ public Void visitFunctionCall(final FunctionCall node, final TypeContext context final FunctionName functionName = node.getName(); for (final Expression argExpr : node.getArguments()) { process(argExpr, context); - final SqlType newSqlType = expressionTypeManager.getExpressionSqlType(argExpr, context); - // for lambdas - if we see this it's the array/map being passed in we save the type - if (context.notAllInputsSeen()) { - if (newSqlType instanceof SqlArray) { - final SqlArray inputArray = (SqlArray) newSqlType; - context.addLambdaInputType(inputArray.getItemType()); - } else if (newSqlType instanceof SqlMap) { - final SqlMap inputMap = (SqlMap) newSqlType; - context.addLambdaInputType(inputMap.getKeyType()); - context.addLambdaInputType(inputMap.getValueType()); - } else { - context.addLambdaInputType(newSqlType); - } - } - + final SqlType resolvedArgType = expressionTypeManager.getExpressionSqlType(argExpr, context); + // for lambdas - if we find an array or map passed in before encountering a lambda function + // we save the type information to resolve the lambda generics + context.visitType(resolvedArgType); if (argExpr instanceof LambdaFunctionCall) { argumentTypes.add( SqlArgument.of( - SqlLambda.of(context.getLambdaInputTypes(), newSqlType))); + SqlLambda.of(context.getLambdaInputTypes(), resolvedArgType))); } else { - argumentTypes.add(SqlArgument.of(newSqlType)); + argumentTypes.add(SqlArgument.of(resolvedArgType)); } } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java index 905117840a08..df5f397918c9 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java @@ -364,7 +364,6 @@ public Pair visitNullLiteral( } @Override - // CHECKSTYLE_RULES.OFF: TodoComment public Pair visitLambdaExpression( final LambdaFunctionCall lambdaFunctionCall, final TypeContext context) { @@ -378,7 +377,7 @@ public Pair visitLambdaExpression( SchemaConverters.sqlToJavaConverter().toJavaType(context.getLambdaType(lambdaArg)) )); } - return new Pair<>(LambdaUtil.function(argPairs, lambdaBody.getLeft()), + return new Pair<>(LambdaUtil.toJavaCode(argPairs, lambdaBody.getLeft()), expressionTypeManager.getExpressionSqlType(lambdaFunctionCall, context)); } @@ -456,26 +455,16 @@ public Pair visitFunctionCall( final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName()); final List argumentSchemas = new ArrayList<>(); for (final Expression argExpr : node.getArguments()) { - final SqlType newSqlType = expressionTypeManager.getExpressionSqlType(argExpr, context); - // for lambdas: if it's the array/map being passed in we save the type for later - if (context.notAllInputsSeen()) { - if (newSqlType instanceof SqlArray) { - final SqlArray inputArray = (SqlArray) newSqlType; - context.addLambdaInputType(inputArray.getItemType()); - } else if (newSqlType instanceof SqlMap) { - final SqlMap inputMap = (SqlMap) newSqlType; - context.addLambdaInputType(inputMap.getKeyType()); - context.addLambdaInputType(inputMap.getValueType()); - } else { - context.addLambdaInputType(newSqlType); - } - } + final SqlType resolvedArgType = expressionTypeManager.getExpressionSqlType(argExpr, context); + // for lambdas - if we find an array or map passed in before encountering a lambda function + // we save the type information to resolve the lambda generics + context.visitType(resolvedArgType); if (argExpr instanceof LambdaFunctionCall) { argumentSchemas.add( - SqlArgument.of(SqlLambda.of(context.getLambdaInputTypes(), newSqlType)) - ); + SqlArgument.of( + SqlLambda.of(context.getLambdaInputTypes(), resolvedArgType))); } else { - argumentSchemas.add(SqlArgument.of(newSqlType)); + argumentSchemas.add(SqlArgument.of(resolvedArgType)); } } @@ -499,9 +488,7 @@ public Pair visitFunctionCall( paramType = function.parameters().get(i); } - final Pair pair = - process(convertArgument(arg, sqlType, paramType), context); - joiner.add(pair.getLeft()); + joiner.add(process(convertArgument(arg, sqlType, paramType), context).getLeft()); } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/TypeContext.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/TypeContext.java index ef800bb36291..11fb919440fd 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/TypeContext.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/TypeContext.java @@ -15,6 +15,12 @@ package io.confluent.ksql.execution.codegen; +import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; +import io.confluent.ksql.schema.ksql.SqlArgument; +import io.confluent.ksql.schema.ksql.types.SqlArray; +import io.confluent.ksql.schema.ksql.types.SqlLambda; +import io.confluent.ksql.schema.ksql.types.SqlMap; import io.confluent.ksql.schema.ksql.types.SqlType; import java.util.ArrayList; @@ -64,4 +70,19 @@ public SqlType getLambdaType(final String name) { public boolean notAllInputsSeen() { return lambdaInputTypeMapping.size() != lambdaInputTypes.size() || lambdaInputTypes.size() == 0; } + + public void visitType(SqlType type) { + if (notAllInputsSeen()) { + if (type instanceof SqlArray) { + final SqlArray inputArray = (SqlArray) type; + addLambdaInputType(inputArray.getItemType()); + } else if (type instanceof SqlMap) { + final SqlMap inputMap = (SqlMap) type; + addLambdaInputType(inputMap.getKeyType()); + addLambdaInputType(inputMap.getValueType()); + } else { + addLambdaInputType(type); + } + } + } } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/CastEvaluator.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/CastEvaluator.java index 352ffacff213..1c3bf0c71e80 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/CastEvaluator.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/CastEvaluator.java @@ -192,7 +192,7 @@ private static CastFunction nonNullSafeCode(final String code) { final Class toJavaType = SchemaConverters.sqlToJavaConverter().toJavaType(to); final String lambdaBody = String.format(code, "val"); - final String function = LambdaUtil.function("val", fromJavaType, lambdaBody); + final String function = LambdaUtil.toJavaCode("val", fromJavaType, lambdaBody); return NullSafe.generateApply(innerCode, function, toJavaType); }; } @@ -322,7 +322,7 @@ private static String mapperFunction( final String lambdaBody = generateCode("val", fromItemType, toItemType, config); - return LambdaUtil.function("val", javaType, lambdaBody); + return LambdaUtil.toJavaCode("val", javaType, lambdaBody); } @FunctionalInterface diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java index 9b6dd8f0cd37..af9dc83c01dc 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java @@ -38,12 +38,12 @@ private LambdaUtil() { * @param lambdaBody the body of the lambda. * @return code to instantiate the function. */ - public static String function( + public static String toJavaCode( final String argName, final Class argType, final String lambdaBody ) { - return function(ImmutableList.of(new Pair<>(argName, argType)), lambdaBody); + return toJavaCode(ImmutableList.of(new Pair<>(argName, argType)), lambdaBody); } /** @@ -55,7 +55,7 @@ public static String function( * @return code to instantiate the function. */ // CHECKSTYLE_RULES.OFF: FinalLocalVariable - public static String function( + public static String toJavaCode( final List>> argList, final String lambdaBody ) { diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java index dcfc0c21a3d4..0d4b6fdd8dd0 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java @@ -476,26 +476,16 @@ public Void visitFunctionCall( final List argTypes = new ArrayList<>(); for (final Expression expression : node.getArguments()) { process(expression, expressionTypeContext); - final SqlType newSqlType = expressionTypeContext.getSqlType(); + final SqlType resolvedArgType = expressionTypeContext.getSqlType(); + // for lambdas - if we find an array or map passed in before encountering a lambda function + // we save the type information to resolve the lambda generics + expressionTypeContext.visitType(resolvedArgType); if (expression instanceof LambdaFunctionCall) { - argTypes.add(SqlArgument.of( - SqlLambda.of(expressionTypeContext.getLambdaInputTypes(), - expressionTypeContext.getSqlType()) - )); + argTypes.add( + SqlArgument.of( + SqlLambda.of(expressionTypeContext.getLambdaInputTypes(), resolvedArgType))); } else { - argTypes.add(SqlArgument.of(newSqlType)); - } - if (expressionTypeContext.notAllInputsSeen()) { - if (newSqlType instanceof SqlArray) { - final SqlArray inputArray = (SqlArray) newSqlType; - expressionTypeContext.addLambdaInputType(inputArray.getItemType()); - } else if (newSqlType instanceof SqlMap) { - final SqlMap inputMap = (SqlMap) newSqlType; - expressionTypeContext.addLambdaInputType(inputMap.getKeyType()); - expressionTypeContext.addLambdaInputType(inputMap.getValueType()); - } else { - expressionTypeContext.addLambdaInputType(newSqlType); - } + argTypes.add(SqlArgument.of(resolvedArgType)); } } diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtilTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtilTest.java index 24b53e8d30bd..599f99090a03 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtilTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtilTest.java @@ -41,7 +41,7 @@ public void shouldGenerateFunctionCode() { // When: final String javaCode = LambdaUtil - .function(argName, argType, argName + " + 1"); + .toJavaCode(argName, argType, argName + " + 1"); // Then: final Object result = CodeGenTestUtil.cookAndEval(javaCode, Function.class); @@ -58,7 +58,7 @@ public void shouldGenerateBiFunction() { final List>> argList = ImmutableList.of(argName1, argName2); // When: - final String javaCode = LambdaUtil.function(argList, "fred + bob + 2"); + final String javaCode = LambdaUtil.toJavaCode(argList, "fred + bob + 2"); // Then: final Object result = CodeGenTestUtil.cookAndEval(javaCode, BiFunction.class); @@ -76,7 +76,7 @@ public void shouldGenerateTriFunction() { final List>> argList = ImmutableList.of(argName1, argName2, argName3); // When: - final String javaCode = LambdaUtil.function(argList, "fred + bob + tim + 1"); + final String javaCode = LambdaUtil.toJavaCode(argList, "fred + bob + tim + 1"); // Then: final Object result = CodeGenTestUtil.cookAndEval(javaCode, TriFunction.class); @@ -95,6 +95,6 @@ public void shouldThrowOnNonSupportedArguments() { final List>> argList = ImmutableList.of(argName1, argName2, argName3, argName4); // When: - LambdaUtil.function(argList, "fred + bob + tim + hello + 1"); + LambdaUtil.toJavaCode(argList, "fred + bob + tim + hello + 1"); } } \ No newline at end of file diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/NullSafeTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/NullSafeTest.java index 7dc8167545e7..0bff0f78c164 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/NullSafeTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/NullSafeTest.java @@ -14,7 +14,7 @@ public class NullSafeTest { public void shouldGenerateApply() { // Given: final String mapperCode = LambdaUtil - .function("val", Long.class, "val.longValue() + 1"); + .toJavaCode("val", Long.class, "val.longValue() + 1"); // When: final String javaCode = NullSafe @@ -30,7 +30,7 @@ public void shouldGenerateApply() { public void shouldGenerateApplyOrDefault() { // Given: final String mapperCode = LambdaUtil - .function("val", Long.class, "val.longValue() + 1"); + .toJavaCode("val", Long.class, "val.longValue() + 1"); // When: final String javaCode = NullSafe diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java index 4128c3446fa3..8b4588b5c669 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java @@ -393,8 +393,8 @@ private EndpointResponse handlePrintTopic( final String reverseSuggestion = possibleAlternatives.isEmpty() ? "" : possibleAlternatives.stream() - .map(name -> "\tprint " + name + ";") - .collect(Collectors.joining( + .map(name -> "\tprint " + name + ";") + .collect(Collectors.joining( System.lineSeparator(), System.lineSeparator() + "Did you mean:" + System.lineSeparator(), "" diff --git a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java index bde0dffe54f6..f5576729b9f1 100644 --- a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java +++ b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java @@ -22,6 +22,7 @@ import io.confluent.ksql.schema.utils.FormatOptions; import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; /** * An internal object to track the input types and return types for a lambda @@ -75,10 +76,19 @@ public int hashCode() { } public String toString() { - return toString(FormatOptions.none()); + return "LAMBDA " + + inputTypes.stream() + .map(Object::toString) + .collect(Collectors.joining(", ", "(", ")")) + + " -> A"; } public String toString(final FormatOptions formatOptions) { - return "Lambda<" + inputTypes + ", " + returnType + ">"; + return "Lambda " + + inputTypes.stream() + .map(Object::toString) + .collect(Collectors.joining(", ", "(", ")")) + + " -> " + + returnType; } } From 6a2e65d4dac208382e110774cf80a7b02e386f5a Mon Sep 17 00:00:00 2001 From: Leah Thomas Date: Thu, 18 Feb 2021 14:29:57 -0600 Subject: [PATCH 5/6] Adding support for multiple variable names --- .../ksql/function/types/LambdaType.java | 2 +- .../ksql/function/types/ParamTypes.java | 4 +- .../confluent/ksql/function/UdfIndexTest.java | 8 +-- .../ksql/function/types/ParamTypesTest.java | 52 +++++++++++++------ .../confluent/ksql/util/GenericsUtilTest.java | 2 +- .../ksql/function/FunctionLoaderUtils.java | 2 +- .../ksql/execution/codegen/CodeGenRunner.java | 32 +++++++++--- .../execution/codegen/SqlToJavaVisitor.java | 22 +++++--- .../ksql/execution/codegen/TypeContext.java | 41 +++++++++------ .../execution/util/ExpressionTypeManager.java | 35 ++++++++++--- .../ksql/schema/ksql/types/SqlLambda.java | 7 +-- 11 files changed, 143 insertions(+), 64 deletions(-) diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/types/LambdaType.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/types/LambdaType.java index 97b3023d1886..2f2c473d4cdf 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/types/LambdaType.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/types/LambdaType.java @@ -73,7 +73,7 @@ public String toString() { + inputTypes.stream() .map(Object::toString) .collect(Collectors.joining(", ", "(", ")")) - + " -> " + + " => " + returnType; } } diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java index 3b25cb1f79ab..026b1bb0cb1d 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java @@ -112,8 +112,8 @@ private static boolean isStructCompatible(final SqlType actual, final ParamType final String k = entry.getKey(); final Optional field = actualStruct.field(k); // intentionally do not allow implicit casting within structs - if (!field.isPresent() || - !areCompatible(SqlArgument.of(field.get().type()), entry.getValue(), false)) { + if (!field.isPresent() + || !areCompatible(SqlArgument.of(field.get().type()), entry.getValue(), false)) { return false; } } diff --git a/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java b/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java index 878e2b73f981..2530ad9ae54e 100644 --- a/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java +++ b/ksqldb-common/src/test/java/io/confluent/ksql/function/UdfIndexTest.java @@ -333,10 +333,10 @@ public void shouldChooseCorrectLambdaForTypeSpecificCollections() { assertThat(fun1.name(), equalTo(EXPECTED)); assertThat(fun2.name(), equalTo(EXPECTED)); assertThat(e.getMessage(), containsString("does not accept parameters (" + - "MAP, LAMBDA (BOOLEAN, INTEGER) -> A).")); + "MAP, LAMBDA (BOOLEAN, INTEGER) => INTEGER).")); assertThat(e.getMessage(), containsString("Valid alternatives are:" + lineSeparator() - + "expected(MAP, LAMBDA (VARCHAR, VARCHAR) -> A)")); + + "expected(MAP, LAMBDA (VARCHAR, VARCHAR) => A)")); } @Test @@ -372,10 +372,10 @@ public void shouldThrowOnInvalidLambdaMapping() { // Then: assertThat(e1.getMessage(), containsString("does not accept parameters (" + - "MAP, LAMBDA (BOOLEAN, STRING) -> A).")); + "MAP, LAMBDA (BOOLEAN, STRING) => INTEGER).")); assertThat(e1.getMessage(), containsString("Valid alternatives are:" + lineSeparator() - + "other(MAP, LAMBDA (A, B) -> C)")); + + "other(MAP, LAMBDA (A, B) => C)")); assertThat(e2.getMessage(), containsString("Number of lambda arguments doesn't match between schema and sql type")); } diff --git a/ksqldb-common/src/test/java/io/confluent/ksql/function/types/ParamTypesTest.java b/ksqldb-common/src/test/java/io/confluent/ksql/function/types/ParamTypesTest.java index 9bf843d03694..4d6736cbe05b 100644 --- a/ksqldb-common/src/test/java/io/confluent/ksql/function/types/ParamTypesTest.java +++ b/ksqldb-common/src/test/java/io/confluent/ksql/function/types/ParamTypesTest.java @@ -29,27 +29,41 @@ public class ParamTypesTest { @Test public void shouldFailINonCompatibleSchemas() { - assertThat(ParamTypes.areCompatible(SqlTypes.STRING, ParamTypes.INTEGER), is(false)); + assertThat(ParamTypes.areCompatible( + SqlArgument.of(SqlTypes.STRING), + ParamTypes.INTEGER, + false), + is(false)); - assertThat(ParamTypes.areCompatible(SqlTypes.STRING, GenericType.of("T")), is(false)); + assertThat(ParamTypes.areCompatible( + SqlArgument.of(SqlTypes.STRING), + GenericType.of("T"), + false), + is(false)); assertThat( - ParamTypes.areCompatible(SqlTypes.array(SqlTypes.INTEGER), ArrayType.of(ParamTypes.STRING)), + ParamTypes.areCompatible( + SqlArgument.of(SqlTypes.array(SqlTypes.INTEGER)), + ArrayType.of(ParamTypes.STRING), + false), is(false)); assertThat(ParamTypes.areCompatible( - SqlTypes.struct().field("a", SqlTypes.decimal(1, 1)).build(), - StructType.builder().field("a", ParamTypes.DOUBLE).build()), + SqlArgument.of(SqlTypes.struct().field("a", SqlTypes.decimal(1, 1)).build()), + StructType.builder().field("a", ParamTypes.DOUBLE).build(), + false), is(false)); assertThat(ParamTypes.areCompatible( - SqlTypes.map(SqlTypes.STRING, SqlTypes.decimal(1, 1)), - MapType.of(ParamTypes.STRING, ParamTypes.INTEGER)), + SqlArgument.of(SqlTypes.map(SqlTypes.STRING, SqlTypes.decimal(1, 1))), + MapType.of(ParamTypes.STRING, ParamTypes.INTEGER), + false), is(false)); assertThat(ParamTypes.areCompatible( - SqlTypes.map(SqlTypes.decimal(1, 1), SqlTypes.INTEGER), - MapType.of(ParamTypes.INTEGER, ParamTypes.INTEGER)), + SqlArgument.of(SqlTypes.map(SqlTypes.decimal(1, 1), SqlTypes.INTEGER)), + MapType.of(ParamTypes.INTEGER, ParamTypes.INTEGER), + false), is(false)); @@ -62,21 +76,29 @@ public void shouldFailINonCompatibleSchemas() { @Test public void shouldPassCompatibleSchemas() { - assertThat(ParamTypes.areCompatible(SqlTypes.STRING, ParamTypes.STRING), + assertThat(ParamTypes.areCompatible( + SqlArgument.of(SqlTypes.STRING), + ParamTypes.STRING, + false), is(true)); assertThat( - ParamTypes.areCompatible(SqlTypes.array(SqlTypes.INTEGER), ArrayType.of(ParamTypes.INTEGER)), + ParamTypes.areCompatible( + SqlArgument.of(SqlTypes.array(SqlTypes.INTEGER)), + ArrayType.of(ParamTypes.INTEGER), + false), is(true)); assertThat(ParamTypes.areCompatible( - SqlTypes.struct().field("a", SqlTypes.decimal(1, 1)).build(), - StructType.builder().field("a", ParamTypes.DECIMAL).build()), + SqlArgument.of(SqlTypes.struct().field("a", SqlTypes.decimal(1, 1)).build()), + StructType.builder().field("a", ParamTypes.DECIMAL).build(), + false), is(true)); assertThat(ParamTypes.areCompatible( - SqlTypes.map(SqlTypes.INTEGER, SqlTypes.decimal(1, 1)), - MapType.of(ParamTypes.INTEGER, ParamTypes.DECIMAL)), + SqlArgument.of(SqlTypes.map(SqlTypes.INTEGER, SqlTypes.decimal(1, 1))), + MapType.of(ParamTypes.INTEGER, ParamTypes.DECIMAL), + false), is(true)); assertThat(ParamTypes.areCompatible( diff --git a/ksqldb-common/src/test/java/io/confluent/ksql/util/GenericsUtilTest.java b/ksqldb-common/src/test/java/io/confluent/ksql/util/GenericsUtilTest.java index 5dca6b6ff37b..4dcc69839ab5 100644 --- a/ksqldb-common/src/test/java/io/confluent/ksql/util/GenericsUtilTest.java +++ b/ksqldb-common/src/test/java/io/confluent/ksql/util/GenericsUtilTest.java @@ -256,7 +256,7 @@ public void shouldFailToIdentifyMismatchedGenericsInLambda() { // Then: assertThat(e.getMessage(), containsString( - "Found invalid instance of generic schema when mapping LAMBDA<[A], A> to Lambda<[DOUBLE], BOOLEAN>. " + "Found invalid instance of generic schema when mapping LAMBDA (A) => A to LAMBDA (DOUBLE) => BOOLEAN. " + "Cannot map A to both DOUBLE and BOOLEAN")); } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java index 203e35c0ce04..ad0b28c862c2 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java @@ -164,7 +164,7 @@ static SchemaProvider handleUdfReturnSchema( return (parameters, arguments) -> { if (schemaProvider != null) { final SqlType returnType = schemaProvider.apply(arguments); - if (!(ParamTypes.areCompatible(returnType, javaReturnSchema))) { + if (!(ParamTypes.areCompatible(SqlArgument.of(returnType), javaReturnSchema, false))) { throw new KsqlException(String.format( "Return type %s of UDF %s does not match the declared " + "return type %s.", diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java index 92331220878f..ac1e5dfbe1a7 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java @@ -187,18 +187,27 @@ public Void visitLikePredicate(final LikePredicate node, final TypeContext conte public Void visitFunctionCall(final FunctionCall node, final TypeContext context) { final List argumentTypes = new ArrayList<>(); final FunctionName functionName = node.getName(); + boolean hasLambda = false; + for (Expression e : node.getArguments()) { + if (e instanceof LambdaFunctionCall) { + hasLambda = true; + break; + } + } for (final Expression argExpr : node.getArguments()) { process(argExpr, context); - final SqlType resolvedArgType = expressionTypeManager.getExpressionSqlType(argExpr, context); - // for lambdas - if we find an array or map passed in before encountering a lambda function - // we save the type information to resolve the lambda generics - context.visitType(resolvedArgType); + final TypeContext childContext = context.getCopy(); + final SqlType resolvedArgType = expressionTypeManager.getExpressionSqlType(argExpr, childContext); + if (argExpr instanceof LambdaFunctionCall) { - argumentTypes.add( - SqlArgument.of( - SqlLambda.of(context.getLambdaInputTypes(), resolvedArgType))); + argumentTypes.add(SqlArgument.of(SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType()))); } else { argumentTypes.add(SqlArgument.of(resolvedArgType)); + // for lambdas - if we find an array or map passed in before encountering a lambda function + // we save the type information to resolve the lambda generics + if (hasLambda) { + context.visitType(resolvedArgType); + } } } @@ -295,5 +304,14 @@ private void addRequiredColumn(final ColumnName columnName) { column.index() ); } + + private boolean hasLambdaFunctionCall(FunctionCall node) { + for (Expression e : node.getArguments()) { + if (e instanceof LambdaFunctionCall) { + return true; + } + } + return false; + } } } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java index df5f397918c9..62559feb826b 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java @@ -454,17 +454,25 @@ public Pair visitFunctionCall( final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName()); final List argumentSchemas = new ArrayList<>(); + boolean hasLambda = false; + for (Expression e : node.getArguments()) { + if (e instanceof LambdaFunctionCall) { + hasLambda = true; + break; + } + } for (final Expression argExpr : node.getArguments()) { - final SqlType resolvedArgType = expressionTypeManager.getExpressionSqlType(argExpr, context); - // for lambdas - if we find an array or map passed in before encountering a lambda function - // we save the type information to resolve the lambda generics - context.visitType(resolvedArgType); + final TypeContext childContext = context.getCopy(); + final SqlType resolvedArgType = expressionTypeManager.getExpressionSqlType(argExpr, childContext); if (argExpr instanceof LambdaFunctionCall) { - argumentSchemas.add( - SqlArgument.of( - SqlLambda.of(context.getLambdaInputTypes(), resolvedArgType))); + argumentSchemas.add(SqlArgument.of(SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType()))); } else { argumentSchemas.add(SqlArgument.of(resolvedArgType)); + // for lambdas - if we find an array or map passed in before encountering a lambda function + // we save the type information to resolve the lambda generics + if (hasLambda) { + context.visitType(resolvedArgType); + } } } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/TypeContext.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/TypeContext.java index 11fb919440fd..f851d8c9d6ac 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/TypeContext.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/TypeContext.java @@ -16,6 +16,7 @@ package io.confluent.ksql.execution.codegen; import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; import io.confluent.ksql.schema.ksql.SqlArgument; import io.confluent.ksql.schema.ksql.types.SqlArray; @@ -30,8 +31,18 @@ public class TypeContext { private SqlType sqlType; - private final List lambdaInputTypes = new ArrayList<>(); - private final Map lambdaInputTypeMapping = new HashMap<>(); + private final List lambdaInputTypes; + private final Map lambdaInputTypeMapping; + + public TypeContext() { + lambdaInputTypes = new ArrayList(); + lambdaInputTypeMapping = new HashMap<>(); + } + + TypeContext (final List lambdaInputTypes, final Map lambdaInputTypeMapping) { + this.lambdaInputTypes = lambdaInputTypes; + this.lambdaInputTypeMapping = lambdaInputTypeMapping; + } public SqlType getSqlType() { return sqlType; @@ -61,28 +72,28 @@ public void mapLambdaInputTypes(final List argumentList) { for (int i = 0; i < argumentList.size(); i++) { this.lambdaInputTypeMapping.putIfAbsent(argumentList.get(i), lambdaInputTypes.get(i)); } + lambdaInputTypes.clear(); } public SqlType getLambdaType(final String name) { return lambdaInputTypeMapping.get(name); } - public boolean notAllInputsSeen() { - return lambdaInputTypeMapping.size() != lambdaInputTypes.size() || lambdaInputTypes.size() == 0; + + public TypeContext getCopy() { + return new TypeContext(this.lambdaInputTypes, this.lambdaInputTypeMapping); } public void visitType(SqlType type) { - if (notAllInputsSeen()) { - if (type instanceof SqlArray) { - final SqlArray inputArray = (SqlArray) type; - addLambdaInputType(inputArray.getItemType()); - } else if (type instanceof SqlMap) { - final SqlMap inputMap = (SqlMap) type; - addLambdaInputType(inputMap.getKeyType()); - addLambdaInputType(inputMap.getValueType()); - } else { - addLambdaInputType(type); - } + if (type instanceof SqlArray) { + final SqlArray inputArray = (SqlArray) type; + addLambdaInputType(inputArray.getItemType()); + } else if (type instanceof SqlMap) { + final SqlMap inputMap = (SqlMap) type; + addLambdaInputType(inputMap.getKeyType()); + addLambdaInputType(inputMap.getValueType()); + } else { + addLambdaInputType(type); } } } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java index 0d4b6fdd8dd0..be22575cc9bb 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java @@ -81,6 +81,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.Function; import java.util.stream.Collectors; public class ExpressionTypeManager { @@ -474,18 +475,27 @@ public Void visitFunctionCall( final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName()); final List argTypes = new ArrayList<>(); + + boolean hasLambda = false; + for (Expression e : node.getArguments()) { + if (e instanceof LambdaFunctionCall) { + hasLambda = true; + break; + } + } for (final Expression expression : node.getArguments()) { - process(expression, expressionTypeContext); - final SqlType resolvedArgType = expressionTypeContext.getSqlType(); - // for lambdas - if we find an array or map passed in before encountering a lambda function - // we save the type information to resolve the lambda generics - expressionTypeContext.visitType(resolvedArgType); + final TypeContext childContext = expressionTypeContext.getCopy(); + process(expression, childContext); + final SqlType resolvedArgType = childContext.getSqlType(); if (expression instanceof LambdaFunctionCall) { - argTypes.add( - SqlArgument.of( - SqlLambda.of(expressionTypeContext.getLambdaInputTypes(), resolvedArgType))); + argTypes.add(SqlArgument.of(SqlLambda.of(expressionTypeContext.getLambdaInputTypes(), childContext.getSqlType()))); } else { argTypes.add(SqlArgument.of(resolvedArgType)); + // for lambdas - if we find an array or map passed in before encountering a lambda function + // we save the type information to resolve the lambda generics + if (hasLambda) { + expressionTypeContext.visitType(resolvedArgType); + } } } @@ -603,4 +613,13 @@ private Optional validateWhenClauses( return previousResult; } } + + private boolean hasLambdaFunctionCall(FunctionCall node) { + for (Expression e : node.getArguments()) { + if (e instanceof LambdaFunctionCall) { + return true; + } + } + return false; + } } diff --git a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java index f5576729b9f1..178ae6bcbfee 100644 --- a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java +++ b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java @@ -80,15 +80,16 @@ public String toString() { + inputTypes.stream() .map(Object::toString) .collect(Collectors.joining(", ", "(", ")")) - + " -> A"; + + " => " + + returnType; } public String toString(final FormatOptions formatOptions) { - return "Lambda " + return "LAMBDA " + inputTypes.stream() .map(Object::toString) .collect(Collectors.joining(", ", "(", ")")) - + " -> " + + " => " + returnType; } } From fd7550a78b9bf7d1cec4ec62d388213ce855ccee Mon Sep 17 00:00:00 2001 From: Steven Zhang Date: Mon, 22 Feb 2021 00:30:28 -0800 Subject: [PATCH 6/6] sql argument optional --- .../confluent/ksql/function/GenericsUtil.java | 51 ++++----- .../io/confluent/ksql/function/UdfIndex.java | 13 +-- .../ksql/function/types/ParamTypes.java | 11 +- .../ksql/function/FunctionLoaderUtils.java | 13 ++- .../UdafAggregateFunctionFactory.java | 2 +- .../udaf/max/MaxAggFunctionFactory.java | 2 +- .../udaf/min/MinAggFunctionFactory.java | 2 +- .../udaf/sum/SumAggFunctionFactory.java | 2 +- .../topk/TopKAggregateFunctionFactory.java | 2 +- .../TopkDistinctAggFunctionFactory.java | 2 +- .../confluent/ksql/function/udf/math/Abs.java | 2 +- .../ksql/function/udf/math/Ceil.java | 2 +- .../ksql/function/udf/math/Floor.java | 2 +- .../ksql/function/udf/math/Round.java | 6 +- .../ksql/function/udtf/array/Explode.java | 2 +- .../ksql/function/UdfLoaderTest.java | 8 +- .../ksql/execution/codegen/CodeGenRunner.java | 31 ++---- .../execution/codegen/SqlToJavaVisitor.java | 29 +++-- .../ksql/execution/codegen/TypeContext.java | 19 ++-- .../expression/tree/FunctionCall.java | 5 + .../execution/util/ExpressionTypeManager.java | 26 +---- .../execution/codegen/TypeContextTest.java | 7 +- .../expression/tree/FunctionCallTest.java | 10 ++ .../ksql/schema/ksql/SqlArgument.java | 42 +++++-- .../ksql/schema/ksql/types/SqlLambda.java | 10 +- .../ksql/schema/ksql/SqlArgumentTest.java | 104 ++++++++++++++++++ 26 files changed, 261 insertions(+), 144 deletions(-) create mode 100644 ksqldb-udf/src/test/java/io/confluent/ksql/schema/ksql/SqlArgumentTest.java diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java index fad20ab4424e..76d2239db5a3 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/GenericsUtil.java @@ -196,7 +196,6 @@ private static boolean resolveGenerics( ) { // CHECKSTYLE_RULES.ON: NPathComplexity // CHECKSTYLE_RULES.ON: CyclomaticComplexity - final SqlType sqlType = instance.getSqlType(); if (!isGeneric(schema) && !matches(schema, instance)) { // cannot identify from type mismatch @@ -211,6 +210,30 @@ private static boolean resolveGenerics( "Cannot resolve generics if the schema and instance have differing types: " + schema + " vs. " + instance); + if (schema instanceof LambdaType) { + final LambdaType lambdaType = (LambdaType) schema; + final SqlLambda sqlLambda = instance.getSqlLambdaOrThrow(); + if (sqlLambda.getInputType().size() != lambdaType.inputTypes().size()) { + throw new KsqlException( + "Number of lambda arguments doesn't match between schema and sql type"); + } + + int i = 0; + for (final ParamType paramType : lambdaType.inputTypes()) { + if (!resolveGenerics( + mapping, paramType, SqlArgument.of(sqlLambda.getInputType().get(i)) + )) { + return false; + } + i++; + } + return resolveGenerics( + mapping, lambdaType.returnType(), SqlArgument.of(sqlLambda.getReturnType()) + ); + } + + final SqlType sqlType = instance.getSqlTypeOrThrow(); + if (isGeneric(schema)) { mapping.add(new HashMap.SimpleEntry<>((GenericType) schema, sqlType)); } @@ -234,37 +257,15 @@ private static boolean resolveGenerics( throw new KsqlException("Generic STRUCT is not yet supported"); } - if (schema instanceof LambdaType) { - final LambdaType lambdaType = (LambdaType) schema; - final SqlLambda sqlLambda = instance.getSqlLambda(); - if (sqlLambda.getInputType().size() != lambdaType.inputTypes().size()) { - throw new KsqlException( - "Number of lambda arguments doesn't match between schema and sql type"); - } - - int i = 0; - for (final ParamType paramType : lambdaType.inputTypes()) { - if (!resolveGenerics( - mapping, paramType, SqlArgument.of(sqlLambda.getInputType().get(i)) - )) { - return false; - } - i++; - } - return resolveGenerics( - mapping, lambdaType.returnType(), SqlArgument.of(sqlLambda.getReturnType()) - ); - } - return true; } private static boolean matches(final ParamType schema, final SqlArgument instance) { - if (instance.getSqlLambda() != null) { + if (instance.getSqlLambda().isPresent()) { return schema instanceof LambdaType; } final ParamType instanceParamType = SchemaConverters - .sqlToFunctionConverter().toFunctionType(instance.getSqlType()); + .sqlToFunctionConverter().toFunctionType(instance.getSqlTypeOrThrow()); return schema.getClass() == instanceParamType.getClass(); } diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/UdfIndex.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/UdfIndex.java index 0dec36b67895..460b067e8f5c 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/UdfIndex.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/UdfIndex.java @@ -22,6 +22,7 @@ import io.confluent.ksql.function.types.ParamType; import io.confluent.ksql.function.types.ParamTypes; import io.confluent.ksql.schema.ksql.SqlArgument; +import io.confluent.ksql.schema.ksql.types.SqlLambda; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.utils.FormatOptions; import io.confluent.ksql.util.KsqlException; @@ -202,12 +203,9 @@ private KsqlException createNoMatchingFunctionException(final List if (argument == null) { return "null"; } else { - final SqlType sqlType = argument.getSqlType(); - if (sqlType != null) { - return sqlType.toString(FormatOptions.noEscape()); - } else { - return argument.getSqlLambda().toString(); - } + final Optional sqlLambda = argument.getSqlLambda(); + return sqlLambda.map(lambda -> lambda.toString(FormatOptions.noEscape())) + .orElseGet(() -> argument.getSqlTypeOrThrow().toString(FormatOptions.noEscape())); } }) .collect(Collectors.joining(", ", "(", ")")); @@ -362,7 +360,8 @@ public int hashCode() { // CHECKSTYLE_RULES.OFF: BooleanExpressionComplexity boolean accepts(final SqlArgument argument, final Map reservedGenerics, final boolean allowCasts) { - if (argument == null || (argument.getSqlLambda() == null && argument.getSqlType() == null)) { + if (argument == null + || (!argument.getSqlLambda().isPresent() && !argument.getSqlType().isPresent())) { return true; } diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java b/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java index 026b1bb0cb1d..0ae2482295bc 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/function/types/ParamTypes.java @@ -41,6 +41,10 @@ private ParamTypes() { public static final ParamType DECIMAL = DecimalType.INSTANCE; public static final TimestampType TIMESTAMP = TimestampType.INSTANCE; + public static boolean areCompatible(final SqlArgument actual, final ParamType declared) { + return areCompatible(actual, declared, false); + } + // CHECKSTYLE_RULES.OFF: CyclomaticComplexity // CHECKSTYLE_RULES.OFF: NPathComplexity public static boolean areCompatible( @@ -50,10 +54,10 @@ public static boolean areCompatible( ) { // CHECKSTYLE_RULES.ON: CyclomaticComplexity // CHECKSTYLE_RULES.ON: NPathComplexity - final SqlType argumentSqlType = argument.getSqlType(); - final SqlLambda sqlLambda = argument.getSqlLambda(); + final Optional sqlLambdaOptional = argument.getSqlLambda(); - if (sqlLambda != null && declared instanceof LambdaType) { + if (sqlLambdaOptional.isPresent() && declared instanceof LambdaType) { + final SqlLambda sqlLambda = sqlLambdaOptional.get(); final LambdaType declaredLambda = (LambdaType) declared; if (sqlLambda.getInputType().size() != declaredLambda.inputTypes().size()) { return false; @@ -75,6 +79,7 @@ public static boolean areCompatible( allowCast); } + final SqlType argumentSqlType = argument.getSqlTypeOrThrow(); if (argumentSqlType.baseType() == SqlBaseType.ARRAY && declared instanceof ArrayType) { return areCompatible( SqlArgument.of(((SqlArray) argumentSqlType).getItemType()), diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java index ad0b28c862c2..de362fc44a66 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java @@ -135,6 +135,7 @@ static ParamType getReturnType( } } + // CHECKSTYLE_RULES.OFF: CyclomaticComplexity static SchemaProvider handleUdfReturnSchema( final Class theClass, final ParamType javaReturnSchema, @@ -144,6 +145,7 @@ static SchemaProvider handleUdfReturnSchema( final String functionName, final boolean isVariadic ) { + // CHECKSTYLE_RULES.ON: CyclomaticComplexity final Function, SqlType> schemaProvider; if (!Udf.NO_SCHEMA_PROVIDER.equals(schemaProviderFunctionName)) { schemaProvider = handleUdfSchemaProviderAnnotation( @@ -180,16 +182,19 @@ static SchemaProvider handleUdfReturnSchema( for (int i = 0; i < Math.min(parameters.size(), arguments.size()); i++) { final ParamType schema = parameters.get(i); if (schema instanceof LambdaType) { - if (isVariadic) { - throw new KsqlException(String.format("Lambda function %s cannot be variadic.", arguments.get(i).toString())); + if (isVariadic && i == parameters.size() - 1) { + throw new KsqlException( + String.format( + "Lambda function %s cannot be variadic.", + arguments.get(i).toString())); } genericMapping.putAll(GenericsUtil.resolveGenerics(schema, arguments.get(i))); } else { // we resolve any variadic as if it were an array so that the type // structure matches the input type final SqlType instance = isVariadic && i == parameters.size() - 1 - ? SqlTypes.array(arguments.get(i).getSqlType()) - : arguments.get(i).getSqlType(); + ? SqlTypes.array(arguments.get(i).getSqlTypeOrThrow()) + : arguments.get(i).getSqlTypeOrThrow(); genericMapping.putAll( GenericsUtil.resolveGenerics(schema, SqlArgument.of(instance)) ); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafAggregateFunctionFactory.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafAggregateFunctionFactory.java index 7eb8ed6978ff..21fde33f0148 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafAggregateFunctionFactory.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafAggregateFunctionFactory.java @@ -61,7 +61,7 @@ public class UdafAggregateFunctionFactory extends AggregateFunctionFactory { throw new KsqlException("There is no aggregate function with name='" + getName() + "' that has arguments of type=" + allParams.stream() - .map(SqlArgument::getSqlType) + .map(SqlArgument::getSqlTypeOrThrow) .map(SqlType::baseType) .map(Objects::toString) .collect(Collectors.joining(","))); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/max/MaxAggFunctionFactory.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/max/MaxAggFunctionFactory.java index 5135ae5d3cae..d6b41b3d1cd5 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/max/MaxAggFunctionFactory.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/max/MaxAggFunctionFactory.java @@ -42,7 +42,7 @@ public KsqlAggregateFunction createAggregateFunction( argTypeList.size() == 1, "expected exactly one argument to aggregate MAX function"); - final SqlType argSchema = argTypeList.get(0).getSqlType(); + final SqlType argSchema = argTypeList.get(0).getSqlTypeOrThrow(); switch (argSchema.baseType()) { case INTEGER: return new IntegerMaxKudaf(FUNCTION_NAME, initArgs.udafIndex()); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/min/MinAggFunctionFactory.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/min/MinAggFunctionFactory.java index 5360c194af7c..7c37e40765fb 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/min/MinAggFunctionFactory.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/min/MinAggFunctionFactory.java @@ -41,7 +41,7 @@ public KsqlAggregateFunction createAggregateFunction( argTypeList.size() == 1, "expected exactly one argument to aggregate MAX function"); - final SqlType argSchema = argTypeList.get(0).getSqlType(); + final SqlType argSchema = argTypeList.get(0).getSqlTypeOrThrow(); switch (argSchema.baseType()) { case INTEGER: return new IntegerMinKudaf(FUNCTION_NAME, initArgs.udafIndex()); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/sum/SumAggFunctionFactory.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/sum/SumAggFunctionFactory.java index 9724102cb2f4..1c68d202bc45 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/sum/SumAggFunctionFactory.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/sum/SumAggFunctionFactory.java @@ -43,7 +43,7 @@ public KsqlAggregateFunction createAggregateFunction( argTypeList.size() == 1, "expected exactly one argument to aggregate MAX function"); - final SqlType argSchema = argTypeList.get(0).getSqlType(); + final SqlType argSchema = argTypeList.get(0).getSqlTypeOrThrow(); switch (argSchema.baseType()) { case INTEGER: return new IntegerSumKudaf(FUNCTION_NAME, initArgs.udafIndex()); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/topk/TopKAggregateFunctionFactory.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/topk/TopKAggregateFunctionFactory.java index 4ec2c0fd33a2..c7b94d548f10 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/topk/TopKAggregateFunctionFactory.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/topk/TopKAggregateFunctionFactory.java @@ -56,7 +56,7 @@ public KsqlAggregateFunction createAggregateFunction( throw new KsqlException("TOPK function should have two arguments."); } final int tkValFromArg = (Integer)(initArgs.arg(0)); - final SqlType argSchema = argumentType.get(0).getSqlType(); + final SqlType argSchema = argumentType.get(0).getSqlTypeOrThrow(); switch (argSchema.baseType()) { case INTEGER: return new TopkKudaf<>( diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/topkdistinct/TopkDistinctAggFunctionFactory.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/topkdistinct/TopkDistinctAggFunctionFactory.java index b49a2435e3e0..b915ed4c1d9b 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/topkdistinct/TopkDistinctAggFunctionFactory.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udaf/topkdistinct/TopkDistinctAggFunctionFactory.java @@ -56,7 +56,7 @@ public KsqlAggregateFunction createAggregateFunction( throw new KsqlException("TOPKDISTINCT function should have two arguments."); } final int tkValFromArg = (Integer)(initArgs.arg(0)); - final SqlType argSchema = argTypeList.get(0).getSqlType(); + final SqlType argSchema = argTypeList.get(0).getSqlTypeOrThrow(); switch (argSchema.baseType()) { case INTEGER: case BIGINT: diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Abs.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Abs.java index 54b0445db1a7..3e443e748272 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Abs.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Abs.java @@ -63,7 +63,7 @@ public BigDecimal abs(@UdfParameter final BigDecimal val) { @UdfSchemaProvider public SqlType absDecimalProvider(final List params) { - final SqlType s = params.get(0).getSqlType(); + final SqlType s = params.get(0).getSqlTypeOrThrow(); if (s.baseType() != SqlBaseType.DECIMAL) { throw new KsqlException("The schema provider method for Abs expects a BigDecimal parameter" + "type"); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Ceil.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Ceil.java index 2b32fb511963..c1805e72a0bd 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Ceil.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Ceil.java @@ -64,7 +64,7 @@ public BigDecimal ceil(@UdfParameter final BigDecimal val) { @UdfSchemaProvider public SqlType ceilDecimalProvider(final List params) { - final SqlType s = params.get(0).getSqlType(); + final SqlType s = params.get(0).getSqlTypeOrThrow(); if (s.baseType() != SqlBaseType.DECIMAL) { throw new KsqlException("The schema provider method for Ceil expects a BigDecimal parameter" + "type"); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Floor.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Floor.java index 188099e0ce8b..124ddc6caa4c 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Floor.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Floor.java @@ -65,7 +65,7 @@ public BigDecimal floor(@UdfParameter final BigDecimal val) { @UdfSchemaProvider public SqlType floorDecimalProvider(final List params) { - final SqlType s = params.get(0).getSqlType(); + final SqlType s = params.get(0).getSqlTypeOrThrow(); if (s.baseType() != SqlBaseType.DECIMAL) { throw new KsqlException("The schema provider method for Floor expects a BigDecimal parameter" + "type"); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Round.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Round.java index ce37a16e7823..0de9f4983c62 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Round.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udf/math/Round.java @@ -124,12 +124,12 @@ public BigDecimal round( @SuppressWarnings("unused") // Invoked via reflection @UdfSchemaProvider public static SqlType provideDecimalSchemaWithDecimalPlaces(final List params) { - final SqlType s0 = params.get(0).getSqlType(); + final SqlType s0 = params.get(0).getSqlTypeOrThrow(); if (s0.baseType() != SqlBaseType.DECIMAL) { throw new KsqlException("The schema provider method for round expects a BigDecimal parameter" + "type as first parameter."); } - final SqlType s1 = params.get(1).getSqlType(); + final SqlType s1 = params.get(1).getSqlTypeOrThrow(); if (s1.baseType() != SqlBaseType.INTEGER) { throw new KsqlException("The schema provider method for round expects an Integer parameter" + "type as second parameter."); @@ -143,7 +143,7 @@ public static SqlType provideDecimalSchemaWithDecimalPlaces(final List params) { - final SqlType s0 = params.get(0).getSqlType(); + final SqlType s0 = params.get(0).getSqlTypeOrThrow(); if (s0.baseType() != SqlBaseType.DECIMAL) { throw new KsqlException("The schema provider method for round expects a BigDecimal parameter" + "type as a parameter."); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udtf/array/Explode.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udtf/array/Explode.java index c0cba364c40c..9a219f55cc44 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/udtf/array/Explode.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/udtf/array/Explode.java @@ -53,7 +53,7 @@ public List explodeBigDecimal(final List input) { @UdfSchemaProvider public SqlType provideSchema(final List params) { - final SqlType argType = params.get(0).getSqlType(); + final SqlType argType = params.get(0).getSqlTypeOrThrow(); if (!(argType instanceof SqlArray)) { throw new KsqlException("explode should be provided with an ARRAY"); } diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java index 7847ee6b359d..105b005a6260 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java @@ -197,7 +197,7 @@ public void shouldLoadDecimalUdfs() { @Test public void shouldLoadLambdaReduceUdfs() { // Given: - final SqlLambda schema = + final SqlLambda lambda = SqlLambda.of( ImmutableList.of(SqlTypes.INTEGER, SqlTypes.INTEGER, SqlTypes.INTEGER), SqlTypes.INTEGER); @@ -208,7 +208,7 @@ public void shouldLoadLambdaReduceUdfs() { ImmutableList.of( SqlArgument.of(SqlMap.of(SqlTypes.INTEGER, SqlTypes.INTEGER)), SqlArgument.of(SqlTypes.INTEGER), - SqlArgument.of(schema))); + SqlArgument.of(lambda))); // Then: assertThat(fun.name().text(), equalToIgnoringCase("reduce_map")); @@ -217,7 +217,7 @@ public void shouldLoadLambdaReduceUdfs() { @Test public void shouldLoadLambdaTransformUdfs() { // Given: - final SqlLambda schema = + final SqlLambda lambda = SqlLambda.of( ImmutableList.of(SqlTypes.INTEGER), SqlTypes.INTEGER); @@ -227,7 +227,7 @@ public void shouldLoadLambdaTransformUdfs() { .getFunction( ImmutableList.of( SqlArgument.of(SqlArray.of(SqlTypes.INTEGER)), - SqlArgument.of(schema))); + SqlArgument.of(lambda))); // Then: assertThat(fun.name().text(), equalToIgnoringCase("array_transform")); diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java index ac1e5dfbe1a7..8beea3027756 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java @@ -40,9 +40,7 @@ import io.confluent.ksql.schema.ksql.SchemaConverters; import io.confluent.ksql.schema.ksql.SchemaConverters.SqlToJavaTypeConverter; import io.confluent.ksql.schema.ksql.SqlArgument; -import io.confluent.ksql.schema.ksql.types.SqlArray; import io.confluent.ksql.schema.ksql.types.SqlLambda; -import io.confluent.ksql.schema.ksql.types.SqlMap; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlException; @@ -187,24 +185,20 @@ public Void visitLikePredicate(final LikePredicate node, final TypeContext conte public Void visitFunctionCall(final FunctionCall node, final TypeContext context) { final List argumentTypes = new ArrayList<>(); final FunctionName functionName = node.getName(); - boolean hasLambda = false; - for (Expression e : node.getArguments()) { - if (e instanceof LambdaFunctionCall) { - hasLambda = true; - break; - } - } + final boolean hasLambda = node.hasLambdaFunctionCallArguments(); for (final Expression argExpr : node.getArguments()) { - process(argExpr, context); final TypeContext childContext = context.getCopy(); - final SqlType resolvedArgType = expressionTypeManager.getExpressionSqlType(argExpr, childContext); + final SqlType resolvedArgType = + expressionTypeManager.getExpressionSqlType(argExpr, childContext); + process(argExpr, context.getCopy()); if (argExpr instanceof LambdaFunctionCall) { - argumentTypes.add(SqlArgument.of(SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType()))); + argumentTypes.add( + SqlArgument.of( + SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType()))); } else { argumentTypes.add(SqlArgument.of(resolvedArgType)); - // for lambdas - if we find an array or map passed in before encountering a lambda function - // we save the type information to resolve the lambda generics + // for lambdas - we save the type information to resolve the lambda generics if (hasLambda) { context.visitType(resolvedArgType); } @@ -304,14 +298,5 @@ private void addRequiredColumn(final ColumnName columnName) { column.index() ); } - - private boolean hasLambdaFunctionCall(FunctionCall node) { - for (Expression e : node.getArguments()) { - if (e instanceof LambdaFunctionCall) { - return true; - } - } - return false; - } } } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java index 62559feb826b..501b16817ecc 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java @@ -377,8 +377,7 @@ public Pair visitLambdaExpression( SchemaConverters.sqlToJavaConverter().toJavaType(context.getLambdaType(lambdaArg)) )); } - return new Pair<>(LambdaUtil.toJavaCode(argPairs, lambdaBody.getLeft()), - expressionTypeManager.getExpressionSqlType(lambdaFunctionCall, context)); + return new Pair<>(LambdaUtil.toJavaCode(argPairs, lambdaBody.getLeft()), null); } @Override @@ -454,22 +453,18 @@ public Pair visitFunctionCall( final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName()); final List argumentSchemas = new ArrayList<>(); - boolean hasLambda = false; - for (Expression e : node.getArguments()) { - if (e instanceof LambdaFunctionCall) { - hasLambda = true; - break; - } - } + final boolean hasLambda = node.hasLambdaFunctionCallArguments(); for (final Expression argExpr : node.getArguments()) { final TypeContext childContext = context.getCopy(); - final SqlType resolvedArgType = expressionTypeManager.getExpressionSqlType(argExpr, childContext); + final SqlType resolvedArgType = + expressionTypeManager.getExpressionSqlType(argExpr, childContext); if (argExpr instanceof LambdaFunctionCall) { - argumentSchemas.add(SqlArgument.of(SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType()))); + argumentSchemas.add( + SqlArgument.of( + SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType()))); } else { argumentSchemas.add(SqlArgument.of(resolvedArgType)); - // for lambdas - if we find an array or map passed in before encountering a lambda function - // we save the type information to resolve the lambda generics + // for lambdas - we save the type information to resolve the lambda generics if (hasLambda) { context.visitType(resolvedArgType); } @@ -487,7 +482,9 @@ public Pair visitFunctionCall( final StringJoiner joiner = new StringJoiner(", "); for (int i = 0; i < arguments.size(); i++) { final Expression arg = arguments.get(i); - final SqlType sqlType = argumentSchemas.get(i).getSqlType(); + + // lambda arguments and null values are considered to have null type + final SqlType sqlType = argumentSchemas.get(i).getSqlType().orElse(null); final ParamType paramType; if (i >= function.parameters().size() - 1 && function.isVariadic()) { @@ -496,7 +493,9 @@ public Pair visitFunctionCall( paramType = function.parameters().get(i); } - joiner.add(process(convertArgument(arg, sqlType, paramType), context).getLeft()); + joiner.add( + process(convertArgument(arg, sqlType, paramType), context.getCopy()) + .getLeft()); } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/TypeContext.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/TypeContext.java index f851d8c9d6ac..6eeaf21cb578 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/TypeContext.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/TypeContext.java @@ -15,12 +15,7 @@ package io.confluent.ksql.execution.codegen; -import io.confluent.ksql.execution.expression.tree.Expression; -import io.confluent.ksql.execution.expression.tree.FunctionCall; -import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; -import io.confluent.ksql.schema.ksql.SqlArgument; import io.confluent.ksql.schema.ksql.types.SqlArray; -import io.confluent.ksql.schema.ksql.types.SqlLambda; import io.confluent.ksql.schema.ksql.types.SqlMap; import io.confluent.ksql.schema.ksql.types.SqlType; @@ -35,11 +30,12 @@ public class TypeContext { private final Map lambdaInputTypeMapping; public TypeContext() { - lambdaInputTypes = new ArrayList(); - lambdaInputTypeMapping = new HashMap<>(); + this(new ArrayList<>(), new HashMap<>()); } - TypeContext (final List lambdaInputTypes, final Map lambdaInputTypeMapping) { + public TypeContext( + final List lambdaInputTypes, + final Map lambdaInputTypeMapping) { this.lambdaInputTypes = lambdaInputTypes; this.lambdaInputTypeMapping = lambdaInputTypeMapping; } @@ -81,10 +77,13 @@ public SqlType getLambdaType(final String name) { public TypeContext getCopy() { - return new TypeContext(this.lambdaInputTypes, this.lambdaInputTypeMapping); + return new TypeContext( + new ArrayList<>(this.lambdaInputTypes), + new HashMap<>(this.lambdaInputTypeMapping) + ); } - public void visitType(SqlType type) { + public void visitType(final SqlType type) { if (type instanceof SqlArray) { final SqlArray inputArray = (SqlArray) type; addLambdaInputType(inputArray.getItemType()); diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/FunctionCall.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/FunctionCall.java index 474b0b4fc400..c79f881ae832 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/FunctionCall.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/FunctionCall.java @@ -53,6 +53,11 @@ public List getArguments() { return arguments; } + public boolean hasLambdaFunctionCallArguments() { + return arguments.stream().anyMatch( + argument -> argument instanceof LambdaFunctionCall); + } + @Override public R accept(final ExpressionVisitor visitor, final C context) { return visitor.visitFunctionCall(this, context); diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java index be22575cc9bb..ae24c19c0140 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java @@ -81,7 +81,6 @@ import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.function.Function; import java.util.stream.Collectors; public class ExpressionTypeManager { @@ -476,23 +475,19 @@ public Void visitFunctionCall( final List argTypes = new ArrayList<>(); - boolean hasLambda = false; - for (Expression e : node.getArguments()) { - if (e instanceof LambdaFunctionCall) { - hasLambda = true; - break; - } - } + final boolean hasLambda = node.hasLambdaFunctionCallArguments(); for (final Expression expression : node.getArguments()) { final TypeContext childContext = expressionTypeContext.getCopy(); process(expression, childContext); final SqlType resolvedArgType = childContext.getSqlType(); if (expression instanceof LambdaFunctionCall) { - argTypes.add(SqlArgument.of(SqlLambda.of(expressionTypeContext.getLambdaInputTypes(), childContext.getSqlType()))); + argTypes.add( + SqlArgument.of( + SqlLambda.of(expressionTypeContext.getLambdaInputTypes(), + childContext.getSqlType()))); } else { argTypes.add(SqlArgument.of(resolvedArgType)); - // for lambdas - if we find an array or map passed in before encountering a lambda function - // we save the type information to resolve the lambda generics + // for lambdas - we save the type information to resolve the lambda generics if (hasLambda) { expressionTypeContext.visitType(resolvedArgType); } @@ -613,13 +608,4 @@ private Optional validateWhenClauses( return previousResult; } } - - private boolean hasLambdaFunctionCall(FunctionCall node) { - for (Expression e : node.getArguments()) { - if (e instanceof LambdaFunctionCall) { - return true; - } - } - return false; - } } diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/TypeContextTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/TypeContextTest.java index a63d7bfbbf29..c321e26ef970 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/TypeContextTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/TypeContextTest.java @@ -29,7 +29,7 @@ public class TypeContextTest { @Test public void shouldThrowOnLambdaMismatch() { // Given - TypeContext context = new TypeContext(); + final TypeContext context = new TypeContext(); context.addLambdaInputType(SqlTypes.STRING); context.addLambdaInputType(SqlTypes.STRING); @@ -43,9 +43,9 @@ public void shouldThrowOnLambdaMismatch() { } @Test - public void shouldMapLambdaTypes() { + public void shouldMapLambdaTypesAndClearInputList() { // Given - TypeContext context = new TypeContext(); + final TypeContext context = new TypeContext(); context.addLambdaInputType(SqlTypes.STRING); context.addLambdaInputType(SqlTypes.BIGINT); @@ -55,5 +55,6 @@ public void shouldMapLambdaTypes() { // Then assertThat(context.getLambdaType("x"), is(SqlTypes.STRING)); assertThat(context.getLambdaType("y"), is(SqlTypes.BIGINT)); + assertThat(context.getLambdaInputTypes().size(), is(0)); } } diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/expression/tree/FunctionCallTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/expression/tree/FunctionCallTest.java index 50464a6ac293..c94ef04c030c 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/expression/tree/FunctionCallTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/expression/tree/FunctionCallTest.java @@ -49,4 +49,14 @@ public void shouldImplementHashCodeAndEqualsProperty() { ) .testEquals(); } + + @Test + public void shouldReturnHasLambdaFunctionCall() { + final FunctionCall functionCall1 = new FunctionCall(SOME_NAME, SOME_ARGS); + final FunctionCall functionCall2 = new FunctionCall(SOME_NAME, ImmutableList.of( + new StringLiteral("jane"), + new LambdaFunctionCall(ImmutableList.of("x"), new StringLiteral("test")))); + assert !functionCall1.hasLambdaFunctionCallArguments(); + assert functionCall2.hasLambdaFunctionCallArguments(); + } } \ No newline at end of file diff --git a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java index 01d0dd5129d7..8bca96c8b731 100644 --- a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java +++ b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/SqlArgument.java @@ -18,6 +18,7 @@ import io.confluent.ksql.schema.ksql.types.SqlLambda; import io.confluent.ksql.schema.ksql.types.SqlType; import java.util.Objects; +import java.util.Optional; /** * A wrapper class to bundle SqlTypes and SqlLambdas for UDF functions that contain @@ -26,12 +27,16 @@ */ public class SqlArgument { - private final SqlType sqlType; - private final SqlLambda sqlLambda; + private final Optional sqlType; + private final Optional sqlLambda; public SqlArgument(final SqlType type, final SqlLambda lambda) { - sqlType = type; - sqlLambda = lambda; + if (type != null && lambda != null) { + throw new RuntimeException( + "A function argument was assigned to be both a type and a lambda"); + } + sqlType = Optional.ofNullable(type); + sqlLambda = Optional.ofNullable(lambda); } public static SqlArgument of(final SqlType type) { @@ -46,14 +51,32 @@ public static SqlArgument of(final SqlType sqlType, final SqlLambda lambdaType) return new SqlArgument(sqlType, lambdaType); } - public SqlType getSqlType() { + public Optional getSqlType() { return sqlType; } - public SqlLambda getSqlLambda() { + public SqlType getSqlTypeOrThrow() { + if (sqlLambda.isPresent()) { + throw new RuntimeException("Was expecting type as a function argument"); + } + // we represent the null type with a null SqlType + return sqlType.orElse(null); + } + + public Optional getSqlLambda() { return sqlLambda; } + public SqlLambda getSqlLambdaOrThrow() { + if (sqlType.isPresent()) { + throw new RuntimeException("Was expecting lambda as a function argument"); + } + if (sqlLambda.isPresent()) { + return sqlLambda.get(); + } + throw new RuntimeException("Was expecting lambda as a function argument"); + } + @Override public int hashCode() { return Objects.hash(sqlType, sqlLambda); @@ -74,10 +97,9 @@ public boolean equals(final Object o) { @Override public String toString() { - if (sqlType != null) { - return sqlType.toString(); - } else { - return sqlLambda.toString(); + if (sqlType.isPresent()) { + return sqlType.get().toString(); } + return sqlLambda.map(SqlLambda::toString).orElse("null"); } } diff --git a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java index 178ae6bcbfee..b1e37b043a39 100644 --- a/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java +++ b/ksqldb-udf/src/main/java/io/confluent/ksql/schema/ksql/types/SqlLambda.java @@ -75,13 +75,9 @@ public int hashCode() { return Objects.hash(inputTypes, returnType); } + @Override public String toString() { - return "LAMBDA " - + inputTypes.stream() - .map(Object::toString) - .collect(Collectors.joining(", ", "(", ")")) - + " => " - + returnType; + return toString(FormatOptions.none()); } public String toString(final FormatOptions formatOptions) { @@ -90,6 +86,6 @@ public String toString(final FormatOptions formatOptions) { .map(Object::toString) .collect(Collectors.joining(", ", "(", ")")) + " => " - + returnType; + + returnType.toString(formatOptions); } } diff --git a/ksqldb-udf/src/test/java/io/confluent/ksql/schema/ksql/SqlArgumentTest.java b/ksqldb-udf/src/test/java/io/confluent/ksql/schema/ksql/SqlArgumentTest.java new file mode 100644 index 000000000000..5f4ba472aa2b --- /dev/null +++ b/ksqldb-udf/src/test/java/io/confluent/ksql/schema/ksql/SqlArgumentTest.java @@ -0,0 +1,104 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.schema.ksql; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableList; +import com.google.common.testing.EqualsTester; +import io.confluent.ksql.schema.ksql.types.SqlArray; +import io.confluent.ksql.schema.ksql.types.SqlLambda; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import org.junit.Test; + +public class SqlArgumentTest { + + @SuppressWarnings("UnstableApiUsage") + @Test + public void shouldImplementHashCodeAndEqualsProperly() { + new EqualsTester() + .addEqualityGroup(SqlArgument.of(SqlArray.of(SqlTypes.STRING)), SqlArgument.of(SqlArray.of(SqlTypes.STRING))) + .addEqualityGroup( + SqlArgument.of(SqlLambda.of(ImmutableList.of(SqlTypes.STRING), SqlTypes.INTEGER)), + SqlArgument.of(SqlLambda.of(ImmutableList.of(SqlTypes.STRING), SqlTypes.INTEGER))) + .addEqualityGroup(SqlArgument.of(null, null), SqlArgument.of(null, null)) + .testEquals(); + } + + @Test + public void shouldReturnNullTypeIfBothLambdaAndTypeNotPresent() { + final SqlArgument argument = SqlArgument.of(null, null); + assertThat("null type", argument.getSqlTypeOrThrow() == null); + } + + @Test + public void shouldReturnTypeIfPresent() { + final SqlArgument argument = SqlArgument.of(SqlTypes.STRING); + assertThat("string type", argument.getSqlTypeOrThrow() == SqlTypes.STRING); + } + + @Test + public void shouldReturnLambdaIfPresent() { + final SqlArgument argument = SqlArgument.of(SqlLambda.of(ImmutableList.of(SqlTypes.STRING), SqlTypes.INTEGER)); + assertThat("lambda", argument.getSqlLambdaOrThrow() + .equals(SqlLambda.of(ImmutableList.of(SqlTypes.STRING), SqlTypes.INTEGER))); + } + + @Test + public void shouldThrowIfAssigningTypeAndLambdaToSqlArgument() { + final Exception e = assertThrows( + RuntimeException.class, + () -> SqlArgument.of(SqlTypes.STRING, (SqlLambda.of(ImmutableList.of(SqlTypes.STRING), SqlTypes.INTEGER))) + ); + assertThat(e.getMessage(), containsString( + "A function argument was assigned to be both a type and a lambda")); + } + + @Test + public void shouldThrowWhenLambdaPresentWhenGettingType() { + final SqlArgument argument = SqlArgument.of(null, (SqlLambda.of(ImmutableList.of(SqlTypes.STRING), SqlTypes.INTEGER))); + + final Exception e = assertThrows( + RuntimeException.class, + argument::getSqlTypeOrThrow + ); + assertThat(e.getMessage(), containsString("Was expecting type as a function argument")); + } + + @Test + public void shouldThrowWhenTypePresentWhenGettingLambda() { + final SqlArgument argument = SqlArgument.of(SqlTypes.STRING, null); + final Exception e = assertThrows( + RuntimeException.class, + argument::getSqlLambdaOrThrow + ); + + assertThat(e.getMessage(), containsString("Was expecting lambda as a function argument")); + } + + @Test + public void shouldThrowWhenLambdaNotPresentGettingLambda() { + final SqlArgument argument = SqlArgument.of(null, null); + final Exception e = assertThrows( + RuntimeException.class, + argument::getSqlLambdaOrThrow + ); + + assertThat(e.getMessage(), containsString("Was expecting lambda as a function argument")); + } +} \ No newline at end of file