Skip to content

Commit

Permalink
type checking without sql type
Browse files Browse the repository at this point in the history
feat: add functions as UDF input types

fix immutable test

proof of concept lambda

follow up

extend lambda functionality

proof of concept lambda

feat: add lambda syntax to grammar

change g4

proof of concept lambda

initial type checking

bug fixes, clean up

clean-up and adding lambda sql type stuff

adding lambdas to sql types

Type cleanup

adding contexts

cleanup

updating qtt

extending support for maps

Updating for multiple input types

Steven's ast fix + functional tests

Initial experiment for new sql type

getting return types working

cleanup
  • Loading branch information
lct45 committed Feb 8, 2021
1 parent 5eaa3be commit 165896b
Show file tree
Hide file tree
Showing 47 changed files with 3,130 additions and 175 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import com.google.common.collect.Sets;
import io.confluent.ksql.function.types.ArrayType;
import io.confluent.ksql.function.types.GenericType;
import io.confluent.ksql.function.types.LambdaType;
import io.confluent.ksql.function.types.MapType;
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.function.types.StructType;
import io.confluent.ksql.schema.ksql.SchemaConverters;
import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlLambda;
import io.confluent.ksql.schema.ksql.types.SqlMap;
import io.confluent.ksql.schema.ksql.types.SqlStruct.Builder;
import io.confluent.ksql.schema.ksql.types.SqlType;
Expand All @@ -34,6 +36,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
Expand Down Expand Up @@ -76,6 +79,14 @@ public static Set<ParamType> constituentGenerics(final ParamType type) {
.collect(Collectors.toSet());
} else if (type instanceof GenericType) {
return ImmutableSet.of(type);
} else if (type instanceof LambdaType) {
final Set<ParamType> inputSet = new HashSet<>();
for (final ParamType paramType: ((LambdaType) type).inputTypes()) {
inputSet.addAll(constituentGenerics(paramType));
}
return Sets.union(
inputSet,
constituentGenerics(((LambdaType) type).returnType()));
} else {
return ImmutableSet.of();
}
Expand Down Expand Up @@ -172,11 +183,58 @@ public static Map<GenericType, SqlType> resolveGenerics(
return ImmutableMap.copyOf(mapping);
}

/**
 * Resolves the generics declared by a lambda parameter type against a concrete lambda.
 *
 * <p>Declared input types are matched positionally against the lambda's input types, and
 * the declared return type is then matched against the lambda's return type.
 *
 * @param schema the declared parameter type; it is cast to {@link LambdaType}
 * @param sqlLambda the lambda supplying the concrete input and return types
 * @return an immutable mapping from each resolved generic to its concrete type
 * @throws KsqlException if the number of inputs differs, if the two structures cannot be
 *     matched, or if one generic would have to resolve to two different types
 */
public static Map<GenericType, SqlType> resolveLambdaGenerics(
    final ParamType schema,
    final SqlLambda sqlLambda
) {
  final List<Entry<GenericType, SqlType>> genericMapping = new ArrayList<>();

  final LambdaType lambdaType = (LambdaType) schema;
  if (sqlLambda.getInputType().size() != lambdaType.inputTypes().size()) {
    throw new KsqlException(
        "Number of lambda arguments don't match between schema and sql type");
  }

  // Short-circuits: once one part fails to resolve, the remaining parts are not visited.
  boolean resolved = true;
  int i = 0;
  for (final ParamType paramType : lambdaType.inputTypes()) {
    resolved = resolved
        && resolveGenerics(genericMapping, paramType, sqlLambda.getInputType().get(i));
    i++;
  }
  resolved = resolved
      && resolveGenerics(genericMapping, lambdaType.returnType(), sqlLambda.getReturnType());

  if (!resolved) {
    throw new KsqlException(
        String.format("Cannot infer generics for %s from %s because "
            + "they do not have the same schema structure.",
            schema,
            sqlLambda));
  }

  // Collapse the (possibly repeated) bindings, rejecting a generic bound to two types.
  final Map<GenericType, SqlType> mapping = new HashMap<>();
  for (final Entry<GenericType, SqlType> entry : genericMapping) {
    final SqlType old = mapping.putIfAbsent(entry.getKey(), entry.getValue());
    if (old != null && !old.equals(entry.getValue())) {
      // Report the generic that clashed and both of its candidate types, rather than the
      // whole schema and the whole lambda, so the message pinpoints the conflict.
      throw new KsqlException(String.format(
          "Found invalid instance of generic schema. Cannot map %s to both %s and %s",
          entry.getKey(),
          old,
          entry.getValue()));
    }
  }

  return ImmutableMap.copyOf(mapping);
}

// CHECKSTYLE_RULES.OFF: CyclomaticComplexity
private static boolean resolveGenerics(
final List<Entry<GenericType, SqlType>> mapping,
final ParamType schema,
final SqlType instance
) {
// CHECKSTYLE_RULES.ON: CyclomaticComplexity
if (!isGeneric(schema) && !matches(schema, instance)) {
// cannot identify from type mismatch
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import io.confluent.ksql.function.udf.Kudf;
import io.confluent.ksql.function.udf.UdfMetadata;
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.util.KsqlException;
import java.util.List;
Expand Down Expand Up @@ -79,6 +80,10 @@ public String toString() {
+ '}';
}

/**
 * Finds the function variant for the supplied arguments by delegating to the UDF index.
 *
 * <p>The method is synchronized on this factory instance.
 *
 * @param argTypes the arguments the function is being invoked with
 * @return whatever the index returns for these arguments
 */
public synchronized KsqlScalarFunction getUdfFunction(final List<SqlArgument> argTypes) {
  final KsqlScalarFunction resolved = udfIndex.getUdfFunction(argTypes);
  return resolved;
}

/**
 * Finds the function variant for the supplied SQL types by delegating to the UDF index.
 *
 * <p>The method is synchronized on this factory instance.
 *
 * @param argTypes the SQL types of the arguments the function is being invoked with
 * @return whatever the index returns for these argument types
 */
public synchronized KsqlScalarFunction getFunction(final List<SqlType> argTypes) {
  final KsqlScalarFunction resolved = udfIndex.getFunction(argTypes);
  return resolved;
}
Expand Down
119 changes: 119 additions & 0 deletions ksqldb-common/src/main/java/io/confluent/ksql/function/UdfIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
import com.google.common.collect.Iterables;
import io.confluent.ksql.function.types.ArrayType;
import io.confluent.ksql.function.types.GenericType;
import io.confluent.ksql.function.types.LambdaType;
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.function.types.ParamTypes;
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlLambda;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.utils.FormatOptions;
import io.confluent.ksql.util.KsqlException;
Expand Down Expand Up @@ -142,6 +145,32 @@ void addFunction(final T function) {
curr.update(function, order);
}

/**
 * Picks the best registered function for the given arguments.
 *
 * <p>Candidates are first collected without implicit casts. Only when that yields nothing,
 * and this index supports implicit casts, is the search repeated with casts allowed. Among
 * the candidates the maximum according to {@code Node::compare} wins.
 *
 * @param arguments the arguments of the call, which may carry lambdas
 * @return the value of the winning node
 * @throws KsqlException if no registered function accepts the arguments
 */
T getUdfFunction(final List<SqlArgument> arguments) {
  final List<Node> matches = new ArrayList<>();

  // Pass 1: strict matching, no implicit casts.
  getUdfCandidates(arguments, 0, root, matches, new HashMap<>(), false);
  final Optional<T> strict = matches.stream()
      .max(Node::compare)
      .map(node -> node.value);
  if (strict.isPresent()) {
    return strict.get();
  }

  if (!supportsImplicitCasts) {
    throw createNoMatchingFunctionExceptionSqlArgument(arguments);
  }

  // Pass 2: the strict pass found nothing (matches is still empty), so retry with casts.
  getUdfCandidates(arguments, 0, root, matches, new HashMap<>(), true);
  final Optional<T> relaxed = matches.stream()
      .max(Node::compare)
      .map(node -> node.value);
  if (!relaxed.isPresent()) {
    throw createNoMatchingFunctionExceptionSqlArgument(arguments);
  }
  return relaxed.get();
}

T getFunction(final List<SqlType> arguments) {
final List<Node> candidates = new ArrayList<>();

Expand All @@ -168,6 +197,37 @@ T getFunction(final List<SqlType> arguments) {
.orElseThrow(() -> createNoMatchingFunctionException(arguments));
}

/**
 * Walks the parameter tree depth-first and collects every node whose path of parameters
 * accepts the given arguments.
 *
 * <p>Each child is tried with its own copy of the reserved generics, so bindings made
 * while exploring one branch do not leak into its siblings.
 *
 * @param arguments the full argument list of the call
 * @param argIndex index of the argument matched at this level
 * @param current the node reached so far
 * @param candidates output list; nodes with a non-null value reached after the last
 *     argument are appended to it
 * @param reservedGenerics generic bindings established on the path to {@code current}
 * @param allowCasts forwarded to the parameter acceptance checks
 */
private void getUdfCandidates(
    final List<SqlArgument> arguments,
    final int argIndex,
    final Node current,
    final List<Node> candidates,
    final Map<GenericType, SqlType> reservedGenerics,
    final boolean allowCasts
) {
  final boolean allArgumentsConsumed = argIndex == arguments.size();
  if (allArgumentsConsumed) {
    if (current.value != null) {
      candidates.add(current);
    }
    return;
  }

  final SqlArgument argument = arguments.get(argIndex);
  final SqlType sqlType = argument.getSqlType();
  for (final Entry<Parameter, Node> child : current.children.entrySet()) {
    final Parameter parameter = child.getKey();
    final Map<GenericType, SqlType> branchGenerics = new HashMap<>(reservedGenerics);

    // Lambda parameters are checked against the argument's lambda, all others against
    // its SQL type.
    final boolean accepted = parameter.type instanceof LambdaType
        ? parameter.acceptsLambda(argument.getSqlLambda(), branchGenerics, allowCasts)
        : parameter.accepts(sqlType, branchGenerics, allowCasts);

    if (accepted) {
      getUdfCandidates(
          arguments, argIndex + 1, child.getValue(), candidates, branchGenerics, allowCasts);
    }
  }
}

private void getCandidates(
final List<SqlType> arguments,
final int argIndex,
Expand Down Expand Up @@ -215,6 +275,28 @@ private KsqlException createNoMatchingFunctionException(final List<SqlType> para
);
}

/**
 * Builds the exception reported when no registered function accepts the given arguments.
 *
 * <p>The message lists the supplied argument types followed by the signatures of all
 * registered functions. The current index is also logged at debug level.
 *
 * @param paramTypes the arguments of the failed lookup
 * @return the exception to throw; this method does not throw it itself
 */
private KsqlException createNoMatchingFunctionExceptionSqlArgument(
    final List<SqlArgument> paramTypes
) {
  LOG.debug("Current UdfIndex:\n{}", describe());

  final String requiredTypes = paramTypes.stream()
      .map(argument -> {
        if (argument.getSqlType() != null) {
          return argument.getSqlType().toString(FormatOptions.noEscape());
        }
        // No SQL type: describe the lambda if the argument carries one, so a lambda
        // argument is not reported as "null". With neither present this yields "null".
        return String.valueOf(argument.getSqlLambda());
      })
      .collect(Collectors.joining(", ", "(", ")"));

  final String acceptedTypes = allFunctions.values().stream()
      .map(UdfIndex::formatAvailableSignatures)
      .collect(Collectors.joining(System.lineSeparator()));

  return new KsqlException("Function '" + udfName
      + "' does not accept parameters " + requiredTypes + "."
      + System.lineSeparator()
      + "Valid alternatives are:"
      + System.lineSeparator()
      + acceptedTypes
      + System.lineSeparator()
      + "For detailed information on a function run: DESCRIBE FUNCTION <Function-Name>;"
  );
}

/**
 * Returns the values view of the registered-functions map.
 *
 * @return the collection obtained from {@code allFunctions.values()}
 */
public Collection<T> values() {
  final Collection<T> registered = allFunctions.values();
  return registered;
}
Expand Down Expand Up @@ -362,6 +444,21 @@ boolean accepts(final SqlType argument, final Map<GenericType, SqlType> reserved
}
// CHECKSTYLE_RULES.ON: BooleanExpressionComplexity

// CHECKSTYLE_RULES.OFF: BooleanExpressionComplexity
/**
 * Decides whether this parameter accepts the given lambda.
 *
 * <p>A {@code null} lambda is always accepted. If this parameter's type contains generics,
 * the lambda must be reservable against {@code reservedGenerics}; otherwise it must be
 * lambda-compatible with the declared type.
 *
 * @param lambda the lambda argument, possibly {@code null}
 * @param reservedGenerics generic bindings made so far; may gain entries
 * @param allowCasts not consulted by this method
 * @return {@code true} if the lambda is accepted
 */
boolean acceptsLambda(final SqlLambda lambda, final Map<GenericType, SqlType> reservedGenerics,
    final boolean allowCasts) {
  if (lambda == null) {
    return true;
  }

  return GenericsUtil.hasGenerics(type)
      ? reserveLambdaGenerics(type, lambda, reservedGenerics)
      : ParamTypes.isLambdaCompatible(lambda, type);
}
// CHECKSTYLE_RULES.ON: BooleanExpressionComplexity

private static boolean reserveGenerics(
final ParamType schema,
final SqlType argument,
Expand All @@ -384,6 +481,28 @@ private static boolean reserveGenerics(
return true;
}

/**
 * Tries to bind the generics of a lambda parameter type to the types of a concrete
 * lambda, recording the bindings in {@code reservedGenerics}.
 *
 * @param schema the declared (generic) parameter type
 * @param argument the lambda being matched
 * @param reservedGenerics bindings made so far; new bindings are added to it, including
 *     any added before a conflict is detected
 * @return {@code true} if every generic resolved and none conflicts with an existing
 *     binding; {@code false} otherwise
 */
private static boolean reserveLambdaGenerics(
    final ParamType schema,
    final SqlLambda argument,
    final Map<GenericType, SqlType> reservedGenerics
) {
  final Map<GenericType, SqlType> genericMapping;
  try {
    genericMapping = GenericsUtil.resolveLambdaGenerics(schema, argument);
  } catch (final KsqlException e) {
    // resolveLambdaGenerics signals an arity or structure mismatch by throwing. For
    // overload resolution that only means "this candidate does not fit", so reject the
    // candidate rather than aborting the whole lookup.
    return false;
  }

  for (final Entry<GenericType, SqlType> entry : genericMapping.entrySet()) {
    final SqlType old = reservedGenerics.putIfAbsent(entry.getKey(), entry.getValue());
    if (old != null && !old.equals(entry.getValue())) {
      return false;
    }
  }

  return true;
}

@Override
public String toString() {
return type + (isVararg ? "(VARARG)" : "");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright 2021 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"; you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.function.types;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;

/**
 * Parameter type describing a lambda: an ordered list of input types and a return type.
 *
 * <p>Instances hold an immutable copy of the input types and compare by value.
 */
public final class LambdaType extends ObjectType {

  private final ImmutableList<ParamType> inputTypes;
  private final ParamType returnType;

  /**
   * @param inputTypes the lambda's input types, in order; copied, must not be null
   * @param returnType the lambda's return type; must not be null
   */
  public LambdaType(
      final List<ParamType> inputTypes,
      final ParamType returnType
  ) {
    Objects.requireNonNull(inputTypes, "inputTypes");
    this.inputTypes = ImmutableList.copyOf(inputTypes);
    Objects.requireNonNull(returnType, "returnType");
    this.returnType = returnType;
  }

  /** Static factory equivalent to calling the constructor. */
  public static LambdaType of(
      final List<ParamType> inputTypes,
      final ParamType returnType
  ) {
    return new LambdaType(inputTypes, returnType);
  }

  /** Returns the input types, in declaration order. */
  public List<ParamType> inputTypes() {
    return inputTypes;
  }

  /** Returns the return type. */
  public ParamType returnType() {
    return returnType;
  }

  @Override
  public boolean equals(final Object o) {
    if (o == this) {
      return true;
    }
    // The class is final, so an instanceof test is equivalent to comparing classes.
    if (!(o instanceof LambdaType)) {
      return false;
    }
    final LambdaType that = (LambdaType) o;
    // Both fields are guaranteed non-null by the constructor.
    return inputTypes.equals(that.inputTypes)
        && returnType.equals(that.returnType);
  }

  @Override
  public int hashCode() {
    return Objects.hash(inputTypes, returnType);
  }

  @Override
  public String toString() {
    return new StringBuilder()
        .append("LAMBDA<")
        .append(inputTypes)
        .append(", ")
        .append(returnType)
        .append(">")
        .toString();
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlBaseType;
import io.confluent.ksql.schema.ksql.types.SqlLambda;
import io.confluent.ksql.schema.ksql.types.SqlMap;
import io.confluent.ksql.schema.ksql.types.SqlStruct;
import io.confluent.ksql.schema.ksql.types.SqlStruct.Field;
Expand Down Expand Up @@ -88,6 +89,25 @@ private static boolean isStructCompatible(final SqlType actual, final ParamType
return actualStruct.fields().size() == ((StructType) declared).getSchema().size();
}

/**
 * Checks whether a concrete lambda is compatible with a declared lambda parameter type.
 *
 * <p>The lambda must have the same number of inputs as the declared type, each input must
 * be compatible with the declared input at the same position, and the return types must
 * be compatible.
 *
 * @param actual the concrete lambda
 * @param declared the declared parameter type; it is cast to {@link LambdaType}
 * @return {@code true} if all of the checks above pass
 */
public static boolean isLambdaCompatible(final SqlLambda actual, final ParamType declared) {
  final LambdaType declaredLambda = (LambdaType) declared;
  if (actual.getInputType().size() != declaredLambda.inputTypes().size()) {
    return false;
  }

  final int arity = declaredLambda.inputTypes().size();
  for (int position = 0; position < arity; position++) {
    final boolean inputMatches = areCompatible(
        actual.getInputType().get(position),
        declaredLambda.inputTypes().get(position));
    if (!inputMatches) {
      return false;
    }
  }

  return areCompatible(actual.getReturnType(), declaredLambda.returnType());
}

// CHECKSTYLE_RULES.OFF: CyclomaticComplexity
private static boolean isPrimitiveMatch(
final SqlType actual,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ PreparedStatement<?> prepare(final ParsedStatement stmt, final Map<String, Strin
} catch (final KsqlStatementException e) {
throw e;
} catch (final Exception e) {
e.printStackTrace();
throw new KsqlStatementException(
"Exception while preparing statement: " + e.getMessage(), stmt.getStatementText(), e);
}
Expand Down
Loading

0 comments on commit 165896b

Please sign in to comment.