Skip to content

Commit

Permalink
fix: Improve error message for where/having type errors (#5023)
Browse files Browse the repository at this point in the history
* fix: Improve error message for where/having type errors
  • Loading branch information
AlanConfluent authored Apr 15, 2020
1 parent 64dd39e commit 23eb80d
Show file tree
Hide file tree
Showing 12 changed files with 422 additions and 38 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright 2020 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.analyzer;

import static java.util.Objects.requireNonNull;

import io.confluent.ksql.engine.rewrite.StatementRewriteForMagicPseudoTimestamp;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.util.ExpressionTypeManager;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.schema.ksql.FormatOptions;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.KsqlStatementException;

/**
* Validates types used in filtering statements.
*/
public final class FilterTypeValidator {

private final LogicalSchema schema;
private final FunctionRegistry functionRegistry;
private final FilterType filterType;

public FilterTypeValidator(
final LogicalSchema schema,
final FunctionRegistry functionRegistry,
final FilterType filterType
) {
this.schema = requireNonNull(schema, "schema");
this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry");
this.filterType = requireNonNull(filterType, "filterType");
}

/**
* Validates the given filter expression.
*/
public void validateFilterExpression(final Expression exp) {
final SqlType type = getExpressionReturnType(exp);
if (!SqlTypes.BOOLEAN.equals(type)) {
throw new KsqlException("Type error in " + filterType.name() + " expression: "
+ "Should evaluate to boolean but is " + exp.toString()
+ " (" + type.toString(FormatOptions.none()) + ") instead.");
}
}

private SqlType getExpressionReturnType(
final Expression exp
) {
final ExpressionTypeManager expressionTypeManager = new ExpressionTypeManager(schema,
functionRegistry);

// Rewrite the expression with magic timestamps, so type checking can pass
final Expression magicTimestampRewrite =
new StatementRewriteForMagicPseudoTimestamp().rewrite(exp);

try {
return expressionTypeManager.getExpressionSqlType(magicTimestampRewrite);
} catch (KsqlException e) {
throw new KsqlStatementException("Error in " + filterType.name() + " expression: "
+ e.getMessage(), exp.toString());
}
}

// The expression type being validated.
public enum FilterType {
WHERE,
HAVING
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import io.confluent.ksql.analyzer.Analysis.AliasedDataSource;
import io.confluent.ksql.analyzer.Analysis.Into;
import io.confluent.ksql.analyzer.Analysis.JoinInfo;
import io.confluent.ksql.analyzer.FilterTypeValidator;
import io.confluent.ksql.analyzer.FilterTypeValidator.FilterType;
import io.confluent.ksql.analyzer.ImmutableAnalysis;
import io.confluent.ksql.analyzer.RewrittenAnalysis;
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter;
Expand Down Expand Up @@ -229,6 +231,15 @@ private AggregateNode buildAggregateNode(final PlanNode sourcePlanNode) {
refRewriter::process
);

if (analysis.getHavingExpression().isPresent()) {
final FilterTypeValidator validator = new FilterTypeValidator(
sourcePlanNode.getSchema(),
functionRegistry,
FilterType.HAVING);

validator.validateFilterExpression(analysis.getHavingExpression().get());
}

return new AggregateNode(
new PlanNodeId("Aggregate"),
sourcePlanNode,
Expand Down Expand Up @@ -327,10 +338,17 @@ private Stream<SelectExpression> resolveSelectItem(
"Unsupported SelectItem type: " + selectItem.getClass().getName());
}

private static FilterNode buildFilterNode(
private FilterNode buildFilterNode(
final PlanNode sourcePlanNode,
final Expression filterExpression
) {
final FilterTypeValidator validator = new FilterTypeValidator(
sourcePlanNode.getSchema(),
functionRegistry,
FilterType.WHERE);

validator.validateFilterExpression(filterExpression);

return new FilterNode(new PlanNodeId("WhereFilter"), sourcePlanNode, filterExpression);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
* Copyright 2020 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/


package io.confluent.ksql.analyzer;

import static io.confluent.ksql.schema.ksql.Column.Namespace.VALUE;
import static io.confluent.ksql.schema.ksql.types.SqlTypes.INTEGER;
import static io.confluent.ksql.schema.ksql.types.SqlTypes.STRING;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

import io.confluent.ksql.analyzer.FilterTypeValidator.FilterType;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression.Type;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.util.KsqlException;
import java.util.Optional;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;

@RunWith(MockitoJUnitRunner.class)
public class FilterTypeValidatorTest {
private static final ColumnName COLUMN1 = ColumnName.of("col1");
private static final ColumnName COLUMN2 = ColumnName.of("col2");

@Mock
private FunctionRegistry functionRegistry;
@Mock
private LogicalSchema schema;

private FilterTypeValidator validator;

@Before
public void setUp() {
validator = new FilterTypeValidator(schema, functionRegistry, FilterType.WHERE);
}


@Test
public void shouldThrowOnBadTypeComparison() {
// Given:
final Expression left = new UnqualifiedColumnReferenceExp(COLUMN1);
final Expression right = new IntegerLiteral(10);

final Expression comparision = new ComparisonExpression(Type.EQUAL, left, right);

when(schema.findValueColumn(any()))
.thenReturn(Optional.of(Column.of(COLUMN1, STRING, VALUE, 10)));

// When:
assertThrows("Error in WHERE expression: "
+ "Cannot compare col1 (STRING) to 10 (INTEGER) with EQUAL.",
KsqlException.class,
() -> validator.validateFilterExpression(comparision));
}

@Test
public void shouldNotThrowOnGoodTypeComparison() {
// Given:
final Expression left = new UnqualifiedColumnReferenceExp(COLUMN1);
final Expression right = new IntegerLiteral(10);

final Expression comparision = new ComparisonExpression(Type.EQUAL, left, right);

when(schema.findValueColumn(any()))
.thenReturn(Optional.of(Column.of(COLUMN1, INTEGER, VALUE, 10)));

// When:
validator.validateFilterExpression(comparision);
}

@Test
public void shouldThrowOnBadTypeComparison_twoVars() {
// Given:
final Expression left = new UnqualifiedColumnReferenceExp(COLUMN1);
final Expression right = new UnqualifiedColumnReferenceExp(COLUMN2);

final Expression comparision = new ComparisonExpression(Type.EQUAL, left, right);

when(schema.findValueColumn(COLUMN1))
.thenReturn(Optional.of(Column.of(COLUMN1, STRING, VALUE, 10)));
when(schema.findValueColumn(COLUMN2))
.thenReturn(Optional.of(Column.of(COLUMN2, INTEGER, VALUE, 10)));

// When:
assertThrows("Error in WHERE expression: "
+ "Cannot compare col1 (STRING) to col2 (INTEGER) with EQUAL.",
KsqlException.class,
() -> validator.validateFilterExpression(comparision));
}

@Test
public void shouldThrowOnBadType() {
// Given:
final Expression literal = new IntegerLiteral(10);

// When:
assertThrows("Type error in WHERE expression: "
+ "Should evaluate to boolean but is 10 (INTEGER) instead.",
KsqlException.class,
() -> validator.validateFilterExpression(literal));
}

@Test
public void shouldThrowOnBadTypeCompoundComparison_leftError() {
// Given:
final Expression left1 = new UnqualifiedColumnReferenceExp(COLUMN1);
final Expression right1 = new UnqualifiedColumnReferenceExp(COLUMN2);
final Expression comparision1 = new ComparisonExpression(Type.EQUAL, left1, right1);

final Expression left2 = new UnqualifiedColumnReferenceExp(COLUMN1);
final Expression right2 = new StringLiteral("foo");
final Expression comparision2 = new ComparisonExpression(Type.EQUAL, left2, right2);

final Expression expression = new LogicalBinaryExpression(LogicalBinaryExpression.Type.AND,
comparision1, comparision2);

when(schema.findValueColumn(COLUMN1))
.thenReturn(Optional.of(Column.of(COLUMN1, STRING, VALUE, 10)));
when(schema.findValueColumn(COLUMN2))
.thenReturn(Optional.of(Column.of(COLUMN2, INTEGER, VALUE, 10)));

// When:
assertThrows("Error in WHERE expression: "
+ "Cannot compare col1 (STRING) to col2 (INTEGER) with EQUAL.",
KsqlException.class,
() -> validator.validateFilterExpression(expression));
}

@Test
public void shouldThrowOnBadTypeCompoundComparison_rightError() {
// Given:
final Expression left1 = new UnqualifiedColumnReferenceExp(COLUMN2);
final Expression right1 = new IntegerLiteral(10);
final Expression comparision1 = new ComparisonExpression(Type.EQUAL, left1, right1);

final Expression left2 = new UnqualifiedColumnReferenceExp(COLUMN1);
final Expression right2 = new UnqualifiedColumnReferenceExp(COLUMN2);
final Expression comparision2 = new ComparisonExpression(Type.EQUAL, left2, right2);

final Expression expression = new LogicalBinaryExpression(LogicalBinaryExpression.Type.AND,
comparision1, comparision2);

when(schema.findValueColumn(COLUMN1))
.thenReturn(Optional.of(Column.of(COLUMN1, STRING, VALUE, 10)));
when(schema.findValueColumn(COLUMN2))
.thenReturn(Optional.of(Column.of(COLUMN2, INTEGER, VALUE, 10)));

// When:
assertThrows("Error in WHERE expression: "
+ "Cannot compare col1 (STRING) to col2 (INTEGER) with EQUAL.",
KsqlException.class,
() -> validator.validateFilterExpression(expression));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ public void testSimpleLeftJoinFilterLogicalPlan() {
final String
simpleQuery =
"SELECT t1.col1, t2.col1, col5, t2.col4, t2.col2 FROM test1 t1 LEFT JOIN test2 t2 ON "
+ "t1.col0 = t2.col0 WHERE t1.col1 > 10 AND t2.col4 = 10.8 EMIT CHANGES;";
+ "t1.col0 = t2.col0 WHERE t1.col3 > 10.8 AND t2.col2 = 'foo' EMIT CHANGES;";
final PlanNode logicalPlan = buildLogicalPlan(simpleQuery);

assertThat(logicalPlan.getSources().get(0), instanceOf(ProjectNode.class));
Expand All @@ -125,7 +125,7 @@ public void testSimpleLeftJoinFilterLogicalPlan() {

assertThat(projectNode.getSources().get(0), instanceOf(FilterNode.class));
final FilterNode filterNode = (FilterNode) projectNode.getSources().get(0);
assertThat(filterNode.getPredicate().toString(), equalTo("((T1_COL1 > 10) AND (T2_COL4 = 10.8))"));
assertThat(filterNode.getPredicate().toString(), equalTo("((T1_COL3 > 10.8) AND (T2_COL2 = 'foo'))"));

assertThat(filterNode.getSources().get(0), instanceOf(JoinNode.class));
final JoinNode joinNode = (JoinNode) filterNode.getSources().get(0);
Expand Down Expand Up @@ -284,7 +284,7 @@ public void shouldCreateStreamOutputForStreamTableJoin() {
final String
simpleQuery =
"SELECT t1.col1, t2.col1, col5, t2.col4, t2.col2 FROM test1 t1 LEFT JOIN test2 t2 ON "
+ "t1.col0 = t2.col0 WHERE t1.col1 > 10 AND t2.col4 = 10.8 EMIT CHANGES;";
+ "t1.col0 = t2.col0 WHERE t1.col3 > 10.8 AND t2.col2 = 'foo' EMIT CHANGES;";
final PlanNode logicalPlan = buildLogicalPlan(simpleQuery);
assertThat(logicalPlan.getNodeOutputType(), equalTo(DataSourceType.KSTREAM));
}
Expand All @@ -300,7 +300,7 @@ public void shouldCreateStreamOutputForStreamFilter() {
@Test
public void shouldCreateTableOutputForTableFilter() {
final String
simpleQuery = "SELECT * FROM test2 WHERE col4 = 10.8 EMIT CHANGES;";
simpleQuery = "SELECT * FROM test2 WHERE col2 = 'foo' EMIT CHANGES;";
final PlanNode logicalPlan = buildLogicalPlan(simpleQuery);
assertThat(logicalPlan.getNodeOutputType(), equalTo(DataSourceType.KTABLE));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,18 @@ final class ComparisonUtil {
private ComparisonUtil() {
}

static void isValidComparison(
static boolean isValidComparison(
final SqlType left, final ComparisonExpression.Type operator, final SqlType right
) {
if (left == null || right == null) {
throw nullSchemaException(left, operator, right);
}

final boolean valid = HANDLERS.stream()
return HANDLERS.stream()
.filter(h -> h.handles.test(left.baseType()))
.findFirst()
.map(h -> h.validator.test(operator, right))
.orElse(false);

if (!valid) {
throw new KsqlException(
"Operator " + operator + " cannot be used to compare "
+ left.baseType() + " and " + right.baseType()
);
}
}

private static KsqlException nullSchemaException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,12 @@ public Void visitComparisonExpression(
final SqlType leftSchema = expressionTypeContext.getSqlType();
process(node.getRight(), expressionTypeContext);
final SqlType rightSchema = expressionTypeContext.getSqlType();
ComparisonUtil.isValidComparison(leftSchema, node.getType(), rightSchema);
if (!ComparisonUtil.isValidComparison(leftSchema, node.getType(), rightSchema)) {
throw new KsqlException("Cannot compare "
+ node.getLeft().toString() + " (" + leftSchema.toString() + ") to "
+ node.getRight().toString() + " (" + rightSchema.toString() + ") "
+ "with " + node.getType() + ".");
}
expressionTypeContext.setSqlType(SqlTypes.BOOLEAN);
return null;
}
Expand Down
Loading

0 comments on commit 23eb80d

Please sign in to comment.