refactor: use SqlArgument wrapper to look up functions #7011

Merged
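Editorial note, before the hunks below: this PR threads a SqlArgument wrapper through every function-lookup path in place of raw SqlType lists. The following is a small illustrative sketch of the caller-side pattern, not code from this PR. SqlArgument.of(...) and UdfFactory.getFunction(List<SqlArgument>) appear in the diff itself; the class name LookupExample and the exact import locations of UdfFactory and KsqlScalarFunction are assumptions.

import io.confluent.ksql.function.KsqlScalarFunction;
import io.confluent.ksql.function.UdfFactory;
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlType;
import java.util.List;
import java.util.stream.Collectors;

final class LookupExample {

  // Callers that used to pass List<SqlType> now wrap each type before the lookup.
  static KsqlScalarFunction lookup(final UdfFactory factory, final List<SqlType> argTypes) {
    final List<SqlArgument> wrapped = argTypes.stream()
        .map(SqlArgument::of)               // same wrapping used throughout the hunks below
        .collect(Collectors.toList());
    return factory.getFunction(wrapped);    // new signature takes List<SqlArgument>
  }
}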
@@ -21,6 +21,7 @@
import io.confluent.ksql.function.types.ParamTypes;
import io.confluent.ksql.function.udf.UdfMetadata;
import io.confluent.ksql.schema.ksql.SchemaConverters;
+ import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlDecimal;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.util.KsqlConstants;
@@ -59,7 +60,7 @@ public AggregateFunctionFactory(final UdfMetadata metadata) {
}

public abstract KsqlAggregateFunction<?, ?, ?> createAggregateFunction(
- List<SqlType> argTypeList, AggregateFunctionInitArguments initArgs);
+ List<SqlArgument> argTypeList, AggregateFunctionInitArguments initArgs);

protected abstract List<List<ParamType>> supportedArgs();

@@ -77,6 +78,7 @@ public void eachFunction(final Consumer<KsqlAggregateFunction<?, ?, ?>> consumer
.map(args -> args
.stream()
.map(AggregateFunctionFactory::getSampleSqlType)
+ .map(arg -> SqlArgument.of(arg))
.collect(Collectors.toList()))
.forEach(args -> consumer.accept(createAggregateFunction(args, getDefaultArguments())));
}
@@ -16,6 +16,7 @@
package io.confluent.ksql.function;

import io.confluent.ksql.name.FunctionName;
+ import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.testing.EffectivelyImmutable;
@@ -114,7 +115,7 @@ public interface FunctionRegistry {
* @return the function instance.
* @throws KsqlException on unknown table function, or on unsupported {@code argumentType}.
*/
- KsqlTableFunction getTableFunction(FunctionName functionName, List<SqlType> argumentTypes);
+ KsqlTableFunction getTableFunction(FunctionName functionName, List<SqlArgument> argumentTypes);

/**
* @return all UDF factories.
@@ -24,6 +24,7 @@
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.function.types.StructType;
import io.confluent.ksql.schema.ksql.SchemaConverters;
+ import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlMap;
import io.confluent.ksql.schema.ksql.types.SqlStruct.Builder;
@@ -145,7 +146,7 @@ public static SqlType applyResolved(
*/
public static Map<GenericType, SqlType> resolveGenerics(
final ParamType schema,
- final SqlType instance
+ final SqlArgument instance
) {
final List<Entry<GenericType, SqlType>> genericMapping = new ArrayList<>();
final boolean success = resolveGenerics(genericMapping, schema, instance);
@@ -172,11 +173,15 @@ public static Map<GenericType, SqlType> resolveGenerics(
return ImmutableMap.copyOf(mapping);
}

+ // CHECKSTYLE_RULES.OFF: CyclomaticComplexity
private static boolean resolveGenerics(
final List<Entry<GenericType, SqlType>> mapping,
final ParamType schema,
- final SqlType instance
+ final SqlArgument instance
) {
+ // CHECKSTYLE_RULES.ON: CyclomaticComplexity
+ final SqlType sqlType = instance.getSqlType();

if (!isGeneric(schema) && !matches(schema, instance)) {
// cannot identify from type mismatch
return false;
@@ -191,19 +196,20 @@ private static boolean resolveGenerics(
+ schema + " vs. " + instance);

if (isGeneric(schema)) {
- mapping.add(new HashMap.SimpleEntry<>((GenericType) schema, instance));
+ mapping.add(new HashMap.SimpleEntry<>((GenericType) schema, sqlType));
}

if (schema instanceof ArrayType) {
+ final SqlArray sqlArray = (SqlArray) sqlType;
return resolveGenerics(
- mapping, ((ArrayType) schema).element(), ((SqlArray) instance).getItemType());
+ mapping, ((ArrayType) schema).element(), SqlArgument.of(sqlArray.getItemType()));
}

if (schema instanceof MapType) {
- final SqlMap sqlMap = (SqlMap) instance;
+ final SqlMap sqlMap = (SqlMap) sqlType;
final MapType mapType = (MapType) schema;
- return resolveGenerics(mapping, mapType.key(), sqlMap.getKeyType())
- && resolveGenerics(mapping, mapType.value(), sqlMap.getValueType());
+ return resolveGenerics(mapping, mapType.key(), SqlArgument.of(sqlMap.getKeyType()))
+ && resolveGenerics(mapping, mapType.value(), SqlArgument.of(sqlMap.getValueType()));
}

if (schema instanceof StructType) {
@@ -213,9 +219,9 @@ private static boolean resolveGenerics(
return true;
}

- private static boolean matches(final ParamType schema, final SqlType instance) {
+ private static boolean matches(final ParamType schema, final SqlArgument instance) {
final ParamType instanceParamType = SchemaConverters
- .sqlToFunctionConverter().toFunctionType(instance);
+ .sqlToFunctionConverter().toFunctionType(instance.getSqlType());
return schema.getClass() == instanceParamType.getClass();
}

@@ -224,7 +230,7 @@ private static boolean matches(final ParamType schema, final SqlType instance) {
* @param instance a schema without generics
* @return whether {@code instance} conforms to the structure of {@code schema}
*/
- public static boolean instanceOf(final ParamType schema, final SqlType instance) {
+ public static boolean instanceOf(final ParamType schema, final SqlArgument instance) {
final List<Entry<GenericType, SqlType>> mappings = new ArrayList<>();

if (!resolveGenerics(mappings, schema, instance)) {
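A hedged usage sketch for the GenericsUtil changes above: resolveGenerics and instanceOf now accept the declared ParamType plus a SqlArgument and unwrap the underlying SqlType internally. The signatures come from the hunks; the factory methods used to build the inputs (GenericType.of, ArrayType.of, SqlTypes.array) and GenericsUtil's package are assumptions.

import io.confluent.ksql.function.GenericsUtil;
import io.confluent.ksql.function.types.ArrayType;
import io.confluent.ksql.function.types.GenericType;
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import java.util.Map;

final class GenericsExample {

  // Bind the generic element type T of a declared ARRAY<T> parameter from a concrete
  // ARRAY<INTEGER> argument; the expected result is a single mapping T -> INTEGER.
  static Map<GenericType, SqlType> bindArrayElement() {
    final ParamType declared = ArrayType.of(GenericType.of("T"));
    final SqlArgument argument = SqlArgument.of(SqlTypes.array(SqlTypes.INTEGER));
    return GenericsUtil.resolveGenerics(declared, argument);
  }
}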
@@ -21,6 +21,7 @@
import io.confluent.ksql.function.types.ArrayType;
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.name.FunctionName;
+ import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.testing.EffectivelyImmutable;
import java.util.List;
@@ -75,7 +76,7 @@ public class KsqlFunction implements FunctionSignature {
}
}

- public SqlType getReturnType(final List<SqlType> arguments) {
+ public SqlType getReturnType(final List<SqlArgument> arguments) {
return returnSchemaProvider.resolve(parameters(), arguments);
}

@@ -16,12 +16,13 @@
package io.confluent.ksql.function;

import io.confluent.ksql.function.types.ParamType;
+ import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlType;
import java.util.List;

@FunctionalInterface
public interface SchemaProvider {

- SqlType resolve(List<ParamType> parameters, List<SqlType> arguments);
+ SqlType resolve(List<ParamType> parameters, List<SqlArgument> arguments);

}
@@ -17,7 +17,8 @@

import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.function.udf.UdfMetadata;
- import io.confluent.ksql.schema.ksql.types.SqlType;
+ import io.confluent.ksql.schema.ksql.SqlArgument;

import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
@@ -46,7 +47,7 @@ public synchronized void eachFunction(final Consumer<KsqlTableFunction> consumer
udtfIndex.values().forEach(consumer);
}

- public synchronized KsqlTableFunction createTableFunction(final List<SqlType> argTypes) {
+ public synchronized KsqlTableFunction createTableFunction(final List<SqlArgument> argTypes) {
return udtfIndex.getFunction(argTypes);
}

@@ -17,7 +17,7 @@

import io.confluent.ksql.function.udf.Kudf;
import io.confluent.ksql.function.udf.UdfMetadata;
- import io.confluent.ksql.schema.ksql.types.SqlType;
+ import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.util.KsqlException;
import java.util.List;
import java.util.Objects;
@@ -79,7 +79,7 @@ public String toString() {
+ '}';
}

- public synchronized KsqlScalarFunction getFunction(final List<SqlType> argTypes) {
+ public synchronized KsqlScalarFunction getFunction(final List<SqlArgument> argTypes) {
return udfIndex.getFunction(argTypes);
}
}
@@ -21,6 +21,7 @@
import io.confluent.ksql.function.types.GenericType;
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.function.types.ParamTypes;
+ import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.utils.FormatOptions;
import io.confluent.ksql.util.KsqlException;
@@ -142,7 +143,7 @@ void addFunction(final T function) {
curr.update(function, order);
}

- T getFunction(final List<SqlType> arguments) {
+ T getFunction(final List<SqlArgument> arguments) {
final List<Node> candidates = new ArrayList<>();

// first try to get the candidates without any implicit casting
@@ -169,7 +170,7 @@
}

private void getCandidates(
- final List<SqlType> arguments,
+ final List<SqlArgument> arguments,
final int argIndex,
final Node current,
final List<Node> candidates,
@@ -183,7 +184,7 @@ private void getCandidates(
return;
}

- final SqlType arg = arguments.get(argIndex);
+ final SqlArgument arg = arguments.get(argIndex);
for (final Entry<Parameter, Node> candidate : current.children.entrySet()) {
final Map<GenericType, SqlType> reservedCopy = new HashMap<>(reservedGenerics);
if (candidate.getKey().accepts(arg, reservedCopy, allowCasts)) {
@@ -193,11 +194,11 @@ private void getCandidates(
}
}

- private KsqlException createNoMatchingFunctionException(final List<SqlType> paramTypes) {
+ private KsqlException createNoMatchingFunctionException(final List<SqlArgument> paramTypes) {
LOG.debug("Current UdfIndex:\n{}", describe());

final String requiredTypes = paramTypes.stream()
- .map(type -> type == null ? "null" : type.toString(FormatOptions.noEscape()))
+ .map(type -> type == null ? "null" : type.getSqlType().toString(FormatOptions.noEscape()))
.collect(Collectors.joining(", ", "(", ")"));

final String acceptedTypes = allFunctions.values().stream()
@@ -348,9 +349,9 @@ public int hashCode() {
* this parameter
*/
// CHECKSTYLE_RULES.OFF: BooleanExpressionComplexity
- boolean accepts(final SqlType argument, final Map<GenericType, SqlType> reservedGenerics,
+ boolean accepts(final SqlArgument argument, final Map<GenericType, SqlType> reservedGenerics,
final boolean allowCasts) {
- if (argument == null) {
+ if (argument == null || argument.getSqlType() == null) {
return true;
}

@@ -364,7 +365,7 @@ boolean accepts(final SqlType argument, final Map<GenericType, SqlType> reserved

private static boolean reserveGenerics(
final ParamType schema,
- final SqlType argument,
+ final SqlArgument argument,
final Map<GenericType, SqlType> reservedGenerics
) {
if (!GenericsUtil.instanceOf(schema, argument)) {
@@ -17,6 +17,7 @@

import static io.confluent.ksql.schema.ksql.SchemaConverters.functionToSqlBaseConverter;

+ import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlBaseType;
import io.confluent.ksql.schema.ksql.types.SqlMap;
@@ -40,33 +41,34 @@ private ParamTypes() {
public static final TimestampType TIMESTAMP = TimestampType.INSTANCE;

public static boolean areCompatible(final SqlType actual, final ParamType declared) {
- return areCompatible(actual, declared, false);
+ return areCompatible(SqlArgument.of(actual), declared, false);
}

public static boolean areCompatible(
- final SqlType actual,
+ final SqlArgument argument,
final ParamType declared,
final boolean allowCast
) {
- if (actual.baseType() == SqlBaseType.ARRAY && declared instanceof ArrayType) {
+ final SqlType argumentSqlType = argument.getSqlType();
+ if (argumentSqlType.baseType() == SqlBaseType.ARRAY && declared instanceof ArrayType) {
return areCompatible(
- ((SqlArray) actual).getItemType(),
+ SqlArgument.of(((SqlArray) argumentSqlType).getItemType()),
((ArrayType) declared).element(),
allowCast);
}

- if (actual.baseType() == SqlBaseType.MAP && declared instanceof MapType) {
- final SqlMap sqlType = (SqlMap) actual;
+ if (argumentSqlType.baseType() == SqlBaseType.MAP && declared instanceof MapType) {
+ final SqlMap sqlType = (SqlMap) argumentSqlType;
final MapType mapType = (MapType) declared;
- return areCompatible(sqlType.getKeyType(), mapType.key(), allowCast)
- && areCompatible(sqlType.getValueType(), mapType.value(), allowCast);
+ return areCompatible(SqlArgument.of(sqlType.getKeyType()), mapType.key(), allowCast)
+ && areCompatible(SqlArgument.of(sqlType.getValueType()), mapType.value(), allowCast);
}

- if (actual.baseType() == SqlBaseType.STRUCT && declared instanceof StructType) {
- return isStructCompatible(actual, declared);
+ if (argumentSqlType.baseType() == SqlBaseType.STRUCT && declared instanceof StructType) {
+ return isStructCompatible(argumentSqlType, declared);
}

- return isPrimitiveMatch(actual, declared, allowCast);
+ return isPrimitiveMatch(argumentSqlType, declared, allowCast);
}

private static boolean isStructCompatible(final SqlType actual, final ParamType declared) {
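And a small sketch of the two ParamTypes.areCompatible entry points after the change above. The signatures come from the hunk; the ParamTypes.STRING constant and the enclosing class are assumptions. The two-argument overload still takes a raw SqlType and wraps it itself, while the three-argument overload takes the SqlArgument wrapper directly.

import io.confluent.ksql.function.types.ParamTypes;
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlTypes;

final class CompatibilityExample {

  // Both calls should report a STRING argument as compatible with a STRING parameter.
  static boolean stringArgMatchesStringParam() {
    final boolean viaSqlType = ParamTypes.areCompatible(SqlTypes.STRING, ParamTypes.STRING);
    final boolean viaArgument =
        ParamTypes.areCompatible(SqlArgument.of(SqlTypes.STRING), ParamTypes.STRING, false);
    return viaSqlType && viaArgument;
  }
}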
@@ -28,6 +28,7 @@
import io.confluent.ksql.function.udf.Kudf;
import io.confluent.ksql.function.udf.UdfMetadata;
import io.confluent.ksql.name.FunctionName;
+ import io.confluent.ksql.schema.ksql.SqlArgument;
import org.apache.kafka.common.KafkaException;
import org.junit.Test;

@@ -42,7 +43,7 @@ public void shouldThrowIfNoVariantFoundThatAcceptsSuppliedParamTypes() {
// When:
final Exception e = assertThrows(
KafkaException.class,
- () -> factory.getFunction(of(STRING, BIGINT))
+ () -> factory.getFunction(of(SqlArgument.of(STRING), SqlArgument.of(BIGINT)))
);

// Then: