diff --git a/core/src/main/java/org/elasticsearch/cluster/service/ClusterService.java b/core/src/main/java/org/elasticsearch/cluster/service/ClusterService.java index 2d3820e14e8ef..1dfb49775b15d 100644 --- a/core/src/main/java/org/elasticsearch/cluster/service/ClusterService.java +++ b/core/src/main/java/org/elasticsearch/cluster/service/ClusterService.java @@ -61,19 +61,16 @@ import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.common.util.concurrent.FutureUtils; import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor; -import org.elasticsearch.common.util.concurrent.PrioritizedRunnable; import org.elasticsearch.common.util.iterable.Iterables; import org.elasticsearch.discovery.Discovery; import org.elasticsearch.discovery.DiscoverySettings; import org.elasticsearch.threadpool.ThreadPool; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; -import java.util.IdentityHashMap; import java.util.Iterator; -import java.util.LinkedHashSet; import java.util.List; import java.util.Locale; import java.util.Map; @@ -85,7 +82,6 @@ import java.util.concurrent.Future; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.UnaryOperator; @@ -114,6 +110,7 @@ public class ClusterService extends AbstractLifecycleComponent { private TimeValue slowTaskLoggingThreshold; private volatile PrioritizedEsThreadPoolExecutor threadPoolExecutor; + private volatile ClusterServiceTaskBatcher taskBatcher; /** * Those 3 state listeners are changing infrequently - CopyOnWriteArrayList is just fine @@ -121,7 +118,6 @@ public class ClusterService extends AbstractLifecycleComponent { private final Collection highPriorityStateAppliers = new CopyOnWriteArrayList<>(); private final Collection normalPriorityStateAppliers = new CopyOnWriteArrayList<>(); private final Collection lowPriorityStateAppliers = new CopyOnWriteArrayList<>(); - final Map> updateTasksPerExecutor = new HashMap<>(); private final Iterable clusterStateAppliers = Iterables.concat(highPriorityStateAppliers, normalPriorityStateAppliers, lowPriorityStateAppliers); @@ -219,8 +215,9 @@ protected synchronized void doStart() { DiscoveryNodes nodes = DiscoveryNodes.builder(state.nodes()).add(localNode).localNodeId(localNode.getId()).build(); return ClusterState.builder(state).nodes(nodes).blocks(initialBlocks).build(); }); - this.threadPoolExecutor = EsExecutors.newSinglePrioritizing(UPDATE_THREAD_NAME, daemonThreadFactory(settings, UPDATE_THREAD_NAME), - threadPool.getThreadContext()); + this.threadPoolExecutor = EsExecutors.newSinglePrioritizing(UPDATE_THREAD_NAME, + daemonThreadFactory(settings, UPDATE_THREAD_NAME), threadPool.getThreadContext(), threadPool.scheduler()); + this.taskBatcher = new ClusterServiceTaskBatcher(logger, threadPoolExecutor); } @Override @@ -244,6 +241,44 @@ protected synchronized void doStop() { protected synchronized void doClose() { } + class ClusterServiceTaskBatcher extends TaskBatcher { + + ClusterServiceTaskBatcher(Logger logger, PrioritizedEsThreadPoolExecutor threadExecutor) { + super(logger, threadExecutor); + } + + @Override + protected void onTimeout(List tasks, TimeValue timeout) { + threadPool.generic().execute( + () -> tasks.forEach( + task -> ((UpdateTask) task).listener.onFailure(task.source, + new ProcessClusterEventTimeoutException(timeout, task.source)))); + } + + @Override + protected void run(Object batchingKey, List tasks, String tasksSummary) { + ClusterStateTaskExecutor taskExecutor = (ClusterStateTaskExecutor) batchingKey; + List updateTasks = (List) tasks; + runTasks(new ClusterService.TaskInputs(taskExecutor, updateTasks, tasksSummary)); + } + + class UpdateTask extends BatchedTask { + final ClusterStateTaskListener listener; + + UpdateTask(Priority priority, String source, Object task, ClusterStateTaskListener listener, + ClusterStateTaskExecutor executor) { + super(priority, source, executor, task); + this.listener = listener; + } + + @Override + public String describeTasks(List tasks) { + return ((ClusterStateTaskExecutor) batchingKey).describeTasks( + tasks.stream().map(BatchedTask::getTask).collect(Collectors.toList())); + } + } + } + /** * The local node. */ @@ -350,6 +385,7 @@ public void addTimeoutListener(@Nullable final TimeValue timeout, final TimeoutC listener.onClose(); return; } + // call the post added notification on the same event thread try { threadPoolExecutor.execute(new SourcePrioritizedRunnable(Priority.HIGH, "_add_listener_") { @@ -432,38 +468,11 @@ public void submitStateUpdateTasks(final String source, if (!lifecycle.started()) { return; } - if (tasks.isEmpty()) { - return; - } try { - @SuppressWarnings("unchecked") - ClusterStateTaskExecutor taskExecutor = (ClusterStateTaskExecutor) executor; - // convert to an identity map to check for dups based on update tasks semantics of using identity instead of equal - final IdentityHashMap tasksIdentity = new IdentityHashMap<>(tasks); - final List updateTasks = tasksIdentity.entrySet().stream().map( - entry -> new UpdateTask(source, entry.getKey(), config.priority(), taskExecutor, safe(entry.getValue(), logger)) - ).collect(Collectors.toList()); - - synchronized (updateTasksPerExecutor) { - LinkedHashSet existingTasks = updateTasksPerExecutor.computeIfAbsent(executor, - k -> new LinkedHashSet<>(updateTasks.size())); - for (UpdateTask existing : existingTasks) { - if (tasksIdentity.containsKey(existing.task)) { - throw new IllegalStateException("task [" + taskExecutor.describeTasks(Collections.singletonList(existing.task)) + - "] with source [" + source + "] is already queued"); - } - } - existingTasks.addAll(updateTasks); - } - - final UpdateTask firstTask = updateTasks.get(0); - - final TimeValue timeout = config.timeout(); - if (timeout != null) { - threadPoolExecutor.execute(firstTask, threadPool.scheduler(), timeout, () -> onTimeout(updateTasks, source, timeout)); - } else { - threadPoolExecutor.execute(firstTask); - } + List safeTasks = tasks.entrySet().stream() + .map(e -> taskBatcher.new UpdateTask(config.priority(), source, e.getKey(), safe(e.getValue(), logger), executor)) + .collect(Collectors.toList()); + taskBatcher.submitTasks(safeTasks, config.timeout()); } catch (EsRejectedExecutionException e) { // ignore cases where we are shutting down..., there is really nothing interesting // to be done here... @@ -473,60 +482,17 @@ public void submitStateUpdateTasks(final String source, } } - private void onTimeout(List updateTasks, String source, TimeValue timeout) { - threadPool.generic().execute(() -> { - final ArrayList toRemove = new ArrayList<>(); - for (UpdateTask task : updateTasks) { - if (task.processed.getAndSet(true) == false) { - logger.debug("cluster state update task [{}] timed out after [{}]", source, timeout); - toRemove.add(task); - } - } - if (toRemove.isEmpty() == false) { - ClusterStateTaskExecutor clusterStateTaskExecutor = toRemove.get(0).executor; - synchronized (updateTasksPerExecutor) { - LinkedHashSet existingTasks = updateTasksPerExecutor.get(clusterStateTaskExecutor); - if (existingTasks != null) { - existingTasks.removeAll(toRemove); - if (existingTasks.isEmpty()) { - updateTasksPerExecutor.remove(clusterStateTaskExecutor); - } - } - } - for (UpdateTask task : toRemove) { - task.listener.onFailure(source, new ProcessClusterEventTimeoutException(timeout, source)); - } - } - }); - } - /** * Returns the tasks that are pending. */ public List pendingTasks() { - PrioritizedEsThreadPoolExecutor.Pending[] pendings = threadPoolExecutor.getPending(); - List pendingClusterTasks = new ArrayList<>(pendings.length); - for (PrioritizedEsThreadPoolExecutor.Pending pending : pendings) { - final String source; - final long timeInQueue; - // we have to capture the task as it will be nulled after execution and we don't want to change while we check things here. - final Object task = pending.task; - if (task == null) { - continue; - } else if (task instanceof SourcePrioritizedRunnable) { - SourcePrioritizedRunnable runnable = (SourcePrioritizedRunnable) task; - source = runnable.source(); - timeInQueue = runnable.getAgeInMillis(); - } else { - assert false : "expected SourcePrioritizedRunnable got " + task.getClass(); - source = "unknown [" + task.getClass() + "]"; - timeInQueue = 0; - } - - pendingClusterTasks.add( - new PendingClusterTask(pending.insertionOrder, pending.priority, new Text(source), timeInQueue, pending.executing)); - } - return pendingClusterTasks; + return Arrays.stream(threadPoolExecutor.getPending()).map(pending -> { + assert pending.task instanceof SourcePrioritizedRunnable : + "thread pool executor should only use SourcePrioritizedRunnable instances but found: " + pending.task.getClass().getName(); + SourcePrioritizedRunnable task = (SourcePrioritizedRunnable) pending.task; + return new PendingClusterTask(pending.insertionOrder, pending.priority, new Text(task.source()), + task.getAgeInMillis(), pending.executing); + }).collect(Collectors.toList()); } /** @@ -585,19 +551,6 @@ public void setDiscoverySettings(DiscoverySettings discoverySettings) { this.discoverySettings = discoverySettings; } - abstract static class SourcePrioritizedRunnable extends PrioritizedRunnable { - protected final String source; - - SourcePrioritizedRunnable(Priority priority, String source) { - super(priority); - this.source = source; - } - - public String source() { - return source; - } - } - void runTasks(TaskInputs taskInputs) { if (!lifecycle.started()) { logger.debug("processing [{}]: ignoring, cluster service not started", taskInputs.summary); @@ -657,8 +610,8 @@ void runTasks(TaskInputs taskInputs) { public TaskOutputs calculateTaskOutputs(TaskInputs taskInputs, ClusterState previousClusterState, long startTimeNS) { ClusterTasksResult clusterTasksResult = executeTasks(taskInputs, startTimeNS, previousClusterState); // extract those that are waiting for results - List nonFailedTasks = new ArrayList<>(); - for (UpdateTask updateTask : taskInputs.updateTasks) { + List nonFailedTasks = new ArrayList<>(); + for (ClusterServiceTaskBatcher.UpdateTask updateTask : taskInputs.updateTasks) { assert clusterTasksResult.executionResults.containsKey(updateTask.task) : "missing " + updateTask; final ClusterStateTaskExecutor.TaskResult taskResult = clusterTasksResult.executionResults.get(updateTask.task); @@ -675,7 +628,8 @@ public TaskOutputs calculateTaskOutputs(TaskInputs taskInputs, ClusterState prev private ClusterTasksResult executeTasks(TaskInputs taskInputs, long startTimeNS, ClusterState previousClusterState) { ClusterTasksResult clusterTasksResult; try { - List inputs = taskInputs.updateTasks.stream().map(tUpdateTask -> tUpdateTask.task).collect(Collectors.toList()); + List inputs = taskInputs.updateTasks.stream() + .map(ClusterServiceTaskBatcher.UpdateTask::getTask).collect(Collectors.toList()); clusterTasksResult = taskInputs.executor.execute(previousClusterState, inputs); } catch (Exception e) { TimeValue executionTime = TimeValue.timeValueMillis(Math.max(0, TimeValue.nsecToMSec(currentTimeInNanos() - startTimeNS))); @@ -693,7 +647,7 @@ private ClusterTasksResult executeTasks(TaskInputs taskInputs, long star } warnAboutSlowTaskIfNeeded(executionTime, taskInputs.summary); clusterTasksResult = ClusterTasksResult.builder() - .failures(taskInputs.updateTasks.stream().map(updateTask -> updateTask.task)::iterator, e) + .failures(taskInputs.updateTasks.stream().map(ClusterServiceTaskBatcher.UpdateTask::getTask)::iterator, e) .build(previousClusterState); } @@ -704,7 +658,7 @@ private ClusterTasksResult executeTasks(TaskInputs taskInputs, long star boolean assertsEnabled = false; assert (assertsEnabled = true); if (assertsEnabled) { - for (UpdateTask updateTask : taskInputs.updateTasks) { + for (ClusterServiceTaskBatcher.UpdateTask updateTask : taskInputs.updateTasks) { assert clusterTasksResult.executionResults.containsKey(updateTask.task) : "missing task result for " + updateTask; } @@ -870,10 +824,10 @@ private void callClusterStateAppliers(ClusterState newClusterState, ClusterChang */ class TaskInputs { public final String summary; - public final ArrayList updateTasks; + public final List updateTasks; public final ClusterStateTaskExecutor executor; - TaskInputs(ClusterStateTaskExecutor executor, ArrayList updateTasks, String summary) { + TaskInputs(ClusterStateTaskExecutor executor, List updateTasks, String summary) { this.summary = summary; this.executor = executor; this.updateTasks = updateTasks; @@ -895,11 +849,11 @@ class TaskOutputs { public final TaskInputs taskInputs; public final ClusterState previousClusterState; public final ClusterState newClusterState; - public final List nonFailedTasks; + public final List nonFailedTasks; public final Map executionResults; TaskOutputs(TaskInputs taskInputs, ClusterState previousClusterState, - ClusterState newClusterState, List nonFailedTasks, + ClusterState newClusterState, List nonFailedTasks, Map executionResults) { this.taskInputs = taskInputs; this.previousClusterState = previousClusterState; @@ -951,7 +905,7 @@ public boolean clusterStateUnchanged() { public void notifyFailedTasks() { // fail all tasks that have failed - for (UpdateTask updateTask : taskInputs.updateTasks) { + for (ClusterServiceTaskBatcher.UpdateTask updateTask : taskInputs.updateTasks) { assert executionResults.containsKey(updateTask.task) : "missing " + updateTask; final ClusterStateTaskExecutor.TaskResult taskResult = executionResults.get(updateTask.task); if (taskResult.isSuccess() == false) { @@ -1071,65 +1025,6 @@ public TimeValue ackTimeout() { } } - class UpdateTask extends SourcePrioritizedRunnable { - - public final Object task; - public final ClusterStateTaskListener listener; - private final ClusterStateTaskExecutor executor; - public final AtomicBoolean processed = new AtomicBoolean(); - - UpdateTask(String source, Object task, Priority priority, ClusterStateTaskExecutor executor, - ClusterStateTaskListener listener) { - super(priority, source); - this.task = task; - this.executor = executor; - this.listener = listener; - } - - @Override - public void run() { - // if this task is already processed, the executor shouldn't execute other tasks (that arrived later), - // to give other executors a chance to execute their tasks. - if (processed.get() == false) { - final ArrayList toExecute = new ArrayList<>(); - final Map> processTasksBySource = new HashMap<>(); - synchronized (updateTasksPerExecutor) { - LinkedHashSet pending = updateTasksPerExecutor.remove(executor); - if (pending != null) { - for (UpdateTask task : pending) { - if (task.processed.getAndSet(true) == false) { - logger.trace("will process {}", task); - toExecute.add(task); - processTasksBySource.computeIfAbsent(task.source, s -> new ArrayList<>()).add(task.task); - } else { - logger.trace("skipping {}, already processed", task); - } - } - } - } - - if (toExecute.isEmpty() == false) { - final String tasksSummary = processTasksBySource.entrySet().stream().map(entry -> { - String tasks = executor.describeTasks(entry.getValue()); - return tasks.isEmpty() ? entry.getKey() : entry.getKey() + "[" + tasks + "]"; - }).reduce((s1, s2) -> s1 + ", " + s2).orElse(""); - - runTasks(new TaskInputs(executor, toExecute, tasksSummary)); - } - } - } - - @Override - public String toString() { - String taskDescription = executor.describeTasks(Collections.singletonList(task)); - if (taskDescription.isEmpty()) { - return "[" + source + "]"; - } else { - return "[" + source + "[" + taskDescription + "]]"; - } - } - } - private void warnAboutSlowTaskIfNeeded(TimeValue executionTime, String source) { if (executionTime.getMillis() > slowTaskLoggingThreshold.getMillis()) { logger.warn("cluster state update task [{}] took [{}] above the warn threshold of {}", source, executionTime, diff --git a/core/src/main/java/org/elasticsearch/cluster/service/SourcePrioritizedRunnable.java b/core/src/main/java/org/elasticsearch/cluster/service/SourcePrioritizedRunnable.java new file mode 100644 index 0000000000000..6358acf7e1c7c --- /dev/null +++ b/core/src/main/java/org/elasticsearch/cluster/service/SourcePrioritizedRunnable.java @@ -0,0 +1,44 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.cluster.service; + +import org.elasticsearch.common.Priority; +import org.elasticsearch.common.util.concurrent.PrioritizedRunnable; + +/** + * PrioritizedRunnable that also has a source string + */ +public abstract class SourcePrioritizedRunnable extends PrioritizedRunnable { + protected final String source; + + public SourcePrioritizedRunnable(Priority priority, String source) { + super(priority); + this.source = source; + } + + public String source() { + return source; + } + + @Override + public String toString() { + return "[" + source + "]"; + } +} diff --git a/core/src/main/java/org/elasticsearch/cluster/service/TaskBatcher.java b/core/src/main/java/org/elasticsearch/cluster/service/TaskBatcher.java new file mode 100644 index 0000000000000..867d4191f800f --- /dev/null +++ b/core/src/main/java/org/elasticsearch/cluster/service/TaskBatcher.java @@ -0,0 +1,207 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.cluster.service; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.Priority; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.IdentityHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * Batching support for {@link PrioritizedEsThreadPoolExecutor} + * Tasks that share the same batching key are batched (see {@link BatchedTask#batchingKey}) + */ +public abstract class TaskBatcher { + + private final Logger logger; + private final PrioritizedEsThreadPoolExecutor threadExecutor; + // package visible for tests + final Map> tasksPerBatchingKey = new HashMap<>(); + + public TaskBatcher(Logger logger, PrioritizedEsThreadPoolExecutor threadExecutor) { + this.logger = logger; + this.threadExecutor = threadExecutor; + } + + public void submitTasks(List tasks, @Nullable TimeValue timeout) throws EsRejectedExecutionException { + if (tasks.isEmpty()) { + return; + } + final BatchedTask firstTask = tasks.get(0); + assert tasks.stream().allMatch(t -> t.batchingKey == firstTask.batchingKey) : + "tasks submitted in a batch should share the same batching key: " + tasks; + // convert to an identity map to check for dups based on task identity + final Map tasksIdentity = tasks.stream().collect(Collectors.toMap( + BatchedTask::getTask, + Function.identity(), + (a, b) -> { throw new IllegalStateException("cannot add duplicate task: " + a); }, + IdentityHashMap::new)); + + synchronized (tasksPerBatchingKey) { + LinkedHashSet existingTasks = tasksPerBatchingKey.computeIfAbsent(firstTask.batchingKey, + k -> new LinkedHashSet<>(tasks.size())); + for (BatchedTask existing : existingTasks) { + // check that there won't be two tasks with the same identity for the same batching key + BatchedTask duplicateTask = tasksIdentity.get(existing.getTask()); + if (duplicateTask != null) { + throw new IllegalStateException("task [" + duplicateTask.describeTasks( + Collections.singletonList(existing)) + "] with source [" + duplicateTask.source + "] is already queued"); + } + } + existingTasks.addAll(tasks); + } + + if (timeout != null) { + threadExecutor.execute(firstTask, timeout, () -> onTimeoutInternal(tasks, timeout)); + } else { + threadExecutor.execute(firstTask); + } + } + + private void onTimeoutInternal(List tasks, TimeValue timeout) { + final ArrayList toRemove = new ArrayList<>(); + for (BatchedTask task : tasks) { + if (task.processed.getAndSet(true) == false) { + logger.debug("task [{}] timed out after [{}]", task.source, timeout); + toRemove.add(task); + } + } + if (toRemove.isEmpty() == false) { + BatchedTask firstTask = toRemove.get(0); + Object batchingKey = firstTask.batchingKey; + assert tasks.stream().allMatch(t -> t.batchingKey == batchingKey) : + "tasks submitted in a batch should share the same batching key: " + tasks; + synchronized (tasksPerBatchingKey) { + LinkedHashSet existingTasks = tasksPerBatchingKey.get(batchingKey); + if (existingTasks != null) { + existingTasks.removeAll(toRemove); + if (existingTasks.isEmpty()) { + tasksPerBatchingKey.remove(batchingKey); + } + } + } + onTimeout(toRemove, timeout); + } + } + + /** + * Action to be implemented by the specific batching implementation. + * All tasks have the same batching key. + */ + protected abstract void onTimeout(List tasks, TimeValue timeout); + + void runIfNotProcessed(BatchedTask updateTask) { + // if this task is already processed, it shouldn't execute other tasks with same batching key that arrived later, + // to give other tasks with different batching key a chance to execute. + if (updateTask.processed.get() == false) { + final List toExecute = new ArrayList<>(); + final Map> processTasksBySource = new HashMap<>(); + synchronized (tasksPerBatchingKey) { + LinkedHashSet pending = tasksPerBatchingKey.remove(updateTask.batchingKey); + if (pending != null) { + for (BatchedTask task : pending) { + if (task.processed.getAndSet(true) == false) { + logger.trace("will process {}", task); + toExecute.add(task); + processTasksBySource.computeIfAbsent(task.source, s -> new ArrayList<>()).add(task); + } else { + logger.trace("skipping {}, already processed", task); + } + } + } + } + + if (toExecute.isEmpty() == false) { + final String tasksSummary = processTasksBySource.entrySet().stream().map(entry -> { + String tasks = updateTask.describeTasks(entry.getValue()); + return tasks.isEmpty() ? entry.getKey() : entry.getKey() + "[" + tasks + "]"; + }).reduce((s1, s2) -> s1 + ", " + s2).orElse(""); + + run(updateTask.batchingKey, toExecute, tasksSummary); + } + } + } + + /** + * Action to be implemented by the specific batching implementation + * All tasks have the given batching key. + */ + protected abstract void run(Object batchingKey, List tasks, String tasksSummary); + + /** + * Represents a runnable task that supports batching. + * Implementors of TaskBatcher can subclass this to add a payload to the task. + */ + protected abstract class BatchedTask extends SourcePrioritizedRunnable { + /** + * whether the task has been processed already + */ + protected final AtomicBoolean processed = new AtomicBoolean(); + + /** + * the object that is used as batching key + */ + protected final Object batchingKey; + /** + * the task object that is wrapped + */ + protected final Object task; + + protected BatchedTask(Priority priority, String source, Object batchingKey, Object task) { + super(priority, source); + this.batchingKey = batchingKey; + this.task = task; + } + + @Override + public void run() { + runIfNotProcessed(this); + } + + @Override + public String toString() { + String taskDescription = describeTasks(Collections.singletonList(this)); + if (taskDescription.isEmpty()) { + return "[" + source + "]"; + } else { + return "[" + source + "[" + taskDescription + "]]"; + } + } + + public abstract String describeTasks(List tasks); + + public Object getTask() { + return task; + } + } +} diff --git a/core/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java b/core/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java index 5e4c00523c938..7ec587cf72718 100644 --- a/core/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java +++ b/core/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java @@ -30,6 +30,7 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedTransferQueue; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -56,8 +57,8 @@ public static int boundedNumberOfProcessors(Settings settings) { return PROCESSORS_SETTING.get(settings); } - public static PrioritizedEsThreadPoolExecutor newSinglePrioritizing(String name, ThreadFactory threadFactory, ThreadContext contextHolder) { - return new PrioritizedEsThreadPoolExecutor(name, 1, 1, 0L, TimeUnit.MILLISECONDS, threadFactory, contextHolder); + public static PrioritizedEsThreadPoolExecutor newSinglePrioritizing(String name, ThreadFactory threadFactory, ThreadContext contextHolder, ScheduledExecutorService timer) { + return new PrioritizedEsThreadPoolExecutor(name, 1, 1, 0L, TimeUnit.MILLISECONDS, threadFactory, contextHolder, timer); } public static EsThreadPoolExecutor newScaling(String name, int min, int max, long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory, ThreadContext contextHolder) { diff --git a/core/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedEsThreadPoolExecutor.java b/core/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedEsThreadPoolExecutor.java index 1b01455c1ca79..5b3dae7ffae71 100644 --- a/core/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedEsThreadPoolExecutor.java +++ b/core/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedEsThreadPoolExecutor.java @@ -44,11 +44,14 @@ public class PrioritizedEsThreadPoolExecutor extends EsThreadPoolExecutor { private static final TimeValue NO_WAIT_TIME_VALUE = TimeValue.timeValueMillis(0); - private AtomicLong insertionOrder = new AtomicLong(); - private Queue current = ConcurrentCollections.newQueue(); + private final AtomicLong insertionOrder = new AtomicLong(); + private final Queue current = ConcurrentCollections.newQueue(); + private final ScheduledExecutorService timer; - PrioritizedEsThreadPoolExecutor(String name, int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory, ThreadContext contextHolder) { + PrioritizedEsThreadPoolExecutor(String name, int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, + ThreadFactory threadFactory, ThreadContext contextHolder, ScheduledExecutorService timer) { super(name, corePoolSize, maximumPoolSize, keepAliveTime, unit, new PriorityBlockingQueue<>(), threadFactory, contextHolder); + this.timer = timer; } public Pending[] getPending() { @@ -111,7 +114,7 @@ protected void afterExecute(Runnable r, Throwable t) { current.remove(r); } - public void execute(Runnable command, final ScheduledExecutorService timer, final TimeValue timeout, final Runnable timeoutCallback) { + public void execute(Runnable command, final TimeValue timeout, final Runnable timeoutCallback) { command = wrapRunnable(command); doExecute(command); if (timeout.nanos() >= 0) { diff --git a/core/src/test/java/org/elasticsearch/cluster/service/ClusterServiceTests.java b/core/src/test/java/org/elasticsearch/cluster/service/ClusterServiceTests.java index c0f8a58114feb..b8f85f733c394 100644 --- a/core/src/test/java/org/elasticsearch/cluster/service/ClusterServiceTests.java +++ b/core/src/test/java/org/elasticsearch/cluster/service/ClusterServiceTests.java @@ -20,7 +20,6 @@ import org.apache.logging.log4j.Level; import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.util.Supplier; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; @@ -39,7 +38,6 @@ import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.common.Priority; import org.elasticsearch.common.collect.Tuple; -import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; @@ -60,8 +58,6 @@ import org.junit.BeforeClass; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -70,7 +66,6 @@ import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Semaphore; @@ -78,17 +73,14 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; import static org.elasticsearch.test.ClusterServiceUtils.setState; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasKey; -import static org.hamcrest.Matchers.hasToString; import static org.hamcrest.Matchers.is; public class ClusterServiceTests extends ESTestCase { @@ -151,118 +143,6 @@ public void disconnectFromNodesExcept(DiscoveryNodes nodesToKeep) { return timedClusterService; } - public void testTimedOutUpdateTaskCleanedUp() throws Exception { - final CountDownLatch block = new CountDownLatch(1); - final CountDownLatch blockCompleted = new CountDownLatch(1); - clusterService.submitStateUpdateTask("block-task", new ClusterStateUpdateTask() { - @Override - public ClusterState execute(ClusterState currentState) { - try { - block.await(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - blockCompleted.countDown(); - return currentState; - } - - @Override - public void onFailure(String source, Exception e) { - throw new RuntimeException(e); - } - }); - - final CountDownLatch block2 = new CountDownLatch(1); - clusterService.submitStateUpdateTask("test", new ClusterStateUpdateTask() { - @Override - public ClusterState execute(ClusterState currentState) { - block2.countDown(); - return currentState; - } - - @Override - public TimeValue timeout() { - return TimeValue.ZERO; - } - - @Override - public void onFailure(String source, Exception e) { - block2.countDown(); - } - }); - block.countDown(); - block2.await(); - blockCompleted.await(); - synchronized (clusterService.updateTasksPerExecutor) { - assertTrue("expected empty map but was " + clusterService.updateTasksPerExecutor, - clusterService.updateTasksPerExecutor.isEmpty()); - } - } - - public void testTimeoutUpdateTask() throws Exception { - final CountDownLatch block = new CountDownLatch(1); - clusterService.submitStateUpdateTask("test1", new ClusterStateUpdateTask() { - @Override - public ClusterState execute(ClusterState currentState) { - try { - block.await(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - return currentState; - } - - @Override - public void onFailure(String source, Exception e) { - throw new RuntimeException(e); - } - }); - - final CountDownLatch timedOut = new CountDownLatch(1); - final AtomicBoolean executeCalled = new AtomicBoolean(); - clusterService.submitStateUpdateTask("test2", new ClusterStateUpdateTask() { - @Override - public TimeValue timeout() { - return TimeValue.timeValueMillis(2); - } - - @Override - public void onFailure(String source, Exception e) { - timedOut.countDown(); - } - - @Override - public ClusterState execute(ClusterState currentState) { - executeCalled.set(true); - return currentState; - } - - @Override - public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { - } - }); - - timedOut.await(); - block.countDown(); - final CountDownLatch allProcessed = new CountDownLatch(1); - clusterService.submitStateUpdateTask("test3", new ClusterStateUpdateTask() { - @Override - public void onFailure(String source, Exception e) { - throw new RuntimeException(e); - } - - @Override - public ClusterState execute(ClusterState currentState) { - allProcessed.countDown(); - return currentState; - } - - }); - allProcessed.await(); // executed another task to double check that execute on the timed out update task is not called... - assertThat(executeCalled.get(), equalTo(false)); - } - - public void testMasterAwareExecution() throws Exception { ClusterService nonMaster = createTimedClusterService(false); @@ -394,164 +274,6 @@ public void onFailure(String source, Exception e) { assertThat(assertionRef.get().getMessage(), containsString("not be the cluster state update thread. Reason: [Blocking operation]")); } - public void testOneExecutorDontStarveAnother() throws InterruptedException { - final List executionOrder = Collections.synchronizedList(new ArrayList<>()); - final Semaphore allowProcessing = new Semaphore(0); - final Semaphore startedProcessing = new Semaphore(0); - - class TaskExecutor implements ClusterStateTaskExecutor { - - @Override - public ClusterTasksResult execute(ClusterState currentState, List tasks) throws Exception { - executionOrder.addAll(tasks); // do this first, so startedProcessing can be used as a notification that this is done. - startedProcessing.release(tasks.size()); - allowProcessing.acquire(tasks.size()); - return ClusterTasksResult.builder().successes(tasks).build(ClusterState.builder(currentState).build()); - } - } - - TaskExecutor executorA = new TaskExecutor(); - TaskExecutor executorB = new TaskExecutor(); - - final ClusterStateTaskConfig config = ClusterStateTaskConfig.build(Priority.NORMAL); - final ClusterStateTaskListener noopListener = (source, e) -> { throw new AssertionError(source, e); }; - // this blocks the cluster state queue, so we can set it up right - clusterService.submitStateUpdateTask("0", "A0", config, executorA, noopListener); - // wait to be processed - startedProcessing.acquire(1); - assertThat(executionOrder, equalTo(Arrays.asList("A0"))); - - - // these will be the first batch - clusterService.submitStateUpdateTask("1", "A1", config, executorA, noopListener); - clusterService.submitStateUpdateTask("2", "A2", config, executorA, noopListener); - - // release the first 0 task, but not the second - allowProcessing.release(1); - startedProcessing.acquire(2); - assertThat(executionOrder, equalTo(Arrays.asList("A0", "A1", "A2"))); - - // setup the queue with pending tasks for another executor same priority - clusterService.submitStateUpdateTask("3", "B3", config, executorB, noopListener); - clusterService.submitStateUpdateTask("4", "B4", config, executorB, noopListener); - - - clusterService.submitStateUpdateTask("5", "A5", config, executorA, noopListener); - clusterService.submitStateUpdateTask("6", "A6", config, executorA, noopListener); - - // now release the processing - allowProcessing.release(6); - - // wait for last task to be processed - startedProcessing.acquire(4); - - assertThat(executionOrder, equalTo(Arrays.asList("A0", "A1", "A2", "B3", "B4", "A5", "A6"))); - - } - - // test that for a single thread, tasks are executed in the order - // that they are submitted - public void testClusterStateUpdateTasksAreExecutedInOrder() throws BrokenBarrierException, InterruptedException { - class TaskExecutor implements ClusterStateTaskExecutor { - List tasks = new ArrayList<>(); - - @Override - public ClusterTasksResult execute(ClusterState currentState, List tasks) throws Exception { - this.tasks.addAll(tasks); - return ClusterTasksResult.builder().successes(tasks).build(ClusterState.builder(currentState).build()); - } - } - - int numberOfThreads = randomIntBetween(2, 8); - TaskExecutor[] executors = new TaskExecutor[numberOfThreads]; - for (int i = 0; i < numberOfThreads; i++) { - executors[i] = new TaskExecutor(); - } - - int tasksSubmittedPerThread = randomIntBetween(2, 1024); - - CopyOnWriteArrayList> failures = new CopyOnWriteArrayList<>(); - CountDownLatch updateLatch = new CountDownLatch(numberOfThreads * tasksSubmittedPerThread); - - ClusterStateTaskListener listener = new ClusterStateTaskListener() { - @Override - public void onFailure(String source, Exception e) { - logger.error((Supplier) () -> new ParameterizedMessage("unexpected failure: [{}]", source), e); - failures.add(new Tuple<>(source, e)); - updateLatch.countDown(); - } - - @Override - public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { - updateLatch.countDown(); - } - }; - - CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads); - - for (int i = 0; i < numberOfThreads; i++) { - final int index = i; - Thread thread = new Thread(() -> { - try { - barrier.await(); - for (int j = 0; j < tasksSubmittedPerThread; j++) { - clusterService.submitStateUpdateTask("[" + index + "][" + j + "]", j, - ClusterStateTaskConfig.build(randomFrom(Priority.values())), executors[index], listener); - } - barrier.await(); - } catch (InterruptedException | BrokenBarrierException e) { - throw new AssertionError(e); - } - }); - thread.start(); - } - - // wait for all threads to be ready - barrier.await(); - // wait for all threads to finish - barrier.await(); - - updateLatch.await(); - - assertThat(failures, empty()); - - for (int i = 0; i < numberOfThreads; i++) { - assertEquals(tasksSubmittedPerThread, executors[i].tasks.size()); - for (int j = 0; j < tasksSubmittedPerThread; j++) { - assertNotNull(executors[i].tasks.get(j)); - assertEquals("cluster state update task executed out of order", j, (int) executors[i].tasks.get(j)); - } - } - } - - public void testSingleBatchSubmission() throws InterruptedException { - Map tasks = new HashMap<>(); - final int numOfTasks = randomInt(10); - final CountDownLatch latch = new CountDownLatch(numOfTasks); - for (int i = 0; i < numOfTasks; i++) { - while (null != tasks.put(randomInt(1024), new ClusterStateTaskListener() { - @Override - public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { - latch.countDown(); - } - - @Override - public void onFailure(String source, Exception e) { - fail(ExceptionsHelper.detailedMessage(e)); - } - })) ; - } - - clusterService.submitStateUpdateTasks("test", tasks, ClusterStateTaskConfig.build(Priority.LANGUID), - (currentState, taskList) -> { - assertThat(taskList.size(), equalTo(tasks.size())); - assertThat(taskList.stream().collect(Collectors.toSet()), equalTo(tasks.keySet())); - return ClusterStateTaskExecutor.ClusterTasksResult.builder().successes(taskList).build(currentState); - }); - - latch.await(); - } - public void testClusterStateBatchedUpdates() throws BrokenBarrierException, InterruptedException { AtomicInteger counter = new AtomicInteger(); class Task { @@ -745,76 +467,6 @@ public void clusterStateProcessed(String source, ClusterState oldState, ClusterS } } - /** - * Note, this test can only work as long as we have a single thread executor executing the state update tasks! - */ - public void testPrioritizedTasks() throws Exception { - BlockingTask block = new BlockingTask(Priority.IMMEDIATE); - clusterService.submitStateUpdateTask("test", block); - int taskCount = randomIntBetween(5, 20); - - // will hold all the tasks in the order in which they were executed - List tasks = new ArrayList<>(taskCount); - CountDownLatch latch = new CountDownLatch(taskCount); - for (int i = 0; i < taskCount; i++) { - Priority priority = randomFrom(Priority.values()); - clusterService.submitStateUpdateTask("test", new PrioritizedTask(priority, latch, tasks)); - } - - block.close(); - latch.await(); - - Priority prevPriority = null; - for (PrioritizedTask task : tasks) { - if (prevPriority == null) { - prevPriority = task.priority(); - } else { - assertThat(task.priority().sameOrAfter(prevPriority), is(true)); - } - } - } - - public void testDuplicateSubmission() throws InterruptedException { - final CountDownLatch latch = new CountDownLatch(2); - try (BlockingTask blockingTask = new BlockingTask(Priority.IMMEDIATE)) { - clusterService.submitStateUpdateTask("blocking", blockingTask); - - ClusterStateTaskExecutor executor = (currentState, tasks) -> - ClusterStateTaskExecutor.ClusterTasksResult.builder().successes(tasks).build(currentState); - - SimpleTask task = new SimpleTask(1); - ClusterStateTaskListener listener = new ClusterStateTaskListener() { - @Override - public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { - latch.countDown(); - } - - @Override - public void onFailure(String source, Exception e) { - fail(ExceptionsHelper.detailedMessage(e)); - } - }; - - clusterService.submitStateUpdateTask("first time", task, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener); - - final IllegalStateException e = - expectThrows( - IllegalStateException.class, - () -> clusterService.submitStateUpdateTask( - "second time", - task, - ClusterStateTaskConfig.build(Priority.NORMAL), - executor, listener)); - assertThat(e, hasToString(containsString("task [1] with source [second time] is already queued"))); - - clusterService.submitStateUpdateTask("third time a charm", new SimpleTask(1), - ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener); - - assertThat(latch.getCount(), equalTo(2L)); - } - latch.await(); - } - @TestLogging("org.elasticsearch.cluster.service:TRACE") // To ensure that we log cluster state events on TRACE level public void testClusterStateUpdateLogging() throws Exception { MockLogAppender mockAppender = new MockLogAppender(); @@ -1249,77 +901,6 @@ public void clusterStateProcessed(String source, ClusterState oldState, ClusterS assertTrue(applierCalled.get()); } - - private static class SimpleTask { - private final int id; - - private SimpleTask(int id) { - this.id = id; - } - - @Override - public int hashCode() { - return super.hashCode(); - } - - @Override - public boolean equals(Object obj) { - return super.equals(obj); - } - - @Override - public String toString() { - return Integer.toString(id); - } - } - - private static class BlockingTask extends ClusterStateUpdateTask implements Releasable { - private final CountDownLatch latch = new CountDownLatch(1); - - BlockingTask(Priority priority) { - super(priority); - } - - @Override - public ClusterState execute(ClusterState currentState) throws Exception { - latch.await(); - return currentState; - } - - @Override - public void onFailure(String source, Exception e) { - } - - public void close() { - latch.countDown(); - } - - } - - private static class PrioritizedTask extends ClusterStateUpdateTask { - - private final CountDownLatch latch; - private final List tasks; - - private PrioritizedTask(Priority priority, CountDownLatch latch, List tasks) { - super(priority); - this.latch = latch; - this.tasks = tasks; - } - - @Override - public ClusterState execute(ClusterState currentState) throws Exception { - tasks.add(this); - latch.countDown(); - return currentState; - } - - @Override - public void onFailure(String source, Exception e) { - latch.countDown(); - } - } - static class TimedClusterService extends ClusterService { public volatile Long currentTimeOverride = null; diff --git a/core/src/test/java/org/elasticsearch/cluster/service/TaskBatcherTests.java b/core/src/test/java/org/elasticsearch/cluster/service/TaskBatcherTests.java new file mode 100644 index 0000000000000..d5af9dd558155 --- /dev/null +++ b/core/src/test/java/org/elasticsearch/cluster/service/TaskBatcherTests.java @@ -0,0 +1,350 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.cluster.service; + +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.logging.log4j.util.Supplier; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.cluster.ClusterStateTaskConfig; +import org.elasticsearch.cluster.metadata.ProcessClusterEventTimeoutException; +import org.elasticsearch.common.Priority; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BrokenBarrierException; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.Semaphore; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasToString; + +public class TaskBatcherTests extends TaskExecutorTests { + + protected TestTaskBatcher taskBatcher; + + @Before + public void setUpBatchingTaskExecutor() throws Exception { + taskBatcher = new TestTaskBatcher(logger, threadExecutor); + } + + class TestTaskBatcher extends TaskBatcher { + + TestTaskBatcher(Logger logger, PrioritizedEsThreadPoolExecutor threadExecutor) { + super(logger, threadExecutor); + } + + @Override + protected void run(Object batchingKey, List tasks, String tasksSummary) { + List updateTasks = (List) tasks; + ((TestExecutor) batchingKey).execute(updateTasks.stream().map(t -> t.task).collect(Collectors.toList())); + updateTasks.forEach(updateTask -> updateTask.listener.processed(updateTask.source)); + } + + @Override + protected void onTimeout(List tasks, TimeValue timeout) { + threadPool.generic().execute( + () -> tasks.forEach( + task -> ((UpdateTask) task).listener.onFailure(task.source, + new ProcessClusterEventTimeoutException(timeout, task.source)))); + } + + class UpdateTask extends BatchedTask { + final TestListener listener; + + UpdateTask(Priority priority, String source, Object task, TestListener listener, TestExecutor executor) { + super(priority, source, executor, task); + this.listener = listener; + } + + @Override + public String describeTasks(List tasks) { + return ((TestExecutor) batchingKey).describeTasks( + tasks.stream().map(BatchedTask::getTask).collect(Collectors.toList())); + } + } + + } + + @Override + protected void submitTask(String source, TestTask testTask) { + submitTask(source, testTask, testTask, testTask, testTask); + } + + private void submitTask(String source, T task, ClusterStateTaskConfig config, TestExecutor executor, + TestListener listener) { + submitTasks(source, Collections.singletonMap(task, listener), config, executor); + } + + private void submitTasks(final String source, + final Map tasks, final ClusterStateTaskConfig config, + final TestExecutor executor) { + List safeTasks = tasks.entrySet().stream() + .map(e -> taskBatcher.new UpdateTask(config.priority(), source, e.getKey(), e.getValue(), executor)) + .collect(Collectors.toList()); + taskBatcher.submitTasks(safeTasks, config.timeout()); + } + + @Override + public void testTimedOutTaskCleanedUp() throws Exception { + super.testTimedOutTaskCleanedUp(); + synchronized (taskBatcher.tasksPerBatchingKey) { + assertTrue("expected empty map but was " + taskBatcher.tasksPerBatchingKey, + taskBatcher.tasksPerBatchingKey.isEmpty()); + } + } + + public void testOneExecutorDoesntStarveAnother() throws InterruptedException { + final List executionOrder = Collections.synchronizedList(new ArrayList<>()); + final Semaphore allowProcessing = new Semaphore(0); + final Semaphore startedProcessing = new Semaphore(0); + + class TaskExecutor implements TestExecutor { + + @Override + public void execute(List tasks) { + executionOrder.addAll(tasks); // do this first, so startedProcessing can be used as a notification that this is done. + startedProcessing.release(tasks.size()); + try { + allowProcessing.acquire(tasks.size()); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + + TaskExecutor executorA = new TaskExecutor(); + TaskExecutor executorB = new TaskExecutor(); + + final ClusterStateTaskConfig config = ClusterStateTaskConfig.build(Priority.NORMAL); + final TestListener noopListener = (source, e) -> { + throw new AssertionError(e); + }; + // this blocks the cluster state queue, so we can set it up right + submitTask("0", "A0", config, executorA, noopListener); + // wait to be processed + startedProcessing.acquire(1); + assertThat(executionOrder, equalTo(Arrays.asList("A0"))); + + + // these will be the first batch + submitTask("1", "A1", config, executorA, noopListener); + submitTask("2", "A2", config, executorA, noopListener); + + // release the first 0 task, but not the second + allowProcessing.release(1); + startedProcessing.acquire(2); + assertThat(executionOrder, equalTo(Arrays.asList("A0", "A1", "A2"))); + + // setup the queue with pending tasks for another executor same priority + submitTask("3", "B3", config, executorB, noopListener); + submitTask("4", "B4", config, executorB, noopListener); + + + submitTask("5", "A5", config, executorA, noopListener); + submitTask("6", "A6", config, executorA, noopListener); + + // now release the processing + allowProcessing.release(6); + + // wait for last task to be processed + startedProcessing.acquire(4); + + assertThat(executionOrder, equalTo(Arrays.asList("A0", "A1", "A2", "B3", "B4", "A5", "A6"))); + } + + static class TaskExecutor implements TestExecutor { + List tasks = new ArrayList<>(); + + @Override + public void execute(List tasks) { + this.tasks.addAll(tasks); + } + } + + // test that for a single thread, tasks are executed in the order + // that they are submitted + public void testTasksAreExecutedInOrder() throws BrokenBarrierException, InterruptedException { + int numberOfThreads = randomIntBetween(2, 8); + TaskExecutor[] executors = new TaskExecutor[numberOfThreads]; + for (int i = 0; i < numberOfThreads; i++) { + executors[i] = new TaskExecutor(); + } + + int tasksSubmittedPerThread = randomIntBetween(2, 1024); + + CopyOnWriteArrayList> failures = new CopyOnWriteArrayList<>(); + CountDownLatch updateLatch = new CountDownLatch(numberOfThreads * tasksSubmittedPerThread); + + final TestListener listener = new TestListener() { + @Override + public void onFailure(String source, Exception e) { + logger.error((Supplier) () -> new ParameterizedMessage("unexpected failure: [{}]", source), e); + failures.add(new Tuple<>(source, e)); + updateLatch.countDown(); + } + + @Override + public void processed(String source) { + updateLatch.countDown(); + } + }; + + CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads); + + for (int i = 0; i < numberOfThreads; i++) { + final int index = i; + Thread thread = new Thread(() -> { + try { + barrier.await(); + for (int j = 0; j < tasksSubmittedPerThread; j++) { + submitTask("[" + index + "][" + j + "]", j, + ClusterStateTaskConfig.build(randomFrom(Priority.values())), executors[index], listener); + } + barrier.await(); + } catch (InterruptedException | BrokenBarrierException e) { + throw new AssertionError(e); + } + }); + thread.start(); + } + + // wait for all threads to be ready + barrier.await(); + // wait for all threads to finish + barrier.await(); + + updateLatch.await(); + + assertThat(failures, empty()); + + for (int i = 0; i < numberOfThreads; i++) { + assertEquals(tasksSubmittedPerThread, executors[i].tasks.size()); + for (int j = 0; j < tasksSubmittedPerThread; j++) { + assertNotNull(executors[i].tasks.get(j)); + assertEquals("cluster state update task executed out of order", j, (int) executors[i].tasks.get(j)); + } + } + } + + public void testSingleBatchSubmission() throws InterruptedException { + Map tasks = new HashMap<>(); + final int numOfTasks = randomInt(10); + final CountDownLatch latch = new CountDownLatch(numOfTasks); + for (int i = 0; i < numOfTasks; i++) { + while (null != tasks.put(randomInt(1024), new TestListener() { + @Override + public void processed(String source) { + latch.countDown(); + } + + @Override + public void onFailure(String source, Exception e) { + fail(ExceptionsHelper.detailedMessage(e)); + } + })) ; + } + + TestExecutor executor = taskList -> { + assertThat(taskList.size(), equalTo(tasks.size())); + assertThat(taskList.stream().collect(Collectors.toSet()), equalTo(tasks.keySet())); + }; + submitTasks("test", tasks, ClusterStateTaskConfig.build(Priority.LANGUID), executor); + + latch.await(); + } + + public void testDuplicateSubmission() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(2); + try (BlockingTask blockingTask = new BlockingTask(Priority.IMMEDIATE)) { + submitTask("blocking", blockingTask); + + TestExecutor executor = tasks -> {}; + SimpleTask task = new SimpleTask(1); + TestListener listener = new TestListener() { + @Override + public void processed(String source) { + latch.countDown(); + } + + @Override + public void onFailure(String source, Exception e) { + fail(ExceptionsHelper.detailedMessage(e)); + } + }; + + submitTask("first time", task, ClusterStateTaskConfig.build(Priority.NORMAL), executor, + listener); + + final IllegalStateException e = + expectThrows( + IllegalStateException.class, + () -> submitTask( + "second time", + task, + ClusterStateTaskConfig.build(Priority.NORMAL), + executor, listener)); + assertThat(e, hasToString(containsString("task [1] with source [second time] is already queued"))); + + submitTask("third time a charm", new SimpleTask(1), + ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener); + + assertThat(latch.getCount(), equalTo(2L)); + } + latch.await(); + } + + private static class SimpleTask { + private final int id; + + private SimpleTask(int id) { + this.id = id; + } + + @Override + public int hashCode() { + return super.hashCode(); + } + + @Override + public boolean equals(Object obj) { + return super.equals(obj); + } + + @Override + public String toString() { + return Integer.toString(id); + } + } + +} diff --git a/core/src/test/java/org/elasticsearch/cluster/service/TaskExecutorTests.java b/core/src/test/java/org/elasticsearch/cluster/service/TaskExecutorTests.java new file mode 100644 index 0000000000000..fe426fdd42a9d --- /dev/null +++ b/core/src/test/java/org/elasticsearch/cluster/service/TaskExecutorTests.java @@ -0,0 +1,365 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.cluster.service; + +import org.elasticsearch.cluster.ClusterStateTaskConfig; +import org.elasticsearch.cluster.metadata.ProcessClusterEventTimeoutException; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.Priority; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadFactory; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.core.Is.is; + +public class TaskExecutorTests extends ESTestCase { + + protected static ThreadPool threadPool; + protected PrioritizedEsThreadPoolExecutor threadExecutor; + + @BeforeClass + public static void createThreadPool() { + threadPool = new TestThreadPool(getTestClass().getName()); + } + + @AfterClass + public static void stopThreadPool() { + if (threadPool != null) { + threadPool.shutdownNow(); + threadPool = null; + } + } + + @Before + public void setUpExecutor() { + threadExecutor = EsExecutors.newSinglePrioritizing("test_thread", + daemonThreadFactory(Settings.EMPTY, "test_thread"), threadPool.getThreadContext(), threadPool.scheduler()); + } + + @After + public void shutDownThreadExecutor() { + ThreadPool.terminate(threadExecutor, 10, TimeUnit.SECONDS); + } + + protected interface TestListener { + void onFailure(String source, Exception e); + + default void processed(String source) { + // do nothing by default + } + } + + protected interface TestExecutor { + void execute(List tasks); + + default String describeTasks(List tasks) { + return tasks.stream().map(T::toString).reduce((s1,s2) -> { + if (s1.isEmpty()) { + return s2; + } else if (s2.isEmpty()) { + return s1; + } else { + return s1 + ", " + s2; + } + }).orElse(""); + } + } + + /** + * Task class that works for single tasks as well as batching (see {@link TaskBatcherTests}) + */ + protected abstract static class TestTask implements TestExecutor, TestListener, ClusterStateTaskConfig { + + @Override + public void execute(List tasks) { + tasks.forEach(TestTask::run); + } + + @Nullable + @Override + public TimeValue timeout() { + return null; + } + + @Override + public Priority priority() { + return Priority.NORMAL; + } + + public abstract void run(); + } + + class UpdateTask extends SourcePrioritizedRunnable { + final TestTask testTask; + + UpdateTask(String source, TestTask testTask) { + super(testTask.priority(), source); + this.testTask = testTask; + } + + @Override + public void run() { + logger.trace("will process {}", source); + testTask.execute(Collections.singletonList(testTask)); + testTask.processed(source); + } + } + + // can be overridden by TaskBatcherTests + protected void submitTask(String source, TestTask testTask) { + SourcePrioritizedRunnable task = new UpdateTask(source, testTask); + TimeValue timeout = testTask.timeout(); + if (timeout != null) { + threadExecutor.execute(task, timeout, () -> threadPool.generic().execute(() -> { + logger.debug("task [{}] timed out after [{}]", task, timeout); + testTask.onFailure(source, new ProcessClusterEventTimeoutException(timeout, source)); + })); + } else { + threadExecutor.execute(task); + } + } + + + public void testTimedOutTaskCleanedUp() throws Exception { + final CountDownLatch block = new CountDownLatch(1); + final CountDownLatch blockCompleted = new CountDownLatch(1); + TestTask blockTask = new TestTask() { + + @Override + public void run() { + try { + block.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + blockCompleted.countDown(); + } + + @Override + public void onFailure(String source, Exception e) { + throw new RuntimeException(e); + } + }; + submitTask("block-task", blockTask); + + final CountDownLatch block2 = new CountDownLatch(1); + TestTask unblockTask = new TestTask() { + + @Override + public void run() { + block2.countDown(); + } + + @Override + public void onFailure(String source, Exception e) { + block2.countDown(); + } + + @Override + public TimeValue timeout() { + return TimeValue.ZERO; + } + }; + submitTask("unblock-task", unblockTask); + + block.countDown(); + block2.await(); + blockCompleted.await(); + } + + public void testTimeoutTask() throws Exception { + final CountDownLatch block = new CountDownLatch(1); + TestTask test1 = new TestTask() { + @Override + public void run() { + try { + block.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onFailure(String source, Exception e) { + throw new RuntimeException(e); + } + }; + submitTask("block-task", test1); + + final CountDownLatch timedOut = new CountDownLatch(1); + final AtomicBoolean executeCalled = new AtomicBoolean(); + TestTask test2 = new TestTask() { + + @Override + public TimeValue timeout() { + return TimeValue.timeValueMillis(2); + } + + @Override + public void run() { + executeCalled.set(true); + } + + @Override + public void onFailure(String source, Exception e) { + timedOut.countDown(); + } + }; + submitTask("block-task", test2); + + timedOut.await(); + block.countDown(); + final CountDownLatch allProcessed = new CountDownLatch(1); + TestTask test3 = new TestTask() { + + @Override + public void run() { + allProcessed.countDown(); + } + + @Override + public void onFailure(String source, Exception e) { + throw new RuntimeException(e); + } + }; + submitTask("block-task", test3); + allProcessed.await(); // executed another task to double check that execute on the timed out update task is not called... + assertThat(executeCalled.get(), equalTo(false)); + } + + static class TaskExecutor implements TestExecutor { + List tasks = new ArrayList<>(); + + @Override + public void execute(List tasks) { + this.tasks.addAll(tasks); + } + } + + /** + * Note, this test can only work as long as we have a single thread executor executing the state update tasks! + */ + public void testPrioritizedTasks() throws Exception { + BlockingTask block = new BlockingTask(Priority.IMMEDIATE); + submitTask("test", block); + int taskCount = randomIntBetween(5, 20); + + // will hold all the tasks in the order in which they were executed + List tasks = new ArrayList<>(taskCount); + CountDownLatch latch = new CountDownLatch(taskCount); + for (int i = 0; i < taskCount; i++) { + Priority priority = randomFrom(Priority.values()); + PrioritizedTask task = new PrioritizedTask(priority, latch, tasks); + submitTask("test", task); + } + + block.close(); + latch.await(); + + Priority prevPriority = null; + for (PrioritizedTask task : tasks) { + if (prevPriority == null) { + prevPriority = task.priority(); + } else { + assertThat(task.priority().sameOrAfter(prevPriority), is(true)); + } + } + } + + protected static class BlockingTask extends TestTask implements Releasable { + private final CountDownLatch latch = new CountDownLatch(1); + private final Priority priority; + + BlockingTask(Priority priority) { + super(); + this.priority = priority; + } + + @Override + public void run() { + try { + latch.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onFailure(String source, Exception e) { + } + + @Override + public Priority priority() { + return priority; + } + + public void close() { + latch.countDown(); + } + + } + + protected static class PrioritizedTask extends TestTask { + private final CountDownLatch latch; + private final List tasks; + private final Priority priority; + + private PrioritizedTask(Priority priority, CountDownLatch latch, List tasks) { + super(); + this.latch = latch; + this.tasks = tasks; + this.priority = priority; + } + + @Override + public void run() { + tasks.add(this); + latch.countDown(); + } + + @Override + public Priority priority() { + return priority; + } + + @Override + public void onFailure(String source, Exception e) { + latch.countDown(); + } + } + +} diff --git a/core/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedExecutorsTests.java b/core/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedExecutorsTests.java index 933a46de510dc..3ed105080b30b 100644 --- a/core/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedExecutorsTests.java +++ b/core/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedExecutorsTests.java @@ -65,7 +65,7 @@ public void testPriorityQueue() throws Exception { } public void testSubmitPrioritizedExecutorWithRunnables() throws Exception { - ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder); + ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder, null); List results = new ArrayList<>(8); CountDownLatch awaitingLatch = new CountDownLatch(1); CountDownLatch finishedLatch = new CountDownLatch(8); @@ -94,7 +94,7 @@ public void testSubmitPrioritizedExecutorWithRunnables() throws Exception { } public void testExecutePrioritizedExecutorWithRunnables() throws Exception { - ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder); + ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder, null); List results = new ArrayList<>(8); CountDownLatch awaitingLatch = new CountDownLatch(1); CountDownLatch finishedLatch = new CountDownLatch(8); @@ -123,7 +123,7 @@ public void testExecutePrioritizedExecutorWithRunnables() throws Exception { } public void testSubmitPrioritizedExecutorWithCallables() throws Exception { - ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder); + ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder, null); List results = new ArrayList<>(8); CountDownLatch awaitingLatch = new CountDownLatch(1); CountDownLatch finishedLatch = new CountDownLatch(8); @@ -152,7 +152,7 @@ public void testSubmitPrioritizedExecutorWithCallables() throws Exception { } public void testSubmitPrioritizedExecutorWithMixed() throws Exception { - ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder); + ExecutorService executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder, null); List results = new ArrayList<>(8); CountDownLatch awaitingLatch = new CountDownLatch(1); CountDownLatch finishedLatch = new CountDownLatch(8); @@ -182,7 +182,7 @@ public void testSubmitPrioritizedExecutorWithMixed() throws Exception { public void testTimeout() throws Exception { ScheduledExecutorService timer = Executors.newSingleThreadScheduledExecutor(EsExecutors.daemonThreadFactory(getTestName())); - PrioritizedEsThreadPoolExecutor executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder); + PrioritizedEsThreadPoolExecutor executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder, timer); final CountDownLatch invoked = new CountDownLatch(1); final CountDownLatch block = new CountDownLatch(1); executor.execute(new Runnable() { @@ -219,7 +219,7 @@ public void run() { public String toString() { return "the waiting"; } - }, timer, TimeValue.timeValueMillis(100) /* enough timeout to catch them in the pending list... */, new Runnable() { + }, TimeValue.timeValueMillis(100) /* enough timeout to catch them in the pending list... */, new Runnable() { @Override public void run() { timedOut.countDown(); @@ -245,14 +245,14 @@ public void testTimeoutCleanup() throws Exception { ThreadPool threadPool = new TestThreadPool("test"); final ScheduledThreadPoolExecutor timer = (ScheduledThreadPoolExecutor) threadPool.scheduler(); final AtomicBoolean timeoutCalled = new AtomicBoolean(); - PrioritizedEsThreadPoolExecutor executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder); + PrioritizedEsThreadPoolExecutor executor = EsExecutors.newSinglePrioritizing(getTestName(), EsExecutors.daemonThreadFactory(getTestName()), holder, timer); final CountDownLatch invoked = new CountDownLatch(1); executor.execute(new Runnable() { @Override public void run() { invoked.countDown(); } - }, timer, TimeValue.timeValueHours(1), new Runnable() { + }, TimeValue.timeValueHours(1), new Runnable() { @Override public void run() { // We should never get here