Skip to content

Commit

Permalink
chore: remove GenericRowValueTypeEnforcer (#3792)
Browse files Browse the repository at this point in the history
This change sees `GenericRowValueTypeEnforcer` removed from the code base.

It's being removed as the functionality it provides is:

1. Not needed in the production code path as values are always of the right type.
1. Only needed for a few tests to pass, (which have been fixed up).
1. It supports coercing between types which are outside of KSQLs implicit casting rules, e.g. coerce String -> BIGINT.

We should support implicitly casting parameters to functions to wider types,
as defined by KSQLs implicit casting rules. However, this is not currently the
case. See #3791.
  • Loading branch information
big-andy-coates authored Nov 7, 2019
1 parent a8cc010 commit 57fb964
Show file tree
Hide file tree
Showing 11 changed files with 53 additions and 737 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -543,9 +543,9 @@ public void shouldHandleArithmeticExpr() {
@Test
public void testCastNumericArithmeticExpressions() {
final Map<Integer, Object> inputValues =
ImmutableMap.of(0, 1, 3, 3, 4, 4, 5, 5);
ImmutableMap.of(0, 1L, 3, 3.0D, 4, 4.0D, 5, 5);

// INT64 - INT32
// INT - BIGINT
assertThat(executeExpression(
"SELECT "
+ "CAST((col5 - col0) AS INTEGER),"
Expand All @@ -555,7 +555,7 @@ public void testCastNumericArithmeticExpressions() {
+ "FROM codegen_test EMIT CHANGES;",
inputValues), contains(4, 4L, 4.0, "4"));

// FLOAT64 - FLOAT64
// DOUBLE - DOUBLE
assertThat(executeExpression(
"SELECT "
+ "CAST((col4 - col3) AS INTEGER),"
Expand All @@ -565,7 +565,7 @@ public void testCastNumericArithmeticExpressions() {
+ "FROM codegen_test EMIT CHANGES;",
inputValues), contains(1, 1L, 1.0, "1.0"));

// FLOAT64 - INT64
// DOUBLE - INT
assertThat(executeExpression(
"SELECT "
+ "CAST((col4 - col0) AS INTEGER),"
Expand Down Expand Up @@ -594,7 +594,7 @@ public void shouldHandleMathUdfs() {
final String query =
"SELECT FLOOR(col3), CEIL(col3*3), ABS(col0+1.34), ROUND(col3*2)+12 FROM codegen_test EMIT CHANGES;";

final Map<Integer, Object> inputValues = ImmutableMap.of(0, 15, 3, 1.5);
final Map<Integer, Object> inputValues = ImmutableMap.of(0, 15L, 3, 1.5);

// When:
final List<Object> columns = executeExpression(query, inputValues);
Expand All @@ -607,7 +607,7 @@ public void shouldHandleMathUdfs() {
public void shouldHandleRandomUdf() {
// Given:
final String query = "SELECT RANDOM()+10, RANDOM()+col0 FROM codegen_test EMIT CHANGES;";
final Map<Integer, Object> inputValues = ImmutableMap.of(0, 15);
final Map<Integer, Object> inputValues = ImmutableMap.of(0, 15L);

// When:
final List<Object> columns = executeExpression(query, inputValues);
Expand Down Expand Up @@ -825,7 +825,7 @@ public void shouldHandleFunctionWithNullArgument() {
final String query =
"SELECT test_udf(col0, NULL) FROM codegen_test EMIT CHANGES;";

final Map<Integer, Object> inputValues = ImmutableMap.of(0, 0);
final Map<Integer, Object> inputValues = ImmutableMap.of(0, 0L);
final List<Object> columns = executeExpression(query, inputValues);
// test
assertThat(columns, equalTo(Collections.singletonList("doStuffLongString")));
Expand All @@ -836,7 +836,7 @@ public void shouldHandleFunctionWithVarargs() {
final String query =
"SELECT test_udf(col0, col0, col0, col0, col0) FROM codegen_test EMIT CHANGES;";

final Map<Integer, Object> inputValues = ImmutableMap.of(0, 0);
final Map<Integer, Object> inputValues = ImmutableMap.of(0, 0L);
final List<Object> columns = executeExpression(query, inputValues);
// test
assertThat(columns, equalTo(Collections.singletonList("doStuffLongVarargs")));
Expand All @@ -860,7 +860,7 @@ public void shouldChoseFunctionWithCorrectNumberOfArgsWhenNullArgument() {
final String query =
"SELECT test_udf(col0, col0, NULL) FROM codegen_test EMIT CHANGES;";

final Map<Integer, Object> inputValues = ImmutableMap.of(0, 0);
final Map<Integer, Object> inputValues = ImmutableMap.of(0, 0L);
final List<Object> columns = executeExpression(query, inputValues);
// test
assertThat(columns, equalTo(Collections.singletonList("doStuffLongLongString")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ public void shouldSelectChosenColumns() {

// When:
final GenericRow transformed = selectMapper.apply(
genericRow(1521834663L, "key1", 1L, "hi", "bye", 2.0F, "blah"));
genericRow(1521834663L, "key1", 1L, "hi", "bye", 2.0D, "blah"));

// Then:
assertThat(transformed, is(genericRow(1L, "bye", 2.0F)));
assertThat(transformed, is(genericRow(1L, "bye", 2.0D)));
}

@Test
Expand All @@ -73,10 +73,10 @@ public void shouldApplyUdfsToColumns() {

// When:
final GenericRow row = selectMapper.apply(
genericRow(1521834663L, "key1", 2L, "foo", "whatever", 6.9F, "boo", "hoo"));
genericRow(1521834663L, "key1", 2L, "foo", "whatever", 6.9D, "boo", "hoo"));

// Then:
assertThat(row, is(genericRow(2L, "foo", "whatever", 7.0F)));
assertThat(row, is(genericRow(2L, "foo", "whatever", 7.0D)));
}

private SelectValueMapper givenSelectMapperFor(final String query) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ public void shouldUseConsistentOrderInPreAggSelectMapper() {
assertThat("invalid test", valueMappers, hasSize(greaterThanOrEqualTo(1)));
final ValueMapper preAggSelectMapper = valueMappers.get(0);
final GenericRow result = (GenericRow) preAggSelectMapper
.apply(new GenericRow("rowtime", "rowkey", "0", "1", "2", "3"));
.apply(new GenericRow("rowtime", "rowkey", 0L, "1", "2", 3.0D));
assertThat("should select col0, col1, col2, col3", result.getColumns(),
contains(0L, "1", "2", 3.0));
}
Expand All @@ -163,7 +163,7 @@ public void shouldUseConsistentOrderInPostAggSelectMapper() {
assertThat("invalid test", valueMappers, hasSize(greaterThanOrEqualTo(2)));
final ValueMapper postAggSelect = valueMappers.get(1);
final GenericRow result = (GenericRow) postAggSelect
.apply(new GenericRow("0", "-1", "2", "3", "4"));
.apply(new GenericRow(0L, "-1", 2.0D, 3L, 4.0D));
assertThat("should select col0, agg1, agg2", result.getColumns(), contains(0L, 2.0, 3L, 4.0));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor;
import io.confluent.ksql.execution.function.udf.structfieldextractor.FetchFieldFromStruct;
import io.confluent.ksql.execution.util.ExpressionTypeManager;
import io.confluent.ksql.execution.util.GenericRowValueTypeEnforcer;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.function.KsqlScalarFunction;
import io.confluent.ksql.function.UdfFactory;
Expand Down Expand Up @@ -111,7 +110,6 @@ public ExpressionMetadata buildCodeGenFromParseTree(Expression expression, Strin
ee,
spec,
expressionType,
new GenericRowValueTypeEnforcer(schema),
expression
);
} catch (KsqlException | CompileException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.google.common.collect.ImmutableMap;
import com.google.errorprone.annotations.Immutable;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.execution.util.GenericRowValueTypeEnforcer;
import io.confluent.ksql.function.udf.Kudf;
import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.schema.ksql.ColumnRef;
Expand Down Expand Up @@ -71,16 +70,13 @@ public String getUniqueNameForFunction(FunctionName functionName, int index) {
return names.get(index);
}

public void resolve(
GenericRow row, GenericRowValueTypeEnforcer typeEnforcer, Object[] parameters
) {
public void resolve(GenericRow row, Object[] parameters) {
for (int paramIdx = 0; paramIdx < arguments.size(); paramIdx++) {
ArgumentSpec spec = arguments.get(paramIdx);

if (spec.colIndex().isPresent()) {
int colIndex = spec.colIndex().getAsInt();
parameters[paramIdx] = typeEnforcer
.enforceColumnType(colIndex, row.getColumns().get(colIndex));
parameters[paramIdx] = row.getColumns().get(colIndex);
} else {
int copyOfParamIdxForLambda = paramIdx;
parameters[paramIdx] = spec.kudf()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.execution.codegen.CodeGenSpec.ArgumentSpec;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.util.GenericRowValueTypeEnforcer;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.util.KsqlException;
import java.lang.reflect.InvocationTargetException;
Expand All @@ -32,18 +31,18 @@ public class ExpressionMetadata {

private final IExpressionEvaluator expressionEvaluator;
private final SqlType expressionType;
private final GenericRowValueTypeEnforcer typeEnforcer;
private final ThreadLocal<Object[]> threadLocalParameters;
private final Expression expression;
private final CodeGenSpec spec;

public ExpressionMetadata(
IExpressionEvaluator expressionEvaluator, CodeGenSpec spec, SqlType expressionType,
GenericRowValueTypeEnforcer typeEnforcer, Expression expression
IExpressionEvaluator expressionEvaluator,
CodeGenSpec spec,
SqlType expressionType,
Expression expression
) {
this.expressionEvaluator = Objects.requireNonNull(expressionEvaluator, "expressionEvaluator");
this.expressionType = Objects.requireNonNull(expressionType, "expressionType");
this.typeEnforcer = Objects.requireNonNull(typeEnforcer, "typeEnforcer");
this.expression = Objects.requireNonNull(expression, "expression");
this.spec = Objects.requireNonNull(spec, "spec");
this.threadLocalParameters = ThreadLocal.withInitial(() -> new Object[spec.arguments().size()]);
Expand Down Expand Up @@ -80,7 +79,7 @@ public Object evaluate(GenericRow row) {

private Object[] getParameters(GenericRow row) {
Object[] parameters = this.threadLocalParameters.get();
spec.resolve(row, typeEnforcer, parameters);
spec.resolve(row, parameters);
return parameters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import io.confluent.ksql.execution.codegen.SqlToJavaVisitor;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.util.EngineProcessingLogMessageFactory;
import io.confluent.ksql.execution.util.GenericRowValueTypeEnforcer;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.logging.processing.ProcessingLogger;
import io.confluent.ksql.schema.ksql.LogicalSchema;
Expand All @@ -39,16 +38,17 @@ public final class SqlPredicate {

private final Expression filterExpression;
private final IExpressionEvaluator ee;
private final GenericRowValueTypeEnforcer genericRowValueTypeEnforcer;
private final ProcessingLogger processingLogger;
private final CodeGenSpec spec;

public SqlPredicate(
Expression filterExpression, LogicalSchema schema, KsqlConfig ksqlConfig,
FunctionRegistry functionRegistry, ProcessingLogger processingLogger
Expression filterExpression,
LogicalSchema schema,
KsqlConfig ksqlConfig,
FunctionRegistry functionRegistry,
ProcessingLogger processingLogger
) {
this.filterExpression = requireNonNull(filterExpression, "filterExpression");
this.genericRowValueTypeEnforcer = new GenericRowValueTypeEnforcer(schema);
this.processingLogger = requireNonNull(processingLogger);

CodeGenRunner codeGenRunner = new CodeGenRunner(schema, ksqlConfig, functionRegistry);
Expand Down Expand Up @@ -86,7 +86,7 @@ public <K> Predicate<K, GenericRow> getPredicate() {

try {
Object[] values = new Object[spec.arguments().size()];
spec.resolve(row, genericRowValueTypeEnforcer, values);
spec.resolve(row, values);
return (Boolean) ee.evaluate(values);
} catch (Exception e) {
logProcessingError(e, row);
Expand Down

This file was deleted.

Loading

0 comments on commit 57fb964

Please sign in to comment.