diff --git a/x-pack/plugin/profiler/src/internalClusterTest/java/org/elasticsearch/xpack/profiler/GetProfilingActionIT.java b/x-pack/plugin/profiler/src/internalClusterTest/java/org/elasticsearch/xpack/profiler/GetProfilingActionIT.java index 7f3113f73ad5c..9c4cf0449ae2d 100644 --- a/x-pack/plugin/profiler/src/internalClusterTest/java/org/elasticsearch/xpack/profiler/GetProfilingActionIT.java +++ b/x-pack/plugin/profiler/src/internalClusterTest/java/org/elasticsearch/xpack/profiler/GetProfilingActionIT.java @@ -7,24 +7,59 @@ package org.elasticsearch.xpack.profiler; +import org.apache.http.entity.ContentType; +import org.apache.http.entity.StringEntity; +import org.apache.logging.log4j.LogManager; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.admin.cluster.node.info.NodeInfo; +import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse; +import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse; import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.Cancellable; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.common.network.NetworkModule; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.PluginsService; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.script.MockScriptPlugin; +import org.elasticsearch.search.lookup.LeafStoredFieldsLookup; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.tasks.TaskInfo; +import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.transport.netty4.Netty4Plugin; import org.elasticsearch.xcontent.XContentType; import org.junit.Before; import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CancellationException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; + +import static org.elasticsearch.action.support.ActionTestUtils.wrapAsRestResponseListener; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.instanceOf; @ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 1) public class GetProfilingActionIT extends ESIntegTestCase { @Override protected Collection> nodePlugins() { - return List.of(ProfilingPlugin.class); + return List.of(ProfilingPlugin.class, ScriptedBlockPlugin.class, getTestTransportPlugin()); } @Override @@ -32,9 +67,21 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { return Settings.builder() .put(super.nodeSettings(nodeOrdinal, otherSettings)) .put(ProfilingPlugin.PROFILING_ENABLED.getKey(), true) + .put(NetworkModule.TRANSPORT_TYPE_KEY, Netty4Plugin.NETTY_TRANSPORT_NAME) + .put(NetworkModule.HTTP_TYPE_KEY, Netty4Plugin.NETTY_HTTP_TRANSPORT_NAME) .build(); } + @Override + protected boolean addMockHttpTransport() { + return false; // enable http + } + + @Override + protected boolean ignoreExternalCluster() { + return true; + } + private byte[] read(String resource) throws IOException { return GetProfilingAction.class.getClassLoader().getResourceAsStream(resource).readAllBytes(); } @@ -104,4 +151,137 @@ public void testGetProfilingDataUnfiltered() throws Exception { assertNotNull(response.getExecutables()); assertNotNull("libc.so.6", response.getExecutables().get("QCCDqjSg3bMK1C4YRK6Tiw")); } + + public void testAutomaticCancellation() throws Exception { + Request restRequest = new Request("POST", "/_profiling/stacktraces"); + restRequest.setEntity(new StringEntity(""" + { + "sample_size": 10000, + "query": { + "bool": { + "filter": [ + { + "script": { + "script": { + "lang": "mockscript", + "source": "search_block", + "params": {} + } + } + } + ] + } + } + } + """, ContentType.APPLICATION_JSON.withCharset(StandardCharsets.UTF_8))); + verifyCancellation(GetProfilingAction.NAME, restRequest); + } + + void verifyCancellation(String action, Request restRequest) throws Exception { + Map nodeIdToName = readNodesInfo(); + List plugins = initBlockFactory(); + + PlainActionFuture future = PlainActionFuture.newFuture(); + Cancellable cancellable = getRestClient().performRequestAsync(restRequest, wrapAsRestResponseListener(future)); + + awaitForBlock(plugins); + cancellable.cancel(); + ensureTaskIsCancelled(action, nodeIdToName::get); + + disableBlocks(plugins); + expectThrows(CancellationException.class, future::actionGet); + } + + private static Map readNodesInfo() { + Map nodeIdToName = new HashMap<>(); + NodesInfoResponse nodesInfoResponse = client().admin().cluster().prepareNodesInfo().get(); + assertFalse(nodesInfoResponse.hasFailures()); + for (NodeInfo node : nodesInfoResponse.getNodes()) { + nodeIdToName.put(node.getNode().getId(), node.getNode().getName()); + } + return nodeIdToName; + } + + private static void ensureTaskIsCancelled(String transportAction, Function nodeIdToName) throws Exception { + SetOnce searchTask = new SetOnce<>(); + ListTasksResponse listTasksResponse = client().admin().cluster().prepareListTasks().get(); + for (TaskInfo task : listTasksResponse.getTasks()) { + if (task.action().equals(transportAction)) { + searchTask.set(task); + } + } + assertNotNull(searchTask.get()); + TaskId taskId = searchTask.get().taskId(); + String nodeName = nodeIdToName.apply(taskId.getNodeId()); + assertBusy(() -> { + TaskManager taskManager = internalCluster().getInstance(TransportService.class, nodeName).getTaskManager(); + Task task = taskManager.getTask(taskId.getId()); + assertThat(task, instanceOf(CancellableTask.class)); + assertTrue(((CancellableTask) task).isCancelled()); + }); + } + + private static List initBlockFactory() { + List plugins = new ArrayList<>(); + for (PluginsService pluginsService : internalCluster().getDataNodeInstances(PluginsService.class)) { + plugins.addAll(pluginsService.filterPlugins(ScriptedBlockPlugin.class)); + } + for (ScriptedBlockPlugin plugin : plugins) { + plugin.reset(); + plugin.enableBlock(); + } + return plugins; + } + + private void awaitForBlock(List plugins) throws Exception { + assertBusy(() -> { + int numberOfBlockedPlugins = 0; + for (ScriptedBlockPlugin plugin : plugins) { + numberOfBlockedPlugins += plugin.hits.get(); + } + logger.info("The plugin blocked on {} shards", numberOfBlockedPlugins); + assertThat(numberOfBlockedPlugins, greaterThan(0)); + }, 10, TimeUnit.SECONDS); + } + + private static void disableBlocks(List plugins) { + for (ScriptedBlockPlugin plugin : plugins) { + plugin.disableBlock(); + } + } + + public static class ScriptedBlockPlugin extends MockScriptPlugin { + static final String SCRIPT_NAME = "search_block"; + + private final AtomicInteger hits = new AtomicInteger(); + + private final AtomicBoolean shouldBlock = new AtomicBoolean(true); + + void reset() { + hits.set(0); + } + + void disableBlock() { + shouldBlock.set(false); + } + + void enableBlock() { + shouldBlock.set(true); + } + + @Override + public Map, Object>> pluginScripts() { + return Collections.singletonMap(SCRIPT_NAME, params -> { + LeafStoredFieldsLookup fieldsLookup = (LeafStoredFieldsLookup) params.get("_fields"); + LogManager.getLogger(GetProfilingActionIT.class).info("Blocking on the document {}", fieldsLookup.get("_id")); + hits.incrementAndGet(); + try { + waitUntil(() -> shouldBlock.get() == false); + } catch (Exception e) { + throw new RuntimeException(e); + } + return true; + }); + } + } }