Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x]Support task resource tracking in OpenSearch (#3982) #4087

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,9 @@ public void onTaskUnregistered(Task task) {}

@Override
public void waitForTaskCompletion(Task task) {}

@Override
public void taskExecutionStarted(Task task, Boolean closeableInvoked) {}
});
}
// Need to run the task in a separate thread because node client's .execute() is blocked by our task listener
Expand Down Expand Up @@ -651,6 +654,9 @@ public void waitForTaskCompletion(Task task) {
waitForWaitingToStart.countDown();
}

@Override
public void taskExecutionStarted(Task task, Boolean closeableInvoked) {}

@Override
public void onTaskRegistered(Task task) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskInfo;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

Expand All @@ -65,8 +66,15 @@ public static long waitForCompletionTimeout(TimeValue timeout) {

private static final TimeValue DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT = timeValueSeconds(30);

private final TaskResourceTrackingService taskResourceTrackingService;

@Inject
public TransportListTasksAction(ClusterService clusterService, TransportService transportService, ActionFilters actionFilters) {
public TransportListTasksAction(
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters,
TaskResourceTrackingService taskResourceTrackingService
) {
super(
ListTasksAction.NAME,
clusterService,
Expand All @@ -77,6 +85,7 @@ public TransportListTasksAction(ClusterService clusterService, TransportService
TaskInfo::new,
ThreadPool.Names.MANAGEMENT
);
this.taskResourceTrackingService = taskResourceTrackingService;
}

@Override
Expand Down Expand Up @@ -106,6 +115,8 @@ protected void processTasks(ListTasksRequest request, Consumer<Task> operation)
}
taskManager.waitForTaskCompletion(task, timeoutNanos);
});
} else {
operation = operation.andThen(taskResourceTrackingService::refreshResourceStats);
}
super.processTasks(request, operation);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ public SearchShardTask(long id, String type, String action, String description,
super(id, type, action, description, parentTaskId, headers);
}

@Override
public boolean supportsResourceTracking() {
return true;
}

@Override
public boolean shouldCancelChildrenOnCancellation() {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ public final String getDescription() {
return descriptionSupplier.get();
}

@Override
public boolean supportsResourceTracking() {
return true;
}

/**
* Attach a {@link SearchProgressListener} to this task.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.action.ActionResponse;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskCancelledException;
import org.opensearch.tasks.TaskId;
Expand Down Expand Up @@ -93,31 +94,39 @@ public final Task execute(Request request, ActionListener<Response> listener) {
*/
final Releasable unregisterChildNode = registerChildNode(request.getParentTask());
final Task task;

try {
task = taskManager.register("transport", actionName, request);
} catch (TaskCancelledException e) {
unregisterChildNode.close();
throw e;
}
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(response);

ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task);
try {
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(response);
}
}
}

@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(e);
@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(e);
}
}
}
});
});
} finally {
storedContext.close();
}

return task;
}

Expand All @@ -134,25 +143,30 @@ public final Task execute(Request request, TaskListener<Response> listener) {
unregisterChildNode.close();
throw e;
}
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(task, response);
ThreadContext.StoredContext storedContext = taskManager.taskExecutionStarted(task);
try {
execute(task, request, new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onResponse(task, response);
}
}
}

@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(task, e);
@Override
public void onFailure(Exception e) {
try {
Releasables.close(unregisterChildNode, () -> taskManager.unregister(task));
} finally {
listener.onFailure(task, e);
}
}
}
});
});
} finally {
storedContext.close();
}
return task;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
import org.opensearch.script.ScriptMetadata;
import org.opensearch.snapshots.SnapshotsInfoService;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.tasks.TaskResultsService;

import java.util.ArrayList;
Expand Down Expand Up @@ -396,6 +397,7 @@ protected void configure() {
bind(NodeMappingRefreshAction.class).asEagerSingleton();
bind(MappingUpdatedAction.class).asEagerSingleton();
bind(TaskResultsService.class).asEagerSingleton();
bind(TaskResourceTrackingService.class).asEagerSingleton();
bind(AllocationDeciders.class).toInstance(allocationDeciders);
bind(ShardsAllocator.class).toInstance(shardsAllocator);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.index.ShardIndexingPressureMemoryManager;
import org.opensearch.index.ShardIndexingPressureSettings;
import org.opensearch.index.ShardIndexingPressureStore;
import org.opensearch.tasks.TaskResourceTrackingService;
import org.opensearch.watcher.ResourceWatcherService;
import org.opensearch.action.admin.cluster.configuration.TransportAddVotingConfigExclusionsAction;
import org.opensearch.action.admin.indices.close.TransportCloseIndexAction;
Expand Down Expand Up @@ -571,7 +572,8 @@ public void apply(Settings value, Settings current, Settings previous) {
ShardIndexingPressureMemoryManager.THROUGHPUT_DEGRADATION_LIMITS,
ShardIndexingPressureMemoryManager.SUCCESSFUL_REQUEST_ELAPSED_TIMEOUT,
ShardIndexingPressureMemoryManager.MAX_OUTSTANDING_REQUESTS,
IndexingPressure.MAX_INDEXING_BYTES
IndexingPressure.MAX_INDEXING_BYTES,
TaskResourceTrackingService.TASK_RESOURCE_TRACKING_ENABLED
)
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.node.Node;
import org.opensearch.threadpool.RunnableTaskExecutionListener;
import org.opensearch.threadpool.TaskAwareRunnable;

import java.util.List;
import java.util.Optional;
Expand All @@ -55,6 +57,7 @@
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

/**
Expand Down Expand Up @@ -177,6 +180,31 @@ public static OpenSearchThreadPoolExecutor newFixed(
);
}

public static OpenSearchThreadPoolExecutor newAutoQueueFixed(
String name,
int size,
int initialQueueCapacity,
int minQueueSize,
int maxQueueSize,
int frameSize,
TimeValue targetedResponseTime,
ThreadFactory threadFactory,
ThreadContext contextHolder
) {
return newAutoQueueFixed(
name,
size,
initialQueueCapacity,
minQueueSize,
maxQueueSize,
frameSize,
targetedResponseTime,
threadFactory,
contextHolder,
null
);
}

/**
* Return a new executor that will automatically adjust the queue size based on queue throughput.
*
Expand All @@ -185,6 +213,7 @@ public static OpenSearchThreadPoolExecutor newFixed(
* @param minQueueSize minimum queue size that the queue can be adjusted to
* @param maxQueueSize maximum queue size that the queue can be adjusted to
* @param frameSize number of tasks during which stats are collected before adjusting queue size
* @param runnableTaskListener callback listener for a TaskAwareRunnable
*/
public static OpenSearchThreadPoolExecutor newAutoQueueFixed(
String name,
Expand All @@ -195,17 +224,30 @@ public static OpenSearchThreadPoolExecutor newAutoQueueFixed(
int frameSize,
TimeValue targetedResponseTime,
ThreadFactory threadFactory,
ThreadContext contextHolder
ThreadContext contextHolder,
AtomicReference<RunnableTaskExecutionListener> runnableTaskListener
) {
if (initialQueueCapacity <= 0) {
throw new IllegalArgumentException(
"initial queue capacity for [" + name + "] executor must be positive, got: " + initialQueueCapacity
);
}

ResizableBlockingQueue<Runnable> queue = new ResizableBlockingQueue<>(
ConcurrentCollections.<Runnable>newBlockingQueue(),
initialQueueCapacity
);

Function<Runnable, WrappedRunnable> runnableWrapper;
if (runnableTaskListener != null) {
runnableWrapper = (runnable) -> {
TaskAwareRunnable taskAwareRunnable = new TaskAwareRunnable(contextHolder, runnable, runnableTaskListener);
return new TimedRunnable(taskAwareRunnable);
};
} else {
runnableWrapper = TimedRunnable::new;
}

return new QueueResizingOpenSearchThreadPoolExecutor(
name,
size,
Expand All @@ -215,7 +257,7 @@ public static OpenSearchThreadPoolExecutor newAutoQueueFixed(
queue,
minQueueSize,
maxQueueSize,
TimedRunnable::new,
runnableWrapper,
frameSize,
targetedResponseTime,
threadFactory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@

import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_COUNT;
import static org.opensearch.http.HttpTransportSettings.SETTING_HTTP_MAX_WARNING_HEADER_SIZE;
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;

/**
* A ThreadContext is a map of string headers and a transient map of keyed objects that are associated with
Expand Down Expand Up @@ -135,16 +136,23 @@ public StoredContext stashContext() {
* This is needed so the DeprecationLogger in another thread can see the value of X-Opaque-ID provided by a user.
* Otherwise when context is stash, it should be empty.
*/

ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT;

if (context.requestHeaders.containsKey(Task.X_OPAQUE_ID)) {
ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT.putHeaders(
threadContextStruct = threadContextStruct.putHeaders(
MapBuilder.<String, String>newMapBuilder()
.put(Task.X_OPAQUE_ID, context.requestHeaders.get(Task.X_OPAQUE_ID))
.immutableMap()
);
threadLocal.set(threadContextStruct);
} else {
threadLocal.set(DEFAULT_CONTEXT);
}

if (context.transientHeaders.containsKey(TASK_ID)) {
threadContextStruct = threadContextStruct.putTransient(TASK_ID, context.transientHeaders.get(TASK_ID));
}

threadLocal.set(threadContextStruct);

return () -> {
// If the node and thus the threadLocal get closed while this task
// is still executing, we don't want this runnable to fail with an
Expand Down
Loading