Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add lambda type checking without adding lambda sql type #6966

Merged
merged 6 commits into from
Feb 22, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import com.google.common.collect.Sets;
import io.confluent.ksql.function.types.ArrayType;
import io.confluent.ksql.function.types.GenericType;
import io.confluent.ksql.function.types.LambdaType;
import io.confluent.ksql.function.types.MapType;
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.function.types.StructType;
import io.confluent.ksql.schema.ksql.SchemaConverters;
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlLambda;
import io.confluent.ksql.schema.ksql.types.SqlMap;
import io.confluent.ksql.schema.ksql.types.SqlStruct.Builder;
import io.confluent.ksql.schema.ksql.types.SqlType;
Expand All @@ -35,6 +37,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
Expand Down Expand Up @@ -77,6 +80,14 @@ public static Set<ParamType> constituentGenerics(final ParamType type) {
.collect(Collectors.toSet());
} else if (type instanceof GenericType) {
return ImmutableSet.of(type);
} else if (type instanceof LambdaType) {
final Set<ParamType> inputSet = new HashSet<>();
for (final ParamType paramType: ((LambdaType) type).inputTypes()) {
inputSet.addAll(constituentGenerics(paramType));
}
return Sets.union(
inputSet,
constituentGenerics(((LambdaType) type).returnType()));
} else {
return ImmutableSet.of();
}
Expand Down Expand Up @@ -163,22 +174,27 @@ public static Map<GenericType, SqlType> resolveGenerics(
final SqlType old = mapping.putIfAbsent(entry.getKey(), entry.getValue());
if (old != null && !old.equals(entry.getValue())) {
throw new KsqlException(String.format(
"Found invalid instance of generic schema. Cannot map %s to both %s and %s",
"Found invalid instance of generic schema when mapping %s to %s. "
+ "Cannot map %s to both %s and %s",
schema,
instance,
entry.getKey(),
old,
instance));
entry.getValue()));
}
}

return ImmutableMap.copyOf(mapping);
}

// CHECKSTYLE_RULES.OFF: NPathComplexity
// CHECKSTYLE_RULES.OFF: CyclomaticComplexity
private static boolean resolveGenerics(
final List<Entry<GenericType, SqlType>> mapping,
final ParamType schema,
final SqlArgument instance
) {
// CHECKSTYLE_RULES.ON: NPathComplexity
// CHECKSTYLE_RULES.ON: CyclomaticComplexity
final SqlType sqlType = instance.getSqlType();

Expand All @@ -202,7 +218,9 @@ private static boolean resolveGenerics(
if (schema instanceof ArrayType) {
final SqlArray sqlArray = (SqlArray) sqlType;
return resolveGenerics(
mapping, ((ArrayType) schema).element(), SqlArgument.of(sqlArray.getItemType()));
mapping,
((ArrayType) schema).element(),
SqlArgument.of(sqlArray.getItemType()));
}

if (schema instanceof MapType) {
Expand All @@ -216,10 +234,37 @@ private static boolean resolveGenerics(
throw new KsqlException("Generic STRUCT is not yet supported");
}

if (schema instanceof LambdaType) {
final LambdaType lambdaType = (LambdaType) schema;
final SqlLambda sqlLambda = instance.getSqlLambda();
boolean resolvedInputs = true;
if (sqlLambda.getInputType().size() != lambdaType.inputTypes().size()) {
throw new KsqlException(
"Number of lambda arguments doesn't match between schema and sql type");
}

int i = 0;
for (final ParamType paramType : lambdaType.inputTypes()) {
resolvedInputs =
resolvedInputs && resolveGenerics(
lct45 marked this conversation as resolved.
Show resolved Hide resolved
mapping, paramType, SqlArgument.of(sqlLambda.getInputType().get(i))
);
i++;
}
return resolvedInputs && resolveGenerics(
mapping, lambdaType.returnType(), SqlArgument.of(sqlLambda.getReturnType())
);
}

return true;
}

private static boolean matches(final ParamType schema, final SqlArgument instance) {
if (schema instanceof LambdaType && instance.getSqlLambda() != null) {
return true;
} else if (schema instanceof LambdaType || instance.getSqlLambda() != null) {
return false;
}
lct45 marked this conversation as resolved.
Show resolved Hide resolved
final ParamType instanceParamType = SchemaConverters
.sqlToFunctionConverter().toFunctionType(instance.getSqlType());
return schema.getClass() == instanceParamType.getClass();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,18 @@ private KsqlException createNoMatchingFunctionException(final List<SqlArgument>
LOG.debug("Current UdfIndex:\n{}", describe());

final String requiredTypes = paramTypes.stream()
.map(type -> type == null ? "null" : type.getSqlType().toString(FormatOptions.noEscape()))
.map(argument -> {
if (argument == null) {
return "null";
} else {
final SqlType sqlType = argument.getSqlType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EDIT: I see @stevenpyzhang already commented about this below, I like his suggestion! Just keeping this comment here for historical purposes


I see the following method:

  public static SqlArgument of(final SqlType sqlType, final SqlLambda lambdaType) {
    return new SqlArgument(sqlType, lambdaType);
  }

What does it mean for a SqlArgument to have both a SqlType and a SqlLambda? Can we update the docs to describe what this does? If that combination is not supposed to be possible, we should enforce it in the code (make sure at most one of them is non-null).

if (sqlType != null) {
return sqlType.toString(FormatOptions.noEscape());
} else {
return argument.getSqlLambda().toString();
}
}
})
.collect(Collectors.joining(", ", "(", ")"));

final String acceptedTypes = allFunctions.values().stream()
Expand Down Expand Up @@ -351,7 +362,7 @@ public int hashCode() {
// CHECKSTYLE_RULES.OFF: BooleanExpressionComplexity
boolean accepts(final SqlArgument argument, final Map<GenericType, SqlType> reservedGenerics,
final boolean allowCasts) {
if (argument == null || argument.getSqlType() == null) {
if (argument == null || (argument.getSqlLambda() == null && argument.getSqlType() == null)) {
return true;
}

Expand All @@ -371,9 +382,8 @@ private static boolean reserveGenerics(
if (!GenericsUtil.instanceOf(schema, argument)) {
return false;
}

final Map<GenericType, SqlType> genericMapping = GenericsUtil
.resolveGenerics(schema, argument);
final Map<GenericType, SqlType> genericMapping =
GenericsUtil.resolveGenerics(schema, argument);

for (final Entry<GenericType, SqlType> entry : genericMapping.entrySet()) {
final SqlType old = reservedGenerics.putIfAbsent(entry.getKey(), entry.getValue());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlBaseType;
import io.confluent.ksql.schema.ksql.types.SqlLambda;
import io.confluent.ksql.schema.ksql.types.SqlMap;
import io.confluent.ksql.schema.ksql.types.SqlStruct;
import io.confluent.ksql.schema.ksql.types.SqlStruct.Field;
Expand All @@ -44,12 +45,33 @@ public static boolean areCompatible(final SqlType actual, final ParamType declar
return areCompatible(SqlArgument.of(actual), declared, false);
}

// CHECKSTYLE_RULES.OFF: CyclomaticComplexity
// CHECKSTYLE_RULES.OFF: NPathComplexity
public static boolean areCompatible(
final SqlArgument argument,
final ParamType declared,
final boolean allowCast
) {
// CHECKSTYLE_RULES.ON: CyclomaticComplexity
// CHECKSTYLE_RULES.ON: NPathComplexity
final SqlType argumentSqlType = argument.getSqlType();
final SqlLambda sqlLambda = argument.getSqlLambda();

if (sqlLambda != null && declared instanceof LambdaType) {
final LambdaType declaredLambda = (LambdaType) declared;
if (sqlLambda.getInputType().size() != declaredLambda.inputTypes().size()) {
return false;
}
int i = 0;
for (final ParamType paramType: declaredLambda.inputTypes()) {
if (!areCompatible(sqlLambda.getInputType().get(i), paramType)) {
return false;
}
i++;
}
return areCompatible(sqlLambda.getReturnType(), declaredLambda.returnType());
}

if (argumentSqlType.baseType() == SqlBaseType.ARRAY && declared instanceof ArrayType) {
return areCompatible(
SqlArgument.of(((SqlArray) argumentSqlType).getItemType()),
Expand All @@ -61,7 +83,11 @@ public static boolean areCompatible(
final SqlMap sqlType = (SqlMap) argumentSqlType;
final MapType mapType = (MapType) declared;
return areCompatible(SqlArgument.of(sqlType.getKeyType()), mapType.key(), allowCast)
&& areCompatible(SqlArgument.of(sqlType.getValueType()), mapType.value(), allowCast);
&& areCompatible(
SqlArgument.of(sqlType.getValueType()),
lct45 marked this conversation as resolved.
Show resolved Hide resolved
mapType.value(),
allowCast
);
}

if (argumentSqlType.baseType() == SqlBaseType.STRUCT && declared instanceof StructType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.google.common.collect.ImmutableList;
import io.confluent.ksql.function.types.ArrayType;
import io.confluent.ksql.function.types.GenericType;
import io.confluent.ksql.function.types.LambdaType;
import io.confluent.ksql.function.types.MapType;
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.function.types.ParamTypes;
Expand All @@ -21,6 +22,7 @@
import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlLambda;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.util.KsqlConfig;
Expand All @@ -45,9 +47,14 @@ public class UdfIndexTest {
private static final ParamType STRUCT2 = StructType.builder().field("b", INT).build();
private static final ParamType MAP1 = MapType.of(STRING, STRING);
private static final ParamType MAP2 = MapType.of(INT, INT);
private static final ParamType LAMBDA_FUNCTION = LambdaType.of(ImmutableList.of(GenericType.of("A")), GenericType.of("B"));
private static final ParamType LAMBDA_BI_FUNCTION = LambdaType.of(ImmutableList.of(GenericType.of("A"), GenericType.of("B")), GenericType.of("C"));
private static final ParamType LAMBDA_BI_FUNCTION_STRING = LambdaType.of(ImmutableList.of(STRING, STRING), GenericType.of("A"));

private static final ParamType GENERIC_LIST = ArrayType.of(GenericType.of("T"));
private static final ParamType GENERIC_MAP = MapType.of(GenericType.of("A"), GenericType.of("B"));

private static final SqlType ARRAY_ARG = SqlTypes.array(INTEGER);
private static final SqlType MAP1_ARG = SqlTypes.map(SqlTypes.STRING, SqlTypes.STRING);
private static final SqlType DECIMAL1_ARG = SqlTypes.decimal(4, 2);

Expand Down Expand Up @@ -231,6 +238,127 @@ public void shouldChooseCorrectMap() {
assertThat(fun.name(), equalTo(EXPECTED));
}

@Test
public void shouldChooseCorrectLambdaFunction() {
lct45 marked this conversation as resolved.
Show resolved Hide resolved
// Given:
givenFunctions(
function(EXPECTED, false, GENERIC_LIST, LAMBDA_FUNCTION)
);

// When:
final KsqlScalarFunction fun = udfIndex.getFunction(
ImmutableList.of(
SqlArgument.of(ARRAY_ARG),
SqlArgument.of(
SqlLambda.of(
ImmutableList.of(SqlTypes.STRING),
INTEGER))));

// Then:
assertThat(fun.name(), equalTo(EXPECTED));
}

@Test
public void shouldChooseCorrectLambdaBiFunction() {
  // Given: one registered function that takes a generic map plus a two-input lambda.
  givenFunctions(function(EXPECTED, false, GENERIC_MAP, LAMBDA_BI_FUNCTION));

  // When: looking it up with a STRING-to-STRING map and a (STRING, STRING) lambda
  // whose return type is INTEGER.
  final SqlArgument mapArgument = SqlArgument.of(MAP1_ARG);
  final SqlArgument lambdaArgument = SqlArgument.of(
      SqlLambda.of(ImmutableList.of(SqlTypes.STRING, SqlTypes.STRING), INTEGER));
  final KsqlScalarFunction fun =
      udfIndex.getFunction(ImmutableList.of(mapArgument, lambdaArgument));

  // Then: the registered function is the one selected.
  assertThat(fun.name(), equalTo(EXPECTED));
}

@Test
public void shouldChooseCorrectLambdaForTypeSpecificCollections() {
// Given:
givenFunctions(
function(EXPECTED, false, MAP1, LAMBDA_BI_FUNCTION_STRING)
);

// When:
final KsqlScalarFunction fun1 = udfIndex.getFunction(
ImmutableList.of(
SqlArgument.of(MAP1_ARG),
SqlArgument.of(
SqlLambda.of(
ImmutableList.of(SqlTypes.STRING, SqlTypes.STRING),
SqlTypes.BOOLEAN))));

final KsqlScalarFunction fun2 = udfIndex.getFunction(
ImmutableList.of(
SqlArgument.of(MAP1_ARG),
SqlArgument.of(
SqlLambda.of(
ImmutableList.of(SqlTypes.STRING, SqlTypes.STRING),
INTEGER))));

final Exception e = assertThrows(
Exception.class,
() -> udfIndex.getFunction(
ImmutableList.of(
SqlArgument.of(MAP1_ARG),
SqlArgument.of(
SqlLambda.of(
ImmutableList.of(SqlTypes.BOOLEAN, INTEGER),
INTEGER))))
);

// Then:
assertThat(fun1.name(), equalTo(EXPECTED));
assertThat(fun2.name(), equalTo(EXPECTED));
assertThat(e.getMessage(), containsString("Valid alternatives are:"
+ lineSeparator()
+ "expected(MAP<VARCHAR, VARCHAR>, LAMBDA<[VARCHAR, VARCHAR], A>)"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is shown to users, we might want to consider having toString on lambdas produce something like `(VARCHAR, VARCHAR) -> A` instead of `LAMBDA<[VARCHAR, VARCHAR], A>`, which is somewhat difficult to read.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we should keep the LAMBDA identifier in the new toString? I'm just wondering whether users will be able to easily identify it as a lambda if it's only `(VARCHAR, VARCHAR) -> A`.

}

@Test
public void shouldThrowOnInvalidLambdaMapping() {
// Given:
givenFunctions(
function(OTHER, false, GENERIC_MAP, LAMBDA_BI_FUNCTION)
);

// When:
final Exception e1 = assertThrows(
Exception.class,
() -> udfIndex.getFunction(
ImmutableList.of(
SqlArgument.of(MAP1_ARG),
SqlArgument.of(
SqlLambda.of(
ImmutableList.of(SqlTypes.BOOLEAN, SqlTypes.STRING),
INTEGER))))
);

final Exception e2 = assertThrows(
Exception.class,
() -> udfIndex.getFunction(
ImmutableList.of(
SqlArgument.of(MAP1_ARG),
SqlArgument.of(
SqlLambda.of(
ImmutableList.of(SqlTypes.STRING,SqlTypes.STRING, SqlTypes.STRING),
INTEGER)
)))
);

// Then:
assertThat(e1.getMessage(), containsString("Valid alternatives are:"
+ lineSeparator()
+ "other(MAP<A, B>, LAMBDA<[A, B], C>)"));
assertThat(e2.getMessage(), containsString("Number of lambda arguments doesn't match between schema and sql type"));
lct45 marked this conversation as resolved.
Show resolved Hide resolved
}

@Test
public void shouldAllowAnyDecimal() {
// Given:
Expand Down Expand Up @@ -459,7 +587,7 @@ public void shouldChooseNonVarargWithNullValuesOfDifferingSchemas() {
);

// When:
final KsqlScalarFunction fun = udfIndex.getFunction(Arrays.asList(SqlArgument.of(null), SqlArgument.of(null)));
final KsqlScalarFunction fun = udfIndex.getFunction(Arrays.asList(SqlArgument.of(null, null), SqlArgument.of(null, null)));

// Then:
assertThat(fun.name(), equalTo(EXPECTED));
Expand All @@ -474,7 +602,7 @@ public void shouldChooseNonVarargWithNullValuesOfSameSchemas() {
);

// When:
final KsqlScalarFunction fun = udfIndex.getFunction(Arrays.asList(SqlArgument.of(null), SqlArgument.of(null)));
final KsqlScalarFunction fun = udfIndex.getFunction(Arrays.asList(SqlArgument.of(null, null), SqlArgument.of(null, null)));

// Then:
assertThat(fun.name(), equalTo(EXPECTED));
Expand Down
Loading