Skip to content

Commit

Permalink
fix: Pull query table scans support LIKE and BETWEEN operators (#8299)
Browse files Browse the repository at this point in the history
  • Loading branch information
hli21 authored Nov 5, 2021
1 parent 73ebeab commit bc3ea64
Show file tree
Hide file tree
Showing 5 changed files with 358 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter;
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context;
import io.confluent.ksql.engine.rewrite.StatementRewriteForMagicPseudoTimestamp;
import io.confluent.ksql.execution.expression.tree.BetweenPredicate;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression.Type;
import io.confluent.ksql.execution.expression.tree.Expression;
Expand All @@ -36,10 +37,16 @@ private PullQueryRewriter() {
/**
 * Rewrites a pull query expression by running it through a fixed chain of rewrite passes.
 *
 * <p>The passes run in this order, each consuming the previous pass's output:
 * {@code StatementRewriteForMagicPseudoTimestamp}, {@link #rewriteBetweenPredicates},
 * {@link #rewriteInPredicates} and finally {@code LogicRewriter.rewriteDNF}.
 *
 * @param expression the expression to rewrite
 * @return the result of {@code LogicRewriter.rewriteDNF} applied to the rewritten expression
 */
public static Expression rewrite(final Expression expression) {
  final Expression pseudoTimestamp = new StatementRewriteForMagicPseudoTimestamp()
      .rewrite(expression);
  // The IN rewrite must be fed the BETWEEN-rewritten tree. Declaring inPredicatesRemoved
  // directly from pseudoTimestamp as well would both skip the BETWEEN pass and redeclare
  // the local, so there is exactly one declaration here.
  final Expression betweenPredicatesRemoved = rewriteBetweenPredicates(pseudoTimestamp);
  final Expression inPredicatesRemoved = rewriteInPredicates(betweenPredicatesRemoved);
  return LogicRewriter.rewriteDNF(inPredicatesRemoved);
}

/**
 * Runs a {@link BetweenPredicateRewriter} over the given expression via an
 * {@link ExpressionTreeRewriter}.
 *
 * @param expression the expression to rewrite
 * @return the expression returned by the tree rewriter
 */
public static Expression rewriteBetweenPredicates(final Expression expression) {
  final BetweenPredicateRewriter betweenRewriter = new BetweenPredicateRewriter();
  final ExpressionTreeRewriter<Void> treeRewriter =
      new ExpressionTreeRewriter<>(betweenRewriter::process);
  // No rewrite context is needed, hence the null context argument.
  return treeRewriter.rewrite(expression, null);
}

public static Expression rewriteInPredicates(final Expression expression) {
return new ExpressionTreeRewriter<>(new InPredicateRewriter()::process)
.rewrite(expression, null);
Expand Down Expand Up @@ -79,4 +86,33 @@ public Optional<Expression> visitInPredicate(
throw new IllegalStateException("Shouldn't have an empty in predicate");
}
}

/**
 * Visitor that turns a BETWEEN predicate into the equivalent pair of comparisons,
 * {@code (value >= min) AND (value <= max)}.
 */
private static final class BetweenPredicateRewriter extends
    VisitParentExpressionVisitor<Optional<Expression>, Context<Void>> {

  /**
   * Generic expression handler: offers no replacement.
   *
   * @return always {@link Optional#empty()}
   */
  @Override
  public Optional<Expression> visitExpression(
      final Expression node,
      final Context<Void> context) {
    return Optional.empty();
  }

  /**
   * Builds the replacement for a BETWEEN predicate.
   *
   * @param node the BETWEEN predicate being visited
   * @param context the rewrite context (unused)
   * @return an AND of {@code value >= min} and {@code value <= max}, all carrying the
   *     location of {@code node}
   */
  @Override
  public Optional<Expression> visitBetweenPredicate(
      final BetweenPredicate node,
      final Context<Void> context
  ) {
    final ComparisonExpression lowerBound =
        compareValue(node, Type.GREATER_THAN_OR_EQUAL, node.getMin());
    final ComparisonExpression upperBound =
        compareValue(node, Type.LESS_THAN_OR_EQUAL, node.getMax());
    final Expression conjunction = new LogicalBinaryExpression(
        node.getLocation(), LogicalBinaryExpression.Type.AND, lowerBound, upperBound);
    return Optional.of(conjunction);
  }

  /** Compares the predicate's value against {@code bound}, reusing the predicate's location. */
  private static ComparisonExpression compareValue(
      final BetweenPredicate node,
      final Type operator,
      final Expression bound
  ) {
    return new ComparisonExpression(node.getLocation(), operator, node.getValue(), bound);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.confluent.ksql.execution.expression.tree.ComparisonExpression.Type;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.LikePredicate;
import io.confluent.ksql.execution.expression.tree.Literal;
import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
Expand All @@ -47,6 +48,7 @@
import io.confluent.ksql.schema.ksql.DefaultSqlValueCoercer;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.SystemColumns;
import io.confluent.ksql.schema.ksql.types.SqlBaseType;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.schema.utils.FormatOptions;
Expand Down Expand Up @@ -275,7 +277,8 @@ private final class Validator extends TraversalExpressionVisitor<Object> {
@Override
public Void process(final Expression node, final Object context) {
if (!(node instanceof LogicalBinaryExpression)
&& !(node instanceof ComparisonExpression)) {
&& !(node instanceof ComparisonExpression)
&& !(node instanceof LikePredicate)) {
throw invalidWhereClauseException("Unsupported expression in WHERE clause: " + node, false);
}
super.process(node, context);
Expand Down Expand Up @@ -358,6 +361,33 @@ public Void visitComparisonExpression(
}
}

/**
 * Validates a LIKE predicate found in the WHERE clause.
 *
 * <p>When the value side is not an unqualified column reference, the decision is delegated
 * to {@code setTableScanOrElseThrow}. Otherwise the referenced column must exist in the
 * schema, its base type must be STRING, and the pattern must be a string or null literal.
 *
 * @param node the LIKE predicate to validate
 * @param context the visitor context (unused)
 * @return always {@code null}
 */
@Override
public Void visitLikePredicate(final LikePredicate node, final Object context) {
  final Expression value = node.getValue();
  if (!(value instanceof UnqualifiedColumnReferenceExp)) {
    setTableScanOrElseThrow(() -> invalidWhereClauseException("Like condition must be between "
        + "strings", isWindowed));
    return null;
  }

  final ColumnName columnName = ((UnqualifiedColumnReferenceExp) value).getColumnName();
  final Column col = schema.findColumn(columnName)
      .orElseThrow(() -> invalidWhereClauseException(
          "Like condition on non-existent column " + columnName, isWindowed));

  final SqlBaseType baseType = col.type().baseType();
  if (baseType != SqlBaseType.STRING) {
    throw invalidWhereClauseException("The column type for Like "
        + "condition must be VARCHAR. The column type is "
        + baseType.toString(), isWindowed);
  }

  final Expression pattern = node.getPattern();
  // Only string and null literals are accepted as the pattern.
  if (!(pattern instanceof StringLiteral) && !(pattern instanceof NullLiteral)) {
    throw invalidWhereClauseException(
        "Like condition on non-string pattern " + pattern.getClass().getName(),
        isWindowed);
  }
  return null;
}

private boolean isKeyQuery(final ComparisonExpression node) {
if (node.getType() == Type.NOT_EQUAL || node.getType() == Type.IS_DISTINCT_FROM
|| node.getType() == Type.IS_NOT_DISTINCT_FROM) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ public void shouldRewriteInPredicate() {
+ "(ORDERS.ITEMID = 'c'))))");
}

@Test
public void shouldRewriteBetweenPredicate() {
  // Integer bounds on an unqualified column reference.
  final String intBetween = "ORDERID BETWEEN 2 AND 5";
  final String intExpected = "((ORDERS.ORDERID >= 2) AND (ORDERS.ORDERID <= 5))";
  assertRewrite("ORDERS", intBetween, intExpected);

  // String bounds on a table-qualified column reference.
  final String stringBetween = "ORDERS.ITEMID BETWEEN 'a' AND 'b'";
  final String stringExpected = "((ORDERS.ITEMID >= 'a') AND (ORDERS.ITEMID <= 'b'))";
  assertRewrite("ORDERS", stringBetween, stringExpected);
}

private void assertRewrite(final String table, final String expressionStr,
final String expectedStr) {
Expression expression = getWhereExpression(table, expressionStr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.confluent.ksql.GenericKey;
import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression;
import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression.Sign;
import io.confluent.ksql.execution.expression.tree.BetweenPredicate;
import io.confluent.ksql.execution.expression.tree.BooleanLiteral;
import io.confluent.ksql.execution.expression.tree.BytesLiteral;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
Expand All @@ -37,6 +38,7 @@
import io.confluent.ksql.execution.expression.tree.InListExpression;
import io.confluent.ksql.execution.expression.tree.InPredicate;
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.LikePredicate;
import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression;
import io.confluent.ksql.execution.expression.tree.NullLiteral;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
Expand Down Expand Up @@ -1623,4 +1625,55 @@ private void expectRangeScan(final Expression expression, final boolean windowed
assertThat(keys.get(0), isA((Class) KeyConstraint.class));
assertThat(((KeyConstraint) keys.get(0)).isRangeOperator(), is(true));
}

@Test
public void shouldThrowNonStringForLike() {
  // Given: a LIKE whose value is a string literal (not a column) and whose pattern is an
  // integer literal.
  final StringLiteral value = new StringLiteral("a");
  final IntegerLiteral pattern = new IntegerLiteral(10);
  final Expression likeOnLiterals = new LikePredicate(value, pattern, Optional.empty());

  // When: the filter node is constructed around that expression.
  final KsqlException thrown = assertThrows(
      KsqlException.class,
      () -> new QueryFilterNode(
          NODE_ID, source, likeOnLiterals, metaStore, ksqlConfig, false, plannerOptions));

  // Then: construction is rejected with the LIKE-specific message.
  final String message = thrown.getMessage();
  assertThat(message, containsString("Like condition must be between strings"));
}

@Test
public void shouldThrowNotKeyColumnForBetween() {
  // Given: a BETWEEN built purely from literals, so no operand references a column.
  final StringLiteral value = new StringLiteral("a");
  final StringLiteral min = new StringLiteral("b");
  final IntegerLiteral max = new IntegerLiteral(10);
  final Expression betweenOnLiterals = new BetweenPredicate(value, min, max);

  // When: the filter node is constructed around that expression.
  final KsqlException thrown = assertThrows(
      KsqlException.class,
      () -> new QueryFilterNode(
          NODE_ID, source, betweenOnLiterals, metaStore, ksqlConfig, false, plannerOptions));

  // Then: construction is rejected with the key-column comparison message.
  final String message = thrown.getMessage();
  assertThat(message, containsString("A comparison must directly reference a key column"));
}
}
Loading

0 comments on commit bc3ea64

Please sign in to comment.