Adding support for multiple variable names
lct45 committed Feb 19, 2021
1 parent 1541033 commit 5049d52
Showing 8 changed files with 114 additions and 56 deletions.
@@ -112,8 +112,8 @@ private static boolean isStructCompatible(final SqlType actual, final ParamType
final String k = entry.getKey();
final Optional<Field> field = actualStruct.field(k);
// intentionally do not allow implicit casting within structs
if (!field.isPresent() ||
!areCompatible(SqlArgument.of(field.get().type()), entry.getValue(), false)) {
if (!field.isPresent()
|| !areCompatible(SqlArgument.of(field.get().type()), entry.getValue(), false)) {
return false;
}
}
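
This first hunk switches the struct field check to the three-argument form of areCompatible, passing false so that struct fields never match through implicit casting. The sketch below is not part of the commit: it models that rule with invented stand-in types, and the allowCast name for the boolean is an assumption rather than the real SqlType/ParamType API.

// Illustrative sketch only, with invented stand-in types.
import java.util.Map;
import java.util.Optional;

final class StructCompatibilitySketch {

  enum SimpleType { STRING, INTEGER, DOUBLE }

  // Assumed meaning of the boolean in the diff: whether implicit casting may be used.
  static boolean areCompatible(final SimpleType actual, final SimpleType declared, final boolean allowCast) {
    if (actual == declared) {
      return true;
    }
    // Example of a widening that is only accepted when allowCast is true.
    return allowCast && actual == SimpleType.INTEGER && declared == SimpleType.DOUBLE;
  }

  // Mirrors the loop above: every declared field must exist in the actual struct
  // and match WITHOUT implicit casting (allowCast = false).
  static boolean isStructCompatible(
      final Map<String, SimpleType> actualFields,
      final Map<String, SimpleType> declaredFields) {
    for (final Map.Entry<String, SimpleType> entry : declaredFields.entrySet()) {
      final Optional<SimpleType> field = Optional.ofNullable(actualFields.get(entry.getKey()));
      if (!field.isPresent()
          || !areCompatible(field.get(), entry.getValue(), false)) {
        return false;
      }
    }
    return true;
  }

  public static void main(String[] args) {
    System.out.println(isStructCompatible(
        Map.of("a", SimpleType.INTEGER), Map.of("a", SimpleType.DOUBLE))); // false: no casting inside structs
    System.out.println(isStructCompatible(
        Map.of("a", SimpleType.STRING), Map.of("a", SimpleType.STRING)));  // true
  }
}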
@@ -29,27 +29,41 @@ public class ParamTypesTest {

@Test
public void shouldFailINonCompatibleSchemas() {
assertThat(ParamTypes.areCompatible(SqlTypes.STRING, ParamTypes.INTEGER), is(false));
assertThat(ParamTypes.areCompatible(
SqlArgument.of(SqlTypes.STRING),
ParamTypes.INTEGER,
false),
is(false));

assertThat(ParamTypes.areCompatible(SqlTypes.STRING, GenericType.of("T")), is(false));
assertThat(ParamTypes.areCompatible(
SqlArgument.of(SqlTypes.STRING),
GenericType.of("T"),
false),
is(false));

assertThat(
ParamTypes.areCompatible(SqlTypes.array(SqlTypes.INTEGER), ArrayType.of(ParamTypes.STRING)),
ParamTypes.areCompatible(
SqlArgument.of(SqlTypes.array(SqlTypes.INTEGER)),
ArrayType.of(ParamTypes.STRING),
false),
is(false));

assertThat(ParamTypes.areCompatible(
SqlTypes.struct().field("a", SqlTypes.decimal(1, 1)).build(),
StructType.builder().field("a", ParamTypes.DOUBLE).build()),
SqlArgument.of(SqlTypes.struct().field("a", SqlTypes.decimal(1, 1)).build()),
StructType.builder().field("a", ParamTypes.DOUBLE).build(),
false),
is(false));

assertThat(ParamTypes.areCompatible(
SqlTypes.map(SqlTypes.STRING, SqlTypes.decimal(1, 1)),
MapType.of(ParamTypes.STRING, ParamTypes.INTEGER)),
SqlArgument.of(SqlTypes.map(SqlTypes.STRING, SqlTypes.decimal(1, 1))),
MapType.of(ParamTypes.STRING, ParamTypes.INTEGER),
false),
is(false));

assertThat(ParamTypes.areCompatible(
SqlTypes.map(SqlTypes.decimal(1, 1), SqlTypes.INTEGER),
MapType.of(ParamTypes.INTEGER, ParamTypes.INTEGER)),
SqlArgument.of(SqlTypes.map(SqlTypes.decimal(1, 1), SqlTypes.INTEGER)),
MapType.of(ParamTypes.INTEGER, ParamTypes.INTEGER),
false),
is(false));


@@ -62,21 +76,29 @@ public void shouldFailINonCompatibleSchemas() {

@Test
public void shouldPassCompatibleSchemas() {
assertThat(ParamTypes.areCompatible(SqlTypes.STRING, ParamTypes.STRING),
assertThat(ParamTypes.areCompatible(
SqlArgument.of(SqlTypes.STRING),
ParamTypes.STRING,
false),
is(true));

assertThat(
ParamTypes.areCompatible(SqlTypes.array(SqlTypes.INTEGER), ArrayType.of(ParamTypes.INTEGER)),
ParamTypes.areCompatible(
SqlArgument.of(SqlTypes.array(SqlTypes.INTEGER)),
ArrayType.of(ParamTypes.INTEGER),
false),
is(true));

assertThat(ParamTypes.areCompatible(
SqlTypes.struct().field("a", SqlTypes.decimal(1, 1)).build(),
StructType.builder().field("a", ParamTypes.DECIMAL).build()),
SqlArgument.of(SqlTypes.struct().field("a", SqlTypes.decimal(1, 1)).build()),
StructType.builder().field("a", ParamTypes.DECIMAL).build(),
false),
is(true));

assertThat(ParamTypes.areCompatible(
SqlTypes.map(SqlTypes.INTEGER, SqlTypes.decimal(1, 1)),
MapType.of(ParamTypes.INTEGER, ParamTypes.DECIMAL)),
SqlArgument.of(SqlTypes.map(SqlTypes.INTEGER, SqlTypes.decimal(1, 1))),
MapType.of(ParamTypes.INTEGER, ParamTypes.DECIMAL),
false),
is(true));

assertThat(ParamTypes.areCompatible(
@@ -256,7 +256,7 @@ public void shouldFailToIdentifyMismatchedGenericsInLambda() {

// Then:
assertThat(e.getMessage(), containsString(
"Found invalid instance of generic schema when mapping LAMBDA<[A], A> to Lambda<[DOUBLE], BOOLEAN>. "
"Found invalid instance of generic schema when mapping LAMBDA (A) -> A to Lambda<[DOUBLE], BOOLEAN>. "
+ "Cannot map A to both DOUBLE and BOOLEAN"));
}
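
The updated expectation only changes how the lambda schema is rendered in the message (LAMBDA (A) -> A instead of LAMBDA<[A], A>); the failure itself is still a conflicting binding for the generic A. A minimal, hypothetical model of that conflict check, not the real generics-resolution code:

// Illustrative sketch only: binding a generic name to two different concrete types fails.
import java.util.HashMap;
import java.util.Map;

final class GenericBindingSketch {

  private final Map<String, String> bindings = new HashMap<>();

  void bind(final String generic, final String concrete) {
    final String existing = bindings.putIfAbsent(generic, concrete);
    if (existing != null && !existing.equals(concrete)) {
      throw new IllegalStateException(
          "Cannot map " + generic + " to both " + existing + " and " + concrete);
    }
  }

  public static void main(String[] args) {
    final GenericBindingSketch sketch = new GenericBindingSketch();
    sketch.bind("A", "DOUBLE");   // input side of LAMBDA (A) -> A
    sketch.bind("A", "BOOLEAN");  // return side conflicts, so this throws
  }
}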

@@ -164,7 +164,7 @@ static SchemaProvider handleUdfReturnSchema(
return (parameters, arguments) -> {
if (schemaProvider != null) {
final SqlType returnType = schemaProvider.apply(arguments);
if (!(ParamTypes.areCompatible(returnType, javaReturnSchema))) {
if (!(ParamTypes.areCompatible(SqlArgument.of(returnType), javaReturnSchema, false))) {
throw new KsqlException(String.format(
"Return type %s of UDF %s does not match the declared "
+ "return type %s.",
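
In this hunk the provider-computed return type is wrapped in SqlArgument and compared to the declared Java return schema with casting disabled. A deliberately simplified sketch of that validation shape, with plain strings standing in for SQL types and the message format taken from the diff:

// Illustrative sketch only.
final class ReturnSchemaCheckSketch {

  static void validateReturnType(final String computedType, final String declaredType, final String udfName) {
    // Stand-in for ParamTypes.areCompatible(SqlArgument.of(returnType), javaReturnSchema, false):
    // here "compatible" is plain equality, with no implicit casting.
    final boolean compatible = computedType.equals(declaredType);
    if (!compatible) {
      throw new IllegalStateException(String.format(
          "Return type %s of UDF %s does not match the declared return type %s.",
          computedType, udfName, declaredType));
    }
  }

  public static void main(String[] args) {
    validateReturnType("DECIMAL(2, 1)", "DECIMAL(2, 1)", "my_udf"); // passes
    validateReturnType("DOUBLE", "DECIMAL(2, 1)", "my_udf");        // throws
  }
}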
@@ -189,16 +189,16 @@ public Void visitFunctionCall(final FunctionCall node, final TypeContext context
final FunctionName functionName = node.getName();
for (final Expression argExpr : node.getArguments()) {
process(argExpr, context);
final SqlType resolvedArgType = expressionTypeManager.getExpressionSqlType(argExpr, context);
// for lambdas - if we find an array or map passed in before encountering a lambda function
// we save the type information to resolve the lambda generics
context.visitType(resolvedArgType);
final TypeContext childContext = context.getCopy();
final SqlType resolvedArgType = expressionTypeManager.getExpressionSqlType(argExpr, childContext);

if (argExpr instanceof LambdaFunctionCall) {
argumentTypes.add(
SqlArgument.of(
SqlLambda.of(context.getLambdaInputTypes(), resolvedArgType)));
argumentTypes.add(SqlArgument.of(SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType())));
} else {
argumentTypes.add(SqlArgument.of(resolvedArgType));
// for lambdas - if we find an array or map passed in before encountering a lambda function
// we save the type information to resolve the lambda generics
context.visitType(node, resolvedArgType);
}
}

@@ -295,5 +295,14 @@ private void addRequiredColumn(final ColumnName columnName) {
column.index()
);
}

private boolean hasLambdaFunctionCall(FunctionCall node) {
for (Expression e : node.getArguments()) {
if (e instanceof LambdaFunctionCall) {
return true;
}
}
return false;
}
}
}
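
Each visitFunctionCall hunk in this commit follows the same pattern: resolve every argument against a copy of the type context, describe a lambda argument as a SqlLambda built from the lambda input types captured so far plus the lambda body's resolved type, and record type information on the shared context only for non-lambda arguments. The self-contained sketch below imitates that control flow with invented stand-in classes; it is not the ksqlDB API.

// Illustrative sketch only, with invented stand-in classes.
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class PerArgumentContextSketch {

  // Stand-in for TypeContext: a resolved type plus mutable lambda bookkeeping and a copy().
  static final class Context {
    String sqlType;
    final List<String> lambdaInputTypes = new ArrayList<>();
    final Map<String, String> lambdaInputTypeMapping = new HashMap<>();

    Context copy() {
      final Context child = new Context();
      child.lambdaInputTypes.addAll(lambdaInputTypes);
      child.lambdaInputTypeMapping.putAll(lambdaInputTypeMapping);
      return child;
    }
  }

  interface Arg { }
  static final class LambdaArg implements Arg { }
  static final class PlainArg implements Arg {
    final String type;
    PlainArg(final String type) { this.type = type; }
  }

  // "Resolves" an argument into the child context: plain arguments carry a type,
  // and a lambda body is pretended to evaluate to STRING.
  static void resolve(final Arg arg, final Context childContext) {
    childContext.sqlType = (arg instanceof PlainArg) ? ((PlainArg) arg).type : "STRING";
  }

  static List<String> describeArguments(final List<Arg> args, final Context context) {
    final List<String> described = new ArrayList<>();
    for (final Arg arg : args) {
      final Context childContext = context.copy(); // per-argument isolation
      resolve(arg, childContext);
      if (arg instanceof LambdaArg) {
        // lambda argument = captured input types plus the body's resolved type
        described.add("LAMBDA " + context.lambdaInputTypes + " -> " + childContext.sqlType);
      } else {
        // remember the type so a later lambda argument can resolve its generics
        context.lambdaInputTypes.add(childContext.sqlType);
        described.add(childContext.sqlType);
      }
    }
    return described;
  }

  public static void main(String[] args) {
    // e.g. a call whose first argument is typed INTEGER, followed by a lambda
    System.out.println(describeArguments(
        List.of(new PlainArg("INTEGER"), new LambdaArg()), new Context()));
    // prints [INTEGER, LAMBDA [INTEGER] -> STRING]
  }
}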
@@ -455,16 +455,15 @@ public Pair<String, SqlType> visitFunctionCall(
final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName());
final List<SqlArgument> argumentSchemas = new ArrayList<>();
for (final Expression argExpr : node.getArguments()) {
final SqlType resolvedArgType = expressionTypeManager.getExpressionSqlType(argExpr, context);
// for lambdas - if we find an array or map passed in before encountering a lambda function
// we save the type information to resolve the lambda generics
context.visitType(resolvedArgType);
final TypeContext childContext = context.getCopy();
final SqlType resolvedArgType = expressionTypeManager.getExpressionSqlType(argExpr, childContext);
if (argExpr instanceof LambdaFunctionCall) {
argumentSchemas.add(
SqlArgument.of(
SqlLambda.of(context.getLambdaInputTypes(), resolvedArgType)));
argumentSchemas.add(SqlArgument.of(SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType())));
} else {
argumentSchemas.add(SqlArgument.of(resolvedArgType));
// for lambdas - if we find an array or map passed in before encountering a lambda function
// we save the type information to resolve the lambda generics
context.visitType(node, resolvedArgType);
}
}

@@ -16,6 +16,7 @@
package io.confluent.ksql.execution.codegen;

import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlArray;
@@ -30,8 +31,18 @@

public class TypeContext {
private SqlType sqlType;
private final List<SqlType> lambdaInputTypes = new ArrayList<>();
private final Map<String, SqlType> lambdaInputTypeMapping = new HashMap<>();
private final List<SqlType> lambdaInputTypes;
private final Map<String, SqlType> lambdaInputTypeMapping;

public TypeContext() {
lambdaInputTypes = new ArrayList<SqlType>();
lambdaInputTypeMapping = new HashMap<>();
}

TypeContext (final List<SqlType> lambdaInputTypes, final Map<String, SqlType> lambdaInputTypeMapping) {
this.lambdaInputTypes = lambdaInputTypes;
this.lambdaInputTypeMapping = lambdaInputTypeMapping;
}

public SqlType getSqlType() {
return sqlType;
@@ -61,28 +72,35 @@ public void mapLambdaInputTypes(final List<String> argumentList) {
for (int i = 0; i < argumentList.size(); i++) {
this.lambdaInputTypeMapping.putIfAbsent(argumentList.get(i), lambdaInputTypes.get(i));
}
lambdaInputTypes.clear();
}

public SqlType getLambdaType(final String name) {
return lambdaInputTypeMapping.get(name);
}

public boolean notAllInputsSeen() {
return lambdaInputTypeMapping.size() != lambdaInputTypes.size() || lambdaInputTypes.size() == 0;

public TypeContext getCopy() {
return new TypeContext(this.lambdaInputTypes, this.lambdaInputTypeMapping);
}

public void visitType(SqlType type) {
if (notAllInputsSeen()) {
if (type instanceof SqlArray) {
final SqlArray inputArray = (SqlArray) type;
addLambdaInputType(inputArray.getItemType());
} else if (type instanceof SqlMap) {
final SqlMap inputMap = (SqlMap) type;
addLambdaInputType(inputMap.getKeyType());
addLambdaInputType(inputMap.getValueType());
} else {
addLambdaInputType(type);
public void visitType(final FunctionCall node, SqlType type) {
boolean hasLambda = false;
for (Expression e : node.getArguments()) {
if (e instanceof LambdaFunctionCall) {
hasLambda = true;
break;
}
}
if (hasLambda) {
final SqlArray inputArray = (SqlArray) type;
addLambdaInputType(inputArray.getItemType());
} else if (type instanceof SqlMap) {
final SqlMap inputMap = (SqlMap) type;
addLambdaInputType(inputMap.getKeyType());
addLambdaInputType(inputMap.getValueType());
} else {
addLambdaInputType(type);
}
}
}
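
The reworked visitType above is keyed to this commit's goal of supporting multiple lambda variable names: when the enclosing function call has a lambda argument, the lambda's input types are derived from the other argument's container type, which for a map means both the key and the value type (enabling two-variable lambdas such as (k, v) -> ...). The sketch below shows that derivation with invented placeholder types, not the ksqlDB classes.

// Illustrative sketch only, with invented placeholder types.
import java.util.ArrayList;
import java.util.List;

final class LambdaInputCaptureSketch {

  interface Type { }

  static final class SketchArray implements Type {
    final Type itemType;
    SketchArray(final Type itemType) { this.itemType = itemType; }
  }

  static final class SketchMap implements Type {
    final Type keyType;
    final Type valueType;
    SketchMap(final Type keyType, final Type valueType) {
      this.keyType = keyType;
      this.valueType = valueType;
    }
  }

  static final class SketchScalar implements Type {
    final String name;
    SketchScalar(final String name) { this.name = name; }
    @Override public String toString() { return name; }
  }

  // Mirrors the shape of visitType/addLambdaInputType: an array contributes its
  // element type, a map contributes its key and value types, anything else is taken as-is.
  static List<Type> lambdaInputTypes(final Type argumentType) {
    final List<Type> inputs = new ArrayList<>();
    if (argumentType instanceof SketchArray) {
      inputs.add(((SketchArray) argumentType).itemType);
    } else if (argumentType instanceof SketchMap) {
      final SketchMap map = (SketchMap) argumentType;
      inputs.add(map.keyType);
      inputs.add(map.valueType);
    } else {
      inputs.add(argumentType);
    }
    return inputs;
  }

  public static void main(String[] args) {
    // A map argument yields two lambda inputs (key, value), one per lambda variable.
    System.out.println(lambdaInputTypes(
        new SketchMap(new SketchScalar("STRING"), new SketchScalar("INTEGER")))); // [STRING, INTEGER]
  }
}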
@@ -81,6 +81,7 @@
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

public class ExpressionTypeManager {
@@ -474,18 +475,18 @@ public Void visitFunctionCall(
final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName());

final List<SqlArgument> argTypes = new ArrayList<>();

for (final Expression expression : node.getArguments()) {
process(expression, expressionTypeContext);
final SqlType resolvedArgType = expressionTypeContext.getSqlType();
// for lambdas - if we find an array or map passed in before encountering a lambda function
// we save the type information to resolve the lambda generics
expressionTypeContext.visitType(resolvedArgType);
final TypeContext childContext = expressionTypeContext.getCopy();
process(expression, childContext);
final SqlType resolvedArgType = childContext.getSqlType();
if (expression instanceof LambdaFunctionCall) {
argTypes.add(
SqlArgument.of(
SqlLambda.of(expressionTypeContext.getLambdaInputTypes(), resolvedArgType)));
argTypes.add(SqlArgument.of(SqlLambda.of(expressionTypeContext.getLambdaInputTypes(), childContext.getSqlType())));
} else {
argTypes.add(SqlArgument.of(resolvedArgType));
// for lambdas - if we find an array or map passed in before encountering a lambda function
// we save the type information to resolve the lambda generics
expressionTypeContext.visitType(node, resolvedArgType);
}
}

@@ -603,4 +604,13 @@ private Optional<SqlType> validateWhenClauses(
return previousResult;
}
}

private boolean hasLambdaFunctionCall(FunctionCall node) {
for (Expression e : node.getArguments()) {
if (e instanceof LambdaFunctionCall) {
return true;
}
}
return false;
}
}
