Skip to content

Commit

Permalink
additional unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stevenpyzhang authored and lct45 committed Feb 18, 2021
1 parent 55e968d commit 4264da9
Show file tree
Hide file tree
Showing 13 changed files with 463 additions and 36 deletions.
2 changes: 1 addition & 1 deletion ksqldb-common/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
<dependency>
<groupId>io.confluent</groupId>
<artifactId>kafka-connect-avro-converter</artifactId>
<version>6.2.0-363</version>
<version>${io.confluent.schema-registry.version}</version>
</dependency>

<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,13 @@ public static Map<GenericType, SqlType> resolveGenerics(
final SqlType old = mapping.putIfAbsent(entry.getKey(), entry.getValue());
if (old != null && !old.equals(entry.getValue())) {
throw new KsqlException(String.format(
"Found invalid instance of generic schema. Cannot map %s to both %s and %s",
"Found invalid instance of generic schema when mapping %s to %s. "
+ "Cannot map %s to both %s and %s",
schema,
instance,
entry.getKey(),
old,
instance));
entry.getValue()));
}
}

Expand Down Expand Up @@ -213,10 +216,11 @@ private static boolean resolveGenerics(
}

if (schema instanceof ArrayType) {
final SqlArray sqlArray = (SqlArray) sqlType;
return resolveGenerics(
mapping,
((ArrayType) schema).element(),
SqlArgument.of(((SqlArray) sqlType).getItemType()));
SqlArgument.of(sqlArray.getItemType()));
}

if (schema instanceof MapType) {
Expand All @@ -236,7 +240,7 @@ private static boolean resolveGenerics(
boolean resolvedInputs = true;
if (sqlLambda.getInputType().size() != lambdaType.inputTypes().size()) {
throw new KsqlException(
"Number of lambda arguments don't match between schema and sql type");
"Number of lambda arguments doesn't match between schema and sql type");
}

int i = 0;
Expand All @@ -258,6 +262,8 @@ resolvedInputs && resolveGenerics(
private static boolean matches(final ParamType schema, final SqlArgument instance) {
if (schema instanceof LambdaType && instance.getSqlLambda() != null) {
return true;
} else if (schema instanceof LambdaType || instance.getSqlLambda() != null) {
return false;
}
final ParamType instanceParamType = SchemaConverters
.sqlToFunctionConverter().toFunctionType(instance.getSqlType());
Expand All @@ -272,7 +278,7 @@ private static boolean matches(final ParamType schema, final SqlArgument instanc
public static boolean instanceOf(final ParamType schema, final SqlArgument instance) {
final List<Entry<GenericType, SqlType>> mappings = new ArrayList<>();

if (!resolveGenerics(mappings, schema, SqlArgument.of(instance.getSqlType()))) {
if (!resolveGenerics(mappings, schema, instance)) {
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.google.common.collect.Iterables;
import io.confluent.ksql.function.types.ArrayType;
import io.confluent.ksql.function.types.GenericType;
import io.confluent.ksql.function.types.LambdaType;
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.function.types.ParamTypes;
import io.confluent.ksql.schema.ksql.SqlArgument;
Expand Down Expand Up @@ -199,7 +198,18 @@ private KsqlException createNoMatchingFunctionException(final List<SqlArgument>
LOG.debug("Current UdfIndex:\n{}", describe());

final String requiredTypes = paramTypes.stream()
.map(type -> type == null ? "null" : type.getSqlType().toString(FormatOptions.noEscape()))
.map(argument -> {
if (argument == null) {
return "null";
} else {
final SqlType sqlType = argument.getSqlType();
if (sqlType != null) {
return sqlType.toString(FormatOptions.noEscape());
} else {
return argument.getSqlLambda().toString();
}
}
})
.collect(Collectors.joining(", ", "(", ")"));

final String acceptedTypes = allFunctions.values().stream()
Expand Down Expand Up @@ -369,8 +379,7 @@ private static boolean reserveGenerics(
final SqlArgument argument,
final Map<GenericType, SqlType> reservedGenerics
) {
if (!(schema instanceof LambdaType)
&& !GenericsUtil.instanceOf(schema, argument)) {
if (!GenericsUtil.instanceOf(schema, argument)) {
return false;
}
final Map<GenericType, SqlType> genericMapping =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.google.common.collect.ImmutableList;
import io.confluent.ksql.function.types.ArrayType;
import io.confluent.ksql.function.types.GenericType;
import io.confluent.ksql.function.types.LambdaType;
import io.confluent.ksql.function.types.MapType;
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.function.types.ParamTypes;
Expand All @@ -21,6 +22,7 @@
import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlLambda;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.util.KsqlConfig;
Expand All @@ -45,9 +47,14 @@ public class UdfIndexTest {
private static final ParamType STRUCT2 = StructType.builder().field("b", INT).build();
private static final ParamType MAP1 = MapType.of(STRING, STRING);
private static final ParamType MAP2 = MapType.of(INT, INT);
private static final ParamType LAMBDA_FUNCTION = LambdaType.of(ImmutableList.of(GenericType.of("A")), GenericType.of("B"));
private static final ParamType LAMBDA_BI_FUNCTION = LambdaType.of(ImmutableList.of(GenericType.of("A"), GenericType.of("B")), GenericType.of("C"));
private static final ParamType LAMBDA_BI_FUNCTION_STRING = LambdaType.of(ImmutableList.of(STRING, STRING), GenericType.of("A"));

private static final ParamType GENERIC_LIST = ArrayType.of(GenericType.of("T"));
private static final ParamType GENERIC_MAP = MapType.of(GenericType.of("A"), GenericType.of("B"));

private static final SqlType ARRAY_ARG = SqlTypes.array(INTEGER);
private static final SqlType MAP1_ARG = SqlTypes.map(SqlTypes.STRING, SqlTypes.STRING);
private static final SqlType DECIMAL1_ARG = SqlTypes.decimal(4, 2);

Expand Down Expand Up @@ -231,6 +238,127 @@ public void shouldChooseCorrectMap() {
assertThat(fun.name(), equalTo(EXPECTED));
}

@Test
public void shouldChooseCorrectLambdaFunction() {
// Given:
givenFunctions(
function(EXPECTED, false, GENERIC_LIST, LAMBDA_FUNCTION)
);

// When:
final KsqlScalarFunction fun = udfIndex.getFunction(
ImmutableList.of(
SqlArgument.of(ARRAY_ARG),
SqlArgument.of(
SqlLambda.of(
ImmutableList.of(SqlTypes.STRING),
INTEGER))));

// Then:
assertThat(fun.name(), equalTo(EXPECTED));
}

@Test
public void shouldChooseCorrectLambdaBiFunction() {
// Given:
givenFunctions(
function(EXPECTED, false, GENERIC_MAP, LAMBDA_BI_FUNCTION)
);

// When:
final KsqlScalarFunction fun = udfIndex.getFunction(
ImmutableList.of(
SqlArgument.of(MAP1_ARG),
SqlArgument.of(
SqlLambda.of(
ImmutableList.of(SqlTypes.STRING, SqlTypes.STRING),
INTEGER))));

// Then:
assertThat(fun.name(), equalTo(EXPECTED));
}

@Test
public void shouldChooseCorrectLambdaForTypeSpecificCollections() {
// Given:
givenFunctions(
function(EXPECTED, false, MAP1, LAMBDA_BI_FUNCTION_STRING)
);

// When:
final KsqlScalarFunction fun1 = udfIndex.getFunction(
ImmutableList.of(
SqlArgument.of(MAP1_ARG),
SqlArgument.of(
SqlLambda.of(
ImmutableList.of(SqlTypes.STRING, SqlTypes.STRING),
SqlTypes.BOOLEAN))));

final KsqlScalarFunction fun2 = udfIndex.getFunction(
ImmutableList.of(
SqlArgument.of(MAP1_ARG),
SqlArgument.of(
SqlLambda.of(
ImmutableList.of(SqlTypes.STRING, SqlTypes.STRING),
INTEGER))));

final Exception e = assertThrows(
Exception.class,
() -> udfIndex.getFunction(
ImmutableList.of(
SqlArgument.of(MAP1_ARG),
SqlArgument.of(
SqlLambda.of(
ImmutableList.of(SqlTypes.BOOLEAN, INTEGER),
INTEGER))))
);

// Then:
assertThat(fun1.name(), equalTo(EXPECTED));
assertThat(fun2.name(), equalTo(EXPECTED));
assertThat(e.getMessage(), containsString("Valid alternatives are:"
+ lineSeparator()
+ "expected(MAP<VARCHAR, VARCHAR>, LAMBDA<[VARCHAR, VARCHAR], A>)"));
}

@Test
public void shouldThrowOnInvalidLambdaMapping() {
// Given:
givenFunctions(
function(OTHER, false, GENERIC_MAP, LAMBDA_BI_FUNCTION)
);

// When:
final Exception e1 = assertThrows(
Exception.class,
() -> udfIndex.getFunction(
ImmutableList.of(
SqlArgument.of(MAP1_ARG),
SqlArgument.of(
SqlLambda.of(
ImmutableList.of(SqlTypes.BOOLEAN, SqlTypes.STRING),
INTEGER))))
);

final Exception e2 = assertThrows(
Exception.class,
() -> udfIndex.getFunction(
ImmutableList.of(
SqlArgument.of(MAP1_ARG),
SqlArgument.of(
SqlLambda.of(
ImmutableList.of(SqlTypes.STRING,SqlTypes.STRING, SqlTypes.STRING),
INTEGER)
)))
);

// Then:
assertThat(e1.getMessage(), containsString("Valid alternatives are:"
+ lineSeparator()
+ "other(MAP<A, B>, LAMBDA<[A, B], C>)"));
assertThat(e2.getMessage(), containsString("Number of lambda arguments don't match between schema and sql type"));
}

@Test
public void shouldAllowAnyDecimal() {
// Given:
Expand Down
Loading

0 comments on commit 4264da9

Please sign in to comment.