From 8a89c29bff8303a23ee6633605cb0e521acaf578 Mon Sep 17 00:00:00 2001
From: James Taylor <jamestaylor@lyft.com>
Date: Sat, 15 Feb 2020 11:54:18 -0800
Subject: [PATCH] Ensure QueryCompletedEvent occurs when query fails during
 planning

---
 .../dispatcher/LocalDispatchQueryFactory.java | 10 ++-
 .../execution/QueryStateMachine.java          | 18 +++--
 .../execution/TestEventListener.java          | 73 ++++++++++++++++++-
 3 files changed, 94 insertions(+), 7 deletions(-)

diff --git a/presto-main/src/main/java/io/prestosql/dispatcher/LocalDispatchQueryFactory.java b/presto-main/src/main/java/io/prestosql/dispatcher/LocalDispatchQueryFactory.java
index 5ebb0b2c76ec..1e24f78b78aa 100644
--- a/presto-main/src/main/java/io/prestosql/dispatcher/LocalDispatchQueryFactory.java
+++ b/presto-main/src/main/java/io/prestosql/dispatcher/LocalDispatchQueryFactory.java
@@ -40,6 +40,7 @@
 import java.util.Optional;
 
 import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED;
+import static io.prestosql.util.Failures.toFailure;
 import static io.prestosql.util.StatementUtils.isTransactionControlStatement;
 import static java.util.Objects.requireNonNull;
 
@@ -116,7 +117,14 @@ public DispatchQuery createDispatchQuery(
                 throw new PrestoException(NOT_SUPPORTED, "Unsupported statement type: " + preparedQuery.getStatement().getClass().getSimpleName());
             }
 
-            return queryExecutionFactory.createQueryExecution(preparedQuery, stateMachine, slug, warningCollector);
+            try {
+                return queryExecutionFactory.createQueryExecution(preparedQuery, stateMachine, slug, warningCollector);
+            }
+            catch (Throwable e) {
+                stateMachine.transitionToFailed(e);
+                queryMonitor.queryImmediateFailureEvent(stateMachine.getBasicQueryInfo(Optional.empty()), toFailure(e));
+                throw e;
+            }
         });
 
         return new LocalDispatchQuery(
diff --git a/presto-main/src/main/java/io/prestosql/execution/QueryStateMachine.java b/presto-main/src/main/java/io/prestosql/execution/QueryStateMachine.java
index 5fc231af6c18..6e316632075b 100644
--- a/presto-main/src/main/java/io/prestosql/execution/QueryStateMachine.java
+++ b/presto-main/src/main/java/io/prestosql/execution/QueryStateMachine.java
@@ -815,8 +815,13 @@ public boolean transitionToFailed(Throwable throwable)
         requireNonNull(throwable, "throwable is null");
         failureCause.compareAndSet(null, toFailure(throwable));
 
-        boolean failed = queryState.setIf(FAILED, currentState -> !currentState.isDone());
-        if (failed) {
+        QueryState oldState = queryState.trySet(FAILED);
+        if (oldState.isDone()) {
+            QUERY_STATE_LOG.debug(throwable, "Failure after query %s finished", queryId);
+            return false;
+        }
+
+        try {
             QUERY_STATE_LOG.debug(throwable, "Query %s failed", queryId);
             session.getTransactionId().ifPresent(transactionId -> {
                 if (transactionManager.isAutoCommit(transactionId)) {
@@ -827,11 +832,14 @@ public boolean transitionToFailed(Throwable throwable)
                 }
             });
         }
-        else {
-            QUERY_STATE_LOG.debug(throwable, "Failure after query %s finished", queryId);
+        finally {
+            // if the query has not started, then there is no final query info to wait for
+            if (oldState.ordinal() <= PLANNING.ordinal()) {
+                finalQueryInfo.compareAndSet(Optional.empty(), Optional.of(getQueryInfo(Optional.empty())));
+            }
         }
 
-        return failed;
+        return true;
     }
 
     public boolean transitionToCanceled()
diff --git a/presto-tests/src/test/java/io/prestosql/execution/TestEventListener.java b/presto-tests/src/test/java/io/prestosql/execution/TestEventListener.java
index c61de6c504cc..334f323c5829 100644
--- a/presto-tests/src/test/java/io/prestosql/execution/TestEventListener.java
+++ b/presto-tests/src/test/java/io/prestosql/execution/TestEventListener.java
@@ -17,12 +17,17 @@
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import io.prestosql.Session;
+import io.prestosql.connector.MockConnectorFactory;
 import io.prestosql.execution.TestEventListenerPlugin.TestingEventListenerPlugin;
 import io.prestosql.plugin.resourcegroups.ResourceGroupManagerPlugin;
 import io.prestosql.plugin.tpch.TpchPlugin;
+import io.prestosql.spi.Plugin;
 import io.prestosql.spi.QueryId;
+import io.prestosql.spi.connector.ConnectorFactory;
+import io.prestosql.spi.connector.SchemaTableName;
 import io.prestosql.spi.eventlistener.QueryCompletedEvent;
 import io.prestosql.spi.eventlistener.QueryCreatedEvent;
+import io.prestosql.spi.eventlistener.QueryFailureInfo;
 import io.prestosql.spi.eventlistener.SplitCompletedEvent;
 import io.prestosql.testing.DistributedQueryRunner;
 import io.prestosql.testing.MaterializedResult;
@@ -37,13 +42,16 @@
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 
+import static com.google.common.base.Strings.nullToEmpty;
 import static com.google.common.collect.Iterables.getOnlyElement;
 import static io.prestosql.execution.TestQueues.createResourceGroupId;
 import static io.prestosql.testing.TestingSession.testSessionBuilder;
+import static java.lang.String.format;
 import static java.util.stream.Collectors.toSet;
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertFalse;
 import static org.testng.Assert.assertTrue;
+import static org.testng.Assert.fail;
 
 @Test(singleThreaded = true)
 public class TestEventListener
@@ -69,6 +77,21 @@ private void setUp()
         queryRunner.installPlugin(new TestingEventListenerPlugin(generatedEvents));
         queryRunner.installPlugin(new ResourceGroupManagerPlugin());
         queryRunner.createCatalog("tpch", "tpch", ImmutableMap.of("tpch.splits-per-node", Integer.toString(SPLITS_PER_NODE)));
+        queryRunner.installPlugin(new Plugin()
+        {
+            @Override
+            public Iterable<ConnectorFactory> getConnectorFactories()
+            {
+                MockConnectorFactory connectorFactory = MockConnectorFactory.builder()
+                        .withListTables((session, s) -> ImmutableList.of(new SchemaTableName("default", "test_table")))
+                        .withApplyProjection((session, handle, projections, assignments) -> {
+                            throw new RuntimeException("Throw from apply projection");
+                        })
+                        .build();
+                return ImmutableList.of(connectorFactory);
+            }
+        });
+        queryRunner.createCatalog("mock", "mock", ImmutableMap.of());
         queryRunner.getCoordinator().getResourceGroupManager().get()
                 .setConfigurationManager("file", ImmutableMap.of("resource-groups.config-file", getResourceFilePath("resource_groups_config_simple.json")));
     }
@@ -93,9 +116,30 @@ private MaterializedResult runQueryAndWaitForEvents(@Language("SQL") String sql,
 
     private MaterializedResult runQueryAndWaitForEvents(@Language("SQL") String sql, int numEventsExpected, Session alternateSession)
             throws Exception
+    {
+        return runQueryAndWaitForEvents(sql, numEventsExpected, alternateSession, Optional.empty());
+    }
+
+    private MaterializedResult runQueryAndWaitForEvents(@Language("SQL") String sql, int numEventsExpected, Session alternateSession, Optional<String> expectedExceptionRegEx)
+            throws Exception
     {
         generatedEvents.initialize(numEventsExpected);
-        MaterializedResult result = queryRunner.execute(alternateSession, sql);
+        MaterializedResult result = null;
+        try {
+            result = queryRunner.execute(alternateSession, sql);
+        }
+        catch (RuntimeException exception) {
+            if (expectedExceptionRegEx.isPresent()) {
+                String regex = expectedExceptionRegEx.get();
+                if (!nullToEmpty(exception.getMessage()).matches(regex)) {
+                    fail(format("Expected exception message '%s' to match '%s' for query: %s", exception.getMessage(), regex, sql), exception);
+                }
+            }
+            else {
+                throw exception;
+            }
+        }
+
         generatedEvents.waitForEvents(10);
 
         return result;
@@ -129,6 +173,33 @@ public void testConstantQuery()
         assertEquals(splitCompletedEvents.get(0).getStatistics().getCompletedPositions(), 1);
     }
 
+    @Test
+    public void testAnalysisFailure()
+            throws Exception
+    {
+        assertFailedQuery("EXPLAIN (TYPE IO) SELECT sum(bogus) FROM lineitem", "line 1:30: Column 'bogus' cannot be resolved");
+    }
+
+    @Test
+    public void testPlanningFailure()
+            throws Exception
+    {
+        assertFailedQuery("SELECT * FROM mock.default.tests_table", "Throw from apply projection");
+    }
+
+    private void assertFailedQuery(@Language("SQL") String sql, String expectedFailure)
+            throws Exception
+    {
+        runQueryAndWaitForEvents(sql, 2, session, Optional.of(expectedFailure));
+
+        QueryCompletedEvent queryCompletedEvent = getOnlyElement(generatedEvents.getQueryCompletedEvents());
+        assertEquals(sql, queryCompletedEvent.getMetadata().getQuery());
+
+        QueryFailureInfo failureInfo = queryCompletedEvent.getFailureInfo()
+                .orElseThrow(() -> new AssertionError("Expected query event to be failed"));
+        assertEquals(expectedFailure, failureInfo.getFailureMessage().orElse(null));
+    }
+
     @Test
     public void testNormalQuery()
             throws Exception