From b362a1b640f4a710ec4d68f93db234295f758417 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 4 Oct 2023 17:15:16 -0700 Subject: [PATCH] fix more places where thread context not restored (#1421) (#1423) (#1436) * fix more places where thread context not restored * fix failed ut * remove unused import --------- Signed-off-by: Yaliang Wu (cherry picked from commit d8c1162a14b1e64d2d4f1b73ea2c135054cedca3) Co-authored-by: Yaliang Wu --- .../ml/action/handler/MLSearchHandler.java | 11 +++++----- .../DeleteModelGroupTransportAction.java | 13 ++++++------ .../TransportUpdateModelGroupAction.java | 16 +++++++------- .../models/DeleteModelTransportAction.java | 19 +++++++++-------- .../models/GetModelTransportAction.java | 19 +++++++++-------- .../tasks/DeleteTaskTransportAction.java | 11 +++++----- .../action/tasks/GetTaskTransportAction.java | 4 ++-- .../tasks/SearchTaskTransportAction.java | 2 +- .../TransportUndeployModelsAction.java | 4 ++-- .../upload_chunk/MLModelChunkUploader.java | 21 ++++++++++--------- .../tasks/SearchTaskTransportActionTests.java | 4 +++- 11 files changed, 67 insertions(+), 57 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java index 029605fca9..682189ad34 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java @@ -80,6 +80,7 @@ public void search(SearchRequest request, ActionListener actionL User user = RestActionUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, "Fail to search model version"); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); List excludes = Optional .ofNullable(request.source()) .map(SearchSourceBuilder::fetchSource) @@ -98,9 +99,9 @@ public void search(SearchRequest request, ActionListener actionL ); request.source().fetchSource(rebuiltFetchSourceContext); if (modelAccessControlHelper.skipModelAccessControl(user)) { - client.search(request, listener); + client.search(request, wrappedListener); } else if (!clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) { - client.search(request, listener); + client.search(request, wrappedListener); } else { SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user); SearchRequest modelGroupSearchRequest = new SearchRequest(); @@ -119,15 +120,15 @@ public void search(SearchRequest request, ActionListener actionL Arrays.stream(r.getHits().getHits()).forEach(hit -> { modelGroupIds.add(hit.getId()); }); request.source().query(rewriteQueryBuilder(request.source().query(), modelGroupIds)); - client.search(request, listener); + client.search(request, wrappedListener); } else { log.debug("No model group found"); request.source().query(rewriteQueryBuilder(request.source().query(), null)); - client.search(request, listener); + client.search(request, wrappedListener); } }, e -> { log.error("Fail to search model groups!", e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); }); client.search(modelGroupSearchRequest, modelGroupSearchActionListener); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java index cbd55a8c22..d7b9d8748b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java @@ -72,9 +72,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> { if (!access) { - actionListener.onFailure(new MLValidationException("User doesn't have privilege to delete this model group")); + wrappedListener.onFailure(new MLValidationException("User doesn't have privilege to delete this model group")); } else { BoolQueryBuilder query = new BoolQueryBuilder(); query.filter(new TermQueryBuilder(PARAMETER_MODEL_GROUP_ID, modelGroupId)); @@ -87,13 +88,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { - actionListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); } else { log.error("Failed to search models with the specified Model Group Id " + modelGroupId, e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); } })); } }, e -> { log.error("Failed to validate Access for Model Group " + modelGroupId, e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); })); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index 5d53c4dea9..494f197857 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -93,6 +93,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { if (modelGroup.isExists()) { try ( @@ -102,17 +103,17 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { - listener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); } else { logException("Failed to get model group", e, log); - listener.onFailure(e); + wrappedListener.onFailure(e); } })); } catch (Exception e) { @@ -186,15 +187,16 @@ private void updateModelGroup(String modelGroupId, Map source, A UpdateRequest updateModelGroupRequest = new UpdateRequest(); updateModelGroupRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupId).doc(source); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); client .update( updateModelGroupRequest, - ActionListener.wrap(r -> { listener.onResponse(new MLUpdateModelGroupResponse("Updated")); }, e -> { + ActionListener.wrap(r -> { wrappedListener.onResponse(new MLUpdateModelGroupResponse("Updated")); }, e -> { if (e instanceof IndexNotFoundException) { - listener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); } else { log.error("Failed to update model group", e, log); - listener.onFailure(new MLValidationException("Failed to update Model Group")); + wrappedListener.onFailure(new MLValidationException("Failed to update Model Group")); } }) ); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index 2948b255d1..5d89e1c113 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -99,6 +99,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { @@ -113,7 +114,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (!access) { - actionListener + wrappedListener .onFailure( new MLValidationException("User doesn't have privilege to perform this operation on this model") ); @@ -125,7 +126,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { log.error("Failed to Search Model index " + modelId, e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); })); } else { - deleteModel(modelId, mlModel.getModelGroupId(), false, actionListener); + deleteModel(modelId, mlModel.getModelGroupId(), false, wrappedListener); } } }, e -> { log.error("Failed to validate Access for Model Id " + modelId, e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); })); } catch (Exception e) { log.error("Failed to parse ml model " + r.getId(), e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); } } else { - actionListener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND)); + wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND)); } - }, e -> { actionListener.onFailure(e); })); + }, e -> { wrappedListener.onFailure(e); })); } catch (Exception e) { log.error("Failed to delete ML model " + modelId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java index 8aa03d54a1..3e508f1f64 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java @@ -79,7 +79,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); + client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); @@ -90,7 +91,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (!access) { - actionListener + wrappedListener .onFailure( new MLValidationException("User Doesn't have privilege to perform this operation on this model") ); @@ -100,19 +101,19 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { log.error("Failed to validate Access for Model Id " + modelId, e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); })); } catch (Exception e) { log.error("Failed to parse ml model " + r.getId(), e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); } } else { - actionListener + wrappedListener .onFailure( new OpenSearchStatusException( "Failed to find model with the provided model id: " + modelId, @@ -122,12 +123,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { - actionListener.onFailure(new MLResourceNotFoundException("Fail to find model")); + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model")); } else { log.error("Failed to get ML model " + modelId, e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); } - }), () -> context.restore())); + })); } catch (Exception e) { log.error("Failed to get ML model " + modelId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java index 753dcb4a05..eea31bb37e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java @@ -57,6 +57,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { @@ -72,24 +73,24 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { actionListener.onFailure(new MLResourceNotFoundException("Fail to find task")); })); + }, e -> { wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find task")); })); } catch (Exception e) { log.error("Failed to delete ml task " + taskId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index ac88e12821..88c05f71c1 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -57,7 +57,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { log.debug("Completed Get Task Request, id:{}", taskId); if (r != null && r.isExists()) { @@ -79,7 +79,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener context.restore())); } catch (Exception e) { log.error("Failed to get ML task " + taskId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java index 8e12e29966..37b0d49e01 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java @@ -32,7 +32,7 @@ public SearchTaskTransportAction(TransportService transportService, ActionFilter @Override protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.search(request, actionListener); + client.search(request, ActionListener.runBefore(actionListener, () -> context.restore())); } catch (Exception e) { log.error(e.getMessage(), e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index 5bf67f291c..a4f0b9f2f9 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -116,12 +116,12 @@ private void validateAccess(String modelId, ActionListener listener) { User user = RestActionUtils.getUserContext(client); String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + mlModelManager.getModel(modelId, null, excludes, ActionListener.runBefore(ActionListener.wrap(mlModel -> { modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, listener); }, e -> { log.error("Failed to find Model", e); listener.onFailure(e); - })); + }), () -> context.restore())); } catch (Exception e) { log.error("Failed to undeploy ML model"); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java index b98cdbe98a..1227703a21 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java @@ -68,6 +68,7 @@ public void uploadModelChunk(MLUploadModelChunkInput uploadModelChunkInput, Acti User user = RestActionUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { @@ -82,7 +83,7 @@ public void uploadModelChunk(MLUploadModelChunkInput uploadModelChunkInput, Acti .validateModelGroupAccess(user, existingModel.getModelGroupId(), client, ActionListener.wrap(access -> { if (!access) { log.error("You don't have permissions to perform this operation on this model."); - listener + wrappedListener .onFailure( new IllegalArgumentException( "You don't have permissions to perform this operation on this model." @@ -167,36 +168,36 @@ public void uploadModelChunk(MLUploadModelChunkInput uploadModelChunkInput, Acti }, e -> { log.error("Failed to update model state", e); semaphore.release(); - listener.onFailure(e); + wrappedListener.onFailure(e); })); } - listener.onResponse(new MLUploadModelChunkResponse("Uploaded")); + wrappedListener.onResponse(new MLUploadModelChunkResponse("Uploaded")); }, e -> { log.error("Failed to upload chunk model", e); - listener.onFailure(e); + wrappedListener.onFailure(e); })); }, ex -> { log.error("Failed to init model index", ex); - listener.onFailure(ex); + wrappedListener.onFailure(ex); })); } }, e -> { logException("Failed to validate model access", e, log); - listener.onFailure(e); + wrappedListener.onFailure(e); })); } catch (Exception e) { log.error("Failed to parse ml model " + r.getId(), e); - listener.onFailure(e); + wrappedListener.onFailure(e); } } else { - listener.onFailure(new MLResourceNotFoundException("Failed to find model")); + wrappedListener.onFailure(new MLResourceNotFoundException("Failed to find model")); } }, e -> { if (e instanceof IndexNotFoundException) { - listener.onFailure(new MLResourceNotFoundException("Failed to find model")); + wrappedListener.onFailure(new MLResourceNotFoundException("Failed to find model")); } else { log.error("Failed to get ML model " + modelId, e); - listener.onFailure(e); + wrappedListener.onFailure(e); } })); } catch (Exception e) { diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java index 1db0344333..3ad05f337a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java @@ -5,6 +5,8 @@ package org.opensearch.ml.action.tasks; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -61,6 +63,6 @@ public void setup() { public void test_DoExecute() { searchTaskTransportAction.doExecute(null, searchRequest, actionListener); - verify(client).search(searchRequest, actionListener); + verify(client).search(eq(searchRequest), any()); } }