feat(static): fail on ROWTIME in projection (#3430)
* feat(static): fail on ROWTIME in projection

At the moment static queries do not support returning ROWTIME, as this information is not available in the response from Kafka Streams interactive queries (KS IQ).

In the future, we _may_ choose to support this by always including ROWTIME in the value of the changelog topic, but this is out of scope for this initial MVP.
big-andy-coates authored Sep 27, 2019
1 parent 5f28ff5 commit 2f27b68
Showing 6 changed files with 157 additions and 15 deletions.
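
In effect, a static (pull) query that references ROWTIME anywhere in its projection now fails validation, while SELECT * on a static query simply omits the meta columns rather than failing. A minimal sketch of the new behaviour, reusing the AGGREGATE table from the query-validation test cases added in this commit (the second statement is not part of those tests and is shown only to illustrate the SELECT * change in visitSelectStar):

SELECT ROWTIME + 10, COUNT FROM AGGREGATE WHERE ROWKEY='10';
-- rejected: "Static queries don't support ROWTIME in select columns."

SELECT * FROM AGGREGATE WHERE ROWKEY='10';
-- unaffected by the new rule: for static queries, SELECT * no longer expands to include ROWTIME.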
@@ -40,7 +40,9 @@
import io.confluent.ksql.serde.SerdeOption;
import io.confluent.ksql.util.SchemaUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -56,6 +58,7 @@ public class Analysis {
private Optional<JoinInfo> joinInfo = Optional.empty();
private Optional<Expression> whereExpression = Optional.empty();
private final List<SelectExpression> selectExpressions = new ArrayList<>();
private final Set<ColumnRef> selectColumnRefs = new HashSet<>();
private final List<Expression> groupByExpressions = new ArrayList<>();
private Optional<WindowExpression> windowExpression = Optional.empty();
private Optional<ColumnName> partitionBy = Optional.empty();
@@ -76,6 +79,10 @@ void addSelectItem(final Expression expression, final ColumnName alias) {
selectExpressions.add(SelectExpression.of(alias, expression));
}

void addSelectColumnRefs(final Collection<ColumnRef> columnRefs) {
selectColumnRefs.addAll(columnRefs);
}

public Optional<Into> getInto() {
return into;
}
@@ -96,6 +103,10 @@ public List<SelectExpression> getSelectExpressions() {
return Collections.unmodifiableList(selectExpressions);
}

Set<ColumnRef> getSelectColumnRefs() {
return Collections.unmodifiableSet(selectColumnRefs);
}

public List<Expression> getGroupByExpressions() {
return ImmutableList.copyOf(groupByExpressions);
}
@@ -156,7 +167,7 @@ public List<AliasedDataSource> getFromDataSources() {
return ImmutableList.copyOf(fromDataSources);
}

public SourceSchemas getFromSourceSchemas() {
SourceSchemas getFromSourceSchemas() {
final Map<SourceName, LogicalSchema> schemaBySource = fromDataSources.stream()
.collect(Collectors.toMap(
AliasedDataSource::getAlias,
32 changes: 29 additions & 3 deletions ksql-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java
@@ -26,6 +26,7 @@
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.windows.KsqlWindowExpression;
import io.confluent.ksql.metastore.MetaStore;
@@ -62,6 +63,7 @@
import io.confluent.ksql.serde.ValueFormat;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@@ -511,7 +513,7 @@ protected AstNode visitSelect(final Select node, final Void context) {
visitSelectStar((AllColumns) selectItem);
} else if (selectItem instanceof SingleColumn) {
final SingleColumn column = (SingleColumn) selectItem;
analysis.addSelectItem(column.getExpression(), column.getAlias());
addSelectItem(column.getExpression(), column.getAlias());
} else {
throw new IllegalArgumentException(
"Unsupported SelectItem type: " + selectItem.getClass().getName());
@@ -562,14 +564,19 @@ private void visitSelectStar(final AllColumns allColumns) {
? source.getAlias().name() + "_"
: "";

for (final Column column : source.getDataSource().getSchema().columns()) {
final LogicalSchema schema = source.getDataSource().getSchema();
for (final Column column : schema.columns()) {

if (staticQuery && schema.isMetaColumn(column.name())) {
continue;
}

final ColumnReferenceExp selectItem = new ColumnReferenceExp(location,
ColumnRef.of(source.getAlias(), column.name()));

final String alias = aliasPrefix + column.name().name();

analysis.addSelectItem(selectItem, ColumnName.of(alias));
addSelectItem(selectItem, ColumnName.of(alias));
}
}
}
@@ -598,6 +605,25 @@ public void validate() {
+ System.lineSeparator() + KAFKA_VALUE_FORMAT_LIMITATION_DETAILS);
}
}

private void addSelectItem(final Expression exp, final ColumnName columnName) {
final Set<ColumnRef> columnRefs = new HashSet<>();
final TraversalExpressionVisitor<Void> visitor = new TraversalExpressionVisitor<Void>() {
@Override
public Void visitColumnReference(
final ColumnReferenceExp node,
final Void context
) {
columnRefs.add(node.getReference());
return null;
}
};

visitor.process(exp, null);

analysis.addSelectItem(exp, columnName);
analysis.addSelectColumnRefs(columnRefs);
}
}

@FunctionalInterface
@@ -17,7 +17,9 @@

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.parser.tree.ResultMaterialization;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
@@ -89,6 +91,12 @@ public class StaticQueryValidator implements QueryValidator {
Rule.of(
analysis -> !analysis.getLimitClause().isPresent(),
"Static queries don't support LIMIT clauses."
),
Rule.of(
analysis -> analysis.getSelectColumnRefs().stream()
.map(ColumnRef::name)
.noneMatch(n -> n.equals(SchemaUtil.ROWTIME_NAME)),
"Static queries don't support ROWTIME in select columns."
)
);

@@ -16,8 +16,11 @@
package io.confluent.ksql.analyzer;

import static io.confluent.ksql.testutils.AnalysisTestUtil.analyzeQuery;
import static io.confluent.ksql.util.SchemaUtil.ROWTIME_NAME;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
@@ -35,6 +38,7 @@
import io.confluent.ksql.analyzer.Analyzer.SerdeOptionsSupplier;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
import io.confluent.ksql.execution.expression.tree.BooleanLiteral;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.Literal;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.execution.plan.SelectExpression;
@@ -53,6 +57,7 @@
import io.confluent.ksql.parser.tree.Sink;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.planner.plan.JoinNode.JoinType;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.serde.Format;
@@ -90,6 +95,11 @@
public class AnalyzerFunctionalTest {

private static final Set<SerdeOption> DEFAULT_SERDE_OPTIONS = SerdeOption.none();
private static final SourceName TEST1 = SourceName.of("TEST1");
private static final ColumnName COL0 = ColumnName.of("COL0");
private static final ColumnName COL1 = ColumnName.of("COL1");
private static final ColumnName COL2 = ColumnName.of("COL2");
private static final ColumnName COL3 = ColumnName.of("COL3");

private MutableMetaStore jsonMetaStore;
private MutableMetaStore avroMetaStore;
@@ -136,17 +146,17 @@ public void testSimpleQueryAnalysis() {
final Analysis analysis = analyzeQuery(simpleQuery, jsonMetaStore);
assertEquals("FROM was not analyzed correctly.",
analysis.getFromDataSources().get(0).getDataSource().getName(),
SourceName.of("TEST1"));
TEST1);
assertThat(analysis.getWhereExpression().get().toString(), is("(TEST1.COL0 > 100)"));

final List<SelectExpression> selects = analysis.getSelectExpressions();
assertThat(selects.get(0).getExpression().toString(), is("TEST1.COL0"));
assertThat(selects.get(1).getExpression().toString(), is("TEST1.COL2"));
assertThat(selects.get(2).getExpression().toString(), is("TEST1.COL3"));

assertThat(selects.get(0).getName(), is(ColumnName.of("COL0")));
assertThat(selects.get(1).getName(), is(ColumnName.of("COL2")));
assertThat(selects.get(2).getName(), is(ColumnName.of("COL3")));
assertThat(selects.get(0).getName(), is(COL0));
assertThat(selects.get(1).getName(), is(COL2));
assertThat(selects.get(2).getName(), is(COL3));
}

@Test
@@ -202,7 +212,7 @@ public void testBooleanExpressionAnalysis() {
final Analysis analysis = analyzeQuery(queryStr, jsonMetaStore);

assertEquals("FROM was not analyzed correctly.",
analysis.getFromDataSources().get(0).getDataSource().getName(), SourceName.of("TEST1"));
analysis.getFromDataSources().get(0).getDataSource().getName(), TEST1);

final List<SelectExpression> selects = analysis.getSelectExpressions();
assertThat(selects.get(0).getExpression().toString(), is("(TEST1.COL0 = 10)"));
@@ -215,7 +225,7 @@ public void testFilterAnalysis() {
final String queryStr = "SELECT col0 = 10, col2, col3 > col1 FROM test1 WHERE col0 > 20 EMIT CHANGES;";
final Analysis analysis = analyzeQuery(queryStr, jsonMetaStore);

assertThat(analysis.getFromDataSources().get(0).getDataSource().getName(), is(SourceName.of("TEST1")));
assertThat(analysis.getFromDataSources().get(0).getDataSource().getName(), is(TEST1));

final List<SelectExpression> selects = analysis.getSelectExpressions();
assertThat(selects.get(0).getExpression().toString(), is("(TEST1.COL0 = 10)"));
@@ -450,6 +460,50 @@ public void shouldThrowOnJoinIfKafkaFormat() {
analyzer.analyze(query, Optional.of(sink));
}

@Test
public void shouldCaptureProjectionColumnRefs() {
// Given:
query = parseSingle("Select COL0, COL0 + COL1, SUBSTRING(COL2, 1) from TEST1;");

// When:
final Analysis analysis = analyzer.analyze(query, Optional.empty());

// Then:
assertThat(analysis.getSelectColumnRefs(), containsInAnyOrder(
ColumnRef.of(TEST1, COL0),
ColumnRef.of(TEST1, COL1),
ColumnRef.of(TEST1, COL2)
));
}

@Test
public void shouldIncludeMetaColumnsForSelectStarOnContinuousQueries() {
// Given:
query = parseSingle("Select * from TEST1 EMIT CHANGES;");

// When:
final Analysis analysis = analyzer.analyze(query, Optional.empty());

// Then:
assertThat(analysis.getSelectExpressions(), hasItem(
SelectExpression.of(ROWTIME_NAME, new ColumnReferenceExp(ColumnRef.of(TEST1, ROWTIME_NAME)))
));
}

@Test
public void shouldNotIncludeMetaColumnsForSelectStarOnStaticQueries() {
// Given:
query = parseSingle("Select * from TEST1;");

// When:
final Analysis analysis = analyzer.analyze(query, Optional.empty());

// Then:
assertThat(analysis.getSelectExpressions(), not(hasItem(
SelectExpression.of(ROWTIME_NAME, new ColumnReferenceExp(ColumnRef.of(TEST1, ROWTIME_NAME)))
)));
}

@SuppressWarnings("unchecked")
private <T extends Statement> T parseSingle(final String simpleQuery) {
return (T) Iterables.getOnlyElement(parse(simpleQuery, jsonMetaStore));
@@ -478,7 +532,7 @@ private void buildProps() {

private void registerKafkaSource() {
final LogicalSchema schema = LogicalSchema.builder()
.valueColumn(ColumnName.of("COL0"), SqlTypes.BIGINT)
.valueColumn(COL0, SqlTypes.BIGINT)
.build();

final KsqlTopic topic = new KsqlTopic(
@@ -19,12 +19,15 @@
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.confluent.ksql.analyzer.Analysis.Into;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.parser.tree.ResultMaterialization;
import io.confluent.ksql.parser.tree.WindowExpression;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.Optional;
import java.util.OptionalInt;
import org.junit.Before;
@@ -109,7 +112,7 @@ public void shouldThrowOnStaticQueryThatIsWindowed() {
}

@Test
public void shouldThrowOnStaticQueryThatHasGroupBy() {
public void shouldThrowOnGroupBy() {
// Given:
when(analysis.getGroupByExpressions()).thenReturn(ImmutableList.of(AN_EXPRESSION));

@@ -122,7 +125,7 @@ public void shouldThrowOnStaticQueryThatHasGroupBy() {
}

@Test
public void shouldThrowOnStaticQueryThatHasPartitionBy() {
public void shouldThrowOnPartitionBy() {
// Given:
when(analysis.getPartitionBy()).thenReturn(Optional.of(ColumnName.of("Something")));

@@ -135,7 +138,7 @@ public void shouldThrowOnStaticQueryThatHasPartitionBy() {
}

@Test
public void shouldThrowOnStaticQueryThatHasHavingClause() {
public void shouldThrowOnHavingClause() {
// Given:
when(analysis.getHavingExpression()).thenReturn(Optional.of(AN_EXPRESSION));

@@ -148,7 +151,7 @@ public void shouldThrowOnStaticQueryThatHasHavingClause() {
}

@Test
public void shouldThrowOnStaticQueryThatHasLimitClause() {
public void shouldThrowOnLimitClause() {
// Given:
when(analysis.getLimitClause()).thenReturn(OptionalInt.of(1));

@@ -159,4 +162,18 @@ public void shouldThrowOnStaticQueryThatHasLimitClause() {
// When:
validator.validate(analysis);
}

@Test
public void shouldThrowOnRowTimeInProjection() {
// Given:
when(analysis.getSelectColumnRefs())
.thenReturn(ImmutableSet.of(ColumnRef.of(SchemaUtil.ROWTIME_NAME)));

// Then:
expectedException.expect(KsqlException.class);
expectedException.expectMessage("Static queries don't support ROWTIME in select columns.");

// When:
validator.validate(analysis);
}
}
@@ -235,6 +235,32 @@
}
]
},
{
"name": "non-windowed projection WITH ROWTIME",
"statements": [
"CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT GROUP BY ROWKEY;",
"SELECT ROWTIME + 10, COUNT FROM AGGREGATE WHERE ROWKEY='10';"
],
"expectedError": {
"type": "io.confluent.ksql.rest.entity.KsqlStatementErrorMessage",
"message": "Static queries don't support ROWTIME in select columns.",
"status": 400
}
},
{
"name": "windowed with projection with ROWTIME",
"statements": [
"CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ROWKEY;",
"SELECT COUNT, ROWTIME + 10 FROM AGGREGATE WHERE ROWKEY='10' AND WindowStart=12000;"
],
"expectedError": {
"type": "io.confluent.ksql.rest.entity.KsqlStatementErrorMessage",
"message": "Static queries don't support ROWTIME in select columns.",
"status": 400
}
},
{
"name": "text datetime window bounds",
"enabled": false,
