From 1b6fd481a6b4bbbdea0c93f652654a2ff2f7758d Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 7 Apr 2023 11:48:51 -0700 Subject: [PATCH] Add support of .opensearch-knn-model as system index to transport actions (#847) * Add thread context stashing to transport actions Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../knn/common/ThreadContextHelper.java | 39 +++++ .../org/opensearch/knn/indices/ModelDao.java | 134 +++++++----------- .../transport/DeleteModelTransportAction.java | 12 +- .../transport/GetModelTransportAction.java | 15 +- .../transport/SearchModelTransportAction.java | 19 ++- .../TrainingJobRouterTransportAction.java | 11 +- .../TrainingModelTransportAction.java | 39 +++-- .../knn/common/ThreadContextHelperTests.java | 52 +++++++ 9 files changed, 210 insertions(+), 112 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/common/ThreadContextHelper.java create mode 100644 src/test/java/org/opensearch/knn/common/ThreadContextHelperTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e6e1a958..0a871e143 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Bump numpy version from 1.22.x to 1.24.2 ([#811](https://github.com/opensearch-project/k-NN/pull/811)) * Support .opensearch-knn-model index as system index with security enabled ([#827](https://github.com/opensearch-project/k-NN/pull/827)) * Set gradle dependency scope for common-utils to testFixturesImplementation ([#844](https://github.com/opensearch-project/k-NN/pull/844)) +* Add support of .opensearch-knn-model as system index to transport actions ([#847](https://github.com/opensearch-project/k-NN/pull/847)) ### Documentation ### Maintenance ### Refactoring diff --git a/src/main/java/org/opensearch/knn/common/ThreadContextHelper.java b/src/main/java/org/opensearch/knn/common/ThreadContextHelper.java new file mode 100644 index 000000000..e0c5ad8f9 --- /dev/null +++ b/src/main/java/org/opensearch/knn/common/ThreadContextHelper.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.common; + +import org.opensearch.client.Client; +import org.opensearch.common.util.concurrent.ThreadContext; + +import java.util.function.Supplier; + +/** + * Class abstracts execution of runnable or function in specific context + */ +public class ThreadContextHelper { + + /** + * Sets the thread context to default and execute function, this needed to allow actions on model system index + * when security plugin is enabled + * @param function runnable that needs to be executed after thread context has been stashed, accepts and returns nothing + */ + public static void runWithStashedThreadContext(Client client, Runnable function) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + function.run(); + } + } + + /** + * Sets the thread context to default and execute function, this needed to allow actions on model system index + * when security plugin is enabled + * @param function supplier function that needs to be executed after thread context has been stashed, return object + */ + public static T runWithStashedThreadContext(Client client, Supplier function) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + return function.get(); + } + } +} diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index cf0dd1890..a5b478213 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -13,7 +13,6 @@ import com.google.common.base.Charsets; import com.google.common.io.Resources; -import lombok.SneakyThrows; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchException; @@ -43,18 +42,18 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.index.IndexNotFoundException; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.common.ThreadContextHelper; import org.opensearch.knn.plugin.transport.DeleteModelResponse; import org.opensearch.knn.plugin.transport.GetModelResponse; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheRequest; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse; -import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction; -import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest; import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest; +import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction; +import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest; import java.io.IOException; import java.net.URL; @@ -64,7 +63,6 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.ExecutionException; -import java.util.function.Supplier; import static java.util.Objects.isNull; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH; @@ -219,21 +217,14 @@ public void create(ActionListener actionListener) throws IO if (isCreated()) { return; } - runWithStashedThreadContext(() -> { - CreateIndexRequest request; - try { - request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping()) - .settings( - Settings.builder() - .put("index.hidden", true) - .put("index.number_of_shards", this.numberOfShards) - .put("index.number_of_replicas", this.numberOfReplicas) - ); - } catch (IOException e) { - throw new RuntimeException(e); - } - client.admin().indices().create(request, actionListener); - }); + CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping()) + .settings( + Settings.builder() + .put("index.hidden", true) + .put("index.number_of_shards", this.numberOfShards) + .put("index.number_of_replicas", this.numberOfReplicas) + ); + client.admin().indices().create(request, actionListener); } @Override @@ -303,9 +294,8 @@ private void putInternal(Model model, ActionListener listener, Do parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, base64Model); } - final IndexRequestBuilder indexRequestBuilder = ModelDao.runWithStashedThreadContext( - () -> client.prepareIndex(MODEL_INDEX_NAME) - ); + IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME); + indexRequestBuilder.setId(model.getModelID()); indexRequestBuilder.setSource(parameters); @@ -315,8 +305,8 @@ private void putInternal(Model model, ActionListener listener, Do // After metadata update finishes, remove item from every node's cache if necessary. If no model id is // passed then nothing needs to be removed from the cache ActionListener onMetaListener; - onMetaListener = ActionListener.wrap(indexResponse -> { - client.execute( + onMetaListener = ActionListener.wrap( + indexResponse -> client.execute( RemoveModelFromCacheAction.INSTANCE, new RemoveModelFromCacheRequest(model.getModelID()), ActionListener.wrap(removeModelFromCacheResponse -> { @@ -329,8 +319,9 @@ private void putInternal(Model model, ActionListener listener, Do listener.onFailure(new RuntimeException(failureMessage)); }, listener::onFailure) - ); - }, listener::onFailure); + ), + listener::onFailure + ); // After the model is indexed, update metadata only if the model is in CREATED state ActionListener onIndexListener; @@ -367,14 +358,13 @@ private ActionListener getUpdateModelMetadataListener( ); } - @SneakyThrows @Override - public Model get(String modelId) { + public Model get(String modelId) throws ExecutionException, InterruptedException { /* GET //?_local */ try { - return ModelDao.runWithStashedThreadContext(() -> { + return ThreadContextHelper.runWithStashedThreadContext(client, () -> { GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) .setPreference("_local"); GetResponse getResponse; @@ -389,7 +379,16 @@ public Model get(String modelId) { } catch (RuntimeException runtimeException) { // we need to use RuntimeException as container for real exception to keep signature // of runWithStashedThreadContext generic - throw runtimeException.getCause(); + Throwable throwable = runtimeException.getCause(); + if (throwable != null) { + if (throwable instanceof InterruptedException) { + throw (InterruptedException) throwable; + } + if (throwable instanceof ExecutionException) { + throw (ExecutionException) throwable; + } + } + throw runtimeException; } } @@ -404,22 +403,20 @@ public void get(String modelId, ActionListener actionListener) /* GET //?_local */ - ModelDao.runWithStashedThreadContext(() -> { - GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) - .setPreference("_local"); - - getRequestBuilder.execute(ActionListener.wrap(response -> { - if (response.isSourceEmpty()) { - String errorMessage = String.format("Model \" %s \" does not exist", modelId); - actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage)); - return; - } - final Map responseMap = response.getSourceAsMap(); - Model model = Model.getModelFromSourceMap(responseMap); - actionListener.onResponse(new GetModelResponse(model)); + GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) + .setPreference("_local"); - }, actionListener::onFailure)); - }); + getRequestBuilder.execute(ActionListener.wrap(response -> { + if (response.isSourceEmpty()) { + String errorMessage = String.format("Model \" %s \" does not exist", modelId); + actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage)); + return; + } + final Map responseMap = response.getSourceAsMap(); + Model model = Model.getModelFromSourceMap(responseMap); + actionListener.onResponse(new GetModelResponse(model)); + + }, actionListener::onFailure)); } /** @@ -430,7 +427,7 @@ public void get(String modelId, ActionListener actionListener) */ @Override public void search(SearchRequest request, ActionListener actionListener) { - ModelDao.runWithStashedThreadContext(() -> { + ThreadContextHelper.runWithStashedThreadContext(client, () -> { request.indices(MODEL_INDEX_NAME); client.search(request, actionListener); }); @@ -533,17 +530,16 @@ public void delete(String modelId, ActionListener listener) ); // Setup delete model request - ModelDao.runWithStashedThreadContext(() -> { - DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME); - deleteRequestBuilder.setId(modelId); - deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - // On model metadata removal, delete the model from the index - clearModelMetadataStep.whenComplete( - acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder), - listener::onFailure - ); - }); + DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME); + deleteRequestBuilder.setId(modelId); + deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + // On model metadata removal, delete the model from the index + clearModelMetadataStep.whenComplete( + acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder), + listener::onFailure + ); + deleteModelFromIndexStep.whenComplete(deleteResponse -> { // If model is not deleted, remove modelId from model graveyard and return with error message if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) { @@ -682,26 +678,4 @@ private String buildRemoveModelErrorMessage(String modelId, RemoveModelFromCache return stringBuilder.toString(); } } - - /** - * Set the thread context to default, this is needed to allow actions on model system index - * when security plugin is enabled - * @param function runnable that needs to be executed after thread context has been stashed, accepts and returns nothing - */ - private static void runWithStashedThreadContext(Runnable function) { - try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) { - function.run(); - } - } - - /** - * Set the thread context to default, this is needed to allow actions on model system index - * when security plugin is enabled - * @param function supplier function that needs to be executed after thread context has been stashed, return object - */ - private static T runWithStashedThreadContext(Supplier function) { - try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) { - return function.get(); - } - } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java index ee7f9e939..f535f37dc 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java @@ -14,7 +14,9 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; +import org.opensearch.knn.common.ThreadContextHelper; import org.opensearch.knn.indices.ModelDao; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -22,16 +24,20 @@ public class DeleteModelTransportAction extends HandledTransportAction { private final ModelDao modelDao; + private final Client client; @Inject - public DeleteModelTransportAction(TransportService transportService, ActionFilters filters) { + public DeleteModelTransportAction(TransportService transportService, ActionFilters filters, Client client) { super(DeleteModelAction.NAME, transportService, filters, DeleteModelRequest::new); this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + this.client = client; } @Override protected void doExecute(Task task, DeleteModelRequest request, ActionListener listener) { - String modelID = request.getModelID(); - modelDao.delete(modelID, listener); + ThreadContextHelper.runWithStashedThreadContext(client, () -> { + String modelID = request.getModelID(); + modelDao.delete(modelID, listener); + }); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java index e47a42d8d..23fa431d3 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java @@ -15,7 +15,9 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; +import org.opensearch.knn.common.ThreadContextHelper; import org.opensearch.knn.indices.ModelDao; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -27,17 +29,20 @@ public class GetModelTransportAction extends HandledTransportAction actionListener) { - String modelID = request.getModelID(); - - modelDao.get(modelID, actionListener); - + ThreadContextHelper.runWithStashedThreadContext(client, () -> { + String modelID = request.getModelID(); + modelDao.get(modelID, actionListener); + }); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/SearchModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/SearchModelTransportAction.java index 4d9f67059..53a08d80e 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/SearchModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/SearchModelTransportAction.java @@ -16,7 +16,9 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; +import org.opensearch.knn.common.ThreadContextHelper; import org.opensearch.knn.indices.ModelDao; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -26,18 +28,23 @@ public class SearchModelTransportAction extends HandledTransportAction { private ModelDao modelDao; + private final Client client; + @Inject - public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters) { + public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { super(SearchModelAction.NAME, transportService, actionFilters, SearchRequest::new); this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + this.client = client; } @Override protected void doExecute(Task task, SearchRequest request, ActionListener listener) { - try { - this.modelDao.search(request, listener); - } catch (IOException e) { - listener.onFailure(e); - } + ThreadContextHelper.runWithStashedThreadContext(client, () -> { + try { + this.modelDao.search(request, listener); + } catch (IOException e) { + listener.onFailure(e); + } + }); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java index 774029c58..b37c03a59 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java @@ -23,6 +23,7 @@ import org.opensearch.common.ValidationException; import org.opensearch.common.collect.ImmutableOpenMap; import org.opensearch.common.inject.Inject; +import org.opensearch.knn.common.ThreadContextHelper; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportRequestOptions; @@ -58,10 +59,12 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener // Get the size of the training request and then route the request. We get/set this here, as opposed to in // TrainingModelTransportAction, because in the future, we may want to use size to factor into our routing // decision. - getTrainingIndexSizeInKB(request, ActionListener.wrap(size -> { - request.setTrainingDataSizeInKB(size); - routeRequest(request, listener); - }, listener::onFailure)); + ThreadContextHelper.runWithStashedThreadContext(client, () -> { + getTrainingIndexSizeInKB(request, ActionListener.wrap(size -> { + request.setTrainingDataSizeInKB(size); + routeRequest(request, listener); + }, listener::onFailure)); + }); } protected void routeRequest(TrainingModelRequest request, ActionListener listener) { diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java index a3c4be16e..1f5b85afd 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java @@ -14,8 +14,10 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.knn.common.ThreadContextHelper; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; @@ -34,10 +36,18 @@ public class TrainingModelTransportAction extends HandledTransportAction wrappedListener.onResponse(new TrainingModelResponse(indexResponse.getId())), - wrappedListener::onFailure - ) - ); - } catch (IOException e) { - wrappedListener.onFailure(e); - } + ThreadContextHelper.runWithStashedThreadContext(client, () -> { + try { + TrainingJobRunner.getInstance() + .execute( + trainingJob, + ActionListener.wrap( + indexResponse -> wrappedListener.onResponse(new TrainingModelResponse(indexResponse.getId())), + wrappedListener::onFailure + ) + ); + } catch (IOException e) { + wrappedListener.onFailure(e); + } + }); } } diff --git a/src/test/java/org/opensearch/knn/common/ThreadContextHelperTests.java b/src/test/java/org/opensearch/knn/common/ThreadContextHelperTests.java new file mode 100644 index 000000000..cb1e2ef02 --- /dev/null +++ b/src/test/java/org/opensearch/knn/common/ThreadContextHelperTests.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.common; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import java.util.function.Supplier; + +public class ThreadContextHelperTests extends KNNTestCase { + + public void testRunWithStashedContextRunnable() { + ThreadPool threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + threadPool.getThreadContext().putHeader("key", "value"); + NodeClient client = new NodeClient(Settings.EMPTY, threadPool); + + assertTrue(client.threadPool().getThreadContext().getHeaders().containsKey("key")); + + Runnable runnable = () -> { assertFalse(client.threadPool().getThreadContext().getHeaders().containsKey("key")); }; + ThreadContextHelper.runWithStashedThreadContext(client, () -> runnable); + + assertTrue(client.threadPool().getThreadContext().getHeaders().containsKey("key")); + + threadPool.shutdownNow(); + client.close(); + } + + public void testRunWithStashedContextSupplier() { + ThreadPool threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + threadPool.getThreadContext().putHeader("key", "value"); + NodeClient client = new NodeClient(Settings.EMPTY, threadPool); + + assertTrue(client.threadPool().getThreadContext().getHeaders().containsKey("key")); + + Supplier supplier = () -> { + assertFalse(client.threadPool().getThreadContext().getHeaders().containsKey("key")); + return this.getClass().getName(); + }; + ThreadContextHelper.runWithStashedThreadContext(client, () -> supplier); + + assertTrue(client.threadPool().getThreadContext().getHeaders().containsKey("key")); + + threadPool.shutdownNow(); + client.close(); + } +}