diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java index 8cf38ae30ed5..ff1cdff2b67d 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java @@ -1111,7 +1111,6 @@ private void transitionToFinishedIfReady() public boolean transitionToFailed(Throwable throwable) { - cleanupQueryQuietly(); queryStateTimer.endQuery(); // NOTE: The failure cause must be set before triggering the state change, so @@ -1120,6 +1119,8 @@ public boolean transitionToFailed(Throwable throwable) requireNonNull(throwable, "throwable is null"); failureCause.compareAndSet(null, toFailure(throwable)); + cleanupQueryQuietly(); + QueryState oldState = queryState.trySet(FAILED); if (oldState.isDone()) { QUERY_STATE_LOG.debug(throwable, "Failure after query %s finished", queryId); @@ -1155,7 +1156,6 @@ public boolean transitionToFailed(Throwable throwable) public boolean transitionToCanceled() { - cleanupQueryQuietly(); queryStateTimer.endQuery(); // NOTE: The failure cause must be set before triggering the state change, so @@ -1163,6 +1163,8 @@ public boolean transitionToCanceled() // can only be observed if the transition to FAILED is successful. failureCause.compareAndSet(null, toFailure(new TrinoException(USER_CANCELED, "Query was canceled"))); + cleanupQueryQuietly(); + boolean canceled = queryState.setIf(FAILED, currentState -> !currentState.isDone()); if (canceled) { session.getTransactionId().flatMap(transactionManager::getTransactionInfoIfExist).ifPresent(transaction -> { diff --git a/core/trino-main/src/test/java/io/trino/execution/TestQueryStateMachine.java b/core/trino-main/src/test/java/io/trino/execution/TestQueryStateMachine.java index d1a2bc418046..d21b2f92e3f8 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestQueryStateMachine.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestQueryStateMachine.java @@ -29,6 +29,7 @@ import io.trino.plugin.base.security.DefaultSystemAccessControl; import io.trino.security.AccessControlConfig; import io.trino.security.AccessControlManager; +import io.trino.server.BasicQueryInfo; import io.trino.spi.TrinoException; import io.trino.spi.connector.CatalogHandle.CatalogVersion; import io.trino.spi.resourcegroups.QueryType; @@ -37,6 +38,7 @@ import io.trino.sql.analyzer.Output; import io.trino.sql.planner.plan.PlanFragmentId; import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.tracing.TracingMetadata; import io.trino.transaction.TransactionManager; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Test; @@ -50,11 +52,16 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Optional; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import java.util.function.Consumer; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.util.concurrent.Uninterruptibles.awaitUninterruptibly; import static io.airlift.concurrent.MoreFutures.tryGetFutureValue; import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.airlift.tracing.Tracing.noopTracer; import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.execution.QueryState.DISPATCHING; import static io.trino.execution.QueryState.FAILED; @@ -68,6 +75,7 @@ import static io.trino.execution.querystats.PlanOptimizersStatsCollector.createPlanOptimizersStatsCollector; import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.trino.spi.StandardErrorCode.TYPE_MISMATCH; import static io.trino.spi.StandardErrorCode.USER_CANCELED; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.testing.TestingEventListenerManager.emptyEventListenerManager; @@ -411,6 +419,74 @@ public void testUpdateMemoryUsage() assertThat(stateMachine.getPeakTaskRevocableMemory()).isEqualTo(10); } + @Test + public void testPreserveFirstFailure() + throws Exception + { + CountDownLatch cleanup = new CountDownLatch(1); + QueryStateMachine queryStateMachine = queryStateMachine() + .withMetadata(new TracingMetadata(noopTracer(), createTestMetadataManager()) + { + @Override + public void cleanupQuery(Session session) + { + cleanup.countDown(); + super.cleanupQuery(session); + } + }) + .build(); + + Future anotherThread = executor.submit(() -> { + checkState(awaitUninterruptibly(cleanup, 10, SECONDS), "Timed out waiting for cleanup latch"); + queryStateMachine.transitionToFailed(new IllegalStateException("Second exception")); + }); + Future failingThread = executor.submit(() -> { + queryStateMachine.transitionToFailed(new TrinoException(TYPE_MISMATCH, "First exception")); + }); + + failingThread.get(10, SECONDS); + anotherThread.get(10, SECONDS); + + ExecutionFailureInfo failureInfo = queryStateMachine.getFinalQueryInfo().orElseThrow().getFailureInfo(); + assertThat(failureInfo).isNotNull(); + assertThat(failureInfo.getErrorCode()).isEqualTo(TYPE_MISMATCH.toErrorCode()); + assertThat(failureInfo.getMessage()).isEqualTo("First exception"); + + BasicQueryInfo basicQueryInfo = queryStateMachine.getBasicQueryInfo(Optional.empty()); + assertThat(basicQueryInfo.getErrorCode()).isEqualTo(TYPE_MISMATCH.toErrorCode()); + } + + @Test + public void testPreserveCancellation() + throws Exception + { + CountDownLatch cleanup = new CountDownLatch(1); + QueryStateMachine queryStateMachine = queryStateMachine() + .withMetadata(new TracingMetadata(noopTracer(), createTestMetadataManager()) + { + @Override + public void cleanupQuery(Session session) + { + cleanup.countDown(); + super.cleanupQuery(session); + } + }) + .build(); + + Future anotherThread = executor.submit(() -> { + checkState(awaitUninterruptibly(cleanup, 10, SECONDS), "Timed out waiting for cleanup latch"); + queryStateMachine.transitionToFailed(new IllegalStateException("Second exception")); + }); + Future cancellingThread = executor.submit(queryStateMachine::transitionToCanceled); + + cancellingThread.get(10, SECONDS); + anotherThread.get(10, SECONDS); + + // TODO queryStateMachine.getFinalQueryInfo() does not exist for cancelled queries, but may be created by anotherThread due to a race + BasicQueryInfo basicQueryInfo = queryStateMachine.getBasicQueryInfo(Optional.empty()); + assertThat(basicQueryInfo.getErrorCode()).isEqualTo(USER_CANCELED.toErrorCode()); + } + private static void assertFinalState(QueryStateMachine stateMachine, QueryState expectedState) { assertFinalState(stateMachine, expectedState, null); @@ -526,6 +602,7 @@ private QueryStateMachineBuilder queryStateMachine() private class QueryStateMachineBuilder { private Ticker ticker = Ticker.systemTicker(); + private Metadata metadata; @CanIgnoreReturnValue public QueryStateMachineBuilder withTicker(Ticker ticker) @@ -534,9 +611,18 @@ public QueryStateMachineBuilder withTicker(Ticker ticker) return this; } + @CanIgnoreReturnValue + public QueryStateMachineBuilder withMetadata(Metadata metadata) + { + this.metadata = metadata; + return this; + } + public QueryStateMachine build() { - Metadata metadata = createTestMetadataManager(); + if (metadata == null) { + metadata = createTestMetadataManager(); + } TransactionManager transactionManager = createTestTransactionManager(); AccessControlManager accessControl = new AccessControlManager( NodeVersion.UNKNOWN,