diff --git a/ksql-common/src/test/java/io/confluent/ksql/configdef/ConfigValidatorsTest.java b/ksql-common/src/test/java/io/confluent/ksql/configdef/ConfigValidatorsTest.java index 2cf0d7d7cfc0..52d2c5180796 100644 --- a/ksql-common/src/test/java/io/confluent/ksql/configdef/ConfigValidatorsTest.java +++ b/ksql-common/src/test/java/io/confluent/ksql/configdef/ConfigValidatorsTest.java @@ -16,6 +16,8 @@ package io.confluent.ksql.configdef; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.util.function.Function; import org.apache.kafka.common.config.ConfigDef.Validator; @@ -27,9 +29,6 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - @RunWith(MockitoJUnitRunner.class) public class ConfigValidatorsTest { diff --git a/ksql-engine/src/main/java/io/confluent/ksql/materialization/AggregatesInfo.java b/ksql-engine/src/main/java/io/confluent/ksql/materialization/AggregatesInfo.java new file mode 100644 index 000000000000..ccfad6bc36a2 --- /dev/null +++ b/ksql-engine/src/main/java/io/confluent/ksql/materialization/AggregatesInfo.java @@ -0,0 +1,70 @@ +/* + * Copyright 2019 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.materialization; + +import static java.util.Objects.requireNonNull; + +import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import java.util.List; + + +@Immutable +public final class AggregatesInfo { + + private final int startingColumnIndex; + private final List aggregateFunctions; + private final LogicalSchema schema; + + /** + * @param startingColumnIndex column index of first aggregate function. + * @param aggregateFunctions the map of column index to aggregate function. + * @param schema the schema required by the aggregators. + * @return the immutable instance. 
+ */ + public static AggregatesInfo of( + final int startingColumnIndex, + final List aggregateFunctions, + final LogicalSchema schema + ) { + return new AggregatesInfo(startingColumnIndex, aggregateFunctions, schema); + } + + private AggregatesInfo( + final int startingColumnIndex, + final List aggregateFunctions, + final LogicalSchema prepareSchema + ) { + this.startingColumnIndex = startingColumnIndex; + this.aggregateFunctions = ImmutableList + .copyOf(requireNonNull(aggregateFunctions, "aggregateFunctions")); + this.schema = requireNonNull(prepareSchema, "prepareSchema"); + } + + public int startingColumnIndex() { + return startingColumnIndex; + } + + public List aggregateFunctions() { + return aggregateFunctions; + } + + public LogicalSchema schema() { + return schema; + } +} diff --git a/ksql-engine/src/main/java/io/confluent/ksql/materialization/KsqlMaterialization.java b/ksql-engine/src/main/java/io/confluent/ksql/materialization/KsqlMaterialization.java index 59ddbc4455aa..5ae711df2be5 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/materialization/KsqlMaterialization.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/materialization/KsqlMaterialization.java @@ -33,24 +33,48 @@ /** * {@link Materialization} implementation responsible for handling HAVING and SELECT clauses. * - *
<p>Underlying {@link Materialization} store data in a different schema and have not had any
- * HAVING predicate applied. Mapping from the aggregate store schema to the table's schema and
- * applying any HAVING predicate is handled by this class.
+ * <p>Underlying {@link Materialization} store data is not the same as the table it serves.
+ * Specifically, it has not had:
+ * <ol>
+ * <li>The {@link io.confluent.ksql.function.udaf.Udaf#map} call applied to convert intermediate
+ * aggregate types to output types</li>
+ * <li>Any HAVING predicate applied.</li>
+ * <li>The select value mapper applied to convert from the internal schema to the table's
+ * schema.</li>
+ * </ol>
+ *
+ * <p>
This class is responsible for this for now. Long term, these should be handled by physical + * plan steps. */ class KsqlMaterialization implements Materialization { private final Materialization inner; + private final Function aggregateTransform; private final Predicate havingPredicate; private final Function storeToTableTransform; private final LogicalSchema schema; + /** + * @param inner the inner materialization, e.g. a KS specific one + * @param aggregateTransform converts from aggregates from intermediate to output types. + * @param havingPredicate the predicate for handling HAVING clauses. + * @param storeToTableTransform maps from internal to table schema. + * @param schema the schema of the materialized table. + */ KsqlMaterialization( final Materialization inner, + final Function aggregateTransform, final Predicate havingPredicate, final Function storeToTableTransform, final LogicalSchema schema ) { this.inner = requireNonNull(inner, "table"); + this.aggregateTransform = requireNonNull(aggregateTransform, "aggregateTransform"); this.havingPredicate = requireNonNull(havingPredicate, "havingPredicate"); this.storeToTableTransform = requireNonNull(storeToTableTransform, "storeToTableTransform"); this.schema = requireNonNull(schema, "schema"); @@ -86,6 +110,9 @@ private Optional filterAndTransform( final GenericRow value ) { return Optional.of(value) + // Call Udaf.map() to convert the internal representation stored in the state store into + // the output type of the aggregator + .map(aggregateTransform) // HAVING predicate from source table query that has not already been applied to the // store, so must be applied to any result from the store. .filter(v -> havingPredicate.test(key, v)) diff --git a/ksql-engine/src/main/java/io/confluent/ksql/materialization/KsqlMaterializationFactory.java b/ksql-engine/src/main/java/io/confluent/ksql/materialization/KsqlMaterializationFactory.java index 6430a037ccfc..6775a3d1b4e0 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/materialization/KsqlMaterializationFactory.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/materialization/KsqlMaterializationFactory.java @@ -25,6 +25,7 @@ import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.execution.sqlpredicate.SqlPredicate; +import io.confluent.ksql.execution.streams.AggregateParams; import io.confluent.ksql.execution.streams.SelectValueMapperFactory; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.logging.processing.ProcessingLogContext; @@ -47,8 +48,9 @@ public final class KsqlMaterializationFactory { private final KsqlConfig ksqlConfig; private final FunctionRegistry functionRegistry; private final ProcessingLogContext processingLogContext; + private final AggregateMapperFactory aggregateMapperFactory; private final SqlPredicateFactory sqlPredicateFactory; - private final ValueMapperFactory valueMapperFactory; + private final SelectMapperFactory selectMapperFactory; private final MaterializationFactory materializationFactory; public KsqlMaterializationFactory( @@ -60,6 +62,7 @@ public KsqlMaterializationFactory( ksqlConfig, functionRegistry, processingLogContext, + defaultAggregateMapperFactory(), SqlPredicate::new, defaultValueMapperFactory(), KsqlMaterialization::new @@ -71,15 +74,17 @@ public KsqlMaterializationFactory( final KsqlConfig ksqlConfig, final FunctionRegistry functionRegistry, final ProcessingLogContext processingLogContext, + final 
AggregateMapperFactory aggregateMapperFactory, final SqlPredicateFactory sqlPredicateFactory, - final ValueMapperFactory valueMapperFactory, + final SelectMapperFactory selectMapperFactory, final MaterializationFactory materializationFactory ) { this.ksqlConfig = requireNonNull(ksqlConfig, "ksqlConfig"); this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry"); this.processingLogContext = requireNonNull(processingLogContext, "processingLogContext"); + this.aggregateMapperFactory = requireNonNull(aggregateMapperFactory, "aggregateMapperFactory"); this.sqlPredicateFactory = requireNonNull(sqlPredicateFactory, "sqlPredicateFactory"); - this.valueMapperFactory = requireNonNull(valueMapperFactory, "valueMapperFactory"); + this.selectMapperFactory = requireNonNull(selectMapperFactory, "selectMapperFactory"); this.materializationFactory = requireNonNull(materializationFactory, "materializationFactory"); } @@ -88,6 +93,9 @@ public Materialization create( final MaterializationInfo info, final QueryContext.Stacker contextStacker ) { + final Function aggregateMapper = + bakeAggregateMapper(info); + final Predicate havingPredicate = bakeHavingExpression(info, contextStacker); @@ -96,12 +104,22 @@ public Materialization create( return materializationFactory.create( delegate, + aggregateMapper, havingPredicate, valueMapper, info.tableSchema() ); } + private Function bakeAggregateMapper( + final MaterializationInfo info + ) { + return aggregateMapperFactory.create( + info.aggregatesInfo(), + functionRegistry + ); + } + private Predicate bakeHavingExpression( final MaterializationInfo info, final QueryContext.Stacker contextStacker @@ -135,7 +153,7 @@ private Function bakeStoreSelects( QueryLoggerUtil.queryLoggerName(contextStacker.push(PROJECT_OP_NAME).getQueryContext()) ); - return valueMapperFactory.create( + return selectMapperFactory.create( info.tableSelects(), info.aggregationSchema(), ksqlConfig, @@ -144,7 +162,19 @@ private Function bakeStoreSelects( ); } - private static ValueMapperFactory defaultValueMapperFactory() { + private static AggregateMapperFactory defaultAggregateMapperFactory() { + return (info, functionRegistry) -> + new AggregateParams( + info.schema(), + info.startingColumnIndex(), + functionRegistry, + info.aggregateFunctions() + ) + .getAggregator() + .getResultMapper()::apply; + } + + private static SelectMapperFactory defaultValueMapperFactory() { return (selectExpressions, sourceSchema, ksqlConfig, functionRegistry, processingLogger) -> SelectValueMapperFactory.create( selectExpressions, @@ -155,6 +185,14 @@ private static ValueMapperFactory defaultValueMapperFactory() { )::apply; } + interface AggregateMapperFactory { + + Function create( + AggregatesInfo info, + FunctionRegistry functionRegistry + ); + } + interface SqlPredicateFactory { SqlPredicate create( @@ -166,7 +204,7 @@ SqlPredicate create( ); } - interface ValueMapperFactory { + interface SelectMapperFactory { Function create( List selectExpressions, @@ -181,6 +219,7 @@ interface MaterializationFactory { KsqlMaterialization create( Materialization inner, + Function aggregateTransform, Predicate havingPredicate, Function storeToTableTransform, LogicalSchema schema diff --git a/ksql-engine/src/main/java/io/confluent/ksql/materialization/MaterializationInfo.java b/ksql-engine/src/main/java/io/confluent/ksql/materialization/MaterializationInfo.java index f9eee18e79d7..556dabc0569a 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/materialization/MaterializationInfo.java +++ 
b/ksql-engine/src/main/java/io/confluent/ksql/materialization/MaterializationInfo.java @@ -32,6 +32,7 @@ public final class MaterializationInfo { private final String stateStoreName; + private final AggregatesInfo aggregatesInfo; private final LogicalSchema aggregationSchema; private final Optional havingExpression; private final LogicalSchema tableSchema; @@ -41,7 +42,8 @@ public final class MaterializationInfo { * Create instance. * * @param stateStoreName the name of the state store - * @param stateStoreSchema the schema of the state store + * @param aggregatesInfo info about the aggregate functions used. + * @param aggregationSchema the schema of the state store * @param havingExpression optional HAVING expression that should be apply to any store result. * @param tableSchema the schema of the table. * @param tableSelects SELECT expressions to convert state store schema to table schema. @@ -49,14 +51,16 @@ public final class MaterializationInfo { */ public static MaterializationInfo of( final String stateStoreName, - final LogicalSchema stateStoreSchema, + final AggregatesInfo aggregatesInfo, + final LogicalSchema aggregationSchema, final Optional havingExpression, final LogicalSchema tableSchema, final List tableSelects ) { return new MaterializationInfo( stateStoreName, - stateStoreSchema, + aggregatesInfo, + aggregationSchema, havingExpression, tableSchema, tableSelects @@ -67,6 +71,10 @@ public String stateStoreName() { return stateStoreName; } + public AggregatesInfo aggregatesInfo() { + return aggregatesInfo; + } + public LogicalSchema aggregationSchema() { return aggregationSchema; } @@ -85,12 +93,14 @@ public List tableSelects() { private MaterializationInfo( final String stateStoreName, + final AggregatesInfo aggregatesInfo, final LogicalSchema aggregationSchema, final Optional havingExpression, final LogicalSchema tableSchema, final List tableSelects ) { this.stateStoreName = requireNonNull(stateStoreName, "stateStoreName"); + this.aggregatesInfo = requireNonNull(aggregatesInfo, "aggregatesInfo"); this.aggregationSchema = requireNonNull(aggregationSchema, "aggregationSchema"); this.havingExpression = requireNonNull(havingExpression, "havingExpression"); this.tableSchema = requireNonNull(tableSchema, "tableSchema"); diff --git a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java index 32972ae4d402..5a06bf5ad7bd 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java @@ -32,6 +32,7 @@ import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.KsqlAggregateFunction; +import io.confluent.ksql.materialization.AggregatesInfo; import io.confluent.ksql.materialization.MaterializationInfo; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.name.ColumnName; @@ -248,6 +249,8 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { final QueryContext.Stacker aggregationContext = contextStacker.push(AGGREGATION_OP_NAME); + // This is the schema post any {@link Udaf#map} steps to reduce intermediate aggregate state + // to the final output state final LogicalSchema outputSchema = buildLogicalSchema( prepareSchema, functionsWithInternalIdentifiers, @@ -280,8 +283,15 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { final List 
finalSelects = internalSchema .updateFinalSelectExpressions(getFinalSelectExpressions()); + final AggregatesInfo aggregatesInfo = AggregatesInfo.of( + requiredColumns.size(), + functionsWithInternalIdentifiers, + prepareSchema + ); + materializationInfo = Optional.of(MaterializationInfo.of( AGGREGATE_STATE_STORE_NAME, + aggregatesInfo, outputSchema, havingExpression, schema, diff --git a/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/StatementRewriterTest.java b/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/StatementRewriterTest.java index fe591db0609c..e97ed2afe5b3 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/StatementRewriterTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/engine/rewrite/StatementRewriterTest.java @@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.SourceName; import io.confluent.ksql.parser.NodeLocation; @@ -42,7 +43,6 @@ import io.confluent.ksql.parser.tree.Join; import io.confluent.ksql.parser.tree.Join.Type; import io.confluent.ksql.parser.tree.JoinCriteria; -import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.parser.tree.Relation; import io.confluent.ksql.parser.tree.ResultMaterialization; diff --git a/ksql-engine/src/test/java/io/confluent/ksql/materialization/KsqlMaterializationFactoryTest.java b/ksql-engine/src/test/java/io/confluent/ksql/materialization/KsqlMaterializationFactoryTest.java index eb4a398785a1..14ec213337ef 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/materialization/KsqlMaterializationFactoryTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/materialization/KsqlMaterializationFactoryTest.java @@ -36,9 +36,10 @@ import io.confluent.ksql.logging.processing.ProcessingLogContext; import io.confluent.ksql.logging.processing.ProcessingLogger; import io.confluent.ksql.logging.processing.ProcessingLoggerFactory; +import io.confluent.ksql.materialization.KsqlMaterializationFactory.AggregateMapperFactory; import io.confluent.ksql.materialization.KsqlMaterializationFactory.MaterializationFactory; +import io.confluent.ksql.materialization.KsqlMaterializationFactory.SelectMapperFactory; import io.confluent.ksql.materialization.KsqlMaterializationFactory.SqlPredicateFactory; -import io.confluent.ksql.materialization.KsqlMaterializationFactory.ValueMapperFactory; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.schema.ksql.LogicalSchema; @@ -85,6 +86,8 @@ public class KsqlMaterializationFactoryTest { @Mock private MaterializationInfo info; @Mock + private AggregatesInfo aggInfo; + @Mock private SqlPredicateFactory sqlPredicateFactory; @Mock private ProcessingLogger filterProcessingLogger; @@ -93,13 +96,17 @@ public class KsqlMaterializationFactoryTest { @Mock private ProcessingLoggerFactory processingLoggerFactory; @Mock - private ValueMapperFactory valueMapperFactory; + private AggregateMapperFactory aggregateMapperFactory; + @Mock + private SelectMapperFactory selectMapperFactory; @Mock private SqlPredicate havingSqlPredicate; @Mock + private Function aggregateMapper; + @Mock private Predicate havingPredicate; @Mock - private Function valueMapper; + private Function selectMapper; @Mock private MaterializationFactory 
materializationFactory; @@ -114,8 +121,9 @@ public void setUp() { ksqlConfig, functionRegistry, processingLogContext, + aggregateMapperFactory, sqlPredicateFactory, - valueMapperFactory, + selectMapperFactory, materializationFactory ); @@ -125,11 +133,13 @@ public void setUp() { when(info.aggregationSchema()).thenReturn(AGGREGATE_SCHEMA); when(info.tableSchema()).thenReturn(TABLE_SCHEMA); + when(info.aggregatesInfo()).thenReturn(aggInfo); + when(aggregateMapperFactory.create(any(), any())).thenReturn(aggregateMapper); when(havingSqlPredicate.getPredicate()).thenReturn((Predicate) havingPredicate); when(sqlPredicateFactory.create(any(), any(), any(), any(), any())) .thenReturn(havingSqlPredicate); - when(valueMapperFactory.create(any(), any(), any(), any(), any())).thenReturn(valueMapper); + when(selectMapperFactory.create(any(), any(), any(), any(), any())).thenReturn(selectMapper); when(info.havingExpression()).thenReturn(Optional.of(HAVING_EXP)); } @@ -188,6 +198,21 @@ public void shouldGetProjectProcessingLoggerWithCorrectParams() { verify(processingLoggerFactory).getLogger("start.project"); } + @Test + public void shouldBuildSelectAggregateMapperWithCorrectParameters() { + // Given: + when(info.tableSelects()).thenReturn(SELECTS); + + // When: + factory.create(materialization, info, contextStacker); + + // Then: + verify(aggregateMapperFactory).create( + aggInfo, + functionRegistry + ); + } + @Test public void shouldBuildSelectValueMapperWithCorrectParameters() { // Given: @@ -197,7 +222,7 @@ public void shouldBuildSelectValueMapperWithCorrectParameters() { factory.create(materialization, info, contextStacker); // Then: - verify(valueMapperFactory).create( + verify(selectMapperFactory).create( SELECTS, AGGREGATE_SCHEMA, ksqlConfig, @@ -214,8 +239,9 @@ public void shouldBuildMaterializationWithCorrectParams() { // Then: verify(materializationFactory).create( eq(materialization), + eq(aggregateMapper), eq(havingPredicate), - any(), + eq(selectMapper), eq(TABLE_SCHEMA) ); } @@ -224,7 +250,8 @@ public void shouldBuildMaterializationWithCorrectParams() { public void shouldReturnMaterialization() { // Given: final KsqlMaterialization ksqlMaterialization = mock(KsqlMaterialization.class); - when(materializationFactory.create(any(), any(), any(), any())).thenReturn(ksqlMaterialization); + when(materializationFactory.create(any(), any(), any(), any(), any())) + .thenReturn(ksqlMaterialization); // When: final Materialization result = factory diff --git a/ksql-engine/src/test/java/io/confluent/ksql/materialization/KsqlMaterializationTest.java b/ksql-engine/src/test/java/io/confluent/ksql/materialization/KsqlMaterializationTest.java index eb5eb79d7a38..f2e914d2ee41 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/materialization/KsqlMaterializationTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/materialization/KsqlMaterializationTest.java @@ -88,6 +88,8 @@ public class KsqlMaterializationTest { @Mock private Materialization inner; @Mock + private Function aggregateTransform; + @Mock private Predicate havingPredicate; @Mock private Function storeToTableTransform; @@ -103,6 +105,7 @@ public class KsqlMaterializationTest { public void setUp() { materialization = new KsqlMaterialization( inner, + aggregateTransform, havingPredicate, storeToTableTransform, SCHEMA @@ -114,6 +117,8 @@ public void setUp() { when(innerNonWindowed.get(any())).thenReturn(Optional.of(ROW)); when(innerWindowed.get(any(), any())).thenReturn(ImmutableList.of(WINDOWED_ROW)); + 
when(aggregateTransform.apply(any())).thenAnswer(inv -> inv.getArgument(0)); + when(havingPredicate.test(any(), any())).thenReturn(true); when(storeToTableTransform.apply(any())).thenAnswer(inv -> inv.getArgument(0)); @@ -271,7 +276,7 @@ public void shouldFilterWindowed() { } @Test - public void shouldTransformRowAfterFilterNonWindowed() { + public void shouldAggregateMapThenTransformRowThenFilterNonWindowed() { // Given: final MaterializedTable table = materialization.nonWindowed(); @@ -279,13 +284,14 @@ public void shouldTransformRowAfterFilterNonWindowed() { table.get(A_KEY); // Then: - final InOrder inOrder = inOrder(havingPredicate, storeToTableTransform); + final InOrder inOrder = inOrder(aggregateTransform, havingPredicate, storeToTableTransform); + inOrder.verify(aggregateTransform).apply(any()); inOrder.verify(havingPredicate).test(any(), any()); inOrder.verify(storeToTableTransform).apply(any()); } @Test - public void shouldTransformRowAfterFilterWindowed() { + public void shouldAggregateMapThenTransformRowThenFilterWindowed() { // Given: final MaterializedWindowedTable table = materialization.windowed(); @@ -293,11 +299,38 @@ public void shouldTransformRowAfterFilterWindowed() { table.get(A_KEY, WINDOW_START_BOUNDS); // Then: - final InOrder inOrder = inOrder(havingPredicate, storeToTableTransform); + final InOrder inOrder = inOrder(aggregateTransform, havingPredicate, storeToTableTransform); + inOrder.verify(aggregateTransform).apply(any()); inOrder.verify(havingPredicate).test(any(), any()); inOrder.verify(storeToTableTransform).apply(any()); } + @Test + public void shouldUseAggregateTransformedFromNonWindowed() { + // Given: + final MaterializedTable table = materialization.nonWindowed(); + when(aggregateTransform.apply(any())).thenReturn(TRANSFORMED); + + // When: + table.get(A_KEY); + + // Then: + verify(havingPredicate).test(A_KEY, TRANSFORMED); + } + + @Test + public void shouldUseAggregateTransformedFromWindowed() { + // Given: + final MaterializedWindowedTable table = materialization.windowed(); + when(aggregateTransform.apply(any())).thenReturn(TRANSFORMED); + + // When: + table.get(A_KEY, AN_INSTANT, AN_INSTANT); + + // Then: + verify(havingPredicate).test(A_KEY, TRANSFORMED); + } + @Test public void shouldCallTransformWithCorrectParamsNonWindowed() { // Given: @@ -324,7 +357,7 @@ public void shouldCallTransformWithCorrectParamsWindowed() { @SuppressWarnings("OptionalGetWithoutIsPresent") @Test - public void shouldReturnTransformedFromNonWindowed() { + public void shouldReturnSelectTransformedFromNonWindowed() { // Given: final MaterializedTable table = materialization.nonWindowed(); when(storeToTableTransform.apply(any())).thenReturn(TRANSFORMED); @@ -340,7 +373,7 @@ public void shouldReturnTransformedFromNonWindowed() { } @Test - public void shouldReturnTransformedFromWindowed() { + public void shouldReturnSelectTransformedFromWindowed() { // Given: final MaterializedWindowedTable table = materialization.windowed(); when(storeToTableTransform.apply(any())).thenReturn(TRANSFORMED); diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java index d3533a609eaa..ccc4f26fd6ca 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedStreamTest.java @@ -31,12 +31,12 @@ import io.confluent.ksql.execution.plan.Formats; import 
io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.execution.streams.MaterializedFactory; +import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.execution.windows.SessionWindowExpression; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.function.InternalFunctionRegistry; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.model.WindowType; -import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.parser.tree.WindowExpression; diff --git a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java index 846f75bbd5e1..47f1f9a5b9d7 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/structured/SchemaKGroupedTableTest.java @@ -25,19 +25,19 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; -import io.confluent.ksql.execution.expression.tree.FunctionCall; -import io.confluent.ksql.name.ColumnName; -import io.confluent.ksql.name.FunctionName; -import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; +import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.plan.ExecutionStep; import io.confluent.ksql.execution.plan.Formats; import io.confluent.ksql.execution.streams.ExecutionStepFactory; import io.confluent.ksql.execution.streams.MaterializedFactory; import io.confluent.ksql.function.InternalFunctionRegistry; import io.confluent.ksql.metastore.model.KeyField; +import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.query.QueryId; +import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.serde.Format; diff --git a/ksql-execution/src/test/java/io/confluent/ksql/execution/function/udaf/window/WindowSelectMapperTest.java b/ksql-execution/src/test/java/io/confluent/ksql/execution/function/udaf/window/WindowSelectMapperTest.java index 1b960b296e56..477844da5f9f 100644 --- a/ksql-execution/src/test/java/io/confluent/ksql/execution/function/udaf/window/WindowSelectMapperTest.java +++ b/ksql-execution/src/test/java/io/confluent/ksql/execution/function/udaf/window/WindowSelectMapperTest.java @@ -22,7 +22,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.confluent.ksql.GenericRow; -import io.confluent.ksql.execution.function.udaf.window.WindowSelectMapper; import io.confluent.ksql.function.KsqlAggregateFunction; import java.util.ArrayList; import java.util.Arrays; diff --git a/ksql-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestExecutor.java b/ksql-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestExecutor.java index 9424202c48b1..e95b90deb254 100644 --- a/ksql-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestExecutor.java +++ 
b/ksql-functional-tests/src/test/java/io/confluent/ksql/test/rest/RestTestExecutor.java @@ -29,6 +29,7 @@ import io.confluent.ksql.rest.client.RestResponse; import io.confluent.ksql.rest.entity.KsqlEntity; import io.confluent.ksql.rest.entity.KsqlEntityList; +import io.confluent.ksql.rest.entity.KsqlStatementErrorMessage; import io.confluent.ksql.services.ServiceContext; import io.confluent.ksql.test.rest.model.Response; import io.confluent.ksql.test.tools.Record; @@ -212,9 +213,13 @@ private Optional> sendStatements( if (resp.isErroneous()) { final Optional>> expectedError = testCase.expectedError(); if (!expectedError.isPresent()) { + final String statement = resp.getErrorMessage() instanceof KsqlStatementErrorMessage + ? ((KsqlStatementErrorMessage)resp.getErrorMessage()).getStatementText() + : ""; + throw new AssertionError( "Server failed to execute statement" + System.lineSeparator() - + "statement: " + System.lineSeparator() + + "statement: " + statement + System.lineSeparator() + "reason: " + resp.getErrorMessage() ); } diff --git a/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json b/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json index d67d50641934..1e9eecef7134 100644 --- a/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json +++ b/ksql-functional-tests/src/test/resources/rest-query-validation-tests/materialized-aggregate-static-queries.json @@ -263,7 +263,6 @@ }, { "name": "text datetime window bounds", - "enabled": false, "statements": [ "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", "CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ROWKEY;", @@ -340,6 +339,51 @@ {"@type": "rows", "rows": []} ] }, + { + "name": "non-windowed with UDAF with different intermediate type", + "statements": [ + "CREATE STREAM INPUT (VAL INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT, AVG(VAL) AS AVG FROM INPUT GROUP BY ROWKEY;", + "SELECT * FROM AGGREGATE WHERE ROWKEY='10';" + ], + "inputs": [ + {"topic": "test_topic", "key": "11", "value": {"val": 1}}, + {"topic": "test_topic", "key": "10", "value": {"val": 2}}, + {"topic": "test_topic", "key": "10", "value": {"val": 4}} + ], + "responses": [ + {"@type": "currentStatus"}, + {"@type": "currentStatus"}, + { + "@type": "rows", + "schema": "`ROWKEY` STRING KEY, `COUNT` BIGINT, `AVG` DOUBLE", + "rows": [["10", 2, 3.0]] + } + ] + }, + { + "name": "windowed with UDAF with different intermediate type", + "statements": [ + "CREATE STREAM INPUT (VAL INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT, AVG(VAL) AS AVG FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ROWKEY;", + "SELECT * FROM AGGREGATE WHERE ROWKEY='10' AND WindowStart=11000;" + ], + "inputs": [ + {"topic": "test_topic", "timestamp": 12345, "key": "11", "value": {"VAL": 1}}, + {"topic": "test_topic", "timestamp": 11345, "key": "10", "value": {"VAL": 6}}, + {"topic": "test_topic", "timestamp": 11346, "key": "10", "value": {"VAL": 4}} + ], + "responses": [ + {"@type": "currentStatus"}, + {"@type": "currentStatus"}, + { + "@type": "rows", + "schema": "`ROWKEY` STRING KEY, `WINDOWSTART` BIGINT KEY, `COUNT` BIGINT, `AVG` DOUBLE", + "rows": [ + ["10", 11000, 2, 5.0] + ]} + ] + }, { "name": 
"fail on unsupported query feature: join", "statements": [ diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/HoppingWindowExpressionTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/HoppingWindowExpressionTest.java index a9112715447b..60cbe69ca945 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/HoppingWindowExpressionTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/HoppingWindowExpressionTest.java @@ -21,32 +21,16 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import com.google.common.testing.EqualsTester; -import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.windows.HoppingWindowExpression; -import io.confluent.ksql.function.UdafAggregator; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.parser.NodeLocation; import io.confluent.ksql.serde.WindowInfo; import java.time.Duration; import java.util.Optional; -import java.util.concurrent.TimeUnit; -import org.apache.kafka.common.utils.Bytes; -import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; -import org.apache.kafka.streams.kstream.KGroupedStream; -import org.apache.kafka.streams.kstream.Materialized; -import org.apache.kafka.streams.kstream.TimeWindowedKStream; -import org.apache.kafka.streams.kstream.TimeWindows; -import org.apache.kafka.streams.state.WindowStore; -import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; @RunWith(MockitoJUnitRunner.class) diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/ParserModelTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/ParserModelTest.java index f9ab9730daae..75706f8b93b0 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/ParserModelTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/ParserModelTest.java @@ -28,15 +28,15 @@ import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.expression.tree.InListExpression; -import io.confluent.ksql.name.SourceName; -import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.execution.expression.tree.StringLiteral; import io.confluent.ksql.execution.expression.tree.Type; import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.execution.windows.TumblingWindowExpression; +import io.confluent.ksql.name.SourceName; import io.confluent.ksql.parser.properties.with.CreateSourceAsProperties; import io.confluent.ksql.parser.properties.with.CreateSourceProperties; import io.confluent.ksql.properties.with.CommonCreateConfigs; +import io.confluent.ksql.schema.ksql.ColumnRef; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.test.util.ClassFinder; diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/SessionWindowExpressionTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/SessionWindowExpressionTest.java index 0740e2d0f0fb..5429ba93e9df 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/SessionWindowExpressionTest.java +++ 
b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/SessionWindowExpressionTest.java @@ -17,31 +17,15 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import io.confluent.ksql.GenericRow; import io.confluent.ksql.execution.windows.SessionWindowExpression; -import io.confluent.ksql.function.UdafAggregator; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.serde.WindowInfo; -import java.time.Duration; import java.util.Optional; import java.util.concurrent.TimeUnit; -import org.apache.kafka.common.utils.Bytes; -import org.apache.kafka.connect.data.Struct; -import org.apache.kafka.streams.kstream.Initializer; -import org.apache.kafka.streams.kstream.KGroupedStream; -import org.apache.kafka.streams.kstream.Materialized; -import org.apache.kafka.streams.kstream.Merger; -import org.apache.kafka.streams.kstream.SessionWindowedKStream; -import org.apache.kafka.streams.kstream.SessionWindows; -import org.apache.kafka.streams.state.SessionStore; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; @RunWith(MockitoJUnitRunner.class) diff --git a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/TumblingWindowExpressionTest.java b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/TumblingWindowExpressionTest.java index 158f343a57e4..deaf61d8c49e 100644 --- a/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/TumblingWindowExpressionTest.java +++ b/ksql-parser/src/test/java/io/confluent/ksql/parser/tree/TumblingWindowExpressionTest.java @@ -18,6 +18,7 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; + import io.confluent.ksql.execution.windows.TumblingWindowExpression; import io.confluent.ksql.model.WindowType; import io.confluent.ksql.serde.WindowInfo; diff --git a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/StaticQueryExecutor.java b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/StaticQueryExecutor.java index a8468f0d6a42..88f955fb7f6a 100644 --- a/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/StaticQueryExecutor.java +++ b/ksql-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/StaticQueryExecutor.java @@ -182,7 +182,7 @@ public static Optional execute( return Optional.of(entity); } catch (final Exception e) { throw new KsqlStatementException( - e.getMessage(), + e.getMessage() == null ? 
"Server Error" : e.getMessage(), statement.getStatementText(), e ); diff --git a/ksql-serde/src/test/java/io/confluent/ksql/serde/delimited/KsqlDelimitedSerializerTest.java b/ksql-serde/src/test/java/io/confluent/ksql/serde/delimited/KsqlDelimitedSerializerTest.java index 1cff6ef577f5..f292aa7d04ff 100644 --- a/ksql-serde/src/test/java/io/confluent/ksql/serde/delimited/KsqlDelimitedSerializerTest.java +++ b/ksql-serde/src/test/java/io/confluent/ksql/serde/delimited/KsqlDelimitedSerializerTest.java @@ -23,12 +23,11 @@ import io.confluent.ksql.util.DecimalUtil; import java.math.BigDecimal; import java.nio.charset.StandardCharsets; +import org.apache.commons.csv.CSVFormat; import org.apache.kafka.common.errors.SerializationException; -import org.apache.kafka.common.serialization.Serializer; import org.apache.kafka.connect.data.Schema; import org.apache.kafka.connect.data.SchemaBuilder; import org.apache.kafka.connect.data.Struct; -import org.apache.commons.csv.CSVFormat; import org.junit.Before; import org.junit.Rule; import org.junit.Test; diff --git a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateParams.java b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateParams.java index 37b2398eaddd..14c051160bbb 100644 --- a/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateParams.java +++ b/ksql-streams/src/main/java/io/confluent/ksql/execution/streams/AggregateParams.java @@ -41,7 +41,7 @@ public final class AggregateParams { private final List> functions; private final KudafAggregatorFactory aggregatorFactory; - AggregateParams( + public AggregateParams( final LogicalSchema internalSchema, final int initialUdafIndex, final FunctionRegistry functionRegistry, diff --git a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableTableJoinBuilderTest.java b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableTableJoinBuilderTest.java index e87816a54662..0cbcbad26c83 100644 --- a/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableTableJoinBuilderTest.java +++ b/ksql-streams/src/test/java/io/confluent/ksql/execution/streams/TableTableJoinBuilderTest.java @@ -20,7 +20,6 @@ import io.confluent.ksql.query.QueryId; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; -import org.apache.kafka.common.serialization.Serde; import org.apache.kafka.connect.data.Struct; import org.apache.kafka.streams.kstream.Joined; import org.apache.kafka.streams.kstream.KTable;