From 4e66825a569bb7d8e4aca6855121d129acee0005 Mon Sep 17 00:00:00 2001 From: Tim Fox Date: Thu, 31 Oct 2019 17:59:07 -0700 Subject: [PATCH] feat: Implement schemaProvider for UDTFs (#3690) This commit allows schemaProvider to be used when specifying return values for UDTFs in the same way as UDFs. --- .../ksql/function/FunctionInvoker.java | 8 +- .../confluent/ksql/function/KsqlFunction.java | 151 +--------------- .../ksql/function/KsqlFunctionBase.java | 163 ++++++++++++++++++ .../ksql/function/KsqlTableFunction.java | 33 +++- .../ksql/function/TableFunctionFactory.java | 1 - .../ksql/function/BaseTableFunction.java | 77 --------- .../ksql/function/FunctionLoaderUtils.java | 25 ++- .../io/confluent/ksql/function/UdfLoader.java | 17 +- .../confluent/ksql/function/UdtfLoader.java | 50 ++++-- .../ksql/function/UserFunctionLoader.java | 6 +- .../ksql/function/udtf/array/Explode.java | 38 ++-- .../ksql/planner/plan/FlatMapNode.java | 15 +- .../InternalFunctionRegistryTest.java | 37 +--- .../ksql/function/UdfLoaderTest.java | 15 +- .../ksql/function/UdtfLoaderTest.java | 124 +++++++++++-- .../confluent/ksql/function/udf/TestUdtf.java | 11 +- .../execution/util/ExpressionTypeManager.java | 6 +- .../query-validation-tests/explode.json | 10 +- .../table-functions.json | 11 +- 19 files changed, 431 insertions(+), 367 deletions(-) rename {ksql-engine => ksql-common}/src/main/java/io/confluent/ksql/function/FunctionInvoker.java (77%) create mode 100644 ksql-common/src/main/java/io/confluent/ksql/function/KsqlFunctionBase.java delete mode 100644 ksql-engine/src/main/java/io/confluent/ksql/function/BaseTableFunction.java diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/FunctionInvoker.java b/ksql-common/src/main/java/io/confluent/ksql/function/FunctionInvoker.java similarity index 77% rename from ksql-engine/src/main/java/io/confluent/ksql/function/FunctionInvoker.java rename to ksql-common/src/main/java/io/confluent/ksql/function/FunctionInvoker.java index be150066e0d7..17274070e7b3 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/FunctionInvoker.java +++ b/ksql-common/src/main/java/io/confluent/ksql/function/FunctionInvoker.java @@ -21,12 +21,12 @@ public interface FunctionInvoker { /** - * Call onto an UDF instance with the expected args. This is providing - * a wrapper such that we can invoke all UDFs in a generic way. + * Call the UDF/UDTF instance with the expected args. This is providing + * a wrapper such that we can invoke all UDFs/UDTFs in a generic way. * - * @param udf the udf that is being called + * @param udf the UDF/UDTF that is being called * @param udfArgs any arguments that need to be passed to the udf - * @return the result of evaluating the udf + * @return the result of evaluating the UDF/UDTF */ Object eval(Object udf, Object... udfArgs); } diff --git a/ksql-common/src/main/java/io/confluent/ksql/function/KsqlFunction.java b/ksql-common/src/main/java/io/confluent/ksql/function/KsqlFunction.java index f984e9f00c98..42a4ec21b632 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/function/KsqlFunction.java +++ b/ksql-common/src/main/java/io/confluent/ksql/function/KsqlFunction.java @@ -15,40 +15,23 @@ package io.confluent.ksql.function; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Iterables; import io.confluent.ksql.function.udf.Kudf; import io.confluent.ksql.name.FunctionName; -import io.confluent.ksql.schema.ksql.FormatOptions; -import io.confluent.ksql.schema.ksql.SchemaConverters; import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlException; -import io.confluent.ksql.util.SchemaUtil; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.function.Function; -import java.util.stream.Collectors; import javax.annotation.concurrent.Immutable; import org.apache.kafka.connect.data.Schema; -import org.apache.kafka.connect.data.Schema.Type; -import org.apache.kafka.connect.data.SchemaBuilder; @Immutable -public final class KsqlFunction implements FunctionSignature { +public final class KsqlFunction extends KsqlFunctionBase { static final String INTERNAL_PATH = "internal"; - private final Function,Schema> returnSchemaProvider; - private final Schema javaReturnType; - private final List parameters; - private final FunctionName functionName; private final Class kudfClass; private final Function udfFactory; - private final String description; - private final String pathLoadedFrom; - private final boolean isVariadic; private KsqlFunction( final Function,Schema> returnSchemaProvider, @@ -60,31 +43,11 @@ private KsqlFunction( final String description, final String pathLoadedFrom, final boolean isVariadic) { - - this.returnSchemaProvider = Objects.requireNonNull(returnSchemaProvider, "schemaProvider"); - this.javaReturnType = Objects.requireNonNull(javaReturnType, "javaReturnType"); - this.parameters = ImmutableList.copyOf(Objects.requireNonNull(arguments, "arguments")); - this.functionName = Objects.requireNonNull(functionName, "functionName"); + super(returnSchemaProvider, javaReturnType, arguments, functionName, description, + pathLoadedFrom, isVariadic + ); this.kudfClass = Objects.requireNonNull(kudfClass, "kudfClass"); this.udfFactory = Objects.requireNonNull(udfFactory, "udfFactory"); - this.description = Objects.requireNonNull(description, "description"); - this.pathLoadedFrom = Objects.requireNonNull(pathLoadedFrom, "pathLoadedFrom"); - this.isVariadic = isVariadic; - - - if (arguments.stream().anyMatch(Objects::isNull)) { - throw new IllegalArgumentException("KSQL Function can't have null argument types"); - } - if (isVariadic) { - if (arguments.isEmpty()) { - throw new IllegalArgumentException( - "KSQL variadic functions must have at least one parameter"); - } - if (!Iterables.getLast(arguments).type().equals(Type.ARRAY)) { - throw new IllegalArgumentException( - "KSQL variadic functions must have ARRAY type as their last parameter"); - } - } } /** @@ -111,6 +74,10 @@ public static KsqlFunction createLegacyBuiltIn( INTERNAL_PATH, false); } + public Class getKudfClass() { + return kudfClass; + } + /** * Create udf. * @@ -139,112 +106,10 @@ static KsqlFunction create( isVariadic); } - public Schema getReturnType(final List arguments) { - - final Schema returnType = returnSchemaProvider.apply(arguments); - - if (returnType == null) { - throw new KsqlException(String.format("Return type of UDF %s cannot be null.", functionName)); - } - - if (!returnType.isOptional()) { - throw new IllegalArgumentException("KSQL only supports optional field types"); - } - - if (!GenericsUtil.hasGenerics(returnType)) { - checkMatchingReturnTypes(returnType, javaReturnType); - return returnType; - } - - final Map genericMapping = new HashMap<>(); - for (int i = 0; i < Math.min(parameters.size(), arguments.size()); i++) { - final Schema schema = parameters.get(i); - - // we resolve any variadic as if it were an array so that the type - // structure matches the input type - final Schema instance = isVariadic && i == parameters.size() - 1 - ? SchemaBuilder.array(arguments.get(i)).build() - : arguments.get(i); - - genericMapping.putAll(GenericsUtil.resolveGenerics(schema, instance)); - } - - final Schema genericSchema = GenericsUtil.applyResolved(returnType, genericMapping); - final Schema genericJavaSchema = GenericsUtil.applyResolved(javaReturnType, genericMapping); - checkMatchingReturnTypes(genericSchema, genericJavaSchema); - - return genericSchema; - } - - private void checkMatchingReturnTypes(final Schema s1, final Schema s2) { - if (!SchemaUtil.areCompatible(s1, s2)) { - throw new KsqlException(String.format("Return type %s of UDF %s does not match the declared " - + "return type %s.", - SchemaConverters.connectToSqlConverter().toSqlType( - s1).toString(), - functionName.toString(FormatOptions.noEscape()), - SchemaConverters.connectToSqlConverter().toSqlType( - s2).toString())); - } - } - - public List getArguments() { - return parameters; - } - - public FunctionName getFunctionName() { - return functionName; - } - - public String getDescription() { - return description; - } - - public Class getKudfClass() { - return kudfClass; - } - - public String getPathLoadedFrom() { - return pathLoadedFrom; - } - - public boolean isVariadic() { - return isVariadic; - } - - @Override - public boolean equals(final Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - final KsqlFunction that = (KsqlFunction) o; - return Objects.equals(javaReturnType, that.javaReturnType) - && Objects.equals(parameters, that.parameters) - && Objects.equals(functionName, that.functionName) - && Objects.equals(kudfClass, that.kudfClass) - && Objects.equals(pathLoadedFrom, that.pathLoadedFrom) - && (isVariadic == that.isVariadic); - } - - @Override - public int hashCode() { - return Objects.hash( - returnSchemaProvider, parameters, functionName, kudfClass, pathLoadedFrom, isVariadic); - } - @Override public String toString() { return "KsqlFunction{" - + "returnType=" + javaReturnType - + ", arguments=" + parameters.stream().map(Schema::type).collect(Collectors.toList()) - + ", functionName='" + functionName + '\'' + ", kudfClass=" + kudfClass - + ", description='" + description + "'" - + ", pathLoadedFrom='" + pathLoadedFrom + "'" - + ", isVariadic=" + isVariadic + '}'; } diff --git a/ksql-common/src/main/java/io/confluent/ksql/function/KsqlFunctionBase.java b/ksql-common/src/main/java/io/confluent/ksql/function/KsqlFunctionBase.java new file mode 100644 index 000000000000..842f29d60c9b --- /dev/null +++ b/ksql-common/src/main/java/io/confluent/ksql/function/KsqlFunctionBase.java @@ -0,0 +1,163 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.function; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import io.confluent.ksql.name.FunctionName; +import io.confluent.ksql.schema.ksql.FormatOptions; +import io.confluent.ksql.schema.ksql.SchemaConverters; +import io.confluent.ksql.util.KsqlException; +import io.confluent.ksql.util.SchemaUtil; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; +import javax.annotation.concurrent.Immutable; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Schema.Type; +import org.apache.kafka.connect.data.SchemaBuilder; + +@Immutable +public class KsqlFunctionBase implements FunctionSignature { + + private final Function, Schema> returnSchemaProvider; + private final Schema javaReturnType; + private final List parameters; + private final FunctionName functionName; + private final String description; + private final String pathLoadedFrom; + private final boolean isVariadic; + + KsqlFunctionBase( + final Function, Schema> returnSchemaProvider, + final Schema javaReturnType, + final List arguments, + final FunctionName functionName, + final String description, + final String pathLoadedFrom, + final boolean isVariadic + ) { + + this.returnSchemaProvider = Objects.requireNonNull(returnSchemaProvider, "schemaProvider"); + this.javaReturnType = Objects.requireNonNull(javaReturnType, "javaReturnType"); + this.parameters = ImmutableList.copyOf(Objects.requireNonNull(arguments, "arguments")); + this.functionName = Objects.requireNonNull(functionName, "functionName"); + this.description = Objects.requireNonNull(description, "description"); + this.pathLoadedFrom = Objects.requireNonNull(pathLoadedFrom, "pathLoadedFrom"); + this.isVariadic = isVariadic; + + if (arguments.stream().anyMatch(Objects::isNull)) { + throw new IllegalArgumentException("KSQL Function can't have null argument types"); + } + if (isVariadic) { + if (arguments.isEmpty()) { + throw new IllegalArgumentException( + "KSQL variadic functions must have at least one parameter"); + } + if (!Iterables.getLast(arguments).type().equals(Type.ARRAY)) { + throw new IllegalArgumentException( + "KSQL variadic functions must have ARRAY type as their last parameter"); + } + } + } + + public Schema getReturnType(final List arguments) { + + final Schema returnType = returnSchemaProvider.apply(arguments); + + if (returnType == null) { + throw new KsqlException(String.format("Return type of UDF %s cannot be null.", functionName)); + } + + if (!returnType.isOptional()) { + throw new IllegalArgumentException("KSQL only supports optional field types"); + } + + if (!GenericsUtil.hasGenerics(returnType)) { + checkMatchingReturnTypes(returnType, javaReturnType); + return returnType; + } + + final Map genericMapping = new HashMap<>(); + for (int i = 0; i < Math.min(parameters.size(), arguments.size()); i++) { + final Schema schema = parameters.get(i); + + // we resolve any variadic as if it were an array so that the type + // structure matches the input type + final Schema instance = isVariadic && i == parameters.size() - 1 + ? SchemaBuilder.array(arguments.get(i)).build() + : arguments.get(i); + + genericMapping.putAll(GenericsUtil.resolveGenerics(schema, instance)); + } + + final Schema genericSchema = GenericsUtil.applyResolved(returnType, genericMapping); + final Schema genericJavaSchema = GenericsUtil.applyResolved(javaReturnType, genericMapping); + checkMatchingReturnTypes(genericSchema, genericJavaSchema); + + return genericSchema; + } + + private void checkMatchingReturnTypes(final Schema s1, final Schema s2) { + if (!SchemaUtil.areCompatible(s1, s2)) { + throw new KsqlException(String.format( + "Return type %s of UDF %s does not match the declared " + + "return type %s.", + SchemaConverters.connectToSqlConverter().toSqlType( + s1).toString(), + functionName.toString(FormatOptions.noEscape()), + SchemaConverters.connectToSqlConverter().toSqlType( + s2).toString() + )); + } + } + + public List getArguments() { + return parameters; + } + + public FunctionName getFunctionName() { + return functionName; + } + + public String getDescription() { + return description; + } + + public String getPathLoadedFrom() { + return pathLoadedFrom; + } + + public boolean isVariadic() { + return isVariadic; + } + + @Override + public String toString() { + return "KsqlFunction{" + + "returnType=" + javaReturnType + + ", arguments=" + parameters.stream().map(Schema::type).collect(Collectors.toList()) + + ", functionName='" + functionName + '\'' + + ", description='" + description + "'" + + ", pathLoadedFrom='" + pathLoadedFrom + "'" + + ", isVariadic=" + isVariadic + + '}'; + } + +} \ No newline at end of file diff --git a/ksql-common/src/main/java/io/confluent/ksql/function/KsqlTableFunction.java b/ksql-common/src/main/java/io/confluent/ksql/function/KsqlTableFunction.java index 5eb7d73b67ff..143c7dae4494 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/function/KsqlTableFunction.java +++ b/ksql-common/src/main/java/io/confluent/ksql/function/KsqlTableFunction.java @@ -15,21 +15,40 @@ package io.confluent.ksql.function; -import io.confluent.ksql.schema.ksql.types.SqlType; +import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.name.FunctionName; import java.util.List; +import java.util.Objects; +import java.util.function.Function; import org.apache.kafka.connect.data.Schema; /** * A wrapper around the actual table function which provides methods to get return type and * description, and allows the function to be invoked. */ -public interface KsqlTableFunction extends FunctionSignature { +@Immutable +public class KsqlTableFunction extends KsqlFunctionBase { - Schema getReturnType(); + private final FunctionInvoker invoker; + private final Object udtf; - SqlType returnType(); + public KsqlTableFunction( + final Function, Schema> returnSchemaProvider, + final FunctionName functionName, + final Schema outputType, + final List arguments, + final String description, + final FunctionInvoker functionInvoker, + final Object udtf + ) { + super(returnSchemaProvider, outputType, arguments, functionName, description, + "", false + ); + this.invoker = Objects.requireNonNull(functionInvoker, "functionInvoker"); + this.udtf = Objects.requireNonNull(udtf, "udtf"); + } - List apply(Object... args); - - String getDescription(); + public List apply(final Object... args) { + return (List) invoker.eval(udtf, args); + } } diff --git a/ksql-common/src/main/java/io/confluent/ksql/function/TableFunctionFactory.java b/ksql-common/src/main/java/io/confluent/ksql/function/TableFunctionFactory.java index 38816f4db387..aa74390724d9 100644 --- a/ksql-common/src/main/java/io/confluent/ksql/function/TableFunctionFactory.java +++ b/ksql-common/src/main/java/io/confluent/ksql/function/TableFunctionFactory.java @@ -20,7 +20,6 @@ import java.util.Objects; import org.apache.kafka.connect.data.Schema; - public abstract class TableFunctionFactory { private final UdfMetadata metadata; diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/BaseTableFunction.java b/ksql-engine/src/main/java/io/confluent/ksql/function/BaseTableFunction.java deleted file mode 100644 index a81833daecc5..000000000000 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/BaseTableFunction.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright 2018 Confluent Inc. - * - * Licensed under the Confluent Community License (the "License"); you may not use - * this file except in compliance with the License. You may obtain a copy of the - * License at - * - * http://www.confluent.io/confluent-community-license - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package io.confluent.ksql.function; - -import com.google.errorprone.annotations.Immutable; -import io.confluent.ksql.name.FunctionName; -import io.confluent.ksql.schema.ksql.SchemaConverters; -import io.confluent.ksql.schema.ksql.SchemaConverters.ConnectToSqlTypeConverter; -import io.confluent.ksql.schema.ksql.types.SqlType; -import java.util.List; -import java.util.Objects; -import org.apache.kafka.connect.data.Schema; - -/** - * Abstract base class for table functions - */ -@Immutable -public abstract class BaseTableFunction implements KsqlTableFunction { - - private static final ConnectToSqlTypeConverter CONNECT_TO_SQL_CONVERTER - = SchemaConverters.connectToSqlConverter(); - - private final Schema outputSchema; - private final SqlType outputType; - private final List arguments; - - protected final FunctionName functionName; - private final String description; - - public BaseTableFunction( - final FunctionName functionName, - final Schema outputType, - final List arguments, - final String description - ) { - this.outputSchema = Objects.requireNonNull(outputType, "outputType"); - this.outputType = CONNECT_TO_SQL_CONVERTER.toSqlType(outputType); - this.arguments = Objects.requireNonNull(arguments, "arguments"); - this.functionName = Objects.requireNonNull(functionName, "functionName"); - this.description = Objects.requireNonNull(description, "description"); - } - - public FunctionName getFunctionName() { - return functionName; - } - - public Schema getReturnType() { - return outputSchema; - } - - @Override - public SqlType returnType() { - return outputType; - } - - public List getArguments() { - return arguments; - } - - @Override - public String getDescription() { - return description; - } -} diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java b/ksql-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java index b0519beeba18..5de811da3a1f 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java @@ -17,8 +17,6 @@ import com.google.common.annotations.VisibleForTesting; import io.confluent.ksql.execution.function.UdfUtil; -import io.confluent.ksql.function.udf.Udf; -import io.confluent.ksql.function.udf.UdfDescription; import io.confluent.ksql.function.udf.UdfParameter; import io.confluent.ksql.function.udf.UdfSchemaProvider; import io.confluent.ksql.schema.ksql.SchemaConverters; @@ -178,16 +176,15 @@ static Schema getReturnType( static Function, Schema> handleUdfReturnSchema( final Class theClass, final Schema javaReturnSchema, - final Udf udfAnnotation, - final UdfDescription descAnnotation + final String schemaProviderFunctionName, + final String functionName ) { - final String schemaProviderName = udfAnnotation.schemaProvider(); - - if (!schemaProviderName.equals("")) { - return handleUdfSchemaProviderAnnotation(schemaProviderName, theClass, descAnnotation); + if (!schemaProviderFunctionName.equals("")) { + return handleUdfSchemaProviderAnnotation( + schemaProviderFunctionName, theClass, functionName); } else if (DecimalUtil.isDecimal(javaReturnSchema)) { throw new KsqlException(String.format("Cannot load UDF %s. BigDecimal return type " - + "is not supported without a schema provider method.", descAnnotation.name())); + + "is not supported without a schema provider method.", functionName)); } return ignored -> javaReturnSchema; @@ -196,19 +193,19 @@ static Function, Schema> handleUdfReturnSchema( private static Function, Schema> handleUdfSchemaProviderAnnotation( final String schemaProviderName, final Class theClass, - final UdfDescription annotation + final String functionName ) { // throws exception if cannot find method final Method m = findSchemaProvider(theClass, schemaProviderName); final Object instance = FunctionLoaderUtils - .instantiateFunctionInstance(theClass, annotation.name()); + .instantiateFunctionInstance(theClass, functionName); return parameterSchemas -> { final List parameterTypes = parameterSchemas.stream() .map(p -> SchemaConverters.connectToSqlConverter().toSqlType(p)) .collect(Collectors.toList()); return SchemaConverters.sqlToConnectConverter().toConnectSchema(invokeSchemaProviderMethod( - instance, m, parameterTypes, annotation)); + instance, m, parameterTypes, functionName)); }; } @@ -236,7 +233,7 @@ private static SqlType invokeSchemaProviderMethod( final Object instance, final Method m, final List args, - final UdfDescription annotation + final String functionName ) { try { return (SqlType) m.invoke(instance, args); @@ -244,7 +241,7 @@ private static SqlType invokeSchemaProviderMethod( | InvocationTargetException e) { throw new KsqlException(String.format("Cannot invoke the schema provider " + "method %s for UDF %s. ", - m.getName(), annotation.name() + m.getName(), functionName ), e); } } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java b/ksql-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java index c82bd38e09d4..f1c52136a0fe 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java @@ -46,17 +46,20 @@ public class UdfLoader { private final Optional metrics; private final SqlTypeParser typeParser; private final ClassLoader parentClassLoader; + private final boolean throwExceptionOnLoadFailure; UdfLoader( final MutableFunctionRegistry functionRegistry, final Optional metrics, final SqlTypeParser typeParser, - final ClassLoader parentClassLoader + final ClassLoader parentClassLoader, + final boolean throwExceptionOnLoadFailure ) { this.functionRegistry = functionRegistry; this.metrics = metrics; this.typeParser = typeParser; this.parentClassLoader = parentClassLoader; + this.throwExceptionOnLoadFailure = throwExceptionOnLoadFailure; } // Does not handle customer udfs, i.e the loader is the ParentClassLoader and path is internal @@ -65,14 +68,13 @@ public class UdfLoader { void loadUdfFromClass(final Class... udfClasses) { for (final Class theClass : udfClasses) { loadUdfFromClass( - theClass, KsqlFunction.INTERNAL_PATH, theClass.getClassLoader()); + theClass, KsqlFunction.INTERNAL_PATH); } } void loadUdfFromClass( final Class theClass, - final String path, - final ClassLoader loader + final String path ) { final UdfDescription udfDescriptionAnnotation = theClass.getAnnotation(UdfDescription.class); if (udfDescriptionAnnotation == null) { @@ -108,8 +110,7 @@ void loadUdfFromClass( sensorName, udfClass )); } catch (final KsqlException e) { - if (parentClassLoader == loader) { - // This only seems to be done for tests, dubious code + if (throwExceptionOnLoadFailure) { throw e; } else { LOGGER.warn( @@ -154,8 +155,8 @@ private KsqlFunction createFunction( FunctionLoaderUtils.handleUdfReturnSchema( theClass, javaReturnSchema, - udfAnnotation, - udfDescriptionAnnotation + udfAnnotation.schemaProvider(), + udfDescriptionAnnotation.name() ), javaReturnSchema, parameters, diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/UdtfLoader.java b/ksql-engine/src/main/java/io/confluent/ksql/function/UdtfLoader.java index 86a558a25e39..ff0cf48cbfe6 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/UdtfLoader.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/UdtfLoader.java @@ -27,6 +27,7 @@ import java.util.Arrays; import java.util.List; import java.util.Optional; +import java.util.function.Function; import org.apache.kafka.common.metrics.Metrics; import org.apache.kafka.connect.data.Schema; import org.slf4j.Logger; @@ -42,15 +43,18 @@ class UdtfLoader { private final MutableFunctionRegistry functionRegistry; private final Optional metrics; private final SqlTypeParser typeParser; + private final boolean throwExceptionOnLoadFailure; UdtfLoader( final MutableFunctionRegistry functionRegistry, final Optional metrics, - final SqlTypeParser typeParser + final SqlTypeParser typeParser, + final boolean throwExceptionOnLoadFailure ) { this.functionRegistry = functionRegistry; this.metrics = metrics; this.typeParser = typeParser; + this.throwExceptionOnLoadFailure = throwExceptionOnLoadFailure; } void loadUdtfFromClass( @@ -91,7 +95,8 @@ void loadUdtfFromClass( final Type ret = method.getGenericReturnType(); if (!(ret instanceof ParameterizedType)) { throw new KsqlException(String - .format("UDTF functions must return a generic List. Class %s Method %s", + .format( + "UDTF functions must return a parameterized List. Class %s Method %s", theClass.getName(), method.getName() )); } @@ -103,15 +108,20 @@ void loadUdtfFromClass( return Optional .of(createTableFunction(method, FunctionName.of(functionName), returnType, parameters, - udtfDescriptionAnnotation.description() + udtfDescriptionAnnotation.description(), + annotation )); } catch (final KsqlException e) { - LOGGER.warn( - "Failed to add UDF to the MetaStore. name={} method={}", - udtfDescriptionAnnotation.name(), - method, - e - ); + if (throwExceptionOnLoadFailure) { + throw e; + } else { + LOGGER.warn( + "Failed to add UDTF to the MetaStore. name={} method={}", + udtfDescriptionAnnotation.name(), + method, + e + ); + } } return Optional.empty(); }) @@ -127,18 +137,22 @@ private KsqlTableFunction createTableFunction( final FunctionName functionName, final Schema outputType, final List arguments, - final String description + final String description, + final Udtf udtfAnnotation ) { final FunctionInvoker invoker = FunctionLoaderUtils.createFunctionInvoker(method); final Object instance = FunctionLoaderUtils .instantiateFunctionInstance(method.getDeclaringClass(), description); - @SuppressWarnings("unchecked") final KsqlTableFunction tableFunction = new BaseTableFunction( - functionName, outputType, arguments, description) { - @Override - public List apply(final Object... args) { - return (List) invoker.eval(instance, args); - } - }; - return tableFunction; + final Function, Schema> schemaProviderFunction = FunctionLoaderUtils + .handleUdfReturnSchema( + method.getDeclaringClass(), + outputType, + udtfAnnotation.schemaProvider(), + functionName.name() + ); + return new KsqlTableFunction( + schemaProviderFunction, + functionName, outputType, arguments, description, invoker, instance + ); } } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/UserFunctionLoader.java b/ksql-engine/src/main/java/io/confluent/ksql/function/UserFunctionLoader.java index 757b57c73992..acf82b0382b7 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/UserFunctionLoader.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/UserFunctionLoader.java @@ -72,9 +72,9 @@ public UserFunctionLoader( Objects.requireNonNull(metrics, "metrics can't be null"); this.loadCustomerUdfs = loadCustomerUdfs; final SqlTypeParser typeParser = SqlTypeParser.create(TypeRegistry.EMPTY); - this.udfLoader = new UdfLoader(functionRegistry, metrics, typeParser, parentClassLoader); + this.udfLoader = new UdfLoader(functionRegistry, metrics, typeParser, parentClassLoader, false); this.udafLoader = new UdafLoader(functionRegistry, metrics, typeParser); - this.udtfLoader = new UdtfLoader(functionRegistry, metrics, typeParser); + this.udtfLoader = new UdtfLoader(functionRegistry, metrics, typeParser, false); } public void load() { @@ -122,7 +122,7 @@ private void loadFunctions(final ClassLoader loader, final Optional path) }) .matchClassesWithAnnotation( UdfDescription.class, - theClass -> udfLoader.loadUdfFromClass(theClass, pathLoadedFrom, loader) + theClass -> udfLoader.loadUdfFromClass(theClass, pathLoadedFrom) ) .matchClassesWithAnnotation( UdafDescription.class, diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/udtf/array/Explode.java b/ksql-engine/src/main/java/io/confluent/ksql/function/udtf/array/Explode.java index ea8738286118..61cdc8ea738c 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/udtf/array/Explode.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/udtf/array/Explode.java @@ -15,9 +15,14 @@ package io.confluent.ksql.function.udtf.array; +import io.confluent.ksql.function.udf.UdfSchemaProvider; import io.confluent.ksql.function.udtf.Udtf; import io.confluent.ksql.function.udtf.UdtfDescription; +import io.confluent.ksql.schema.ksql.types.SqlArray; +import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.util.KsqlConstants; +import io.confluent.ksql.util.KsqlException; +import java.math.BigDecimal; import java.util.Collections; import java.util.List; @@ -31,32 +36,21 @@ public class Explode { @Udtf - public List explodeLong(final List input) { - return explode(input); - } - - @Udtf - public List explodeInt(final List input) { - return explode(input); - } - - @Udtf - public List explodeDouble(final List input) { - return explode(input); - } - - @Udtf - public List explodeBoolean(final List input) { - return explode(input); + public List explode(final List list) { + return list == null ? Collections.emptyList() : list; } - @Udtf - public List explodeString(final List input) { + @Udtf(schemaProvider = "provideSchema") + public List explodeBigDecimal(final List input) { return explode(input); } - private List explode(final List list) { - return list == null ? Collections.emptyList() : list; + @UdfSchemaProvider + public SqlType provideSchema(final List params) { + final SqlType argType = params.get(0); + if (!(argType instanceof SqlArray)) { + throw new KsqlException("explode should be provided with an ARRAY"); + } + return ((SqlArray) argType).getItemType(); } - } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/FlatMapNode.java b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/FlatMapNode.java index 3a29811dbe68..6774b5f886db 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/FlatMapNode.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/FlatMapNode.java @@ -25,17 +25,14 @@ import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor; -import io.confluent.ksql.execution.function.UdtfUtil; import io.confluent.ksql.execution.plan.SelectExpression; +import io.confluent.ksql.execution.util.ExpressionTypeManager; import io.confluent.ksql.function.FunctionRegistry; -import io.confluent.ksql.function.KsqlTableFunction; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.schema.ksql.LogicalSchema; -import io.confluent.ksql.schema.ksql.SchemaConverters; -import io.confluent.ksql.schema.ksql.SchemaConverters.ConnectToSqlTypeConverter; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.services.KafkaTopicClient; import io.confluent.ksql.structured.SchemaKStream; @@ -128,16 +125,14 @@ private LogicalSchema buildLogicalSchema(final LogicalSchema inputSchema) { schemaBuilder.valueColumn(col); } - final ConnectToSqlTypeConverter converter = SchemaConverters.connectToSqlConverter(); + final ExpressionTypeManager expressionTypeManager = new ExpressionTypeManager( + inputSchema, functionRegistry); // And add new columns representing the exploded values at the end for (int i = 0; i < analysis.getTableFunctions().size(); i++) { - final KsqlTableFunction tableFunction = - UdtfUtil.resolveTableFunction(functionRegistry, - analysis.getTableFunctions().get(i), inputSchema - ); + final FunctionCall functionCall = analysis.getTableFunctions().get(i); final ColumnName colName = ColumnName.synthesisedSchemaColumn(i); - final SqlType fieldType = converter.toSqlType(tableFunction.getReturnType()); + final SqlType fieldType = expressionTypeManager.getExpressionSqlType(functionCall); schemaBuilder.valueColumn(colName, fieldType); } diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/InternalFunctionRegistryTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/InternalFunctionRegistryTest.java index ba8ef1b13c31..196a6a6e0e67 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/InternalFunctionRegistryTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/InternalFunctionRegistryTest.java @@ -90,6 +90,9 @@ public Object evaluate(final Object... args) { @Mock private AggregateFunctionFactory udafFactory; + @Mock + private KsqlTableFunction tableFunction; + @Before public void setUp() { when(udfFactory.getName()).thenReturn(UDF_NAME); @@ -432,42 +435,12 @@ public List> supportedArgs() { }; } - private static TableFunctionFactory createTableFunctionFactory() { + private TableFunctionFactory createTableFunctionFactory() { return new TableFunctionFactory(new UdfMetadata("my_tablefunction", "", "", "", "", false)) { @Override public KsqlTableFunction createTableFunction(List argTypeList) { - return new KsqlTableFunction() { - @Override - public Schema getReturnType() { - return null; - } - - @Override - public SqlType returnType() { - return null; - } - - @Override - public List apply(Object... args) { - return null; - } - - @Override - public String getDescription() { - return null; - } - - @Override - public FunctionName getFunctionName() { - return null; - } - - @Override - public List getArguments() { - return null; - } - }; + return tableFunction; } @Override diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java index 6d40bd6b2ff9..a388613d0b9d 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java @@ -280,7 +280,8 @@ public void shouldThrowOnMissingAnnotation() throws ClassNotFoundException { functionRegistry, Optional.empty(), SqlTypeParser.create(TypeRegistry.EMPTY), - udfClassLoader + udfClassLoader, + true ); // Expect: @@ -306,7 +307,8 @@ public void shouldThrowOnMissingSchemaProvider() throws ClassNotFoundException { functionRegistry, Optional.empty(), SqlTypeParser.create(TypeRegistry.EMPTY), - udfClassLoader + udfClassLoader, + true ); // Expect: @@ -333,7 +335,8 @@ public void shouldThrowOnReturnDecimalWithoutSchemaProvider() throws ClassNotFou functionRegistry, Optional.empty(), SqlTypeParser.create(TypeRegistry.EMPTY), - udfClassLoader + udfClassLoader, + true ); // Expect: @@ -464,7 +467,8 @@ public void shouldNotLoadInternalUdfs() { functionRegistry, Optional.empty(), SqlTypeParser.create(TypeRegistry.EMPTY), - PARENT_CLASS_LOADER + PARENT_CLASS_LOADER, + true ); udfLoader.loadUdfFromClass(UdfLoaderTest.SomeFunctionUdf.class); @@ -484,7 +488,8 @@ public void shouldLoadSomeFunction() { functionRegistry, Optional.empty(), SqlTypeParser.create(TypeRegistry.EMPTY), - PARENT_CLASS_LOADER + PARENT_CLASS_LOADER, + true ); final List args = ImmutableList.of( Schema.STRING_SCHEMA, diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/UdtfLoaderTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/UdtfLoaderTest.java index 297c923b3508..3306558c39ec 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/UdtfLoaderTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/UdtfLoaderTest.java @@ -16,16 +16,28 @@ package io.confluent.ksql.function; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; import com.google.common.collect.ImmutableList; +import io.confluent.ksql.function.udtf.Udtf; +import io.confluent.ksql.function.udtf.UdtfDescription; +import io.confluent.ksql.metastore.TypeRegistry; +import io.confluent.ksql.schema.ksql.SqlTypeParser; import io.confluent.ksql.util.DecimalUtil; +import io.confluent.ksql.util.KsqlException; import java.io.File; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import org.apache.kafka.connect.data.Schema; import org.apache.kafka.connect.data.SchemaBuilder; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; public class UdtfLoaderTest { @@ -33,6 +45,10 @@ public class UdtfLoaderTest { private static final FunctionRegistry FUNC_REG = initializeFunctionRegistry(); + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + @Test public void shouldLoadSimpleParams() { @@ -51,7 +67,7 @@ public void shouldLoadSimpleParams() { final KsqlTableFunction function = FUNC_REG.getTableFunction("test_udtf", args); // Then: - assertThat(function.getReturnType(), equalTo(Schema.OPTIONAL_STRING_SCHEMA)); + assertThat(function.getReturnType(args), equalTo(Schema.OPTIONAL_STRING_SCHEMA)); } @Test @@ -72,7 +88,7 @@ public void shouldLoadParameterizedListParams() { final KsqlTableFunction function = FUNC_REG.getTableFunction("test_udtf", args); // Then: - assertThat(function.getReturnType(), equalTo(Schema.OPTIONAL_STRING_SCHEMA)); + assertThat(function.getReturnType(args), equalTo(Schema.OPTIONAL_STRING_SCHEMA)); } @Test @@ -96,7 +112,7 @@ public void shouldLoadParameterizedMapParams() { final KsqlTableFunction function = FUNC_REG.getTableFunction("test_udtf", args); // Then: - assertThat(function.getReturnType(), equalTo(Schema.OPTIONAL_STRING_SCHEMA)); + assertThat(function.getReturnType(args), equalTo(Schema.OPTIONAL_STRING_SCHEMA)); } @Test @@ -109,7 +125,7 @@ public void shouldLoadListIntReturn() { final KsqlTableFunction function = FUNC_REG.getTableFunction("test_udtf", args); // Then: - assertThat(function.getReturnType(), equalTo(Schema.OPTIONAL_INT32_SCHEMA)); + assertThat(function.getReturnType(args), equalTo(Schema.OPTIONAL_INT32_SCHEMA)); } @Test @@ -122,7 +138,7 @@ public void shouldLoadListLongReturn() { final KsqlTableFunction function = FUNC_REG.getTableFunction("test_udtf", args); // Then: - assertThat(function.getReturnType(), equalTo(Schema.OPTIONAL_INT64_SCHEMA)); + assertThat(function.getReturnType(args), equalTo(Schema.OPTIONAL_INT64_SCHEMA)); } @Test @@ -135,7 +151,7 @@ public void shouldLoadListDoubleReturn() { final KsqlTableFunction function = FUNC_REG.getTableFunction("test_udtf", args); // Then: - assertThat(function.getReturnType(), equalTo(Schema.OPTIONAL_FLOAT64_SCHEMA)); + assertThat(function.getReturnType(args), equalTo(Schema.OPTIONAL_FLOAT64_SCHEMA)); } @Test @@ -148,7 +164,7 @@ public void shouldLoadListBooleanReturn() { final KsqlTableFunction function = FUNC_REG.getTableFunction("test_udtf", args); // Then: - assertThat(function.getReturnType(), equalTo(Schema.OPTIONAL_BOOLEAN_SCHEMA)); + assertThat(function.getReturnType(args), equalTo(Schema.OPTIONAL_BOOLEAN_SCHEMA)); } @Test @@ -161,11 +177,11 @@ public void shouldLoadListStringReturn() { final KsqlTableFunction function = FUNC_REG.getTableFunction("test_udtf", args); // Then: - assertThat(function.getReturnType(), equalTo(Schema.OPTIONAL_STRING_SCHEMA)); + assertThat(function.getReturnType(args), equalTo(Schema.OPTIONAL_STRING_SCHEMA)); } @Test - public void shouldLoadListBigDecimalReturnWithSchemaAnnotation() { + public void shouldLoadListBigDecimalReturnWithSchemaProvider() { // Given: final List args = ImmutableList.of(DECIMAL_SCHEMA); @@ -174,7 +190,7 @@ public void shouldLoadListBigDecimalReturnWithSchemaAnnotation() { final KsqlTableFunction function = FUNC_REG.getTableFunction("test_udtf", args); // Then: - assertThat(function.getReturnType(), equalTo(DecimalUtil.builder(10, 10).build())); + assertThat(function.getReturnType(args), equalTo(DecimalUtil.builder(30, 10).build())); } @Test @@ -187,7 +203,7 @@ public void shouldLoadListStructReturnWithSchemaAnnotation() { final KsqlTableFunction function = FUNC_REG.getTableFunction("test_udtf", args); // Then: - assertThat(function.getReturnType(), equalTo(STRUCT_SCHEMA)); + assertThat(function.getReturnType(args), equalTo(STRUCT_SCHEMA)); } @Test @@ -200,7 +216,91 @@ public void shouldLoadVarArgsMethod() { final KsqlTableFunction function = FUNC_REG.getTableFunction("test_udtf", args); // Then: - assertThat(function.getReturnType(), equalTo(STRUCT_SCHEMA)); + assertThat(function.getReturnType(args), equalTo(STRUCT_SCHEMA)); + } + + @Test + public void shouldNotLoadUdtfWithWrongReturnValue() { + // Given: + final MutableFunctionRegistry functionRegistry = new InternalFunctionRegistry(); + final SqlTypeParser typeParser = SqlTypeParser.create(TypeRegistry.EMPTY); + final UdtfLoader udtfLoader = new UdtfLoader( + functionRegistry, Optional.empty(), typeParser, true + ); + + // Expect: + expectedException.expect(KsqlException.class); + expectedException + .expectMessage( + is("UDTF functions must return a List. Class io.confluent.ksql.function.UdtfLoaderTest$UdtfBadReturnValue Method badReturn")); + + // When: + udtfLoader.loadUdtfFromClass(UdtfBadReturnValue.class, KsqlFunction.INTERNAL_PATH); + } + + @Test + public void shouldNotLoadUdtfWithRawListReturn() { + // Given: + final MutableFunctionRegistry functionRegistry = new InternalFunctionRegistry(); + final SqlTypeParser typeParser = SqlTypeParser.create(TypeRegistry.EMPTY); + final UdtfLoader udtfLoader = new UdtfLoader( + functionRegistry, Optional.empty(), typeParser, true + ); + + // Expect: + expectedException.expect(KsqlException.class); + expectedException + .expectMessage( + is("UDTF functions must return a parameterized List. Class io.confluent.ksql.function.UdtfLoaderTest$RawListReturn Method badReturn")); + + // When: + udtfLoader.loadUdtfFromClass(RawListReturn.class, KsqlFunction.INTERNAL_PATH); + } + + @Test + public void shouldNotLoadUdtfWithBigDecimalReturnAndNoSchemaProvider() { + // Given: + final MutableFunctionRegistry functionRegistry = new InternalFunctionRegistry(); + final SqlTypeParser typeParser = SqlTypeParser.create(TypeRegistry.EMPTY); + final UdtfLoader udtfLoader = new UdtfLoader( + functionRegistry, Optional.empty(), typeParser, true + ); + + // Expect: + expectedException.expect(KsqlException.class); + expectedException + .expectMessage( + is("Cannot load UDF bigDecimalNoSchemaProvider. BigDecimal return type is not supported without a schema provider method.")); + + // When: + udtfLoader.loadUdtfFromClass(BigDecimalNoSchemaProvider.class, KsqlFunction.INTERNAL_PATH); + } + + @UdtfDescription(name = "badReturnUdtf", description = "whatever") + static class UdtfBadReturnValue { + + @Udtf + public Map badReturn(int foo) { + return new HashMap<>(); + } + } + + @UdtfDescription(name = "rawListReturn", description = "whatever") + static class RawListReturn { + + @Udtf + public List badReturn(int foo) { + return new ArrayList(); + } + } + + @UdtfDescription(name = "bigDecimalNoSchemaProvider", description = "whatever") + static class BigDecimalNoSchemaProvider { + + @Udtf + public List badReturn(int foo) { + return ImmutableList.of(new BigDecimal("123")); + } } private static final Schema STRUCT_SCHEMA = diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/udf/TestUdtf.java b/ksql-engine/src/test/java/io/confluent/ksql/function/udf/TestUdtf.java index 69a474573591..203a3d693ce7 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/udf/TestUdtf.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/udf/TestUdtf.java @@ -18,6 +18,8 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.function.udtf.Udtf; import io.confluent.ksql.function.udtf.UdtfDescription; +import io.confluent.ksql.schema.ksql.types.SqlDecimal; +import io.confluent.ksql.schema.ksql.types.SqlType; import java.math.BigDecimal; import java.util.List; import java.util.Map; @@ -92,8 +94,8 @@ public List listStringReturn(String s) { return ImmutableList.of(s); } - @Udtf(schema = "DECIMAL(10, 10)") - public List listBigDecimalReturn(BigDecimal bd) { + @Udtf(schemaProvider = "provideSchema") + public List listBigDecimalReturnWithSchemaProvider(BigDecimal bd) { return ImmutableList.of(bd); } @@ -102,4 +104,9 @@ public List listStructReturn(@UdfParameter(schema = "STRUCT") return ImmutableList.of(struct); } + @UdfSchemaProvider + public SqlType provideSchema(final List params) { + return SqlDecimal.of(30, 10); + } + } diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java index 4ba5155eb580..6d4fdcdce94f 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java @@ -394,15 +394,15 @@ public Void visitFunctionCall( } if (functionRegistry.isTableFunction(node.getName().name())) { - final List schema = node.getArguments().isEmpty() + final List argumentTypes = node.getArguments().isEmpty() ? ImmutableList.of(FunctionRegistry.DEFAULT_FUNCTION_ARG_SCHEMA) : node.getArguments().stream().map(ExpressionTypeManager.this::getExpressionSchema) .collect(Collectors.toList()); final KsqlTableFunction tableFunction = functionRegistry - .getTableFunction(node.getName().name(), schema); + .getTableFunction(node.getName().name(), argumentTypes); - expressionTypeContext.setSchema(tableFunction.getReturnType()); + expressionTypeContext.setSchema(tableFunction.getReturnType(argumentTypes)); return null; } diff --git a/ksql-functional-tests/src/test/resources/query-validation-tests/explode.json b/ksql-functional-tests/src/test/resources/query-validation-tests/explode.json index 3ebf4522753d..e3d06921acb4 100644 --- a/ksql-functional-tests/src/test/resources/query-validation-tests/explode.json +++ b/ksql-functional-tests/src/test/resources/query-validation-tests/explode.json @@ -67,17 +67,17 @@ { "name": "explode different types", "statements": [ - "CREATE STREAM TEST (F0 ARRAY, F1 ARRAY, F2 ARRAY, F3 ARRAY, F4 ARRAY) WITH (kafka_topic='test_topic', value_format='JSON');", - "CREATE STREAM OUTPUT AS SELECT EXPLODE(F0), EXPLODE(F1), EXPLODE(F2), EXPLODE(F3), EXPLODE(F4) FROM TEST;" + "CREATE STREAM TEST (F0 ARRAY, F1 ARRAY, F2 ARRAY, F3 ARRAY, F4 ARRAY, F5 ARRAY) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT EXPLODE(F0), EXPLODE(F1), EXPLODE(F2), EXPLODE(F3), EXPLODE(F4), EXPLODE(F5) FROM TEST;" ], "inputs": [ { - "topic": "test_topic", "key": 0, "value": {"F0": [1, 2], "F1": [2, 3], "F2": [3.1, 4.1], "F3": [true, false], "F4": ["foo", "bar"]} + "topic": "test_topic", "key": 0, "value": {"F0": [1, 2], "F1": [2, 3], "F2": [3.1, 4.1], "F3": [true, false], "F4": ["foo", "bar"], "F5": [123.456, 456.123]} } ], "outputs": [ - {"topic": "OUTPUT", "key": "0", "value": {"KSQL_COL_0": 1, "KSQL_COL_1": 2, "KSQL_COL_2": 3.1, "KSQL_COL_3": true, "KSQL_COL_4": "foo"}}, - {"topic": "OUTPUT", "key": "0", "value": {"KSQL_COL_0": 2, "KSQL_COL_1": 3, "KSQL_COL_2": 4.1, "KSQL_COL_3": false, "KSQL_COL_4": "bar"}} + {"topic": "OUTPUT", "key": "0", "value": {"KSQL_COL_0": 1, "KSQL_COL_1": 2, "KSQL_COL_2": 3.1, "KSQL_COL_3": true, "KSQL_COL_4": "foo", "KSQL_COL_5": 123.456}}, + {"topic": "OUTPUT", "key": "0", "value": {"KSQL_COL_0": 2, "KSQL_COL_1": 3, "KSQL_COL_2": 4.1, "KSQL_COL_3": false, "KSQL_COL_4": "bar", "KSQL_COL_5": 456.123}} ] } ] diff --git a/ksql-functional-tests/src/test/resources/query-validation-tests/table-functions.json b/ksql-functional-tests/src/test/resources/query-validation-tests/table-functions.json index afe58730e07e..77fbd3f45766 100644 --- a/ksql-functional-tests/src/test/resources/query-validation-tests/table-functions.json +++ b/ksql-functional-tests/src/test/resources/query-validation-tests/table-functions.json @@ -207,7 +207,16 @@ ], "outputs": [ {"topic": "OUTPUT", "key": "0", "value": {"KSQL_COL_0": 1, "KSQL_COL_1": 2, "KSQL_COL_2": 3.1, "KSQL_COL_3": true, "KSQL_COL_4": "foo", "KSQL_COL_5": 123.456, "KSQL_COL_6": {"A": "bar"}}} - ] + ], + "post": { + "sources": [ + { + "name": "OUTPUT", + "type": "stream", + "valueSchema": "STRUCT>" + } + ] + } } ] } \ No newline at end of file