diff --git a/ksqldb-api-reactive-streams-tck/src/test/java/io/confluent/ksql/api/tck/BlockingQueryPublisherVerificationTest.java b/ksqldb-api-reactive-streams-tck/src/test/java/io/confluent/ksql/api/tck/BlockingQueryPublisherVerificationTest.java index 54e58b1c2e3f..2f6e26d341dc 100644 --- a/ksqldb-api-reactive-streams-tck/src/test/java/io/confluent/ksql/api/tck/BlockingQueryPublisherVerificationTest.java +++ b/ksqldb-api-reactive-streams-tck/src/test/java/io/confluent/ksql/api/tck/BlockingQueryPublisherVerificationTest.java @@ -17,7 +17,7 @@ import io.confluent.ksql.GenericRow; import io.confluent.ksql.api.impl.BlockingQueryPublisher; -import io.confluent.ksql.api.server.PushQueryHandle; +import io.confluent.ksql.api.server.QueryHandle; import io.confluent.ksql.query.BlockingRowQueue; import io.confluent.ksql.query.TransientQueryQueue; import io.confluent.ksql.util.KeyValue; @@ -27,6 +27,7 @@ import java.util.ArrayList; import java.util.List; import java.util.OptionalInt; +import java.util.function.Consumer; import org.reactivestreams.Publisher; import org.reactivestreams.tck.PublisherVerification; import org.reactivestreams.tck.TestEnvironment; @@ -49,7 +50,7 @@ public Publisher, GenericRow>> createPublisher(long elements) { final Context context = vertx.getOrCreateContext(); BlockingQueryPublisher publisher = new BlockingQueryPublisher(context, workerExecutor); final TestQueryHandle queryHandle = new TestQueryHandle(elements); - publisher.setQueryHandle(queryHandle); + publisher.setQueryHandle(queryHandle, false); if (elements < Integer.MAX_VALUE) { for (long l = 0; l < elements; l++) { queryHandle.queue.acceptRow(null, generateRow(l)); @@ -71,7 +72,7 @@ private static GenericRow generateRow(long num) { return GenericRow.fromList(l); } - private static class TestQueryHandle implements PushQueryHandle { + private static class TestQueryHandle implements QueryHandle { private final TransientQueryQueue queue; @@ -105,5 +106,9 @@ public void stop() { public BlockingRowQueue getQueue() { return queue; } + + @Override + public void onException(Consumer onException) { + } } } diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/reactive/BasePublisher.java b/ksqldb-common/src/main/java/io/confluent/ksql/reactive/BasePublisher.java index 8d842936fe41..297b59d3f452 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/reactive/BasePublisher.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/reactive/BasePublisher.java @@ -38,7 +38,7 @@ public abstract class BasePublisher implements Publisher { private long demand; private boolean cancelled; private boolean sentComplete; - private volatile Exception failure; + private volatile Throwable failure; public BasePublisher(final Context ctx) { this.ctx = Objects.requireNonNull(ctx); @@ -75,7 +75,7 @@ protected void checkContext() { VertxUtils.checkContext(ctx); } - protected final void sendError(final Exception e) { + protected final void sendError(final Throwable e) { checkContext(); try { if (subscriber != null) { diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/KsqlExecutionContext.java b/ksqldb-engine/src/main/java/io/confluent/ksql/KsqlExecutionContext.java index 60b64a912537..1ecd658c23c9 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/KsqlExecutionContext.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/KsqlExecutionContext.java @@ -17,7 +17,6 @@ import io.confluent.ksql.engine.KsqlEngine; import io.confluent.ksql.engine.KsqlPlan; -import 
io.confluent.ksql.execution.streams.RoutingFilter.RoutingFilterFactory; import io.confluent.ksql.execution.streams.RoutingOptions; import io.confluent.ksql.internal.PullQueryExecutorMetrics; import io.confluent.ksql.logging.processing.ProcessingLogContext; @@ -145,18 +144,19 @@ TransientQueryMetadata executeQuery( * plan. The physical plan is then traversed for every row in the state store. * @param serviceContext The service context to execute the query in * @param statement The pull query - * @param routingFilterFactory The filters used to route requests for HA routing * @param routingOptions Configuration parameters used for routing requests * @param pullQueryMetrics JMX metrics + * @param startImmediately Whether to start the pull query immediately. If not, the caller must + * call PullQueryResult.start to start the query. * @return the rows that are the result of the query evaluation. */ PullQueryResult executePullQuery( ServiceContext serviceContext, ConfiguredStatement statement, HARouting routing, - RoutingFilterFactory routingFilterFactory, RoutingOptions routingOptions, - Optional pullQueryMetrics + Optional pullQueryMetrics, + boolean startImmediately ); /** diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/EngineExecutor.java b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/EngineExecutor.java index 65383a423f75..3edf272363c3 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/EngineExecutor.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/EngineExecutor.java @@ -24,7 +24,6 @@ import io.confluent.ksql.config.SessionConfig; import io.confluent.ksql.execution.ddl.commands.DdlCommand; import io.confluent.ksql.execution.plan.ExecutionStep; -import io.confluent.ksql.execution.streams.RoutingFilter.RoutingFilterFactory; import io.confluent.ksql.execution.streams.RoutingOptions; import io.confluent.ksql.internal.PullQueryExecutorMetrics; import io.confluent.ksql.metastore.model.DataSource; @@ -41,6 +40,7 @@ import io.confluent.ksql.physical.pull.HARouting; import io.confluent.ksql.physical.pull.PullPhysicalPlan; import io.confluent.ksql.physical.pull.PullPhysicalPlanBuilder; +import io.confluent.ksql.physical.pull.PullQueryQueuePopulator; import io.confluent.ksql.physical.pull.PullQueryResult; import io.confluent.ksql.planner.LogicalPlanNode; import io.confluent.ksql.planner.LogicalPlanner; @@ -49,6 +49,7 @@ import io.confluent.ksql.planner.plan.KsqlStructuredDataOutputNode; import io.confluent.ksql.planner.plan.OutputNode; import io.confluent.ksql.planner.plan.PlanNode; +import io.confluent.ksql.query.PullQueryQueue; import io.confluent.ksql.query.QueryExecutor; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.schema.ksql.LogicalSchema; @@ -137,7 +138,6 @@ ExecuteResult execute(final KsqlPlan plan) { * Evaluates a pull query by first analyzing it, then building the logical plan and finally * the physical plan. The execution is then done using the physical plan in a pipelined manner. 
* @param statement The pull query - * @param routingFilterFactory The filters used for HA routing * @param routingOptions Configuration parameters used for HA routing * @param pullQueryMetrics JMX metrics * @return the rows that are the result of evaluating the pull query @@ -145,9 +145,9 @@ ExecuteResult execute(final KsqlPlan plan) { PullQueryResult executePullQuery( final ConfiguredStatement statement, final HARouting routing, - final RoutingFilterFactory routingFilterFactory, final RoutingOptions routingOptions, - final Optional pullQueryMetrics + final Optional pullQueryMetrics, + final boolean startImmediately ) { if (!statement.getStatement().isPullQuery()) { @@ -168,10 +168,17 @@ PullQueryResult executePullQuery( logicalPlan, analysis ); - return routing.handlePullQuery( + final PullQueryQueue pullQueryQueue = new PullQueryQueue(); + final PullQueryQueuePopulator populator = () -> routing.handlePullQuery( serviceContext, physicalPlan, statement, routingOptions, physicalPlan.getOutputSchema(), - physicalPlan.getQueryId()); + physicalPlan.getQueryId(), pullQueryQueue); + final PullQueryResult result = new PullQueryResult(physicalPlan.getOutputSchema(), populator, + physicalPlan.getQueryId(), pullQueryQueue, pullQueryMetrics); + if (startImmediately) { + result.start(); + } + return result; } catch (final Exception e) { pullQueryMetrics.ifPresent(metrics -> metrics.recordErrorRate(1)); throw new KsqlStatementException( diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/KsqlEngine.java b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/KsqlEngine.java index b77defb0394a..52745370e8cf 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/KsqlEngine.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/KsqlEngine.java @@ -19,7 +19,6 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.KsqlExecutionContext; import io.confluent.ksql.ServiceInfo; -import io.confluent.ksql.execution.streams.RoutingFilter.RoutingFilterFactory; import io.confluent.ksql.execution.streams.RoutingOptions; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.internal.KsqlEngineMetrics; @@ -270,9 +269,9 @@ public PullQueryResult executePullQuery( final ServiceContext serviceContext, final ConfiguredStatement statement, final HARouting routing, - final RoutingFilterFactory routingFilterFactory, final RoutingOptions routingOptions, - final Optional pullQueryMetrics + final Optional pullQueryMetrics, + final boolean startImmediately ) { return EngineExecutor .create( @@ -283,9 +282,9 @@ public PullQueryResult executePullQuery( .executePullQuery( statement, routing, - routingFilterFactory, routingOptions, - pullQueryMetrics + pullQueryMetrics, + startImmediately ); } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/SandboxedExecutionContext.java b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/SandboxedExecutionContext.java index 8b2bf3acdab6..6832507e12fb 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/SandboxedExecutionContext.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/SandboxedExecutionContext.java @@ -17,7 +17,6 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.KsqlExecutionContext; -import io.confluent.ksql.execution.streams.RoutingFilter.RoutingFilterFactory; import io.confluent.ksql.execution.streams.RoutingOptions; import io.confluent.ksql.internal.PullQueryExecutorMetrics; import 
io.confluent.ksql.logging.processing.NoopProcessingLogContext; @@ -164,9 +163,9 @@ public PullQueryResult executePullQuery( final ServiceContext serviceContext, final ConfiguredStatement statement, final HARouting routing, - final RoutingFilterFactory routingFilterFactory, final RoutingOptions routingOptions, - final Optional pullQueryMetrics + final Optional pullQueryMetrics, + final boolean startImmediately ) { return EngineExecutor.create( engineContext, @@ -175,9 +174,9 @@ public PullQueryResult executePullQuery( ).executePullQuery( statement, routing, - routingFilterFactory, routingOptions, - pullQueryMetrics + pullQueryMetrics, + startImmediately ); } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/HARouting.java b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/HARouting.java index 22cb6303a946..6d51878dd5ba 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/HARouting.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/HARouting.java @@ -16,6 +16,7 @@ package io.confluent.ksql.physical.pull; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -26,9 +27,11 @@ import io.confluent.ksql.execution.streams.materialization.MaterializationException; import io.confluent.ksql.internal.PullQueryExecutorMetrics; import io.confluent.ksql.parser.tree.Query; +import io.confluent.ksql.query.PullQueryQueue; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.rest.client.RestResponse; import io.confluent.ksql.rest.entity.StreamedRow; +import io.confluent.ksql.rest.entity.StreamedRow.Header; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.services.ServiceContext; import io.confluent.ksql.statement.ConfiguredStatement; @@ -38,18 +41,20 @@ import io.confluent.ksql.util.KsqlServerException; import io.confluent.ksql.util.KsqlStatementException; import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -94,14 +99,15 @@ public void close() { executorService.shutdown(); } - public PullQueryResult handlePullQuery( + public CompletableFuture handlePullQuery( final ServiceContext serviceContext, final PullPhysicalPlan pullPhysicalPlan, final ConfiguredStatement statement, final RoutingOptions routingOptions, final LogicalSchema outputSchema, - final QueryId queryId - ) throws InterruptedException { + final QueryId queryId, + final PullQueryQueue pullQueryQueue + ) { final List locations = pullPhysicalPlan.getMaterialization().locator() .locate( pullPhysicalPlan.getKeys(), @@ -119,12 +125,31 @@ public PullQueryResult handlePullQuery( statement.getStatementText())); } - // The source nodes associated with each of the 
rows - final List sourceNodes = new ArrayList<>(); - // Each of the table rows returned, aggregated across nodes - final List> tableRows = new ArrayList<>(); - // Each of the schemas returned, grouped by node - final Map> schemasByHost = new HashMap<>(); + final CompletableFuture completableFuture = new CompletableFuture<>(); + executorService.submit(() -> { + try { + executeRounds(serviceContext, pullPhysicalPlan, statement, routingOptions, outputSchema, + queryId, locations, pullQueryQueue); + completableFuture.complete(null); + } catch (Throwable t) { + completableFuture.completeExceptionally(t); + } + }); + + return completableFuture; + } + + private void executeRounds( + final ServiceContext serviceContext, + final PullPhysicalPlan pullPhysicalPlan, + final ConfiguredStatement statement, + final RoutingOptions routingOptions, + final LogicalSchema outputSchema, + final QueryId queryId, + final List locations, + final PullQueryQueue pullQueryQueue + ) throws InterruptedException { + // The remaining partition locations to retrieve without error List remainingLocations = ImmutableList.copyOf(locations); // For each round, each set of partition location objects is grouped by host, and all // keys associated with that host are batched together. For any requests that fail, @@ -144,13 +169,16 @@ public PullQueryResult handlePullQuery( // Make requests to each host, specifying the partitions we're interested in from // this host. - final Map> futures = new LinkedHashMap<>(); + final Map> futures = new LinkedHashMap<>(); for (Map.Entry> entry : groupedByHost.entrySet()) { final KsqlNode node = entry.getKey(); futures.put(node, executorService.submit( - () -> routeQuery.routeQuery( - node, entry.getValue(), statement, serviceContext, routingOptions, - pullQueryMetrics, pullPhysicalPlan, outputSchema, queryId) + () -> { + routeQuery.routeQuery( + node, entry.getValue(), statement, serviceContext, routingOptions, + pullQueryMetrics, pullPhysicalPlan, outputSchema, queryId, pullQueryQueue); + return null; + } )); } @@ -158,18 +186,14 @@ public PullQueryResult handlePullQuery( // the locations to the nextRoundRemaining list. final ImmutableList.Builder nextRoundRemaining = ImmutableList.builder(); - for (Map.Entry> entry : futures.entrySet()) { - final Future future = entry.getValue(); + for (Map.Entry> entry : futures.entrySet()) { + final Future future = entry.getValue(); final KsqlNode node = entry.getKey(); try { - final PullQueryResult result = future.get(); - result.getSourceNodes().ifPresent(sourceNodes::addAll); - schemasByHost.putIfAbsent(node, new ArrayList<>()); - schemasByHost.get(node).add(result.getSchema()); - tableRows.addAll(result.getTableRows()); + future.get(); } catch (ExecutionException e) { LOG.warn("Error routing query {} to host {} at timestamp {} with exception {}", - statement.getStatementText(), node, System.currentTimeMillis(), e.getCause()); + statement.getStatementText(), node, System.currentTimeMillis(), e.getCause()); nextRoundRemaining.addAll(groupedByHost.get(node)); } } @@ -177,12 +201,8 @@ public PullQueryResult handlePullQuery( // If there are no partition locations remaining, then we're done. if (remainingLocations.size() == 0) { - final LogicalSchema schema = validateSchemas(schemasByHost); - return new PullQueryResult( - tableRows, - sourceNodes.isEmpty() ? 
Optional.empty() : Optional.of(sourceNodes), - schema, - queryId); + pullQueryQueue.close(); + return; } } } @@ -216,21 +236,22 @@ private static Map> groupByHost( @VisibleForTesting interface RouteQuery { - PullQueryResult routeQuery( + void routeQuery( KsqlNode node, List locations, ConfiguredStatement statement, ServiceContext serviceContext, RoutingOptions routingOptions, Optional pullQueryMetrics, - PullPhysicalPlan pullPhysicalPlan, - LogicalSchema outputSchema, - QueryId queryId + PullPhysicalPlan pullPhysicalPlan, + LogicalSchema outputSchema, + QueryId queryId, + PullQueryQueue pullQueryQueue ); } @VisibleForTesting - static PullQueryResult executeOrRouteQuery( + static void executeOrRouteQuery( final KsqlNode node, final List locations, final ConfiguredStatement statement, @@ -239,16 +260,19 @@ static PullQueryResult executeOrRouteQuery( final Optional pullQueryMetrics, final PullPhysicalPlan pullPhysicalPlan, final LogicalSchema outputSchema, - final QueryId queryId + final QueryId queryId, + final PullQueryQueue pullQueryQueue ) { - List> rows = null; + final BiFunction, LogicalSchema, PullQueryRow> rowFactory = (rawRow, schema) -> + new PullQueryRow(rawRow, schema, Optional.ofNullable( + routingOptions.getIsDebugRequest() ? node : null)); if (node.isLocal()) { try { LOG.debug("Query {} executed locally at host {} at timestamp {}.", statement.getStatementText(), node.location(), System.currentTimeMillis()); pullQueryMetrics .ifPresent(queryExecutorMetrics -> queryExecutorMetrics.recordLocalRequests(1)); - rows = pullPhysicalPlan.execute(locations); + pullPhysicalPlan.execute(locations, pullQueryQueue, rowFactory); } catch (Exception e) { LOG.error("Error executing query {} locally at node {} with exception {}", statement.getStatementText(), node, e.getCause()); @@ -264,7 +288,8 @@ static PullQueryResult executeOrRouteQuery( statement.getStatementText(), node.location(), System.currentTimeMillis()); pullQueryMetrics .ifPresent(queryExecutorMetrics -> queryExecutorMetrics.recordRemoteRequests(1)); - rows = forwardTo(node, locations, statement, serviceContext); + forwardTo(node, locations, statement, serviceContext, pullQueryQueue, rowFactory, + outputSchema); } catch (Exception e) { LOG.error("Error forwarding query {} to node {} with exception {}", statement.getStatementText(), node, e.getCause()); @@ -275,21 +300,16 @@ static PullQueryResult executeOrRouteQuery( ); } } - final Optional> debugNodes = Optional.ofNullable( - routingOptions.getIsDebugRequest() - ? Collections.nCopies(rows.size(), node) : null); - return new PullQueryResult( - rows, - debugNodes, - outputSchema, - queryId); } - private static List> forwardTo( + private static void forwardTo( final KsqlNode owner, final List locations, final ConfiguredStatement statement, - final ServiceContext serviceContext + final ServiceContext serviceContext, + final PullQueryQueue pullQueryQueue, + final BiFunction, LogicalSchema, PullQueryRow> rowFactory, + final LogicalSchema outputSchema ) { // Specify the partitions we specifically want to read. 
This will prevent reading unintended @@ -302,13 +322,15 @@ private static List> forwardTo( KsqlRequestConfig.KSQL_REQUEST_QUERY_PULL_SKIP_FORWARDING, true, KsqlRequestConfig.KSQL_REQUEST_INTERNAL_REQUEST, true, KsqlRequestConfig.KSQL_REQUEST_QUERY_PULL_PARTITIONS, partitions); - final RestResponse> response = serviceContext + final RestResponse response = serviceContext .getKsqlClient() .makeQueryRequest( owner.location(), statement.getStatementText(), statement.getSessionConfig().getOverrides(), - requestProperties + requestProperties, + streamedRowsHandler(owner, statement, requestProperties, pullQueryQueue, rowFactory, + outputSchema) ); if (response.isErroneous()) { @@ -318,59 +340,78 @@ private static List> forwardTo( owner, response.getErrorMessage())); } - final List streamedRows = response.getResponse(); - if (streamedRows.isEmpty()) { + final int numRows = response.getResponse(); + if (numRows == 0) { throw new KsqlServerException(String.format( "Forwarding pull query request [%s, %s, %s] to node %s failed due to invalid " + "empty response from forwarding call, expected a header row.", statement.getStatement(), statement.getSessionConfig().getOverrides(), requestProperties, owner)); } + } - final List> rows = new ArrayList<>(); - - for (final StreamedRow row : streamedRows.subList(1, streamedRows.size())) { - if (row.getErrorMessage().isPresent()) { - throw new KsqlStatementException( - row.getErrorMessage().get().getMessage(), - statement.getStatementText() - ); + private static Consumer> streamedRowsHandler( + final KsqlNode owner, + final ConfiguredStatement statement, + final Map requestProperties, + final PullQueryQueue pullQueryQueue, + final BiFunction, LogicalSchema, PullQueryRow> rowFactory, + final LogicalSchema outputSchema + ) { + final AtomicInteger processedRows = new AtomicInteger(0); + final AtomicReference
header = new AtomicReference<>(); + return streamedRows -> { + if (streamedRows == null || streamedRows.isEmpty()) { + return; } + final List rows = new ArrayList<>(); - if (!row.getRow().isPresent()) { - throw new KsqlServerException(String.format( - "Forwarding pull query request [%s, %s, %s] to node %s failed due to " - + "missing row data.", - statement.getStatement(), statement.getSessionConfig().getOverrides(), - requestProperties, owner)); - } + // If this is the first row overall, skip the header + final int previousProcessedRows = processedRows.getAndAdd(streamedRows.size()); + for (int i = 0; i < streamedRows.size(); i++) { + final StreamedRow row = streamedRows.get(i); + if (i == 0 && previousProcessedRows == 0) { + final Optional
optionalHeader = row.getHeader(); + optionalHeader.ifPresent(h -> validateSchema(outputSchema, h.getSchema(), owner)); + optionalHeader.ifPresent(header::set); + continue; + } - rows.add(row.getRow().get().getColumns()); - } + if (row.getErrorMessage().isPresent()) { + throw new KsqlStatementException( + row.getErrorMessage().get().getMessage(), + statement.getStatementText() + ); + } - return rows; - } + if (!row.getRow().isPresent()) { + throw new KsqlServerException(String.format( + "Forwarding pull query request [%s, %s, %s] to node %s failed due to " + + "missing row data.", + statement.getStatement(), statement.getSessionConfig().getOverrides(), + requestProperties, owner)); + } - private LogicalSchema validateSchemas(final Map> schemasByNode) { - LogicalSchema compareAgainst = null; - KsqlNode host = null; - for (Entry> entry: schemasByNode.entrySet()) { - final KsqlNode node = entry.getKey(); - final List schemas = entry.getValue(); - if (compareAgainst == null) { - compareAgainst = schemas.get(0); - host = node; + final List r = row.getRow().get().getColumns(); + Preconditions.checkNotNull(header.get()); + rows.add(rowFactory.apply(r, header.get().getSchema())); } - for (LogicalSchema s : schemas) { - if (!s.equals(compareAgainst)) { - throw new KsqlException(String.format( - "Schemas %s from host %s differs from schema %s from hosts %s", - s, node, compareAgainst, host)); - } + + if (!pullQueryQueue.acceptRows(rows)) { + LOG.info("Failed to queue all rows"); } - } - return compareAgainst; + }; } - + private static void validateSchema( + final LogicalSchema expectedSchema, + final LogicalSchema forwardedSchema, + final KsqlNode forwardedNode + ) { + if (!forwardedSchema.equals(expectedSchema)) { + throw new KsqlException(String.format( + "Schemas %s from host %s differs from schema %s", + forwardedSchema, forwardedNode, expectedSchema)); + } + } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullPhysicalPlan.java b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullPhysicalPlan.java index 8d6f24ab0da0..b68b188e2f35 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullPhysicalPlan.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullPhysicalPlan.java @@ -20,11 +20,14 @@ import io.confluent.ksql.execution.streams.materialization.Materialization; import io.confluent.ksql.physical.pull.operators.AbstractPhysicalOperator; import io.confluent.ksql.physical.pull.operators.DataSourceOperator; +import io.confluent.ksql.query.PullQueryQueue; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.schema.ksql.LogicalSchema; -import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.function.BiFunction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Represents the physical plan for pull queries. It is a tree of physical operators that gets @@ -33,6 +36,8 @@ * the data stores. 
*/ public class PullPhysicalPlan { + private static final Logger LOGGER = LoggerFactory.getLogger(PullPhysicalPlan.class); + private final AbstractPhysicalOperator root; private final LogicalSchema schema; private final QueryId queryId; @@ -57,22 +62,30 @@ public PullPhysicalPlan( dataSourceOperator, "dataSourceOperator"); } - public List> execute( - final List locations) { + public void execute( + final List locations, + final PullQueryQueue pullQueryQueue, + final BiFunction, LogicalSchema, PullQueryRow> rowFactory) { // We only know at runtime which partitions to get from which node. // That's why we need to set this explicitly for the dataSource operators dataSourceOperator.setPartitionLocations(locations); open(); - final List> localResult = new ArrayList<>(); - List row = null; + List row; while ((row = (List)next()) != null) { - localResult.add(row); + if (pullQueryQueue.isClosed()) { + // If the queue has been closed, we stop adding rows and cleanup. This should be triggered + // because the client has closed their connection with the server before the results have + // completed. + LOGGER.info("Queue closed before results completed. Stopping execution."); + break; + } + if (!pullQueryQueue.acceptRow(rowFactory.apply(row, schema))) { + LOGGER.info("Failed to queue row"); + } } close(); - - return localResult; } private void open() { diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullQueryQueuePopulator.java b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullQueryQueuePopulator.java new file mode 100644 index 000000000000..926ee1d90307 --- /dev/null +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullQueryQueuePopulator.java @@ -0,0 +1,29 @@ +/* + * Copyright 2020 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.physical.pull; + +import java.util.concurrent.CompletableFuture; + +public interface PullQueryQueuePopulator { + + /** + * Runs the pull query asynchronously. When the returned future is complete, the pull query has + * run to completion and every row has been added to the PullQueryQueue. If there's an error + * during completion, the future will also complete with an error. 
+ * @return The future + */ + CompletableFuture run(); +} diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullQueryResult.java b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullQueryResult.java index c964e8abec57..fe6d92991679 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullQueryResult.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullQueryResult.java @@ -15,30 +15,41 @@ package io.confluent.ksql.physical.pull; -import io.confluent.ksql.execution.streams.materialization.Locator.KsqlNode; +import com.google.common.base.Preconditions; +import io.confluent.ksql.internal.PullQueryExecutorMetrics; +import io.confluent.ksql.query.PullQueryQueue; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.schema.ksql.LogicalSchema; -import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; public class PullQueryResult { - private final List> tableRows; - private final Optional> sourceNodes; private final LogicalSchema schema; + private final PullQueryQueuePopulator populator; private final QueryId queryId; + private final PullQueryQueue pullQueryQueue; + private final Optional pullQueryMetrics; + + // This future is used to keep track of all of the callbacks since we allow for adding them both + // before and after the pull query has been started. When the pull query has completed, it will + // pass on the outcome to this future. + private CompletableFuture future = new CompletableFuture<>(); + private boolean started = false; public PullQueryResult( - final List> tableRowsEntity, - final Optional> sourceNodes, final LogicalSchema schema, - final QueryId queryId + final PullQueryQueuePopulator populator, + final QueryId queryId, + final PullQueryQueue pullQueryQueue, + final Optional pullQueryMetrics ) { - - this.tableRows = tableRowsEntity; - this.sourceNodes = sourceNodes; this.schema = schema; + this.populator = populator; this.queryId = queryId; + this.pullQueryQueue = pullQueryQueue; + this.pullQueryMetrics = pullQueryMetrics; } public LogicalSchema getSchema() { @@ -49,12 +60,34 @@ public QueryId getQueryId() { return queryId; } - public List> getTableRows() { - return tableRows; + public PullQueryQueue getPullQueryQueue() { + return pullQueryQueue; + } + + public void start() { + Preconditions.checkState(!started, "Should only start once"); + started = true; + final CompletableFuture f = populator.run(); + f.exceptionally(t -> { + future.completeExceptionally(t); + return null; + }); + f.thenAccept(future::complete); } - public Optional> getSourceNodes() { - return sourceNodes; + public void stop() { + pullQueryQueue.close(); } + public void onException(final Consumer consumer) { + future.exceptionally(t -> { + pullQueryMetrics.ifPresent(metrics -> metrics.recordErrorRate(1)); + consumer.accept(t); + return null; + }); + } + + public void onCompletion(final Consumer consumer) { + future.thenAccept(consumer::accept); + } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullQueryRow.java b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullQueryRow.java new file mode 100644 index 000000000000..779fd53bcdf1 --- /dev/null +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/physical/pull/PullQueryRow.java @@ -0,0 +1,58 @@ +/* + * Copyright 2020 Confluent Inc. 
* + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.physical.pull; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.streams.materialization.Locator.KsqlNode; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import java.util.List; +import java.util.Optional; + +public class PullQueryRow { + + private final List<?> row; + private final LogicalSchema schema; + private final Optional<KsqlNode> sourceNode; + + public PullQueryRow( + final List<?> row, + final LogicalSchema schema, + final Optional<KsqlNode> sourceNode) { + this.row = row; + this.schema = schema; + this.sourceNode = sourceNode; + } + + public List<?> getRow() { + return row; + } + + public LogicalSchema getSchema() { + return schema; + } + + public Optional<KsqlNode> getSourceNode() { + return sourceNode; + } + + public GenericRow getGenericRow() { + return toGenericRow(row); + } + + private static GenericRow toGenericRow(final List<?> values) { + return new GenericRow().appendAll(values); + } +} diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/query/PullQueryQueue.java b/ksqldb-engine/src/main/java/io/confluent/ksql/query/PullQueryQueue.java new file mode 100644 index 000000000000..906b764e7874 --- /dev/null +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/query/PullQueryQueue.java @@ -0,0 +1,218 @@ +/* + * Copyright 2020 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.query; + +import io.confluent.ksql.GenericRow; +import io.confluent.ksql.physical.pull.PullQueryRow; +import io.confluent.ksql.util.KeyValue; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This queue allows for results to be streamed back to the client when running pull queries. + * Streaming behavior is important when dealing with large results since we don't want to hold it + * all in memory at once. + * + *
<p>
New rows are produced and enqueued by PullPhysicalPlan if the request is being handled locally + * or HARouting if the request must be forwarded to another node. This is done with the method + * acceptRow and may block the caller if the queue is at capacity. + * + *
<p>
Rows are consumed by the request thread of the endpoint. This is done with the various poll + * methods. + */ +public class PullQueryQueue implements BlockingRowQueue { + private static final Logger LOG = LoggerFactory.getLogger(PullQueryQueue.class); + + // The capacity to allow before blocking when enqueuing + private static final int BLOCKING_QUEUE_CAPACITY = 50; + // The time to wait while enqueuing a row before quitting to retry + private static final long DEFAULT_OFFER_TIMEOUT_MS = 100; + + private final BlockingQueue rowQueue; + private final long offerTimeoutMs; + private AtomicBoolean closed = new AtomicBoolean(false); + + /** + * The callback run when we've hit the end of the data. Specifically, this happens when + * {@link #close()} is called. + */ + private LimitHandler limitHandler; + /** + * Callback is checked before enqueueing new rows and called when new rows are actually added. + */ + private Runnable queuedCallback; + + public PullQueryQueue() { + this(BLOCKING_QUEUE_CAPACITY, DEFAULT_OFFER_TIMEOUT_MS); + } + + public PullQueryQueue( + final int queueSizeLimit, + final long offerTimeoutMs) { + this.queuedCallback = () -> { }; + this.limitHandler = () -> { }; + this.rowQueue = new ArrayBlockingQueue<>(queueSizeLimit); + this.offerTimeoutMs = offerTimeoutMs; + } + + @Override + public void setLimitHandler(final LimitHandler limitHandler) { + this.limitHandler = limitHandler; + } + + @Override + public void setQueuedCallback(final Runnable queuedCallback) { + final Runnable parent = this.queuedCallback; + + this.queuedCallback = () -> { + parent.run(); + queuedCallback.run(); + }; + } + + @Override + public KeyValue, GenericRow> poll(final long timeout, final TimeUnit unit) + throws InterruptedException { + return pullQueryRowToKeyValue(rowQueue.poll(timeout, unit)); + } + + @Override + public KeyValue, GenericRow> poll() { + return pullQueryRowToKeyValue(rowQueue.poll()); + } + + @Override + public void drainTo(final Collection, GenericRow>> collection) { + final List list = new ArrayList<>(); + drainRowsTo(list); + list.stream() + .map(PullQueryQueue::pullQueryRowToKeyValue) + .forEach(collection::add); + } + + /** + * Similar to {@link #poll(long, TimeUnit)} , but returns a {@link PullQueryRow}. + */ + public PullQueryRow pollRow(final long timeout, final TimeUnit unit) throws InterruptedException { + return rowQueue.poll(timeout, unit); + } + + /** + * Similar to {@link #drainTo(Collection)}, but takes {@link PullQueryRow}s. + */ + public void drainRowsTo(final Collection collection) { + rowQueue.drainTo(collection); + } + + @Override + public int size() { + return rowQueue.size(); + } + + @Override + public boolean isEmpty() { + return rowQueue.isEmpty(); + } + + /** + * Unlike push queries that run forever until someone deliberately kills it, pull queries have an + * ending. When they've reached their end, this is expected to be called. Also, if the system + * wants to end pull queries prematurely, such as when the client connection closes, this should + * also be called then. + */ + @Override + public void close() { + if (!closed.getAndSet(true)) { + // Unlike limits based on a number of rows which can be checked and possibly triggered after + // every queuing of a row, pull queries just declare they've reached their limit when close is + // called. + limitHandler.limitReached(); + } + } + + public boolean isClosed() { + return closed.get(); + } + + /** + * Similar to {@link #acceptRow(PullQueryRow)} but takes many rows. + * @param tableRows The rows to enqueue. 
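+ * @return Whether all rows were enqueued; false if the list was null, the thread was + *         interrupted, or the queue was closed before every row was accepted.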
+ */ + public boolean acceptRows(final List tableRows) { + if (tableRows == null) { + return false; + } + for (PullQueryRow row : tableRows) { + if (!acceptRow(row)) { + return false; + } + } + return true; + } + + private static KeyValue, GenericRow> pullQueryRowToKeyValue(final PullQueryRow row) { + if (row == null) { + return null; + } + return KeyValue.keyValue(null, row.getGenericRow()); + } + + /** + * Enqueues a row on the queue. Blocks until the row can be accepted. + * @param row The row to enqueue. + */ + public boolean acceptRow(final PullQueryRow row) { + try { + if (row == null) { + return false; + } + + while (!closed.get()) { + if (rowQueue.offer(row, offerTimeoutMs, TimeUnit.MILLISECONDS)) { + queuedCallback.run(); + return true; + } + } + } catch (final InterruptedException e) { + // Forced shutdown? + LOG.error("Interrupted while trying to offer row to queue", e); + Thread.currentThread().interrupt(); + } + return false; + } + + /** + * If you don't want to rely on poll timeouts, a sentinel can be directly used, rather than + * interrupting the sleeping thread. The main difference between this and acceptRow is that + * this allows the addition of the sentinel even if the queue is closed. + * @param row The row to use as the sentinel + */ + public void putSentinelRow(final PullQueryRow row) { + try { + rowQueue.put(row); + } catch (InterruptedException e) { + LOG.error("Interrupted while trying to put row into queue", e); + Thread.currentThread().interrupt(); + } + } +} diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/physical/pull/HARoutingTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/physical/pull/HARoutingTest.java index 2bd061037fc3..7e06abeff888 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/physical/pull/HARoutingTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/physical/pull/HARoutingTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -35,6 +36,7 @@ import io.confluent.ksql.execution.streams.materialization.MaterializationException; import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.physical.pull.HARouting.RouteQuery; +import io.confluent.ksql.query.PullQueryQueue; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.services.ServiceContext; @@ -43,6 +45,9 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -57,6 +62,11 @@ public class HARoutingTest { private static final List ROW1 = ImmutableList.of("a", "b"); private static final List ROW2 = ImmutableList.of("c", "d"); + private static final LogicalSchema SCHEMA = LogicalSchema.builder().build(); + private static final PullQueryRow PQ_ROW1 = new PullQueryRow(ROW1, SCHEMA, Optional.empty()); + private static final PullQueryRow PQ_ROW2 = new PullQueryRow(ROW2, SCHEMA, Optional.empty()); + + @Mock private ConfiguredStatement statement; @Mock @@ -92,6 +102,8 @@ public class HARoutingTest { @Mock private KsqlConfig ksqlConfig; + private PullQueryQueue pullQueryQueue = new 
PullQueryQueue(); + private HARouting haRouting; @Before @@ -101,7 +113,8 @@ public void setUp() { when(location2.getNodes()).thenReturn(ImmutableList.of(node2, node1)); when(location3.getNodes()).thenReturn(ImmutableList.of(node1, node2)); when(location4.getNodes()).thenReturn(ImmutableList.of(node2, node1)); - when(ksqlConfig.getInt(KsqlConfig.KSQL_QUERY_PULL_THREAD_POOL_SIZE_CONFIG)).thenReturn(1); + // We require at least two threads, one for the orchestrator, and the other for the partitions. + when(ksqlConfig.getInt(KsqlConfig.KSQL_QUERY_PULL_THREAD_POOL_SIZE_CONFIG)).thenReturn(2); haRouting = new HARouting( routingFilterFactory, Optional.empty(), ksqlConfig, routeQuery); } @@ -114,12 +127,8 @@ public void tearDown() { } @Test - public void shouldCallRouteQuery_success() throws InterruptedException { + public void shouldCallRouteQuery_success() throws InterruptedException, ExecutionException { // Given: - final PullQueryResult pullQueryResult1 = new PullQueryResult( - ImmutableList.of(ROW1), Optional.empty(), logicalSchema, queryId); - final PullQueryResult pullQueryResult2 = new PullQueryResult( - ImmutableList.of(ROW2), Optional.empty(), logicalSchema, queryId); List locations = ImmutableList.of(location1, location2, location3, location4); when(pullPhysicalPlan.getMaterialization()).thenReturn(materialization); when(pullPhysicalPlan.getMaterialization().locator()).thenReturn(locator); @@ -129,41 +138,45 @@ public void shouldCallRouteQuery_success() throws InterruptedException { routingFilterFactory )).thenReturn(locations); List> locationsQueried = new ArrayList<>(); - when(routeQuery.routeQuery(eq(node1), any(), any(), any(), any(), any(), any(), any(), any())) - .thenAnswer(inv -> { - locationsQueried.add(inv.getArgument(1)); - return pullQueryResult1; - }); - when(routeQuery.routeQuery(eq(node2), any(), any(), any(), any(), any(), any(), any(), any())) - .thenAnswer(inv -> { - locationsQueried.add(inv.getArgument(1)); - return pullQueryResult2; - }); + doAnswer(inv -> { + locationsQueried.add(inv.getArgument(1)); + PullQueryQueue queue = inv.getArgument(9); + queue.acceptRow(PQ_ROW1); + return null; + }).when(routeQuery).routeQuery(eq(node1), any(), any(), any(), any(), any(), any(), any(), + any(), any()); + doAnswer(inv -> { + locationsQueried.add(inv.getArgument(1)); + PullQueryQueue queue = inv.getArgument(9); + queue.acceptRow(PQ_ROW2); + return null; + }).when(routeQuery).routeQuery(eq(node2), any(), any(), any(), any(), any(), any(), any(), + any(), any()); // When: - PullQueryResult result = haRouting.handlePullQuery(serviceContext, pullPhysicalPlan, statement, - routingOptions, logicalSchema, queryId); + CompletableFuture future = haRouting.handlePullQuery( + serviceContext, pullPhysicalPlan, statement, routingOptions, logicalSchema, queryId, + pullQueryQueue); + future.get(); // Then: - verify(routeQuery).routeQuery(eq(node1), any(), any(), any(), any(), any(), any(), any(), any()); + verify(routeQuery).routeQuery(eq(node1), any(), any(), any(), any(), any(), any(), any(), any(), + any()); assertThat(locationsQueried.get(0).get(0), is(location1)); assertThat(locationsQueried.get(0).get(1), is(location3)); - verify(routeQuery).routeQuery(eq(node2), any(), any(), any(), any(), any(), any(), any(), any()); + verify(routeQuery).routeQuery(eq(node2), any(), any(), any(), any(), any(), any(), any(), any(), + any()); assertThat(locationsQueried.get(1).get(0), is(location2)); assertThat(locationsQueried.get(1).get(1), is(location4)); - assertThat(result.getTableRows().size(), 
is(2)); - assertThat(result.getTableRows().get(0), is(ROW1)); - assertThat(result.getTableRows().get(1), is(ROW2)); + assertThat(pullQueryQueue.size(), is(2)); + assertThat(pullQueryQueue.pollRow(1, TimeUnit.SECONDS).getRow(), is(ROW1)); + assertThat(pullQueryQueue.pollRow(1, TimeUnit.SECONDS).getRow(), is(ROW2)); } @Test - public void shouldCallRouteQuery_twoRound() throws InterruptedException { + public void shouldCallRouteQuery_twoRound() throws InterruptedException, ExecutionException { // Given: - PullQueryResult pullQueryResult1 = new PullQueryResult( - ImmutableList.of(ROW1), Optional.empty(), logicalSchema, queryId); - PullQueryResult pullQueryResult2 = new PullQueryResult( - ImmutableList.of(ROW2), Optional.empty(), logicalSchema, queryId); List locations = ImmutableList.of(location1, location2, location3, location4); when(pullPhysicalPlan.getMaterialization()).thenReturn(materialization); when(pullPhysicalPlan.getMaterialization().locator()).thenReturn(locator); @@ -173,48 +186,52 @@ public void shouldCallRouteQuery_twoRound() throws InterruptedException { routingFilterFactory )).thenReturn(locations); List> locationsQueried = new ArrayList<>(); - when(routeQuery.routeQuery(eq(node1), any(), any(), any(), any(), any(), any(), any(), any())) - .thenAnswer(inv -> { - locationsQueried.add(inv.getArgument(1)); - throw new RuntimeException("Error!"); - }); - when(routeQuery.routeQuery(eq(node2), any(), any(), any(), any(), any(), any(), any(), any())) - .thenAnswer(new Answer() { + doAnswer(inv -> { + locationsQueried.add(inv.getArgument(1)); + throw new RuntimeException("Error!"); + }).when(routeQuery).routeQuery(eq(node1), any(), any(), any(), any(), any(), any(), any(), + any(), any()); + doAnswer(new Answer() { private int count = 0; public Object answer(InvocationOnMock invocation) { locationsQueried.add(invocation.getArgument(1)); - if (++count == 1) - return pullQueryResult2; - - return pullQueryResult1; + PullQueryQueue queue = invocation.getArgument(9); + if (++count == 1) { + queue.acceptRow(PQ_ROW2); + } else { + queue.acceptRow(PQ_ROW1); + } + return null; } - }); + }).when(routeQuery).routeQuery(eq(node2), any(), any(), any(), any(), any(), any(), any(), + any(), any()); // When: - PullQueryResult result = haRouting.handlePullQuery(serviceContext, pullPhysicalPlan, statement, - routingOptions, logicalSchema, queryId); + CompletableFuture future = haRouting.handlePullQuery(serviceContext, pullPhysicalPlan, + statement, routingOptions, logicalSchema, queryId, pullQueryQueue); + future.get(); // Then: - verify(routeQuery).routeQuery(eq(node1), any(), any(), any(), any(), any(), any(), any(), any()); + verify(routeQuery).routeQuery(eq(node1), any(), any(), any(), any(), any(), any(), any(), any(), + any()); assertThat(locationsQueried.get(0).get(0), is(location1)); assertThat(locationsQueried.get(0).get(1), is(location3)); - verify(routeQuery, times(2)).routeQuery(eq(node2), any(), any(), any(), any(), any(), any(), any(), any()); + verify(routeQuery, times(2)).routeQuery(eq(node2), any(), any(), any(), any(), any(), any(), + any(), any(), any()); assertThat(locationsQueried.get(1).get(0), is(location2)); assertThat(locationsQueried.get(1).get(1), is(location4)); assertThat(locationsQueried.get(2).get(0), is(location1)); assertThat(locationsQueried.get(2).get(1), is(location3)); - assertThat(result.getTableRows().size(), is(2)); - assertThat(result.getTableRows().get(0), is(ROW2)); - assertThat(result.getTableRows().get(1), is(ROW1)); + assertThat(pullQueryQueue.size(), 
is(2)); + assertThat(pullQueryQueue.pollRow(1, TimeUnit.SECONDS).getRow(), is(ROW2)); + assertThat(pullQueryQueue.pollRow(1, TimeUnit.SECONDS).getRow(), is(ROW1)); } @Test public void shouldCallRouteQuery_allFail() { // Given: - PullQueryResult pullQueryResult2 = new PullQueryResult( - ImmutableList.of(ROW2), Optional.empty(), logicalSchema, queryId); List locations = ImmutableList.of(location1, location2, location3, location4); when(pullPhysicalPlan.getMaterialization()).thenReturn(materialization); when(pullPhysicalPlan.getMaterialization().locator()).thenReturn(locator); @@ -224,42 +241,50 @@ public void shouldCallRouteQuery_allFail() { routingFilterFactory )).thenReturn(locations); List> locationsQueried = new ArrayList<>(); - when(routeQuery.routeQuery(eq(node1), any(), any(), any(), any(), any(), any(), any(), any())) - .thenAnswer(inv -> { - locationsQueried.add(inv.getArgument(1)); - throw new RuntimeException("Error!"); - }); - when(routeQuery.routeQuery(eq(node2), any(), any(), any(), any(), any(), any(), any(), any())) - .thenAnswer(new Answer() { - private int count = 0; - - public Object answer(InvocationOnMock invocation) { - locationsQueried.add(invocation.getArgument(1)); - if (++count == 1) - return pullQueryResult2; + doAnswer(inv -> { + locationsQueried.add(inv.getArgument(1)); + throw new RuntimeException("Error!"); + }).when(routeQuery).routeQuery(eq(node1), any(), any(), any(), any(), any(), any(), any(), + any(), any()); + doAnswer(new Answer() { + private int count = 0; - throw new RuntimeException("Error!"); - } - }); + public Object answer(InvocationOnMock invocation) { + locationsQueried.add(invocation.getArgument(1)); + PullQueryQueue queue = invocation.getArgument(9); + if (++count == 1) { + queue.acceptRow(PQ_ROW2); + } else { + throw new RuntimeException("Error!"); + } + return null; + } + }).when(routeQuery).routeQuery(eq(node2), any(), any(), any(), any(), any(), any(), any(), + any(), any()); // When: final Exception e = assertThrows( - MaterializationException.class, - () -> haRouting.handlePullQuery(serviceContext, pullPhysicalPlan, statement, routingOptions, - logicalSchema, queryId) + ExecutionException.class, + () -> { + CompletableFuture future = haRouting.handlePullQuery(serviceContext, + pullPhysicalPlan, statement, routingOptions, logicalSchema, queryId, pullQueryQueue); + future.get(); + } ); // Then: - verify(routeQuery).routeQuery(eq(node1), any(), any(), any(), any(), any(), any(), any(), any()); + verify(routeQuery).routeQuery(eq(node1), any(), any(), any(), any(), any(), any(), any(), any(), + any()); assertThat(locationsQueried.get(0).get(0), is(location1)); assertThat(locationsQueried.get(0).get(1), is(location3)); - verify(routeQuery, times(2)).routeQuery(eq(node2), any(), any(), any(), any(), any(), any(), any(), any()); + verify(routeQuery, times(2)).routeQuery(eq(node2), any(), any(), any(), any(), any(), any(), + any(), any(), any()); assertThat(locationsQueried.get(1).get(0), is(location2)); assertThat(locationsQueried.get(1).get(1), is(location4)); assertThat(locationsQueried.get(2).get(0), is(location1)); assertThat(locationsQueried.get(2).get(1), is(location3)); - assertThat(e.getMessage(), containsString("Unable to execute pull query: foo. " + assertThat(e.getCause().getMessage(), containsString("Unable to execute pull query: foo. 
" + "Exhausted standby hosts to try.")); } @@ -280,7 +305,7 @@ public void shouldCallRouteQuery_allFiltered() { final Exception e = assertThrows( MaterializationException.class, () -> haRouting.handlePullQuery(serviceContext, pullPhysicalPlan, statement, routingOptions, - logicalSchema, queryId) + logicalSchema, queryId, pullQueryQueue) ); // Then: diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/query/PullQueryQueueTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/query/PullQueryQueueTest.java new file mode 100644 index 000000000000..0d4a78dff6bc --- /dev/null +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/query/PullQueryQueueTest.java @@ -0,0 +1,147 @@ +package io.confluent.ksql.query; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import io.confluent.ksql.physical.pull.PullQueryRow; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class PullQueryQueueTest { + private static final int QUEUE_SIZE = 5; + + private static final PullQueryRow VAL_ONE = mock(PullQueryRow.class); + private static final PullQueryRow VAL_TWO = mock(PullQueryRow.class); + + @Rule + public final Timeout timeout = Timeout.seconds(10); + + @Mock + private LimitHandler limitHandler; + @Mock + private Runnable queuedCallback; + private PullQueryQueue queue; + private ScheduledExecutorService executorService; + + @Before + public void setUp() { + givenQueue(); + } + + @After + public void tearDown() { + if (executorService != null) { + executorService.shutdownNow(); + } + } + + @Test + public void shouldQueue() { + // When: + queue.acceptRow(VAL_ONE); + queue.acceptRow(VAL_TWO); + + // Then: + assertThat(drainValues(), contains(VAL_ONE, VAL_TWO)); + verify(queuedCallback, times(2)).run(); + } + + @Test + public void shouldQueueUntilClosed() { + // When: + IntStream.range(0, QUEUE_SIZE) + .forEach(idx -> { + queue.acceptRow(VAL_ONE); + if (idx == 2) { + queue.close(); + } + }); + + // Then: + assertThat(queue.size(), is(3)); + verify(queuedCallback, times(3)).run(); + } + + @Test + public void shouldPoll() throws Exception { + // Given: + queue.acceptRow(VAL_ONE); + queue.acceptRow(VAL_TWO); + + // When: + final PullQueryRow result1 = queue.pollRow(1, TimeUnit.SECONDS); + final PullQueryRow result2 = queue.pollRow(1, TimeUnit.SECONDS); + final PullQueryRow result3 = queue.pollRow(1, TimeUnit.MICROSECONDS); + + // Then: + assertThat(result1, is(VAL_ONE)); + assertThat(result2, is(VAL_TWO)); + assertThat(result3, is(nullValue())); + verify(queuedCallback, times(2)).run(); + } + + @Test + public void shouldCallLimitHandlerOnClose() { + // When: + queue.close(); + queue.close(); + + // Then: + verify(limitHandler, times(1)).limitReached(); + } + + + @Test + public void shouldBlockOnProduceOnceQueueLimitReachedAndUnblockOnClose() { + // Given: + givenQueue(); + + IntStream.range(0, 
QUEUE_SIZE) + .forEach(idx -> queue.acceptRow(VAL_ONE)); + + givenWillCloseQueueAsync(); + + // When: + queue.acceptRow(VAL_TWO); + + // Then: did not block and: + assertThat(queue.size(), is(QUEUE_SIZE)); + verify(queuedCallback, times(QUEUE_SIZE)).run(); + } + + private void givenWillCloseQueueAsync() { + executorService = Executors.newSingleThreadScheduledExecutor(); + executorService.schedule(queue::close, 200, TimeUnit.MILLISECONDS); + } + + private void givenQueue() { + queue = new PullQueryQueue(QUEUE_SIZE, 1); + + queue.setLimitHandler(limitHandler); + queue.setQueuedCallback(queuedCallback); + } + + private List drainValues() { + final List entries = new ArrayList<>(); + queue.drainRowsTo(entries); + return entries; + } +} diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/services/DisabledKsqlClient.java b/ksqldb-execution/src/main/java/io/confluent/ksql/services/DisabledKsqlClient.java index 714b9bb435e4..4a432575a74d 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/services/DisabledKsqlClient.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/services/DisabledKsqlClient.java @@ -24,6 +24,7 @@ import java.net.URI; import java.util.List; import java.util.Map; +import java.util.function.Consumer; /** * A KSQL client implementation for use when communication with other nodes is not supported. @@ -55,6 +56,17 @@ public RestResponse> makeQueryRequest( throw new UnsupportedOperationException("KSQL client is disabled"); } + @Override + public RestResponse makeQueryRequest( + final URI serverEndPoint, + final String sql, + final Map configOverrides, + final Map requestProperties, + final Consumer> rowConsumer + ) { + throw new UnsupportedOperationException("KSQL client is disabled"); + } + @Override public void makeAsyncHeartbeatRequest( final URI serverEndPoint, diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/services/SimpleKsqlClient.java b/ksqldb-execution/src/main/java/io/confluent/ksql/services/SimpleKsqlClient.java index d215f1663fda..b61f352c78f2 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/services/SimpleKsqlClient.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/services/SimpleKsqlClient.java @@ -24,6 +24,7 @@ import java.net.URI; import java.util.List; import java.util.Map; +import java.util.function.Consumer; import javax.annotation.concurrent.ThreadSafe; @ThreadSafe @@ -57,6 +58,24 @@ RestResponse> makeQueryRequest( Map requestProperties ); + /** + * Send pull query request to remote Ksql server. This version of makeQueryRequest allows + * consuming the rows as they stream in rather than aggregating them all in one list. + * @param serverEndPoint the remote destination + * @param sql the pull query statement + * @param configOverrides the config overrides provided by the client + * @param requestProperties the request metadata provided by the server + * @param rowConsumer A consumer that's fed lists of rows as they stream in + * @return the number of rows returned by pull query + */ + RestResponse makeQueryRequest( + URI serverEndPoint, + String sql, + Map configOverrides, + Map requestProperties, + Consumer> rowConsumer + ); + /** * Send heartbeat to remote Ksql server. * @param serverEndPoint the remote destination. 
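To make the streaming contract above concrete, here is a minimal, hypothetical caller of the new makeQueryRequest overload. The endpoint URI, statement text, and printing logic are illustrative assumptions; the overload's parameters, the StreamedRow accessors, and the rows-count response come from this patch.

import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.rest.client.RestResponse;
import io.confluent.ksql.rest.entity.StreamedRow;
import io.confluent.ksql.services.SimpleKsqlClient;
import java.net.URI;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

public final class StreamingPullQueryCaller {

  static int runPullQuery(final SimpleKsqlClient client) {
    final AtomicBoolean seenHeader = new AtomicBoolean(false);
    final RestResponse<Integer> response = client.makeQueryRequest(
        URI.create("http://other-node:8088"),       // hypothetical remote endpoint
        "SELECT * FROM AGGREGATE WHERE ID=10.5;",   // statement borrowed from the test data
        ImmutableMap.of(),                          // no config overrides
        ImmutableMap.of(),                          // no request properties
        (List<StreamedRow> batch) -> {
          // Batches arrive as they stream in; the first row overall is the header.
          for (final StreamedRow row : batch) {
            if (!seenHeader.getAndSet(true) && row.getHeader().isPresent()) {
              continue; // skip the header, as HARouting's streamedRowsHandler does
            }
            row.getRow().ifPresent(r -> System.out.println(r.getColumns()));
          }
        });
    if (response.isErroneous()) {
      throw new IllegalStateException(response.getErrorMessage().toString());
    }
    // Per the new javadoc, the successful response carries the number of rows returned.
    return response.getResponse();
  }
}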
diff --git a/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-against-materialized-aggregates.json b/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-against-materialized-aggregates.json index d7b9ff3b30e0..39921d2549cc 100644 --- a/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-against-materialized-aggregates.json +++ b/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-against-materialized-aggregates.json @@ -258,19 +258,19 @@ "statements": [ "CREATE STREAM INPUT (ID DOUBLE KEY, IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", "CREATE TABLE AGGREGATE AS SELECT ID, COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ID;", - "SELECT * FROM AGGREGATE WHERE ID=10.1;", + "SELECT * FROM AGGREGATE WHERE ID=10.5;", "SELECT * FROM AGGREGATE WHERE ID=0;" ], "inputs": [ - {"topic": "test_topic", "timestamp": 12346, "key": 11.1, "value": {"val": 1}}, - {"topic": "test_topic", "timestamp": 12345, "key": 10.1, "value": {"val": 2}} + {"topic": "test_topic", "timestamp": 12346, "key": 11.5, "value": {"val": 1}}, + {"topic": "test_topic", "timestamp": 12345, "key": 10.5, "value": {"val": 2}} ], "responses": [ {"admin": {"@type": "currentStatus"}}, {"admin": {"@type": "currentStatus"}}, {"query": [ {"header":{"schema":"`ID` DOUBLE KEY, `WINDOWSTART` BIGINT KEY, `WINDOWEND` BIGINT KEY, `COUNT` BIGINT"}}, - {"row":{"columns":[10.1, 12000, 13000, 1]}} + {"row":{"columns":[10.5, 12000, 13000, 1]}} ]}, {"query": [ {"header":{"schema":"`ID` DOUBLE KEY, `WINDOWSTART` BIGINT KEY, `WINDOWEND` BIGINT KEY, `COUNT` BIGINT"}} @@ -706,18 +706,18 @@ "statements": [ "CREATE STREAM INPUT (ID DOUBLE KEY, IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", "CREATE TABLE AGGREGATE AS SELECT ID, COUNT(1) AS COUNT FROM INPUT GROUP BY ID;", - "SELECT ID, COUNT FROM AGGREGATE WHERE ID=10.1;" + "SELECT ID, COUNT FROM AGGREGATE WHERE ID=10.5;" ], "inputs": [ {"topic": "test_topic", "timestamp": 12345, "key": 11.2, "value": {}}, - {"topic": "test_topic", "timestamp": 12346, "key": 10.1, "value": {}} + {"topic": "test_topic", "timestamp": 12346, "key": 10.5, "value": {}} ], "responses": [ {"admin": {"@type": "currentStatus"}}, {"admin": {"@type": "currentStatus"}}, {"query": [ {"header":{"schema":"`ID` DOUBLE KEY, `COUNT` BIGINT"}}, - {"row":{"columns":[10.1, 1]}} + {"row":{"columns":[10.5, 1]}} ]} ] }, @@ -726,19 +726,19 @@ "statements": [ "CREATE STREAM INPUT (ID DOUBLE KEY, IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", "CREATE TABLE AGGREGATE AS SELECT ID, COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ID;", - "SELECT COUNT, ID FROM AGGREGATE WHERE ID=10.1 AND WindowStart=12000;" + "SELECT COUNT, ID FROM AGGREGATE WHERE ID=10.5 AND WindowStart=12000;" ], "inputs": [ - {"topic": "test_topic", "timestamp": 12345, "key": 11.1, "value": {}}, - {"topic": "test_topic", "timestamp": 11345, "key": 10.1, "value": {}}, - {"topic": "test_topic", "timestamp": 12345, "key": 10.1, "value": {}} + {"topic": "test_topic", "timestamp": 12345, "key": 11.5, "value": {}}, + {"topic": "test_topic", "timestamp": 11345, "key": 10.5, "value": {}}, + {"topic": "test_topic", "timestamp": 12345, "key": 10.5, "value": {}} ], "responses": [ {"admin": {"@type": "currentStatus"}}, {"admin": {"@type": "currentStatus"}}, {"query": [ {"header":{"schema":"`COUNT` BIGINT, `ID` DOUBLE KEY"}}, - {"row":{"columns":[1, 10.1]}} + 
{"row":{"columns":[1, 10.5]}} ]} ] }, @@ -1146,18 +1146,18 @@ "statements": [ "CREATE STREAM INPUT (ID DOUBLE KEY, IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", "CREATE TABLE AGGREGATE AS SELECT ID, COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ID;", - "SELECT ID, TIMESTAMPTOSTRING(WINDOWSTART, 'yyyy-MM-dd HH:mm:ss Z', 'UTC') AS WSTART, TIMESTAMPTOSTRING(WINDOWEND, 'yyyy-MM-dd HH:mm:ss Z', 'UTC') AS WEND, ROWTIME, COUNT FROM AGGREGATE WHERE ID=10.1 AND WINDOWSTART=1580223282000;" + "SELECT ID, TIMESTAMPTOSTRING(WINDOWSTART, 'yyyy-MM-dd HH:mm:ss Z', 'UTC') AS WSTART, TIMESTAMPTOSTRING(WINDOWEND, 'yyyy-MM-dd HH:mm:ss Z', 'UTC') AS WEND, ROWTIME, COUNT FROM AGGREGATE WHERE ID=10.5 AND WINDOWSTART=1580223282000;" ], "inputs": [ - {"topic": "test_topic", "timestamp": 1580223282123, "key": 11.1, "value": {"val": 1}}, - {"topic": "test_topic", "timestamp": 1580223282456, "key": 10.1, "value": {"val": 2}} + {"topic": "test_topic", "timestamp": 1580223282123, "key": 11.5, "value": {"val": 1}}, + {"topic": "test_topic", "timestamp": 1580223282456, "key": 10.5, "value": {"val": 2}} ], "responses": [ {"admin": {"@type": "currentStatus"}}, {"admin": {"@type": "currentStatus"}}, {"query": [ {"header":{"schema":"`ID` DOUBLE KEY, `WSTART` STRING, `WEND` STRING, `ROWTIME` BIGINT, `COUNT` BIGINT"}}, - {"row":{"columns":[10.1, "2020-01-28 14:54:42 +0000", "2020-01-28 14:54:43 +0000", 1580223282456, 1]}} + {"row":{"columns":[10.5, "2020-01-28 14:54:42 +0000", "2020-01-28 14:54:43 +0000", 1580223282456, 1]}} ]} ] }, @@ -1166,19 +1166,19 @@ "statements": [ "CREATE STREAM INPUT (ID DOUBLE KEY, IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", "CREATE TABLE AGGREGATE AS SELECT ID, COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ID;", - "SELECT ROWTIME, ID, WINDOWSTART, WINDOWEND FROM AGGREGATE WHERE ID=10.1 AND WindowStart=12000;" + "SELECT ROWTIME, ID, WINDOWSTART, WINDOWEND FROM AGGREGATE WHERE ID=10.5 AND WindowStart=12000;" ], "inputs": [ - {"topic": "test_topic", "timestamp": 12345, "key": 11.1, "value": {}}, - {"topic": "test_topic", "timestamp": 11345, "key": 10.1, "value": {}}, - {"topic": "test_topic", "timestamp": 12345, "key": 10.1, "value": {}} + {"topic": "test_topic", "timestamp": 12345, "key": 11.5, "value": {}}, + {"topic": "test_topic", "timestamp": 11345, "key": 10.5, "value": {}}, + {"topic": "test_topic", "timestamp": 12345, "key": 10.5, "value": {}} ], "responses": [ {"admin": {"@type": "currentStatus"}}, {"admin": {"@type": "currentStatus"}}, {"query": [ {"header":{"schema":"`ROWTIME` BIGINT, `ID` DOUBLE KEY, `WINDOWSTART` BIGINT KEY, `WINDOWEND` BIGINT KEY"}}, - {"row":{"columns":[12345, 10.1, 12000, 13000]}} + {"row":{"columns":[12345, 10.5, 12000, 13000]}} ]} ] }, @@ -1797,22 +1797,23 @@ "statements": [ "CREATE STREAM INPUT (ID DOUBLE KEY, IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", "CREATE TABLE AGGREGATE AS SELECT ID, COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ID;", - "SELECT * FROM AGGREGATE WHERE ID IN (10.1, 8.1);", + "SELECT * FROM AGGREGATE WHERE ID IN (10.5, 8.5);", "SELECT * FROM AGGREGATE WHERE ID IN (0, 1.0);" ], "inputs": [ - {"topic": "test_topic", "timestamp": 12346, "key": 11.1, "value": {"val": 1}}, - {"topic": "test_topic", "timestamp": 12345, "key": 10.1, "value": {"val": 2}}, - {"topic": "test_topic", "timestamp": 12366, "key": 9.1, "value": {"val": 3}}, - {"topic": "test_topic", "timestamp": 12367, "key": 8.1, "value": {"val": 4}}, - 
{"topic": "test_topic", "timestamp": 12368, "key": 12.1, "value": {"val": 5}} + {"topic": "test_topic", "timestamp": 12346, "key": 11.5, "value": {"val": 1}}, + {"topic": "test_topic", "timestamp": 12345, "key": 10.5, "value": {"val": 2}}, + {"topic": "test_topic", "timestamp": 12366, "key": 9.5, "value": {"val": 3}}, + {"topic": "test_topic", "timestamp": 12367, "key": 8.5, "value": {"val": 4}}, + {"topic": "test_topic", "timestamp": 12368, "key": 12.5, "value": {"val": 5}} ], "responses": [ {"admin": {"@type": "currentStatus"}}, {"admin": {"@type": "currentStatus"}}, {"query": [ {"header":{"schema":"`ID` DOUBLE KEY, `WINDOWSTART` BIGINT KEY, `WINDOWEND` BIGINT KEY, `COUNT` BIGINT"}}, - {"row":{"columns":[10.1, 12000, 13000, 1]}} + {"row":{"columns":[10.5, 12000, 13000, 1]}}, + {"row":{"columns":[8.5, 12000, 13000, 1]}} ]}, {"query": [ {"header":{"schema":"`ID` DOUBLE KEY, `WINDOWSTART` BIGINT KEY, `WINDOWEND` BIGINT KEY, `COUNT` BIGINT"}} diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/BlockingQueryPublisher.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/BlockingQueryPublisher.java index ca146caa9710..443999e91111 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/BlockingQueryPublisher.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/BlockingQueryPublisher.java @@ -17,7 +17,7 @@ import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import io.confluent.ksql.GenericRow; -import io.confluent.ksql.api.server.PushQueryHandle; +import io.confluent.ksql.api.server.QueryHandle; import io.confluent.ksql.api.spi.QueryPublisher; import io.confluent.ksql.query.BlockingRowQueue; import io.confluent.ksql.reactive.BasePublisher; @@ -47,7 +47,8 @@ public class BlockingQueryPublisher extends BasePublisher, Gene private final WorkerExecutor workerExecutor; private BlockingRowQueue queue; - private PushQueryHandle queryHandle; + private boolean isPullQuery; + private QueryHandle queryHandle; private List columnNames; private List columnTypes; private boolean complete; @@ -59,13 +60,21 @@ public BlockingQueryPublisher(final Context ctx, this.workerExecutor = Objects.requireNonNull(workerExecutor); } - public void setQueryHandle(final PushQueryHandle queryHandle) { + public void setQueryHandle(final QueryHandle queryHandle, final boolean isPullQuery) { this.columnNames = queryHandle.getColumnNames(); this.columnTypes = queryHandle.getColumnTypes(); this.queue = queryHandle.getQueue(); + this.isPullQuery = isPullQuery; this.queue.setQueuedCallback(this::maybeSend); - this.queue.setLimitHandler(() -> complete = true); + this.queue.setLimitHandler(() -> { + complete = true; + // This allows us to hit the limit without having to queue one last row + if (queue.isEmpty()) { + ctx.runOnContext(v -> sendComplete()); + } + }); this.queryHandle = queryHandle; + queryHandle.onException(t -> ctx.runOnContext(v -> sendError(t))); } @Override @@ -90,7 +99,7 @@ public void close() { @Override public boolean isPullQuery() { - return false; + return isPullQuery; } @Override @@ -122,6 +131,7 @@ private void doSend() { while (getDemand() > 0 && !queue.isEmpty()) { if (num < SEND_MAX_BATCH_SIZE) { doOnNext(queue.poll()); + if (complete && queue.isEmpty()) { ctx.runOnContext(v -> sendComplete()); } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/PullQueryPublisher.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/PullQueryPublisher.java deleted file mode 100644 index d1be615b348a..000000000000 --- 
a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/PullQueryPublisher.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright 2020 Confluent Inc. - * - * Licensed under the Confluent Community License (the "License"); you may not use - * this file except in compliance with the License. You may obtain a copy of the - * License at - * - * http://www.confluent.io/confluent-community-license - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package io.confluent.ksql.api.impl; - -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.api.spi.QueryPublisher; -import io.confluent.ksql.reactive.BufferedPublisher; -import io.confluent.ksql.rest.entity.TableRows; -import io.confluent.ksql.util.KeyValue; -import io.vertx.core.Context; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; - -public class PullQueryPublisher - extends BufferedPublisher, GenericRow>> - implements QueryPublisher { - - private final List columnNames; - private final List columnTypes; - - public PullQueryPublisher(final Context ctx, final TableRows tableRows, - final List columnNames, final List columnTypes) { - super(ctx, toGenericRows(tableRows)); - this.columnNames = Objects.requireNonNull(columnNames); - this.columnTypes = Objects.requireNonNull(columnTypes); - } - - @SuppressWarnings({"unchecked", "rawtypes"}) - private static List, GenericRow>> toGenericRows(final TableRows tableRows) { - final List, GenericRow>> genericRows = - new ArrayList<>(tableRows.getRows().size()); - - for (List row : tableRows.getRows()) { - final GenericRow genericRow = GenericRow.fromList(row); - genericRows.add(KeyValue.keyValue(null, genericRow)); - } - return genericRows; - } - - @Override - public List getColumnNames() { - return columnNames; - } - - @Override - public List getColumnTypes() { - return columnTypes; - } - - @Override - public boolean isPullQuery() { - return true; - } -} diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/QueryEndpoint.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/QueryEndpoint.java index b36f7c7eebff..328b710bcbcc 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/QueryEndpoint.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/impl/QueryEndpoint.java @@ -17,7 +17,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.RateLimiter; -import io.confluent.ksql.api.server.PushQueryHandle; +import io.confluent.ksql.api.server.QueryHandle; import io.confluent.ksql.api.spi.QueryPublisher; import io.confluent.ksql.config.SessionConfig; import io.confluent.ksql.engine.KsqlEngine; @@ -33,7 +33,6 @@ import io.confluent.ksql.physical.pull.HARouting; import io.confluent.ksql.physical.pull.PullQueryResult; import io.confluent.ksql.query.BlockingRowQueue; -import io.confluent.ksql.rest.entity.TableRows; import io.confluent.ksql.rest.server.LocalCommands; import io.confluent.ksql.rest.server.resources.streaming.PullQueryConfigRoutingOptions; import io.confluent.ksql.schema.ksql.Column; @@ -51,6 +50,8 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; import java.util.stream.Collectors; import 
org.apache.kafka.common.utils.Time; @@ -96,8 +97,8 @@ public QueryPublisher createQueryPublisher( if (statement.getStatement().isPullQuery()) { return createPullQueryPublisher( - context, serviceContext, routingFilterFactory, statement, pullQueryMetrics, - startTimeNanos); + context, serviceContext, statement, pullQueryMetrics, + startTimeNanos, workerExecutor); } else { return createPushQueryPublisher(context, serviceContext, statement, workerExecutor); } @@ -116,7 +117,7 @@ private QueryPublisher createPushQueryPublisher( localCommands.ifPresent(lc -> lc.write(queryMetadata)); - publisher.setQueryHandle(new KsqlQueryHandle(queryMetadata)); + publisher.setQueryHandle(new KsqlQueryHandle(queryMetadata), false); return publisher; } @@ -124,10 +125,10 @@ private QueryPublisher createPushQueryPublisher( private QueryPublisher createPullQueryPublisher( final Context context, final ServiceContext serviceContext, - final RoutingFilterFactory routingFilterFactory, final ConfiguredStatement statement, final Optional pullQueryMetrics, - final long startTimeNanos + final long startTimeNanos, + final WorkerExecutor workerExecutor ) { final RoutingOptions routingOptions = new PullQueryConfigRoutingOptions( @@ -142,25 +143,20 @@ private QueryPublisher createPullQueryPublisher( serviceContext, statement, routing, - routingFilterFactory, routingOptions, - pullQueryMetrics + pullQueryMetrics, + false ); - pullQueryMetrics.ifPresent(p -> p.recordLatency(startTimeNanos)); + result.onCompletion(v -> { + pullQueryMetrics.ifPresent(p -> p.recordLatency(startTimeNanos)); + }); - final TableRows tableRows = new TableRows( - statement.getStatementText(), - result.getQueryId(), - result.getSchema(), - result.getTableRows()); + final BlockingQueryPublisher publisher = new BlockingQueryPublisher(context, workerExecutor); - return new PullQueryPublisher( - context, - tableRows, - colNamesFromSchema(tableRows.getSchema().columns()), - colTypesFromSchema(tableRows.getSchema().columns()) - ); + publisher.setQueryHandle(new KsqlPullQueryHandle(result, pullQueryMetrics), true); + + return publisher; } private ConfiguredStatement createStatement(final String queryString, @@ -196,7 +192,7 @@ private static List colNamesFromSchema(final List columns) { .collect(Collectors.toList()); } - private static class KsqlQueryHandle implements PushQueryHandle { + private static class KsqlQueryHandle implements QueryHandle { private final TransientQueryMetadata queryMetadata; @@ -228,5 +224,63 @@ public void stop() { public BlockingRowQueue getQueue() { return queryMetadata.getRowQueue(); } + + @Override + public void onException(final Consumer onException) { + // We don't try to do anything on exception for push queries, but rely on the + // existing exception handling + } + } + + private static class KsqlPullQueryHandle implements QueryHandle { + + private final PullQueryResult result; + private final Optional pullQueryMetrics; + private final CompletableFuture future = new CompletableFuture<>(); + + KsqlPullQueryHandle(final PullQueryResult result, + final Optional pullQueryMetrics) { + this.result = Objects.requireNonNull(result); + this.pullQueryMetrics = Objects.requireNonNull(pullQueryMetrics); + } + + @Override + public List getColumnNames() { + return colNamesFromSchema(result.getSchema().columns()); + } + + @Override + public List getColumnTypes() { + return colTypesFromSchema(result.getSchema().columns()); + } + + @Override + public void start() { + try { + result.start(); + result.onException(future::completeExceptionally); 
+ result.onCompletion(future::complete); + } catch (Exception e) { + pullQueryMetrics.ifPresent(metrics -> metrics.recordErrorRate(1)); + } + } + + @Override + public void stop() { + result.stop(); + } + + @Override + public BlockingRowQueue getQueue() { + return result.getPullQueryQueue(); + } + + @Override + public void onException(final Consumer onException) { + future.exceptionally(t -> { + onException.accept(t); + return null; + }); + } } } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/OldApiUtils.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/OldApiUtils.java index eceea46c34aa..94b963cb0cdf 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/OldApiUtils.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/OldApiUtils.java @@ -139,7 +139,8 @@ private static void streamEndpointResponse(final Server server, final WorkerExecutor workerExecutor = server.getWorkerExecutor(); final VertxCompletableFuture vcf = new VertxCompletableFuture<>(); workerExecutor.executeBlocking(promise -> { - final OutputStream ros = new ResponseOutputStream(routingContext.response()); + final OutputStream ros = new ResponseOutputStream(routingContext.response(), + streamingOutput.getWriteTimeoutMs()); routingContext.request().connection().closeHandler(v -> { // Close the OutputStream on close of the HTTP connection try { diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/PushQueryHandle.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryHandle.java similarity index 83% rename from ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/PushQueryHandle.java rename to ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryHandle.java index 8778bd19343b..bc0b7749f2f7 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/PushQueryHandle.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryHandle.java @@ -17,11 +17,12 @@ import io.confluent.ksql.query.BlockingRowQueue; import java.util.List; +import java.util.function.Consumer; /** - * Handle to a push query running in the engine + * Handle to a query running in the engine */ -public interface PushQueryHandle { +public interface QueryHandle { List getColumnNames(); @@ -32,4 +33,6 @@ public interface PushQueryHandle { void stop(); BlockingRowQueue getQueue(); + + void onException(Consumer onException); } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryStreamHandler.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryStreamHandler.java index 50eb3eb33e37..cff4921f15de 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryStreamHandler.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/QueryStreamHandler.java @@ -86,6 +86,9 @@ public void handle(final RoutingContext routingContext) { metadata = new QueryResponseMetadata( queryPublisher.getColumnNames(), queryPublisher.getColumnTypes()); + + // When response is complete, publisher should be closed + routingContext.response().endHandler(v -> queryPublisher.close()); } else { final PushQueryHolder query = connectionQueryManager .createApiQuery(queryPublisher, routingContext.request()); @@ -111,5 +114,4 @@ public void handle(final RoutingContext routingContext) { .exceptionally(t -> ServerUtils.handleEndpointException(t, routingContext, "Failed to execute query")); } - } diff --git 
a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/ResponseOutputStream.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/ResponseOutputStream.java index d7ad0b208101..0904954aee2e 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/ResponseOutputStream.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/ResponseOutputStream.java @@ -34,15 +34,15 @@ This is only used by legacy streaming endpoints from the old API which work with output streams. */ public class ResponseOutputStream extends OutputStream { - - private static final int WRITE_TIMEOUT_MS = 10 * 60000; private static final int BLOCK_TIME_MS = 100; private final HttpServerResponse response; + private final int writeTimeoutMs; private volatile boolean closed; - public ResponseOutputStream(final HttpServerResponse response) { + public ResponseOutputStream(final HttpServerResponse response, final int writeTimeoutMs) { this.response = response; + this.writeTimeoutMs = writeTimeoutMs; } @Override @@ -107,7 +107,7 @@ private void blockOnWrite(final CompletableFuture cf) throws IOException { } catch (Exception e) { throw new KsqlException(e); } - } while (System.currentTimeMillis() - start < WRITE_TIMEOUT_MS); + } while (System.currentTimeMillis() - start < writeTimeoutMs); throw new KsqlException("Timed out waiting to write to client"); } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/StreamingOutput.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/StreamingOutput.java index a1aa38686caf..33b014e7f322 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/StreamingOutput.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/api/server/StreamingOutput.java @@ -25,4 +25,9 @@ public interface StreamingOutput extends AutoCloseable { @Override void close(); + /** + * The amount of time the system will spend trying to clear the Vert.x response write queue + * before giving up and throwing an error. + */ + int getWriteTimeoutMs(); } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/entity/TableRows.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/entity/TableRows.java deleted file mode 100644 index 222c59616818..000000000000 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/entity/TableRows.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright 2019 Confluent Inc. - * - * Licensed under the Confluent Community License (the "License"); you may not use - * this file except in compliance with the License. You may obtain a copy of the - * License at - * - * http://www.confluent.io/confluent-community-license - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. 
- */ - -package io.confluent.ksql.rest.entity; - -import static java.util.Objects.requireNonNull; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableList.Builder; -import io.confluent.ksql.query.QueryId; -import io.confluent.ksql.schema.ksql.LogicalSchema; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Objects; - -public class TableRows { - - private final LogicalSchema schema; - private final QueryId queryId; - private final ImmutableList> rows; - private final String statementText; - - public TableRows( - final String statementText, - final QueryId queryId, - final LogicalSchema schema, - final List> rows - ) { - this.statementText = requireNonNull(statementText, "statementText"); - this.schema = requireNonNull(schema, "schema"); - this.queryId = requireNonNull(queryId, "queryId"); - this.rows = deepCopy(requireNonNull(rows, "rows")); - - rows.forEach(this::validate); - } - - public String getStatementText() { - return statementText; - } - - public LogicalSchema getSchema() { - return schema; - } - - public QueryId getQueryId() { - return queryId; - } - - public List> getRows() { - return rows; - } - - @Override - public boolean equals(final Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - final TableRows that = (TableRows) o; - return Objects.equals(schema, that.schema) - && Objects.equals(queryId, that.queryId) - && Objects.equals(rows, that.rows) - && Objects.equals(statementText, that.statementText); - } - - @Override - public int hashCode() { - return Objects.hash(schema, queryId, rows, statementText); - } - - private void validate(final List row) { - final int expectedSize = schema.key().size() + schema.value().size(); - final int actualSize = row.size(); - - if (expectedSize != actualSize) { - throw new IllegalArgumentException("column count mismatch." - + " expected: " + expectedSize - + ", got: " + actualSize - ); - } - } - - private static ImmutableList> deepCopy(final List> rows) { - final Builder> builder = ImmutableList.builder(); - rows.stream() - .>map(ArrayList::new) - .map(Collections::unmodifiableList) - .forEach(builder::add); - - return builder.build(); - } -} diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/entity/TableRowsFactory.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/entity/TableRowsFactory.java deleted file mode 100644 index 2680a51312bf..000000000000 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/entity/TableRowsFactory.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright 2019 Confluent Inc. - * - * Licensed under the Confluent Community License (the "License"); you may not use - * this file except in compliance with the License. You may obtain a copy of the - * License at - * - * http://www.confluent.io/confluent-community-license - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. 
- */ - -package io.confluent.ksql.rest.entity; - -import io.confluent.ksql.execution.streams.materialization.TableRow; -import io.confluent.ksql.schema.ksql.LogicalSchema; -import io.confluent.ksql.schema.ksql.LogicalSchema.Builder; -import io.confluent.ksql.schema.ksql.SystemColumns; -import io.confluent.ksql.schema.ksql.types.SqlTypes; -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; - -/** - * Factory class for {@link TableRows} - */ -public final class TableRowsFactory { - - private TableRowsFactory() { - } - - public static List> createRows( - final List result - ) { - return result.stream() - .map(TableRowsFactory::createRow) - .collect(Collectors.toList()); - } - - public static LogicalSchema buildSchema( - final LogicalSchema schema, - final boolean windowed - ) { - final Builder builder = LogicalSchema.builder() - .keyColumns(schema.key()); - - if (windowed) { - builder.keyColumn(SystemColumns.WINDOWSTART_NAME, SqlTypes.BIGINT); - builder.keyColumn(SystemColumns.WINDOWEND_NAME, SqlTypes.BIGINT); - } - - return builder - .valueColumns(schema.value()) - .build(); - } - - private static List createRow(final TableRow row) { - final List rowList = new ArrayList<>(row.key().values()); - - row.window().ifPresent(window -> { - rowList.add(window.start().toEpochMilli()); - rowList.add(window.end().toEpochMilli()); - }); - - rowList.addAll(row.value().values()); - - return rowList; - } -} diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryPublisher.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryPublisher.java index 9b757b6b9604..d7720dd8dbde 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryPublisher.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryPublisher.java @@ -18,37 +18,34 @@ import static java.util.Objects.requireNonNull; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.util.concurrent.ListeningScheduledExecutorService; import com.google.common.util.concurrent.RateLimiter; import io.confluent.ksql.GenericRow; import io.confluent.ksql.engine.KsqlEngine; import io.confluent.ksql.engine.PullQueryExecutionUtil; import io.confluent.ksql.execution.streams.RoutingFilter.RoutingFilterFactory; import io.confluent.ksql.execution.streams.RoutingOptions; -import io.confluent.ksql.execution.streams.materialization.Locator.KsqlNode; import io.confluent.ksql.internal.PullQueryExecutorMetrics; import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.physical.pull.HARouting; import io.confluent.ksql.physical.pull.PullQueryResult; -import io.confluent.ksql.rest.entity.KsqlHostInfoEntity; import io.confluent.ksql.rest.entity.StreamedRow; -import io.confluent.ksql.rest.entity.TableRows; import io.confluent.ksql.rest.server.resources.streaming.Flow.Subscriber; import io.confluent.ksql.services.ServiceContext; import io.confluent.ksql.statement.ConfiguredStatement; -import io.confluent.ksql.util.Pair; +import io.confluent.ksql.util.KeyValue; import java.util.Collection; import java.util.List; import java.util.Optional; -import java.util.concurrent.Callable; import java.util.stream.Collectors; -import java.util.stream.IntStream; class PullQueryPublisher implements Flow.Publisher> { private 
final KsqlEngine ksqlEngine; private final ServiceContext serviceContext; + private final ListeningScheduledExecutorService exec; private final ConfiguredStatement query; private final Optional pullQueryMetrics; private final long startTimeNanos; @@ -60,6 +57,7 @@ class PullQueryPublisher implements Flow.Publisher> { PullQueryPublisher( final KsqlEngine ksqlEngine, final ServiceContext serviceContext, + final ListeningScheduledExecutorService exec, final ConfiguredStatement query, final Optional pullQueryMetrics, final long startTimeNanos, @@ -69,6 +67,7 @@ class PullQueryPublisher implements Flow.Publisher> { ) { this.ksqlEngine = requireNonNull(ksqlEngine, "ksqlEngine"); this.serviceContext = requireNonNull(serviceContext, "serviceContext"); + this.exec = requireNonNull(exec, "exec"); this.query = requireNonNull(query, "query"); this.pullQueryMetrics = pullQueryMetrics; this.startTimeNanos = startTimeNanos; @@ -79,98 +78,68 @@ class PullQueryPublisher implements Flow.Publisher> { @Override public synchronized void subscribe(final Subscriber> subscriber) { - final PullQuerySubscription subscription = new PullQuerySubscription( - subscriber, - () -> { - final RoutingOptions routingOptions = new PullQueryConfigRoutingOptions( - query.getSessionConfig().getConfig(false), - query.getSessionConfig().getOverrides(), - ImmutableMap.of() - ); - - PullQueryExecutionUtil.checkRateLimit(rateLimiter); - - final PullQueryResult result = ksqlEngine.executePullQuery( - serviceContext, - query, - routing, - routingFilterFactory, - routingOptions, - pullQueryMetrics - ); - - pullQueryMetrics.ifPresent(pullQueryExecutorMetrics -> pullQueryExecutorMetrics - .recordLatency(startTimeNanos)); - return result; - }, - query + final RoutingOptions routingOptions = new PullQueryConfigRoutingOptions( + query.getSessionConfig().getConfig(false), + query.getSessionConfig().getOverrides(), + ImmutableMap.of() + ); + + PullQueryExecutionUtil.checkRateLimit(rateLimiter); + + final PullQueryResult result = ksqlEngine.executePullQuery( + serviceContext, + query, + routing, + routingOptions, + pullQueryMetrics, + true ); + result.onCompletion(v -> { + pullQueryMetrics.ifPresent(p -> p.recordLatency(startTimeNanos)); + }); + + final PullQuerySubscription subscription = new PullQuerySubscription( + exec, subscriber, result); + subscriber.onSubscribe(subscription); } - private static final class PullQuerySubscription implements Flow.Subscription { + private static final class PullQuerySubscription + extends PollingSubscription> { private final Subscriber> subscriber; - private final Callable executor; - private final ConfiguredStatement query; - private boolean done = false; + private final PullQueryResult result; private PullQuerySubscription( + final ListeningScheduledExecutorService exec, final Subscriber> subscriber, - final Callable executor, - final ConfiguredStatement query + final PullQueryResult result ) { + super(exec, subscriber, result.getSchema()); this.subscriber = requireNonNull(subscriber, "subscriber"); - this.executor = requireNonNull(executor, "executor"); - this.query = requireNonNull(query, "query"); + this.result = requireNonNull(result, "result"); + + result.onCompletion(v -> setDone()); + result.onException(this::setError); } @Override - public void request(final long n) { - Preconditions.checkArgument(n == 1, "number of requested items must be 1"); - - if (done) { - return; - } - - done = true; - - try { - final PullQueryResult result = executor.call(); - final TableRows entity = new TableRows( - 
query.getStatementText(), - result.getQueryId(), - result.getSchema(), - result.getTableRows()); - final Optional<List<KsqlHostInfoEntity>> hosts = result.getSourceNodes() - .map(list -> list.stream().map(KsqlNode::location) - .map(location -> new KsqlHostInfoEntity(location.getHost(), location.getPort())) - .collect(Collectors.toList())); - - subscriber.onSchema(entity.getSchema()); - - hosts.ifPresent(h -> Preconditions.checkState(h.size() == entity.getRows().size())); - final List<StreamedRow> rows = IntStream.range(0, entity.getRows().size()) - .mapToObj(i -> Pair.of( - PullQuerySubscription.toGenericRow(entity.getRows().get(i)), - hosts.map(h -> h.get(i)))) - .map(pair -> StreamedRow.pullRow(pair.getLeft(), pair.getRight())) - .collect(Collectors.toList()); - - subscriber.onNext(rows); - subscriber.onComplete(); - } catch (final Exception e) { - subscriber.onError(e); + Collection<StreamedRow> poll() { + final List<KeyValue<List<?>, GenericRow>> rows = Lists.newLinkedList(); + result.getPullQueryQueue().drainTo(rows); + if (rows.isEmpty()) { + return null; + } else { + return rows.stream() + .map(kv -> StreamedRow.pushRow(kv.value())) + .collect(Collectors.toCollection(Lists::newLinkedList)); } } @Override - public void cancel() { - } - - private static GenericRow toGenericRow(final List<?> values) { - return new GenericRow().appendAll(values); + void close() { + result.stop(); } } } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryStreamWriter.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryStreamWriter.java new file mode 100644 index 000000000000..ba4a4721a549 --- /dev/null +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryStreamWriter.java @@ -0,0 +1,373 @@ +/* + * Copyright 2020 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License.
+ */ + +package io.confluent.ksql.rest.server.resources.streaming; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.Lists; +import io.confluent.ksql.api.server.StreamingOutput; +import io.confluent.ksql.execution.streams.materialization.Locator.KsqlNode; +import io.confluent.ksql.physical.pull.PullQueryResult; +import io.confluent.ksql.physical.pull.PullQueryRow; +import io.confluent.ksql.query.PullQueryQueue; +import io.confluent.ksql.rest.Errors; +import io.confluent.ksql.rest.entity.KsqlHostInfoEntity; +import io.confluent.ksql.rest.entity.StreamedRow; +import io.confluent.ksql.util.KsqlException; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.time.Clock; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class PullQueryStreamWriter implements StreamingOutput { + private static final Logger LOG = LoggerFactory.getLogger(PullQueryStreamWriter.class); + private static final int WRITE_TIMEOUT_MS = 3000; + + private static final int FLUSH_SIZE_BYTES = 50 * 1024; + private static final long MAX_FLUSH_MS = 1000; + + private final long disconnectCheckInterval; + private final PullQueryQueue pullQueryQueue; + private final Clock clock; + private final PullQueryResult result; + private final ObjectMapper objectMapper; + private AtomicBoolean completed = new AtomicBoolean(false); + private AtomicBoolean connectionClosed = new AtomicBoolean(false); + private AtomicReference pullQueryException = new AtomicReference<>(null); + private AtomicBoolean closed = new AtomicBoolean(false); + private boolean sentAtLeastOneRow = false; + + PullQueryStreamWriter( + final PullQueryResult result, + final long disconnectCheckInterval, + final ObjectMapper objectMapper, + final PullQueryQueue pullQueryQueue, + final Clock clock, + final CompletableFuture connectionClosedFuture + ) { + this.result = Objects.requireNonNull(result, "result"); + this.objectMapper = Objects.requireNonNull(objectMapper, "objectMapper"); + this.disconnectCheckInterval = disconnectCheckInterval; + this.pullQueryQueue = Objects.requireNonNull(pullQueryQueue, "pullQueryQueue"); + this.clock = Objects.requireNonNull(clock, "clock"); + connectionClosedFuture.thenAccept(v -> connectionClosed.set(true)); + result.onException(t -> { + if (pullQueryException.getAndSet(t) == null) { + interruptWriterThread(); + } + }); + result.onCompletion(v -> { + if (!completed.getAndSet(true)) { + interruptWriterThread(); + } + }); + } + + @Override + public void write(final OutputStream output) { + try { + final WriterState writerState = new WriterState(clock); + final QueueWrapper queueWrapper = new QueueWrapper(pullQueryQueue, disconnectCheckInterval); + + // First write the header with the schema + final StreamedRow header + = StreamedRow.header(result.getQueryId(), result.getSchema()); + writerState.append("[").append(writeValueAsString(header)); + + // While the query is still running, and the client hasn't closed the connection, continue to + // poll new rows. 
+ while (!connectionClosed.get() && !isCompletedOrHasException()) { + processRow(output, writerState, queueWrapper); + } + + if (connectionClosed.get()) { + return; + } + + // If the query finished quickly, we might not have thrown the error + drainAndThrowOnError(output, writerState, queueWrapper); + + // If no error was thrown above, drain the queue + drain(writerState, queueWrapper); + writerState.append("]"); + if (writerState.length() > 0) { + output.write(writerState.getStringToFlush().getBytes(StandardCharsets.UTF_8)); + output.flush(); + } + } catch (InterruptedException e) { + // The most likely cause of this is the server shutting down. Should just try to close + // gracefully, without writing any more to the connection stream. + LOG.warn("Interrupted while writing to connection stream"); + } catch (Throwable e) { + LOG.error("Exception occurred while writing to connection stream: ", e); + outputException(output, e); + } + } + + /** + * Processes a single row from the queue or times out waiting for one. If an error has occurred + * during pull query execution, the queue is drained and the error is thrown. + * + *
+ * <p>
Also, the thread may be interrupted by the completion callback, in which case, this method + * completes immediately. + * @param output The output stream to write to + * @param writerState writer state + * @param queueWrapper the queue wrapper + * @throws Throwable If an exception is found while running the pull query, it's rethrown here. + */ + private void processRow( + final OutputStream output, + final WriterState writerState, + final QueueWrapper queueWrapper + ) throws Throwable { + final PullQueryRow toProcess = queueWrapper.pollNextRow(); + if (toProcess == QueueWrapper.END_ROW) { + return; + } + if (toProcess != null) { + writeRow(toProcess, writerState, queueWrapper.hasAnotherRow()); + if (writerState.length() >= FLUSH_SIZE_BYTES + || (clock.millis() - writerState.getLastFlushMs()) >= MAX_FLUSH_MS + ) { + output.write(writerState.getStringToFlush().getBytes(StandardCharsets.UTF_8)); + output.flush(); + } + } + drainAndThrowOnError(output, writerState, queueWrapper); + } + + /** + * Does the job of writing the row to the writer state. + * @param row The row to write + * @param writerState writer state + * @param hasAnotherRow if there's another row after this one. This is used for determining how + * to write proper JSON, e.g. whether to add a comma. + */ + private void writeRow( + final PullQueryRow row, + final WriterState writerState, + final boolean hasAnotherRow + ) { + // Send for a comma after the header + if (!sentAtLeastOneRow) { + writerState.append(",").append(System.lineSeparator()); + sentAtLeastOneRow = true; + } + final StreamedRow streamedRow = StreamedRow + .pullRow(row.getGenericRow(), toKsqlHostInfo(row.getSourceNode())); + writerState.append(writeValueAsString(streamedRow)); + if (hasAnotherRow) { + writerState.append(",").append(System.lineSeparator()); + } + } + + /** + * If an error has been stored in pullQueryException, drains the queue and throws the exception. + * @param output The output stream to write to + * @param writerState writer state + * @throws Throwable If an exception is stored, it's rethrown. + */ + private void drainAndThrowOnError( + final OutputStream output, + final WriterState writerState, + final QueueWrapper queueWrapper + ) throws Throwable { + if (pullQueryException.get() != null) { + drain(writerState, queueWrapper); + output.write(writerState.getStringToFlush().getBytes(StandardCharsets.UTF_8)); + output.flush(); + throw pullQueryException.get(); + } + } + + /** + * Drains the queue and writes the contained rows. + * @param writerState writer state + * @param queueWrapper the queue wrapper + */ + private void drain(final WriterState writerState, final QueueWrapper queueWrapper) { + final List rows = queueWrapper.drain(); + int i = 0; + for (final PullQueryRow row : rows) { + writeRow(row, writerState, i + 1 < rows.size()); + i++; + } + } + + /** + * Outputs the given exception to the output stream. 
+ * @param out The output stream + * @param exception The exception to write + */ + private void outputException(final OutputStream out, final Throwable exception) { + if (connectionClosed.get()) { + return; + } + try { + out.write(",\n".getBytes(StandardCharsets.UTF_8)); + if (exception.getCause() instanceof KsqlException) { + objectMapper.writeValue(out, StreamedRow + .error(exception.getCause(), Errors.ERROR_CODE_SERVER_ERROR)); + } else { + objectMapper.writeValue(out, StreamedRow + .error(exception, Errors.ERROR_CODE_SERVER_ERROR)); + } + out.write("]\n".getBytes(StandardCharsets.UTF_8)); + out.flush(); + } catch (final IOException e) { + LOG.debug("Client disconnected while attempting to write an error message"); + } + } + + @Override + public void close() { + if (!closed.getAndSet(true)) { + result.stop(); + } + } + + @Override + public int getWriteTimeoutMs() { + return WRITE_TIMEOUT_MS; + } + + private boolean isCompletedOrHasException() { + return completed.get() || pullQueryException.get() != null; + } + + private void interruptWriterThread() { + pullQueryQueue.putSentinelRow(QueueWrapper.END_ROW); + } + + /** + * Converts the object to json and returns the string. + * @param object The object to convert + * @return The serialized JSON + */ + private String writeValueAsString(final Object object) { + try { + return objectMapper.writeValueAsString(object); + } catch (final JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + /** + * Converts the KsqlNode to KsqlHostInfoEntity + */ + private static Optional toKsqlHostInfo(final Optional ksqlNode) { + return ksqlNode.map( + node -> new KsqlHostInfoEntity(node.location().getHost(), node.location().getPort())); + } + + /** + * State that's kept for the buffered response and the last flush time. + */ + private static class WriterState { + private final Clock clock; + // The buffer of JSON that we're always flushing as we hit either time or size thresholds. + private StringBuilder sb = new StringBuilder(); + // Last flush timestamp in millis + private long lastFlushMs; + + WriterState(final Clock clock) { + this.clock = clock; + } + + public WriterState append(final String str) { + sb.append(str); + return this; + } + + public int length() { + return sb.length(); + } + + public long getLastFlushMs() { + return lastFlushMs; + } + + public String getStringToFlush() { + final String str = sb.toString(); + sb = new StringBuilder(); + lastFlushMs = clock.millis(); + return str; + } + } + + /** + * Wraps the PullQueryQueue to keep a hold of the head of the queue explicitly so it always knows + * if there's something next. + */ + static final class QueueWrapper { + public static final PullQueryRow END_ROW = new PullQueryRow(null, null, null); + private final PullQueryQueue pullQueryQueue; + private final long disconnectCheckInterval; + // We always keep a reference to the head of the queue so that we know if there's another + // row in the result in order to produce proper JSON. 
+ private PullQueryRow head = null; + + QueueWrapper(final PullQueryQueue pullQueryQueue, final long disconnectCheckInterval) { + this.pullQueryQueue = pullQueryQueue; + this.disconnectCheckInterval = disconnectCheckInterval; + } + + public boolean hasAnotherRow() { + return head != null; + } + + public PullQueryRow pollNextRow() throws InterruptedException { + final PullQueryRow row = pullQueryQueue.pollRow( + disconnectCheckInterval, + TimeUnit.MILLISECONDS + ); + + if (row == END_ROW) { + return END_ROW; + } + + if (row != null) { + // The head becomes the next thing to process and the newly polled row becomes the head. + // The first time this is run, we'll always return null since we're keeping the row as the + // new head. + final PullQueryRow toProcess = head; + head = row; + return toProcess; + } + return null; + } + + public List drain() { + final List rows = Lists.newArrayList(); + if (head != null) { + rows.add(head); + } + head = null; + pullQueryQueue.drainRowsTo(rows); + rows.remove(END_ROW); + return rows; + } + } +} diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/QueryStreamWriter.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/QueryStreamWriter.java index 1cf68c80ebea..3a45aef8fe57 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/QueryStreamWriter.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/QueryStreamWriter.java @@ -39,6 +39,7 @@ import org.slf4j.LoggerFactory; class QueryStreamWriter implements StreamingOutput { + private static final int WRITE_TIMEOUT_MS = 10 * 60000; private static final Logger log = LoggerFactory.getLogger(QueryStreamWriter.class); @@ -125,6 +126,11 @@ public synchronized void close() { } } + @Override + public int getWriteTimeoutMs() { + return WRITE_TIMEOUT_MS; + } + private StreamedRow buildHeader() { final QueryId queryId = queryMetadata.getQueryId(); diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java index be7d31bc5de8..9f1fcf948c81 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResource.java @@ -18,7 +18,6 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; import com.google.common.util.concurrent.RateLimiter; import io.confluent.ksql.GenericRow; import io.confluent.ksql.analyzer.PullQueryValidator; @@ -27,7 +26,6 @@ import io.confluent.ksql.engine.PullQueryExecutionUtil; import io.confluent.ksql.execution.streams.RoutingFilter.RoutingFilterFactory; import io.confluent.ksql.execution.streams.RoutingOptions; -import io.confluent.ksql.execution.streams.materialization.Locator.KsqlNode; import io.confluent.ksql.internal.PullQueryExecutorMetrics; import io.confluent.ksql.parser.KsqlParser.PreparedStatement; import io.confluent.ksql.parser.tree.PrintTopic; @@ -38,11 +36,8 @@ import io.confluent.ksql.rest.ApiJsonMapper; import io.confluent.ksql.rest.EndpointResponse; import io.confluent.ksql.rest.Errors; -import io.confluent.ksql.rest.entity.KsqlHostInfoEntity; import 
io.confluent.ksql.rest.entity.KsqlMediaType; import io.confluent.ksql.rest.entity.KsqlRequest; -import io.confluent.ksql.rest.entity.StreamedRow; -import io.confluent.ksql.rest.entity.TableRows; import io.confluent.ksql.rest.server.LocalCommands; import io.confluent.ksql.rest.server.StatementParser; import io.confluent.ksql.rest.server.computation.CommandQueue; @@ -56,9 +51,9 @@ import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.KsqlStatementException; -import io.confluent.ksql.util.Pair; import io.confluent.ksql.util.TransientQueryMetadata; import io.confluent.ksql.version.metrics.ActivenessRegistrar; +import java.time.Clock; import java.time.Duration; import java.util.Collection; import java.util.HashMap; @@ -68,7 +63,6 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; -import java.util.stream.IntStream; import org.apache.kafka.common.errors.TopicAuthorizationException; import org.apache.kafka.streams.StreamsConfig; import org.slf4j.Logger; @@ -247,7 +241,8 @@ private EndpointResponse handleStatement( configProperties, request.getRequestProperties(), isInternalRequest, - pullQueryMetrics + pullQueryMetrics, + connectionClosedFuture ); } @@ -286,7 +281,8 @@ private EndpointResponse handlePullQuery( final Map configOverrides, final Map requestProperties, final Optional isInternalRequest, - final Optional pullQueryMetrics + final Optional pullQueryMetrics, + final CompletableFuture connectionClosedFuture ) { final ConfiguredStatement configured = ConfiguredStatement .of(statement, SessionConfig.of(ksqlConfig, configOverrides)); @@ -325,41 +321,20 @@ private EndpointResponse handlePullQuery( serviceContext, configured, routing, - routingFilterFactory, routingOptions, - pullQueryMetrics - ); - final TableRows tableRows = new TableRows( - statement.getStatementText(), - result.getQueryId(), - result.getSchema(), - result.getTableRows()); - - final Optional> hosts = result.getSourceNodes() - .map(list -> list.stream().map(KsqlNode::location) - .map(location -> new KsqlHostInfoEntity(location.getHost(), location.getPort())) - .collect(Collectors.toList())); - - final StreamedRow header = StreamedRow.header( - tableRows.getQueryId(), - tableRows.getSchema() + pullQueryMetrics, + true ); - hosts.ifPresent(h -> Preconditions.checkState(h.size() == tableRows.getRows().size())); - final List rows = IntStream.range(0, tableRows.getRows().size()) - .mapToObj(i -> Pair.of( - StreamedQueryResource.toGenericRow(tableRows.getRows().get(i)), - hosts.map(h -> h.get(i)))) - .map(pair -> StreamedRow.pullRow(pair.getLeft(), pair.getRight())) - .collect(Collectors.toList()); - - rows.add(0, header); - - final String data = rows.stream() - .map(StreamedQueryResource::writeValueAsString) - .collect(Collectors.joining("," + System.lineSeparator(), "[", "]")); + final PullQueryStreamWriter pullQueryStreamWriter = new PullQueryStreamWriter( + result, + disconnectCheckInterval.toMillis(), + OBJECT_MAPPER, + result.getPullQueryQueue(), + Clock.systemUTC(), + connectionClosedFuture); - return EndpointResponse.ok(data); + return EndpointResponse.ok(pullQueryStreamWriter); } private EndpointResponse handlePushQuery( diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/TopicStreamWriter.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/TopicStreamWriter.java index 55fc89288a93..3b5c1c8af06c 100644 --- 
a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/TopicStreamWriter.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/TopicStreamWriter.java @@ -41,6 +41,7 @@ public class TopicStreamWriter implements StreamingOutput { private static final Logger log = LoggerFactory.getLogger(TopicStreamWriter.class); + private static final int WRITE_TIMEOUT_MS = 10 * 60000; private final long interval; private final Duration disconnectCheckInterval; private final KafkaConsumer topicConsumer; @@ -150,6 +151,11 @@ public synchronized void close() { } } + @Override + public int getWriteTimeoutMs() { + return WRITE_TIMEOUT_MS; + } + private static void outputException(final OutputStream out, final Exception exception) { try { out.write(exception.getMessage().getBytes(UTF_8)); diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpoint.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpoint.java index 8e46239fb1ec..b65aa18dcee1 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpoint.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/resources/streaming/WSQueryEndpoint.java @@ -364,7 +364,7 @@ private static void startPushQueryPublisher( private static void startPullQueryPublisher( final KsqlEngine ksqlEngine, final ServiceContext serviceContext, - final ListeningScheduledExecutorService ignored, + final ListeningScheduledExecutorService exec, final ConfiguredStatement query, final WebSocketSubscriber streamSubscriber, final Optional pullQueryMetrics, @@ -376,6 +376,7 @@ private static void startPullQueryPublisher( new PullQueryPublisher( ksqlEngine, serviceContext, + exec, query, pullQueryMetrics, startTimeNanos, diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/DefaultKsqlClient.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/DefaultKsqlClient.java index 7fb7c72d1457..d2de6abd0f75 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/DefaultKsqlClient.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/DefaultKsqlClient.java @@ -37,6 +37,7 @@ import java.util.Map; import java.util.Optional; import java.util.function.BiFunction; +import java.util.function.Consumer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -94,6 +95,7 @@ public RestResponse> makeQueryRequest( final Map configOverrides, final Map requestProperties ) { + final KsqlTarget target = sharedClient .target(serverEndPoint) .properties(configOverrides); @@ -108,6 +110,28 @@ public RestResponse> makeQueryRequest( return RestResponse.successful(resp.getStatusCode(), resp.getResponse()); } + @Override + public RestResponse makeQueryRequest( + final URI serverEndPoint, + final String sql, + final Map configOverrides, + final Map requestProperties, + final Consumer> rowConsumer + ) { + final KsqlTarget target = sharedClient + .target(serverEndPoint) + .properties(configOverrides); + + final RestResponse resp = getTarget(target, authHeader) + .postQueryRequest(sql, requestProperties, Optional.empty(), rowConsumer); + + if (resp.isErroneous()) { + return RestResponse.erroneous(resp.getStatusCode(), resp.getErrorMessage()); + } + + return RestResponse.successful(resp.getStatusCode(), resp.getResponse()); + } + @Override public void makeAsyncHeartbeatRequest( final URI serverEndPoint, diff 
--git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/ServerInternalKsqlClient.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/ServerInternalKsqlClient.java index 35ab12e25096..d1463e5d0d21 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/ServerInternalKsqlClient.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/services/ServerInternalKsqlClient.java @@ -34,6 +34,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.function.Consumer; /** * A KSQL client implementation that sends requests to KsqlResource directly, rather than going @@ -82,6 +83,17 @@ public RestResponse> makeQueryRequest( throw new UnsupportedOperationException(); } + @Override + public RestResponse makeQueryRequest( + final URI serverEndpoint, + final String sql, + final Map configOverrides, + final Map requestProperties, + final Consumer> rowConsumer + ) { + throw new UnsupportedOperationException(); + } + @Override public void makeAsyncHeartbeatRequest( final URI serverEndPoint, diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/ApiTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/ApiTest.java index cb440f561266..fbe87053f2ca 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/ApiTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/ApiTest.java @@ -35,14 +35,19 @@ import io.confluent.ksql.rest.entity.PushQueryId; import io.confluent.ksql.util.AppInfo; import io.confluent.ksql.util.VertxCompletableFuture; +import io.vertx.core.Vertx; import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.HttpVersion; import io.vertx.core.json.JsonArray; import io.vertx.core.json.JsonObject; +import io.vertx.ext.web.client.HttpRequest; import io.vertx.ext.web.client.HttpResponse; import io.vertx.ext.web.client.WebClient; +import io.vertx.ext.web.client.WebClientOptions; import io.vertx.ext.web.codec.BodyCodec; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/QueryStreamRunner.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/QueryStreamRunner.java index c0d3228b7762..99910daecdfc 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/QueryStreamRunner.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/api/perf/QueryStreamRunner.java @@ -24,7 +24,7 @@ import io.confluent.ksql.api.impl.BlockingQueryPublisher; import io.confluent.ksql.api.server.InsertResult; import io.confluent.ksql.api.server.InsertsStreamSubscriber; -import io.confluent.ksql.api.server.PushQueryHandle; +import io.confluent.ksql.api.server.QueryHandle; import io.confluent.ksql.api.spi.Endpoints; import io.confluent.ksql.api.spi.QueryPublisher; import io.confluent.ksql.query.BlockingRowQueue; @@ -48,6 +48,7 @@ import java.util.OptionalInt; import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; import org.reactivestreams.Subscriber; public class QueryStreamRunner extends BasePerfRunner { @@ -101,7 +102,7 @@ public synchronized CompletableFuture createQueryPublisher(final final ApiSecurityContext apiSecurityContext) { QueryStreamPublisher publisher = new QueryStreamPublisher(context, server.getWorkerExecutor()); - publisher.setQueryHandle(new TestQueryHandle()); + 
publisher.setQueryHandle(new TestQueryHandle(), false); publishers.add(publisher); publisher.start(); return CompletableFuture.completedFuture(publisher); @@ -207,7 +208,7 @@ synchronized void closePublishers() { } } - private static class TestQueryHandle implements PushQueryHandle { + private static class TestQueryHandle implements QueryHandle { private final TransientQueryQueue queue = new TransientQueryQueue(OptionalInt.empty()); @@ -226,6 +227,10 @@ public BlockingRowQueue getQueue() { return queue; } + @Override + public void onException(Consumer onException) { + } + @Override public void start() { } @@ -252,9 +257,9 @@ public void start() { } @Override - public void setQueryHandle(final PushQueryHandle queryHandle) { + public void setQueryHandle(final QueryHandle queryHandle, boolean isPullQuery) { this.queue = (TransientQueryQueue) queryHandle.getQueue(); - super.setQueryHandle(queryHandle); + super.setQueryHandle(queryHandle, isPullQuery); } public void close() { diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/entity/TableRowsFactoryTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/entity/TableRowsFactoryTest.java deleted file mode 100644 index 825f40ed5585..000000000000 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/entity/TableRowsFactoryTest.java +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Copyright 2019 Confluent Inc. - * - * Licensed under the Confluent Community License (the "License"); you may not use - * this file except in compliance with the License. You may obtain a copy of the - * License at - * - * http://www.confluent.io/confluent-community-license - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. 
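TestQueryHandle now implements the merged QueryHandle interface, whose onException hook lets a publisher register a callback for failures raised after the query has started (push and pull queries share the handle, hence the extra isPullQuery flag on setQueryHandle). A minimal sketch of the registration side, assuming the callback type is Consumer<Throwable>:

    import java.util.List;
    import java.util.concurrent.CopyOnWriteArrayList;
    import java.util.function.Consumer;

    public final class ExceptionNotifyingHandle {

      // Callbacks registered by publishers interested in late failures.
      private final List<Consumer<Throwable>> onException = new CopyOnWriteArrayList<>();

      // Mirrors the shape of QueryHandle.onException(Consumer<Throwable>).
      public void onException(final Consumer<Throwable> consumer) {
        onException.add(consumer);
      }

      // Whatever executes the query calls this when it fails after start().
      public void notifyException(final Throwable t) {
        onException.forEach(c -> c.accept(t));
      }

      public static void main(final String[] args) {
        final ExceptionNotifyingHandle handle = new ExceptionNotifyingHandle();
        handle.onException(t -> System.err.println("query failed: " + t.getMessage()));
        handle.notifyException(new RuntimeException("Boom!"));
      }
    }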
- */ - -package io.confluent.ksql.rest.entity; - -import static io.confluent.ksql.GenericKey.genericKey; -import static io.confluent.ksql.GenericRow.genericRow; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.is; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableList.Builder; -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.execution.streams.materialization.Row; -import io.confluent.ksql.execution.streams.materialization.TableRow; -import io.confluent.ksql.execution.streams.materialization.WindowedRow; -import io.confluent.ksql.name.ColumnName; -import io.confluent.ksql.schema.ksql.LogicalSchema; -import io.confluent.ksql.schema.ksql.types.SqlTypes; -import java.time.Instant; -import java.util.List; -import org.apache.kafka.streams.kstream.Windowed; -import org.apache.kafka.streams.kstream.internals.TimeWindow; -import org.junit.Test; - -public class TableRowsFactoryTest { - - private static final ColumnName K0 = ColumnName.of("k0"); - - private static final LogicalSchema SIMPLE_SCHEMA = LogicalSchema.builder() - .keyColumn(K0, SqlTypes.STRING) - .valueColumn(ColumnName.of("v0"), SqlTypes.BOOLEAN) - .build(); - - private static final LogicalSchema SCHEMA = LogicalSchema.builder() - .keyColumn(K0, SqlTypes.STRING) - .keyColumn(ColumnName.of("k1"), SqlTypes.BOOLEAN) - .valueColumn(ColumnName.of("v0"), SqlTypes.INTEGER) - .valueColumn(ColumnName.of("v1"), SqlTypes.BOOLEAN) - .build(); - - private static final LogicalSchema SCHEMA_NULL = LogicalSchema.builder() - .keyColumn(K0, SqlTypes.STRING) - .valueColumn(ColumnName.of("v0"), SqlTypes.STRING) - .valueColumn(ColumnName.of("v1"), SqlTypes.INTEGER) - .valueColumn(ColumnName.of("v2"), SqlTypes.DOUBLE) - .valueColumn(ColumnName.of("v3"), SqlTypes.BOOLEAN) - .build(); - - private static final long ROWTIME = 285775L; - - @Test - public void shouldAddNonWindowedRowToValues() { - // Given: - final List input = ImmutableList.of( - Row.of( - SIMPLE_SCHEMA, - genericKey("x"), - genericRow(false), - ROWTIME - ) - ); - - // When: - final List> output = TableRowsFactory.createRows(input); - - // Then: - assertThat(output, hasSize(1)); - assertThat(output.get(0), contains("x", false)); - } - - @Test - public void shouldAddWindowedRowToValues() { - // Given: - final Instant now = Instant.now(); - final TimeWindow window0 = new TimeWindow(now.toEpochMilli(), now.plusMillis(2).toEpochMilli()); - final TimeWindow window1 = new TimeWindow(now.toEpochMilli(), now.plusMillis(1).toEpochMilli()); - - final List input = ImmutableList.of( - WindowedRow.of( - SIMPLE_SCHEMA, - new Windowed<>(genericKey("x"), window0), - genericRow(true), - ROWTIME - ), - WindowedRow.of( - SIMPLE_SCHEMA, - new Windowed<>(genericKey("y"), window1), - genericRow(false), - ROWTIME - ) - ); - - // When: - final List> output = TableRowsFactory.createRows(input); - - // Then: - assertThat(output, hasSize(2)); - assertThat(output.get(0), - contains("x", now.toEpochMilli(), now.plusMillis(2).toEpochMilli(), true)); - assertThat(output.get(1), - contains("y", now.toEpochMilli(), now.plusMillis(1).toEpochMilli(), false)); - } - - @Test - public void shouldSupportNullColumns() { - // Given: - final GenericRow row = genericRow(null, null, null, null); - - final Builder builder = ImmutableList.builder(); - builder.add(Row.of(SCHEMA_NULL, genericKey("k"), row, ROWTIME)); - - // When: - final List> output = 
TableRowsFactory.createRows(builder.build()); - - // Then: - assertThat(output, hasSize(1)); - assertThat(output.get(0), contains("k", null, null, null, null)); - } - - @Test - public void shouldJustDuplicateRowTimeInValueIfNotWindowed() { - // When: - final LogicalSchema result = TableRowsFactory.buildSchema(SCHEMA, false); - - // Then: - assertThat(result, is(LogicalSchema.builder() - .keyColumn(K0, SqlTypes.STRING) - .keyColumn(ColumnName.of("k1"), SqlTypes.BOOLEAN) - .valueColumn(ColumnName.of("v0"), SqlTypes.INTEGER) - .valueColumn(ColumnName.of("v1"), SqlTypes.BOOLEAN) - .build() - )); - } - - @Test - public void shouldAddHoppingWindowFieldsToSchema() { - // When: - final LogicalSchema result = TableRowsFactory.buildSchema(SCHEMA, true); - - // Then: - assertThat(result, is(LogicalSchema.builder() - .keyColumn(K0, SqlTypes.STRING) - .keyColumn(ColumnName.of("k1"), SqlTypes.BOOLEAN) - .keyColumn(ColumnName.of("WINDOWSTART"), SqlTypes.BIGINT) - .keyColumn(ColumnName.of("WINDOWEND"), SqlTypes.BIGINT) - .valueColumn(ColumnName.of("v0"), SqlTypes.INTEGER) - .valueColumn(ColumnName.of("v1"), SqlTypes.BOOLEAN) - .build() - )); - } -} diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/entity/TableRowsTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/entity/TableRowsTest.java deleted file mode 100644 index eb1ba5e444e3..000000000000 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/entity/TableRowsTest.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright 2019 Confluent Inc. - * - * Licensed under the Confluent Community License (the "License"); you may not use - * this file except in compliance with the License. You may obtain a copy of the - * License at - * - * http://www.confluent.io/confluent-community-license - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. 
- */ - -package io.confluent.ksql.rest.entity; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; -import com.google.common.collect.ImmutableList; -import io.confluent.ksql.json.KsqlTypesSerializationModule; -import io.confluent.ksql.name.ColumnName; -import io.confluent.ksql.parser.json.KsqlTypesDeserializationModule; -import io.confluent.ksql.query.QueryId; -import io.confluent.ksql.schema.ksql.LogicalSchema; -import io.confluent.ksql.schema.ksql.types.SqlTypes; -import java.util.List; -import org.junit.Test; - -public class TableRowsTest { - - private static final ObjectMapper MAPPER; - private static final String SOME_SQL = "some SQL"; - - private static final LogicalSchema LOGICAL_SCHEMA = LogicalSchema.builder() - .keyColumn(ColumnName.of("ROWKEY"), SqlTypes.STRING) - .valueColumn(ColumnName.of("v0"), SqlTypes.DOUBLE) - .valueColumn(ColumnName.of("v1"), SqlTypes.STRING) - .build(); - - private static final QueryId QUERY_ID = new QueryId("bob"); - - private static final List A_VALUE = - ImmutableList.of("key value", 10.1D, "some text"); - - static { - MAPPER = new ObjectMapper(); - MAPPER.registerModule(new Jdk8Module()); - MAPPER.registerModule(new KsqlTypesSerializationModule()); - MAPPER.registerModule(new KsqlTypesDeserializationModule()); - } - - @Test(expected = IllegalArgumentException.class) - public void shouldThrowOnRowWindowTypeMismatch() { - new TableRows( - SOME_SQL, - QUERY_ID, - LOGICAL_SCHEMA, - ImmutableList.of(ImmutableList.of("too", "few")) - ); - } -} \ No newline at end of file diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/HighAvailabilityTestUtil.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/HighAvailabilityTestUtil.java index c9f011d284b4..bb9f105bc9be 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/HighAvailabilityTestUtil.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/HighAvailabilityTestUtil.java @@ -138,13 +138,13 @@ static void waitForClusterToBeDiscovered( } static void waitForStreamsMetadataToInitialize( - final TestKsqlRestApp restApp, List hosts, String queryId + final TestKsqlRestApp restApp, List hosts ) { - waitForStreamsMetadataToInitialize(restApp, hosts, queryId, Optional.empty()); + waitForStreamsMetadataToInitialize(restApp, hosts, Optional.empty()); } static void waitForStreamsMetadataToInitialize( - final TestKsqlRestApp restApp, List hosts, String queryId, + final TestKsqlRestApp restApp, List hosts, final Optional credentials ) { while (true) { diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/PullQueryRoutingFunctionalTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/PullQueryRoutingFunctionalTest.java index ed920002ab91..1c1c9a1c18db 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/PullQueryRoutingFunctionalTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/PullQueryRoutingFunctionalTest.java @@ -278,7 +278,7 @@ public void setUp() { waitForTableRows(); waitForStreamsMetadataToInitialize( - REST_APP_0, ImmutableList.of(HOST0, HOST1, HOST2), queryId, USER_CREDS); + REST_APP_0, ImmutableList.of(HOST0, HOST1, HOST2), USER_CREDS); } @After diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/PullQuerySingleNodeFunctionalTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/PullQuerySingleNodeFunctionalTest.java 
index bfa86a67708f..bb8d9439884e 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/PullQuerySingleNodeFunctionalTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/PullQuerySingleNodeFunctionalTest.java @@ -209,12 +209,12 @@ public void setUp() { sqlKey3 = "SELECT * FROM " + output + " WHERE USERID = '" + KEY_3 + "';"; waitForStreamsMetadataToInitialize( - REST_APP_0, ImmutableList.of(host0), queryId); + REST_APP_0, ImmutableList.of(host0)); } @Test public void restoreAfterClearState() { - waitForStreamsMetadataToInitialize(REST_APP_0, ImmutableList.of(host0), queryId); + waitForStreamsMetadataToInitialize(REST_APP_0, ImmutableList.of(host0)); waitForRemoteServerToChangeStatus(REST_APP_0, host0, HighAvailabilityTestUtil .lagsReported(host0, Optional.empty(), 5)); @@ -243,7 +243,7 @@ public void restoreAfterClearState() { LOG.info("Restarting the server " + host0.toString()); REST_APP_0.start(); - waitForStreamsMetadataToInitialize(REST_APP_0, ImmutableList.of(host0), queryId); + waitForStreamsMetadataToInitialize(REST_APP_0, ImmutableList.of(host0)); waitForRemoteServerToChangeStatus(REST_APP_0, host0, HighAvailabilityTestUtil .lagsReported(host0, Optional.of(2L), 5)); @@ -291,6 +291,7 @@ public void restoreAfterClearState() { assertThat(updatedRows.get(1).getRow().get().getColumns(), is(ImmutableList.of(KEY_3, 1))); } + private static String extractQueryId(final String outputString) { final java.util.regex.Matcher matcher = QUERY_ID_PATTERN.matcher(outputString); assertThat("Could not find query id in: " + outputString, matcher.find()); diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java index cb7cfa43e56f..a7b08b7d203a 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestApiTest.java @@ -22,6 +22,9 @@ import static io.confluent.ksql.test.util.EmbeddedSingleNodeKafkaCluster.ops; import static io.confluent.ksql.test.util.EmbeddedSingleNodeKafkaCluster.prefixedResource; import static io.confluent.ksql.test.util.EmbeddedSingleNodeKafkaCluster.resource; +import static io.netty.handler.codec.http.HttpResponseStatus.METHOD_NOT_ALLOWED; +import static io.vertx.core.http.HttpMethod.POST; +import static io.vertx.core.http.HttpVersion.HTTP_2; import static org.apache.kafka.common.acl.AclOperation.ALL; import static org.apache.kafka.common.acl.AclOperation.CREATE; import static org.apache.kafka.common.acl.AclOperation.DESCRIBE; @@ -40,8 +43,10 @@ import static org.hamcrest.Matchers.startsWith; import com.fasterxml.jackson.core.type.TypeReference; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.confluent.common.utils.IntegrationTest; +import io.confluent.ksql.api.utils.QueryResponse; import io.confluent.ksql.integration.IntegrationTestHarness; import io.confluent.ksql.rest.ApiJsonMapper; import io.confluent.ksql.rest.entity.CommandId; @@ -52,6 +57,7 @@ import io.confluent.ksql.rest.entity.CommandStatuses; import io.confluent.ksql.rest.entity.KsqlMediaType; import io.confluent.ksql.rest.entity.KsqlRequest; +import io.confluent.ksql.rest.entity.QueryStreamArgs; import io.confluent.ksql.rest.entity.ServerClusterId; import io.confluent.ksql.rest.entity.ServerInfo; import io.confluent.ksql.rest.entity.ServerMetadata; @@ -215,7 +221,9 @@ public void 
tearDown() { @AfterClass public static void classTearDown() { REST_APP.getPersistentQueries().forEach(str -> makeKsqlRequest("TERMINATE " + str + ";")); } @Test @@ -584,7 +592,7 @@ public void shouldExecutePullQueryOverRest() { } @Test - public void shouldExecutePullQueryOverRestHttp2() { + public void shouldFailToExecutePullQueryOverRestHttp2() { // Given final KsqlRequest request = new KsqlRequest( "SELECT COUNT, USERID from " + AGG_TABLE + " WHERE USERID='" + AN_AGG_KEY + "';", @@ -592,22 +600,36 @@ Collections.emptyMap(), null ); - final Supplier<List<String>> call = () -> { - final String response = rawRestRequest( + final Supplier<Integer> call = () -> { + return rawRestRequest( HttpVersion.HTTP_2, HttpMethod.POST, "/query", request - ).body().toString(); - return Arrays.asList(response.split(System.lineSeparator())); + ).statusCode(); }; // When: - final List<String> messages = assertThatEventually(call, hasSize(HEADER + 1)); + assertThatEventually(call, is(METHOD_NOT_ALLOWED.code())); + } - // Then: - assertThat(messages, hasSize(HEADER + 1)); - assertThat(messages.get(0), startsWith("[{\"header\":{\"queryId\":\"")); - assertThat(messages.get(0), - endsWith("\",\"schema\":\"`COUNT` BIGINT, `USERID` STRING KEY\"}},")); - assertThat(messages.get(1), is("{\"row\":{\"columns\":[1,\"USER_1\"]}}]")); + @Test + public void shouldExecutePullQueryOverHttp2QueryStream() { + QueryStreamArgs queryStreamArgs = new QueryStreamArgs( + "SELECT COUNT, USERID from " + AGG_TABLE + " WHERE USERID='" + AN_AGG_KEY + "';", + Collections.emptyMap()); + + QueryResponse[] queryResponse = new QueryResponse[1]; + assertThatEventually(() -> { + try { + HttpResponse<Buffer> resp = RestIntegrationTestUtil.rawRestRequest(REST_APP, + HTTP_2, POST, + "/query-stream", queryStreamArgs, "application/vnd.ksqlapi.delimited.v1", + Optional.empty()); + queryResponse[0] = new QueryResponse(resp.body().toString()); + return queryResponse[0].rows.size(); + } catch (Throwable t) { + return Integer.MAX_VALUE; + } + }, is(1)); + assertThat(queryResponse[0].rows.get(0).getList(), is(ImmutableList.of(1, "USER_1"))); } @Test diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestIntegrationTestUtil.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestIntegrationTestUtil.java index 4905c59d818a..8fa0eee75501 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestIntegrationTestUtil.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/integration/RestIntegrationTestUtil.java @@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.net.UrlEscapers; +import io.confluent.ksql.api.utils.ReceiveStream; import io.confluent.ksql.rest.ApiJsonMapper; import io.confluent.ksql.rest.client.BasicCredentials; import io.confluent.ksql.rest.client.KsqlRestClient; @@ -52,10 +53,12 @@ import io.vertx.core.http.HttpMethod; import io.vertx.core.http.HttpVersion; import io.vertx.core.http.WebsocketVersion; +import io.vertx.core.streams.WriteStream; import io.vertx.ext.web.client.HttpRequest; import io.vertx.ext.web.client.HttpResponse; import io.vertx.ext.web.client.WebClient; import io.vertx.ext.web.client.WebClientOptions; +import io.vertx.ext.web.codec.BodyCodec; import java.net.URI; import java.nio.charset.Charset; import java.util.ArrayList; @@ -232,7 +235,8 @@ static HttpResponse<Buffer> rawRestQueryRequest( final KsqlRequest
request = new KsqlRequest(sql, ImmutableMap.of(), ImmutableMap.of(), null); - return rawRestRequest(restApp, HTTP_1_1, POST, "/query", request, mediaType); + return rawRestRequest(restApp, HTTP_1_1, POST, "/query", request, mediaType, + Optional.empty()); } static HttpResponse rawRestRequest( @@ -242,14 +246,15 @@ static HttpResponse rawRestRequest( final String uri, final Object requestBody ) { - return rawRestRequest( - restApp, - httpVersion, - method, - uri, - requestBody, - "application/json" - ); + return rawRestRequest( + restApp, + httpVersion, + method, + uri, + requestBody, + "application/json", + Optional.empty() + ); } static HttpResponse rawRestRequest( @@ -258,11 +263,11 @@ static HttpResponse rawRestRequest( final HttpMethod method, final String uri, final Object requestBody, - final String mediaType + final String mediaType, + final Optional> writeStream ) { Vertx vertx = Vertx.vertx(); WebClient webClient = null; - try { WebClientOptions webClientOptions = new WebClientOptions() .setDefaultHost(restApp.getHttpListener().getHost()) @@ -274,7 +279,37 @@ static HttpResponse rawRestRequest( } webClient = WebClient.create(vertx, webClientOptions); + return rawRestRequest( + vertx, + webClient, + restApp, + httpVersion, + method, + uri, + requestBody, + mediaType, + writeStream + ); + } finally { + if (webClient != null) { + webClient.close(); + } + vertx.close(); + } + } + static HttpResponse rawRestRequest( + final Vertx vertx, + final WebClient webClient, + final TestKsqlRestApp restApp, + final HttpVersion httpVersion, + final HttpMethod method, + final String uri, + final Object requestBody, + final String mediaType, + final Optional> writeStream + ) { + try { byte[] bytes = ApiJsonMapper.INSTANCE.get().writeValueAsBytes(requestBody); Buffer bodyBuffer = Buffer.buffer(bytes); @@ -290,15 +325,11 @@ static HttpResponse rawRestRequest( } else { request.send(requestFuture); } + writeStream.ifPresent(s -> request.as(BodyCodec.pipe(s))); return requestFuture.get(); } catch (Exception e) { throw new RuntimeException(e); - } finally { - if (webClient != null) { - webClient.close(); - } - vertx.close(); } } diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/FaultyKsqlClient.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/FaultyKsqlClient.java index e4af7f825e80..989786900ee1 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/FaultyKsqlClient.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/FaultyKsqlClient.java @@ -25,6 +25,7 @@ import java.net.URI; import java.util.List; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -65,10 +66,22 @@ public RestResponse> makeQueryRequest( final URI serverEndPoint, final String sql, final Map configOverrides, - final Map requestProperties) { + final Map requestProperties + ) { return getClient().makeQueryRequest(serverEndPoint, sql, configOverrides, requestProperties); } + @Override + public RestResponse makeQueryRequest( + final URI serverEndPoint, + final String sql, + final Map configOverrides, + final Map requestProperties, + final Consumer> rowConsumer) { + return getClient().makeQueryRequest(serverEndPoint, sql, configOverrides, requestProperties, + rowConsumer); + } + @Override public void makeAsyncHeartbeatRequest(final URI serverEndPoint, final KsqlHostInfo host, final long timestamp) { diff --git 
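Both client implementations now carry a second makeQueryRequest overload that accepts a row consumer, so a server forwarding a pull query to another node can hand rows on as each HTTP chunk is parsed instead of buffering the whole response. A sketch of such a consumer; per the chunk and end handlers added to KsqlTarget further below, it is invoked once per parsed chunk and finally with null when the response ends. Buffering into a list here is purely illustrative (the forwarding path would feed rows into its PullQueryQueue), and the StreamedRow.pullRow factory usage follows the form seen earlier in this diff:

    import com.google.common.collect.ImmutableList;
    import io.confluent.ksql.GenericRow;
    import io.confluent.ksql.rest.entity.StreamedRow;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.Optional;
    import java.util.function.Consumer;

    public final class CollectingRowConsumer implements Consumer<List<StreamedRow>> {

      private final List<StreamedRow> rows = new ArrayList<>();
      private volatile boolean complete;

      @Override
      public void accept(final List<StreamedRow> chunk) {
        if (chunk == null) {
          complete = true;        // null is the end-of-stream signal
          return;
        }
        rows.addAll(chunk);
      }

      public static void main(final String[] args) {
        final CollectingRowConsumer consumer = new CollectingRowConsumer();
        consumer.accept(ImmutableList.of(
            StreamedRow.pullRow(GenericRow.genericRow("USER_1", 1), Optional.empty())));
        consumer.accept(null);
        System.out.println(consumer.complete + " / " + consumer.rows.size()); // true / 1
      }
    }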
a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryPublisherTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryPublisherTest.java index 9f400f02f279..92ca04cc8ed2 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryPublisherTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryPublisherTest.java @@ -16,15 +16,22 @@ package io.confluent.ksql.rest.server.resources.streaming; import static com.google.common.util.concurrent.RateLimiter.create; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.ListeningScheduledExecutorService; import io.confluent.ksql.GenericRow; import io.confluent.ksql.config.SessionConfig; import io.confluent.ksql.engine.KsqlEngine; @@ -33,19 +40,20 @@ import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.physical.pull.HARouting; import io.confluent.ksql.physical.pull.PullQueryResult; -import io.confluent.ksql.query.QueryId; +import io.confluent.ksql.query.PullQueryQueue; import io.confluent.ksql.rest.entity.StreamedRow; -import io.confluent.ksql.rest.entity.TableRows; import io.confluent.ksql.rest.server.resources.streaming.Flow.Subscriber; import io.confluent.ksql.rest.server.resources.streaming.Flow.Subscription; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; import io.confluent.ksql.services.ServiceContext; import io.confluent.ksql.statement.ConfiguredStatement; +import io.confluent.ksql.util.KeyValue; import io.confluent.ksql.util.KsqlConfig; import java.util.Collection; import java.util.List; import java.util.Optional; +import java.util.function.Consumer; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -54,7 +62,6 @@ import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import org.mockito.stubbing.Answer; @RunWith(MockitoJUnitRunner.class) public class PullQueryPublisherTest { @@ -62,11 +69,17 @@ public class PullQueryPublisherTest { private static final LogicalSchema PULL_SCHEMA = LogicalSchema.builder() .keyColumn(ColumnName.of("id"), SqlTypes.STRING) - .valueColumn(ColumnName.of("bob"), SqlTypes.INTEGER) - .valueColumn(ColumnName.of("foo"), SqlTypes.BIGINT) - .valueColumn(ColumnName.of("bar"), SqlTypes.DOUBLE) + .valueColumn(ColumnName.of("bob"), SqlTypes.STRING) .build(); + private static final List ROW1 = ImmutableList.of("a", "b"); + private static final List ROW2 = ImmutableList.of("c", "d"); + + private static final KeyValue, GenericRow> KV1 + = new KeyValue<>(null, GenericRow.fromList(ROW1)); + private static final KeyValue, GenericRow> KV2 + = new KeyValue<>(null, GenericRow.fromList(ROW2)); + @Mock private KsqlEngine engine; @Mock @@ -76,14 +89,12 @@ public class 
PullQueryPublisherTest { @Mock private Subscriber> subscriber; @Mock - private List> tableRows; + private ListeningScheduledExecutorService exec; @Mock - private TableRows entity; + private PullQueryQueue pullQueryQueue; @Mock private PullQueryResult pullQueryResult; @Mock - private QueryId queryId; - @Mock private RoutingFilterFactory routingFilterFactory; @Mock private SessionConfig sessionConfig; @@ -94,6 +105,10 @@ public class PullQueryPublisherTest { @Captor private ArgumentCaptor subscriptionCaptor; + @Captor + private ArgumentCaptor> completeCaptor; + @Captor + private ArgumentCaptor> onErrorCaptor; private Subscription subscription; private PullQueryPublisher publisher; @@ -103,6 +118,7 @@ public void setUp() { publisher = new PullQueryPublisher( engine, serviceContext, + exec, statement, Optional.empty(), TIME_NANOS, @@ -110,19 +126,32 @@ public void setUp() { create(1), haRouting); - - when(statement.getStatementText()).thenReturn(""); when(statement.getSessionConfig()).thenReturn(sessionConfig); when(sessionConfig.getConfig(false)).thenReturn(ksqlConfig); when(sessionConfig.getOverrides()).thenReturn(ImmutableMap.of()); - when(pullQueryResult.getQueryId()).thenReturn(queryId); when(pullQueryResult.getSchema()).thenReturn(PULL_SCHEMA); - when(pullQueryResult.getTableRows()).thenReturn(tableRows); - when(pullQueryResult.getSourceNodes()).thenReturn(Optional.empty()); - when(engine.executePullQuery(any(), any(), any(), any(), any(), any())) + when(pullQueryResult.getPullQueryQueue()).thenReturn(pullQueryQueue); + doNothing().when(pullQueryResult).onException(onErrorCaptor.capture()); + doNothing().when(pullQueryResult).onCompletion(completeCaptor.capture()); + int[] times = new int[1]; + doAnswer(inv -> { + Collection, GenericRow>> c = inv.getArgument(0); + if (times[0] == 0) { + c.add(KV1); + } else if (times[0] == 1) { + c.add(KV2); + completeCaptor.getValue().accept(null); + } + times[0]++; + return null; + }).when(pullQueryQueue).drainTo(any()); + when(engine.executePullQuery(any(), any(), any(), any(), any(), anyBoolean())) .thenReturn(pullQueryResult); - - doAnswer(callRequestAgain()).when(subscriber).onNext(any()); + when(exec.submit(any(Runnable.class))).thenAnswer(inv -> { + Runnable runnable = inv.getArgument(0); + runnable.run(); + return null; + }); } @Test @@ -144,7 +173,8 @@ public void shouldRunQueryWithCorrectParams() { // Then: verify(engine).executePullQuery( - eq(serviceContext), eq(statement), eq(haRouting), eq(routingFilterFactory), any(), eq(Optional.empty())); + eq(serviceContext), eq(statement), eq(haRouting), any(), eq(Optional.empty()), + anyBoolean()); } @Test @@ -158,7 +188,8 @@ public void shouldOnlyExecuteOnce() { // Then: verify(subscriber).onNext(any()); verify(engine).executePullQuery( - eq(serviceContext), eq(statement), eq(haRouting), eq(routingFilterFactory), any(), eq(Optional.empty())); + eq(serviceContext), eq(statement), eq(haRouting), any(), + eq(Optional.empty()), anyBoolean()); } @Test @@ -168,11 +199,13 @@ public void shouldCallOnSchemaThenOnNextThenOnCompleteOnSuccess() { // When: subscription.request(1); + subscription.request(1); + subscription.request(1); // Then: final InOrder inOrder = inOrder(subscriber); inOrder.verify(subscriber).onSchema(any()); - inOrder.verify(subscriber).onNext(any()); + inOrder.verify(subscriber, times(2)).onNext(any()); inOrder.verify(subscriber).onComplete(); } @@ -189,13 +222,29 @@ public void shouldPassSchema() { } @Test - public void shouldCallOnErrorOnFailure() { + public void 
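A pattern in this rewritten test worth noting: rather than letting completion and failure arrive on background threads, the test captures the Consumer callbacks that the publisher registers on PullQueryResult (onCompletion/onException) and fires them by hand, and it runs submitted Runnables inline so the whole flow is deterministic. A standalone illustration of the captor half, with a hypothetical Result interface standing in for PullQueryResult:

    import static org.mockito.Mockito.doNothing;
    import static org.mockito.Mockito.mock;

    import java.util.function.Consumer;
    import org.mockito.ArgumentCaptor;

    public final class CallbackCaptorExample {

      // Stand-in for PullQueryResult's callback registration surface.
      interface Result {
        void onCompletion(Consumer<Void> handler);
      }

      @SuppressWarnings("unchecked")
      public static void main(final String[] args) {
        final Result result = mock(Result.class);
        final ArgumentCaptor<Consumer<Void>> captor = ArgumentCaptor.forClass(Consumer.class);

        // Capture whatever callback the code under test registers...
        doNothing().when(result).onCompletion(captor.capture());
        result.onCompletion(v -> System.out.println("query complete"));

        // ...then fire it by hand at the exact point the test wants "completion" to happen.
        captor.getValue().accept(null);
      }
    }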
shouldCallOnErrorOnFailure_initial() { + // Given: + when(engine.executePullQuery(any(), any(), any(), any(), any(), anyBoolean())) + .thenThrow(new RuntimeException("Boom!")); + + // When: + final Exception e = assertThrows( + RuntimeException.class, + () -> givenSubscribed() + ); + + // Then: + assertThat(e.getMessage(), containsString("Boom!")); + } + + @Test + public void shouldCallOnErrorOnFailure_duringStream() { // Given: givenSubscribed(); - final Throwable e = new RuntimeException("Boom!"); - when(engine.executePullQuery(any(), any(), any(), any(), any(), any())).thenThrow(e); + RuntimeException e = new RuntimeException("Boom!"); // When: + onErrorCaptor.getValue().accept(e); subscription.request(1); // Then: @@ -207,28 +256,19 @@ public void shouldBuildStreamingRows() { // Given: givenSubscribed(); - when(pullQueryResult.getTableRows()).thenReturn(ImmutableList.of( - ImmutableList.of("a", 1, 2L, 3.0f), - ImmutableList.of("b", 1, 2L, 3.0f) - )); - when(pullQueryResult.getSourceNodes()) - .thenReturn(Optional.empty()); - // When: subscription.request(1); + subscription.request(1); + subscription.request(1); // Then: + verify(subscriber, times(2)).onNext(any()); verify(subscriber).onNext(ImmutableList.of( - StreamedRow.pushRow(GenericRow.genericRow("a", 1, 2L, 3.0f)), - StreamedRow.pushRow(GenericRow.genericRow("b", 1, 2L, 3.0f)) + StreamedRow.pushRow(GenericRow.fromList(ROW1)) + )); + verify(subscriber).onNext(ImmutableList.of( + StreamedRow.pushRow(GenericRow.fromList(ROW2)) )); - } - - private Answer callRequestAgain() { - return inv -> { - subscription.request(1); - return null; - }; } private void givenSubscribed() { diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryStreamWriterTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryStreamWriterTest.java new file mode 100644 index 000000000000..c338f86cc283 --- /dev/null +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/PullQueryStreamWriterTest.java @@ -0,0 +1,311 @@ +package io.confluent.ksql.rest.server.resources.streaming; + +import static io.confluent.ksql.rest.server.resources.streaming.PullQueryStreamWriter.QueueWrapper.END_ROW; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasItems; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.physical.pull.PullQueryResult; +import io.confluent.ksql.physical.pull.PullQueryRow; +import io.confluent.ksql.query.PullQueryQueue; +import io.confluent.ksql.query.QueryId; +import io.confluent.ksql.rest.ApiJsonMapper; +import io.confluent.ksql.rest.entity.StreamedRow; +import io.confluent.ksql.schema.ksql.LogicalSchema; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.util.KsqlException; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.time.Clock; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import 
java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; + +@SuppressWarnings("unchecked") +@RunWith(MockitoJUnitRunner.class) +public class PullQueryStreamWriterTest { + + private static final LogicalSchema SCHEMA = LogicalSchema.builder() + .valueColumn(ColumnName.of("a"), SqlTypes.STRING) + .build(); + + @Rule + public final Timeout timeout = Timeout.builder() + .withTimeout(30, TimeUnit.SECONDS) + .withLookingForStuckThread(true) + .build(); + + @Mock + private PullQueryResult pullQueryResult; + @Mock + private PullQueryQueue pullQueryQueue; + @Mock + private Clock clock; + + @Captor + private ArgumentCaptor> throwableConsumerCapture; + @Captor + private ArgumentCaptor> completeCapture; + + private ScheduledExecutorService executorService; + private ByteArrayOutputStream out; + private PullQueryStreamWriter writer; + + @Before + public void setUp() { + when(pullQueryResult.getQueryId()).thenReturn(new QueryId("Query id")); + when(pullQueryResult.getSchema()).thenReturn(SCHEMA); + doNothing().when(pullQueryResult).onCompletion(completeCapture.capture()); + writer = new PullQueryStreamWriter(pullQueryResult, 1000, ApiJsonMapper.INSTANCE.get(), + pullQueryQueue, clock, new CompletableFuture<>()); + + executorService = Executors.newSingleThreadScheduledExecutor(); + + out = new ByteArrayOutputStream(); + } + + @After + public void tearDown() { + executorService.shutdownNow(); + } + + @Test + public void shouldWriteAnyPendingRowsBeforeReportingException() throws IOException { + // Given: + doAnswer(streamRows("Row1", "Row2", "Row3")) + .when(pullQueryQueue).drainRowsTo(any()); + + givenUncaughtException(new KsqlException("Server went Boom")); + + // When: + writer.write(out); + + // Then: + final List lines = getOutput(out); + assertThat(lines, contains( + containsString("header"), + containsString("Row1"), + containsString("Row2"), + containsString("Row3"), + containsString("Server went Boom") + )); + } + + @Test + public void shouldWriteNoRows() throws IOException { + // Given: + completeCapture.getValue().accept(null); + doAnswer(streamRows()) + .when(pullQueryQueue).drainRowsTo(any()); + + // When: + writer.write(out); + + // Then: + verify(pullQueryQueue).putSentinelRow(END_ROW); + final List lines = getOutput(out); + assertThat(lines, contains( + containsString("header") + )); + } + + @Test + public void shouldExitAndDrainIfAlreadyComplete() throws IOException { + // Given: + completeCapture.getValue().accept(null); + doAnswer(streamRows("Row1", "Row2", "Row3")) + .when(pullQueryQueue).drainRowsTo(any()); + + // When: + writer.write(out); + + // Then: + verify(pullQueryQueue).putSentinelRow(END_ROW); + final List lines = getOutput(out); + assertThat(lines, contains( + containsString("header"), + containsString("Row1"), + containsString("Row2"), + containsString("Row3"))); + } + + @Test + public void shouldExitAndDrainIfLimitReached() throws IOException { + // Given: + doAnswer(streamRows("Row1", "Row2", "Row3")) + 
.when(pullQueryQueue).drainRowsTo(any()); + + completeCapture.getValue().accept(null); + + // When: + writer.write(out); + + // Then: + verify(pullQueryQueue).putSentinelRow(END_ROW); + final List lines = getOutput(out); + assertThat(lines, hasItems( + containsString("header"), + containsString("Row1"), + containsString("Row2"), + containsString("Row3"))); + } + + + @Test + public void shouldWriteOneAndClose() throws InterruptedException, IOException { + // Given: + when(pullQueryQueue.pollRow(anyLong(), any())) + .thenReturn(new PullQueryRow(ImmutableList.of("Row1"), SCHEMA, Optional.empty())) + .thenAnswer(inv -> { + completeCapture.getValue().accept(null); + return END_ROW; + }); + + // When: + writer.write(out); + + // Then: + verify(pullQueryQueue).putSentinelRow(END_ROW); + final List lines = getOutput(out); + assertThat(lines, contains( + containsString("header"), + containsString("Row1") + )); + } + + @Test + public void shouldWriteTwoAndClose() throws InterruptedException, IOException { + // Given: + when(pullQueryQueue.pollRow(anyLong(), any())) + .thenReturn(streamRow("Row1")) + .thenReturn(streamRow("Row2")) + .thenAnswer(inv -> { + completeCapture.getValue().accept(null); + return END_ROW; + }); + + // When: + writer.write(out); + + // Then: + verify(pullQueryQueue).putSentinelRow(END_ROW); + final List lines = getOutput(out); + assertThat(lines, contains( + containsString("header"), + containsString("Row1"), + containsString("Row2") + )); + } + + @Test + public void shouldWriteTwoAndCloseWithOneMoreQueue() throws InterruptedException, IOException { + // Given: + when(pullQueryQueue.pollRow(anyLong(), any())) + .thenReturn(streamRow("Row1")) + .thenReturn(streamRow("Row2")) + .thenAnswer(inv -> { + completeCapture.getValue().accept(null); + return END_ROW; + }); + doAnswer(streamRows("Row3")) + .when(pullQueryQueue).drainRowsTo(any()); + + // When: + writer.write(out); + + // Then: + verify(pullQueryQueue).putSentinelRow(END_ROW); + final List lines = getOutput(out); + assertThat(lines, contains( + containsString("header"), + containsString("Row1"), + containsString("Row2"), + containsString("Row3") + )); + } + + @Test + public void shouldProperlyEscapeJSON() throws InterruptedException, IOException { + // Given: + when(pullQueryQueue.pollRow(anyLong(), any())) + .thenReturn(streamRow("foo\nbar")) + .thenAnswer(inv -> { + completeCapture.getValue().accept(null); + return null; + }); + + // When: + writer.write(out); + + // Then: + final List lines = getOutput(out); + assertThat(lines.size(), is(2)); + assertThat(lines, contains( + containsString("header"), + containsString("foo\\nbar") + )); + String lastRow = lines.get(1).replaceFirst("]$", ""); + StreamedRow row = ApiJsonMapper.INSTANCE.get().readValue(lastRow, StreamedRow.class); + assertThat(row.getRow().isPresent(), is(true)); + assertThat(row.getRow().get().getColumns().get(0), is("foo\nbar")); + } + + private void givenUncaughtException(final KsqlException e) { + verify(pullQueryResult).onException(throwableConsumerCapture.capture()); + throwableConsumerCapture.getValue().accept(e); + } + + private static Answer streamRows(final Object... 
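Several of these tests verify putSentinelRow(END_ROW): on completion the writer enqueues a sentinel so a consumer blocked in pollRow wakes up immediately instead of waiting out its poll timeout. The poison-pill idiom in miniature (plain BlockingQueue, illustrative names):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.BlockingQueue;
    import java.util.concurrent.LinkedBlockingQueue;

    public final class SentinelQueueExample {

      private static final Object END = new Object();   // analogous to END_ROW

      public static void main(final String[] args) throws InterruptedException {
        final BlockingQueue<Object> queue = new LinkedBlockingQueue<>();
        final List<Object> drained = new ArrayList<>();

        queue.put("row1");
        queue.put("row2");
        queue.put(END);                 // the completion callback would do this

        while (true) {
          final Object next = queue.take();   // blocks until a row or the sentinel arrives
          if (next == END) {
            break;                      // end of results: terminate without a timeout
          }
          drained.add(next);
        }
        System.out.println(drained);    // [row1, row2]
      }
    }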
rows) { + return inv -> { + final Collection output = inv.getArgument(0); + + Arrays.stream(rows) + .map(row -> new PullQueryRow(ImmutableList.of(row), SCHEMA, Optional.empty())) + .forEach(output::add); + + return null; + }; + } + + private static PullQueryRow streamRow(final Object row) { + return new PullQueryRow(ImmutableList.of(row), SCHEMA, Optional.empty()); + } + + private static List getOutput(final ByteArrayOutputStream out) throws IOException { + // Make sure it's parsable as valid JSON + ApiJsonMapper.INSTANCE.get().readTree(out.toByteArray()); + final String[] lines = new String(out.toByteArray(), StandardCharsets.UTF_8).split("\n"); + return Arrays.stream(lines) + .filter(line -> !line.isEmpty()) + .collect(Collectors.toList()); + } +} diff --git a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResourceTest.java b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResourceTest.java index b8e862f828d8..c0bb79265b44 100644 --- a/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResourceTest.java +++ b/ksqldb-rest-app/src/test/java/io/confluent/ksql/rest/server/resources/streaming/StreamedQueryResourceTest.java @@ -34,6 +34,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; @@ -65,6 +66,7 @@ import io.confluent.ksql.query.BlockingRowQueue; import io.confluent.ksql.query.KafkaStreamsBuilder; import io.confluent.ksql.query.LimitHandler; +import io.confluent.ksql.query.PullQueryQueue; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.rest.ApiJsonMapper; import io.confluent.ksql.rest.EndpointResponse; @@ -189,6 +191,8 @@ public class StreamedQueryResourceTest { private LogicalSchema schema; @Mock private HARouting haRouting; + @Mock + private PullQueryQueue pullQueryQueue; private StreamedQueryResource testResource; private PreparedStatement invalid; @@ -311,10 +315,9 @@ public void shouldRateLimit() { Optional.empty() ); testResource.configure(VALID_CONFIG); - when(mockKsqlEngine.executePullQuery(any(), any(), any(), any(), any(), any())).thenReturn(pullQueryResult); - when(pullQueryResult.getTableRows()).thenReturn(Collections.emptyList()); - when(pullQueryResult.getSchema()).thenReturn(schema); - when(pullQueryResult.getQueryId()).thenReturn(queryId); + when(mockKsqlEngine.executePullQuery(any(), any(), any(), any(), any(), anyBoolean())) + .thenReturn(pullQueryResult); + when(pullQueryResult.getPullQueryQueue()).thenReturn(pullQueryQueue); // When: testResource.streamQuery( diff --git a/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlTarget.java b/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlTarget.java index 5589139a2b93..a0a9893c3414 100644 --- a/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlTarget.java +++ b/ksqldb-rest-client/src/main/java/io/confluent/ksql/rest/client/KsqlTarget.java @@ -49,13 +49,16 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; +import java.util.function.Consumer; import 
java.util.function.Function; +import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -@SuppressWarnings("WeakerAccess") // Public API +@SuppressWarnings({"WeakerAccess", "checkstyle:ClassDataAbstractionCoupling"}) // Public API public final class KsqlTarget { private static final Logger log = LoggerFactory.getLogger(KsqlTarget.class); @@ -171,6 +174,25 @@ public RestResponse postKsqlRequest( ); } + public RestResponse postQueryRequest( + final String ksql, + final Map requestProperties, + final Optional previousCommandSeqNum, + final Consumer> rowConsumer + ) { + final AtomicInteger rowCount = new AtomicInteger(0); + return post( + QUERY_PATH, + createKsqlRequest(ksql, requestProperties, previousCommandSeqNum), + rowCount::get, + rows -> { + final List streamedRows = toRows(rows); + rowCount.addAndGet(streamedRows.size()); + return streamedRows; + }, + rowConsumer); + } + public RestResponse> postQueryRequest( final String ksql, final Map requestProperties, @@ -179,8 +201,7 @@ public RestResponse> postQueryRequest( return post( QUERY_PATH, createKsqlRequest(ksql, requestProperties, previousCommandSeqNum), - KsqlTarget::toRows - ); + KsqlTarget::toRows); } public RestResponse> postQueryRequestStreamed( @@ -223,6 +244,17 @@ private RestResponse post( return executeRequestSync(HttpMethod.POST, path, jsonEntity, mapper); } + private RestResponse post( + final String path, + final Object jsonEntity, + final Supplier responseSupplier, + final Function mapper, + final Consumer chunkHandler + ) { + return executeRequestSync(HttpMethod.POST, path, jsonEntity, responseSupplier, mapper, + chunkHandler); + } + private CompletableFuture> executeRequestAsync( final HttpMethod httpMethod, final String path, @@ -245,6 +277,36 @@ private RestResponse executeRequestSync( }); } + private RestResponse executeRequestSync( + final HttpMethod httpMethod, + final String path, + final Object requestBody, + final Supplier responseSupplier, + final Function chunkMapper, + final Consumer chunkHandler + ) { + return executeSync(httpMethod, path, requestBody, resp -> responseSupplier.get(), + (resp, vcf) -> { + resp.handler(buff -> { + try { + chunkHandler.accept(chunkMapper.apply(buff)); + } catch (Throwable t) { + log.error("Error while handling chunk", t); + vcf.completeExceptionally(t); + } + }); + resp.endHandler(v -> { + try { + chunkHandler.accept(null); + vcf.complete(new ResponseWithBody(resp, Buffer.buffer())); + } catch (Throwable t) { + log.error("Error while handling end", t); + vcf.completeExceptionally(t); + } + }); + }); + } + private RestResponse> executeQueryRequestWithStreamResponse( final String ksql, final Optional previousCommandSeqNum, @@ -324,13 +386,18 @@ private CompletableFuture execute( } private static List toRows(final ResponseWithBody resp) { + return toRows(resp.getBody()); + } + + // This is meant to parse partial chunk responses as well as full pull query responses. + private static List toRows(final Buffer buff) { final List rows = new ArrayList<>(); - final Buffer buff = resp.getBody(); int begin = 0; for (int i = 0; i <= buff.length(); i++) { - if ((i == buff.length() && (i - begin > 1)) || buff.getByte(i) == (byte) '\n') { + if ((i == buff.length() && (i - begin > 1)) + || (i < buff.length() && buff.getByte(i) == (byte) '\n')) { if (begin != i) { // Ignore random newlines - the server can send these final Buffer sliced = buff.slice(begin, i); final Buffer tidied = StreamPublisher.toJsonMsg(sliced);
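The bounds check added to toRows above (i < buff.length() before getByte(i)) matters now that the same scan parses partial chunk responses: the old condition could read one byte past the end of the buffer whenever the final fragment was a single byte. A standalone sketch of the newline-delimited scan, minus the JSON bracket tidying that StreamPublisher.toJsonMsg performs:

    import io.vertx.core.buffer.Buffer;
    import java.util.ArrayList;
    import java.util.List;

    public final class DelimitedChunkParser {

      static List<String> split(final Buffer buff) {
        final List<String> rows = new ArrayList<>();
        int begin = 0;
        for (int i = 0; i <= buff.length(); i++) {
          // A 1-byte trailing fragment is ignored, as in the original.
          final boolean atEnd = i == buff.length() && i - begin > 1;
          final boolean newline = i < buff.length() && buff.getByte(i) == (byte) '\n';
          if (atEnd || newline) {
            if (begin != i) {                 // skip empty lines between rows
              rows.add(buff.slice(begin, i).toString());
            }
            begin = i + 1;
          }
        }
        return rows;
      }

      public static void main(final String[] args) {
        System.out.println(split(Buffer.buffer("[{\"row\":1}\n{\"row\":2}")));
        // -> [[{"row":1}, {"row":2}]
      }
    }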