Skip to content

Commit

Permalink
restore thread context before running action listener
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Oct 2, 2023
1 parent df06324 commit 4839729
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,14 @@ void registerModel(ActionListener<MLRegisterModelResponse> listener) throws Inte
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
modelGroup.toXContent(builder, ToXContent.EMPTY_PARAMS);
createModelGroupRequest.source(builder);
client.index(createModelGroupRequest, ActionListener.wrap(r -> {
client.index(createModelGroupRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
client.execute(MLRegisterModelAction.INSTANCE, registerRequest, ActionListener.wrap(listener::onResponse, e -> {
log.error("Failed to Register Model", e);
listener.onFailure(e);
}));
}, e-> {
listener.onFailure(e);
}));
}), () -> context.restore()));
} catch (IOException e) {
throw new MLException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ private void initMasterKey() {
if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
client.get(getRequest, new LatchedActionListener(ActionListener.<GetResponse>wrap(r -> {
client.get(getRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.<GetResponse>wrap(r -> {
if (r.isExists()) {
String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY);
this.masterKey = masterKey;
Expand All @@ -120,7 +120,7 @@ private void initMasterKey() {
}, e -> {
log.error("Failed to get ML encryption master key", e);
exceptionRef.set(e);
}), latch));
}), latch), () -> context.restore()));
}
} else {
exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.query(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId));
searchRequest.source(sourceBuilder);
client.search(searchRequest, ActionListener.wrap(searchResponse -> {
client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(searchResponse -> {
SearchHit[] searchHits = searchResponse.getHits().getHits();
if (searchHits.length == 0) {
deleteConnector(deleteRequest, connectorId, actionListener);
Expand All @@ -92,7 +92,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
}
log.error("Failed to delete ML connector: " + connectorId, e);
actionListener.onFailure(e);
}));
}), () -> context.restore()));
} catch (Exception e) {
log.error(e.getMessage(), e);
actionListener.onFailure(e);
Expand All @@ -108,7 +108,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete

private void deleteConnector(DeleteRequest deleteRequest, String connectorId, ActionListener<DeleteResponse> actionListener) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.delete(deleteRequest, new ActionListener<>() {
client.delete(deleteRequest, ActionListener.runBefore(new ActionListener<>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) {
Expand All @@ -125,7 +125,7 @@ public void onFailure(Exception e) {
log.error("Failed to delete ML connector: " + connectorId, e);
actionListener.onFailure(e);
}
});
}, () -> context.restore()));
} catch (Exception e) {
log.error("Failed to delete ML connector: " + connectorId, e);
actionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ protected void doExecute(Task task, SearchRequest request, ActionListener<Search
private void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
User user = RestActionUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<SearchResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
List<String> excludes = Optional
.ofNullable(request.source())
.map(SearchSourceBuilder::fetchSource)
Expand All @@ -78,11 +79,11 @@ private void search(SearchRequest request, ActionListener<SearchResponse> action
);
request.source().fetchSource(rebuiltFetchSourceContext);
if (connectorAccessControlHelper.skipConnectorAccessControl(user)) {
client.search(request, actionListener);
client.search(request, wrappedListener);
} else {
SearchSourceBuilder sourceBuilder = connectorAccessControlHelper.addUserBackendRolesFilter(user, request.source());
request.source(sourceBuilder);
client.search(request, actionListener);
client.search(request, wrappedListener);
}
} catch (Exception e) {
log.error(e.getMessage(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,15 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLDeployModelResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
if (functionName == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
listener
wrappedListener
.onFailure(new MLValidationException("User Doesn't have privilege to perform this operation on this model"));
} else {
String[] targetNodeIds = deployModelRequest.getModelNodeIds();
Expand Down Expand Up @@ -172,7 +173,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
Set<String> difference = new HashSet<String>(Arrays.asList(workerNodes));
difference.removeAll(Arrays.asList(targetNodeIds));
if (difference.size() > 0) {
listener
wrappedListener
.onFailure(
new IllegalArgumentException(
"Model already deployed to these nodes: "
Expand All @@ -188,7 +189,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
eligibleNodes.addAll(Arrays.asList(allEligibleNodes));
}
if (nodeIds.size() == 0) {
listener.onFailure(new IllegalArgumentException("no eligible node found"));
wrappedListener.onFailure(new IllegalArgumentException("no eligible node found"));
return;
}

Expand All @@ -215,7 +216,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
mlTask.setTaskId(taskId);
try {
mlTaskManager.add(mlTask, nodeIds);
listener.onResponse(new MLDeployModelResponse(taskId, MLTaskState.CREATED.name()));
wrappedListener.onResponse(new MLDeployModelResponse(taskId, MLTaskState.CREATED.name()));
threadPool
.executor(DEPLOY_THREAD_POOL)
.execute(
Expand All @@ -238,20 +239,20 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
TASK_SEMAPHORE_TIMEOUT,
true
);
listener.onFailure(ex);
wrappedListener.onFailure(ex);
}
}, exception -> {
log.error("Failed to create deploy model task for " + modelId, exception);
listener.onFailure(exception);
wrappedListener.onFailure(exception);
}));
}
}, e -> {
log.error("Failed to Validate Access for ModelId " + modelId, e);
listener.onFailure(e);
wrappedListener.onFailure(e);
}));
}, e -> {
log.error("Failed to deploy model " + modelId, e);
listener.onFailure(e);
wrappedListener.onFailure(e);
}));
} catch (Exception e) {
log.error("Failed to get ML model " + modelId, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,14 @@ protected void doExecute(Task task, SearchRequest request, ActionListener<Search

private void preProcessRoleAndPerformSearch(SearchRequest request, User user, ActionListener<SearchResponse> listener) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<SearchResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
if (modelAccessControlHelper.skipModelAccessControl(user)) {
client.search(request, listener);
client.search(request, wrappedListener);
} else {
// Security is enabled, filter is enabled and user isn't admin
modelAccessControlHelper.addUserBackendRolesFilter(user, request.source());
log.debug("Filtering result by " + user.getBackendRoles());
client.search(request, listener);
client.search(request, wrappedListener);
}
} catch (Exception e) {
log.error("Failed to search", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
final User userInfo = user;

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLTaskResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
mlModelManager.getModel(modelId, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
modelAccessControlHelper
.validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
listener
wrappedListener
.onFailure(
new MLValidationException("User Doesn't have privilege to perform this operation on this model")
);
Expand All @@ -100,7 +101,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId());
long startTime = System.nanoTime();
mlPredictTaskRunner
.run(functionName, mlPredictionTaskRequest, transportService, ActionListener.runAfter(listener, () -> {
.run(functionName, mlPredictionTaskRequest, transportService, ActionListener.runAfter(wrappedListener, () -> {
long endTime = System.nanoTime();
double durationInMs = (endTime - startTime) / 1e6;
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
Expand All @@ -109,11 +110,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
}
}, e -> {
log.error("Failed to Validate Access for ModelId " + modelId, e);
listener.onFailure(e);
wrappedListener.onFailure(e);
}));
}, e -> {
log.error("Failed to find model " + modelId, e);
listener.onFailure(e);
wrappedListener.onFailure(e);
}));

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,15 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str
String modelName = input.getName();
User user = RestActionUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<String> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
validateUniqueModelGroupName(input.getName(), ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
Iterator<SearchHit> iterator = modelGroups.getHits().iterator();
while (iterator.hasNext()) {
String id = iterator.next().getId();
listener
wrappedListener
.onFailure(
new IllegalArgumentException(
"The name you provided is already being used by another model with ID: "
Expand Down Expand Up @@ -121,19 +122,19 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str

client.index(indexRequest, ActionListener.wrap(r -> {
log.debug("Indexed model group doc successfully {}", modelName);
listener.onResponse(r.getId());
wrappedListener.onResponse(r.getId());
}, e -> {
log.error("Failed to index model group doc", e);
listener.onFailure(e);
wrappedListener.onFailure(e);
}));
}, ex -> {
log.error("Failed to init model group index", ex);
listener.onFailure(ex);
wrappedListener.onFailure(ex);
}));
}
}, e -> {
log.error("Failed to search model group index", e);
listener.onFailure(e);
wrappedListener.onFailure(e);
}));
} catch (Exception e) {
log.error("Failed to create model group doc", e);
Expand Down
Loading

0 comments on commit 4839729

Please sign in to comment.