Skip to content

Commit

Permalink
Add support of .opensearch-knn-model as system index to transport act…
Browse files Browse the repository at this point in the history
…ions (#847)

* Add thread context stashing to transport actions

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Apr 7, 2023
1 parent 427cd32 commit 1b6fd48
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 112 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Bump numpy version from 1.22.x to 1.24.2 ([#811](https://github.com/opensearch-project/k-NN/pull/811))
* Support .opensearch-knn-model index as system index with security enabled ([#827](https://github.com/opensearch-project/k-NN/pull/827))
* Set gradle dependency scope for common-utils to testFixturesImplementation ([#844](https://github.com/opensearch-project/k-NN/pull/844))
* Add support of .opensearch-knn-model as system index to transport actions ([#847](https://github.com/opensearch-project/k-NN/pull/847))
### Documentation
### Maintenance
### Refactoring
Expand Down
39 changes: 39 additions & 0 deletions src/main/java/org/opensearch/knn/common/ThreadContextHelper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.common;

import org.opensearch.client.Client;
import org.opensearch.common.util.concurrent.ThreadContext;

import java.util.function.Supplier;

/**
* Class abstracts execution of runnable or function in specific context
*/
public class ThreadContextHelper {

/**
* Sets the thread context to default and execute function, this needed to allow actions on model system index
* when security plugin is enabled
* @param function runnable that needs to be executed after thread context has been stashed, accepts and returns nothing
*/
public static void runWithStashedThreadContext(Client client, Runnable function) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
function.run();
}
}

/**
* Sets the thread context to default and execute function, this needed to allow actions on model system index
* when security plugin is enabled
* @param function supplier function that needs to be executed after thread context has been stashed, return object
*/
public static <T> T runWithStashedThreadContext(Client client, Supplier<T> function) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
return function.get();
}
}
}
134 changes: 54 additions & 80 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import lombok.SneakyThrows;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
Expand Down Expand Up @@ -43,18 +42,18 @@
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.common.ThreadContextHelper;
import org.opensearch.knn.plugin.transport.DeleteModelResponse;
import org.opensearch.knn.plugin.transport.GetModelResponse;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheRequest;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest;

import java.io.IOException;
import java.net.URL;
Expand All @@ -64,7 +63,6 @@
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;

import static java.util.Objects.isNull;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH;
Expand Down Expand Up @@ -219,21 +217,14 @@ public void create(ActionListener<CreateIndexResponse> actionListener) throws IO
if (isCreated()) {
return;
}
runWithStashedThreadContext(() -> {
CreateIndexRequest request;
try {
request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping())
.settings(
Settings.builder()
.put("index.hidden", true)
.put("index.number_of_shards", this.numberOfShards)
.put("index.number_of_replicas", this.numberOfReplicas)
);
} catch (IOException e) {
throw new RuntimeException(e);
}
client.admin().indices().create(request, actionListener);
});
CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping())
.settings(
Settings.builder()
.put("index.hidden", true)
.put("index.number_of_shards", this.numberOfShards)
.put("index.number_of_replicas", this.numberOfReplicas)
);
client.admin().indices().create(request, actionListener);
}

@Override
Expand Down Expand Up @@ -303,9 +294,8 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, base64Model);
}

final IndexRequestBuilder indexRequestBuilder = ModelDao.runWithStashedThreadContext(
() -> client.prepareIndex(MODEL_INDEX_NAME)
);
IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME);

indexRequestBuilder.setId(model.getModelID());
indexRequestBuilder.setSource(parameters);

Expand All @@ -315,8 +305,8 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
// After metadata update finishes, remove item from every node's cache if necessary. If no model id is
// passed then nothing needs to be removed from the cache
ActionListener<IndexResponse> onMetaListener;
onMetaListener = ActionListener.wrap(indexResponse -> {
client.execute(
onMetaListener = ActionListener.wrap(
indexResponse -> client.execute(
RemoveModelFromCacheAction.INSTANCE,
new RemoveModelFromCacheRequest(model.getModelID()),
ActionListener.wrap(removeModelFromCacheResponse -> {
Expand All @@ -329,8 +319,9 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do

listener.onFailure(new RuntimeException(failureMessage));
}, listener::onFailure)
);
}, listener::onFailure);
),
listener::onFailure
);

// After the model is indexed, update metadata only if the model is in CREATED state
ActionListener<IndexResponse> onIndexListener;
Expand Down Expand Up @@ -367,14 +358,13 @@ private ActionListener<IndexResponse> getUpdateModelMetadataListener(
);
}

@SneakyThrows
@Override
public Model get(String modelId) {
public Model get(String modelId) throws ExecutionException, InterruptedException {
/*
GET /<model_index>/<modelId>?_local
*/
try {
return ModelDao.runWithStashedThreadContext(() -> {
return ThreadContextHelper.runWithStashedThreadContext(client, () -> {
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");
GetResponse getResponse;
Expand All @@ -389,7 +379,16 @@ public Model get(String modelId) {
} catch (RuntimeException runtimeException) {
// we need to use RuntimeException as container for real exception to keep signature
// of runWithStashedThreadContext generic
throw runtimeException.getCause();
Throwable throwable = runtimeException.getCause();
if (throwable != null) {
if (throwable instanceof InterruptedException) {
throw (InterruptedException) throwable;
}
if (throwable instanceof ExecutionException) {
throw (ExecutionException) throwable;
}
}
throw runtimeException;
}
}

Expand All @@ -404,22 +403,20 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
/*
GET /<model_index>/<modelId>?_local
*/
ModelDao.runWithStashedThreadContext(() -> {
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");

getRequestBuilder.execute(ActionListener.wrap(response -> {
if (response.isSourceEmpty()) {
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage));
return;
}
final Map<String, Object> responseMap = response.getSourceAsMap();
Model model = Model.getModelFromSourceMap(responseMap);
actionListener.onResponse(new GetModelResponse(model));
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");

}, actionListener::onFailure));
});
getRequestBuilder.execute(ActionListener.wrap(response -> {
if (response.isSourceEmpty()) {
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage));
return;
}
final Map<String, Object> responseMap = response.getSourceAsMap();
Model model = Model.getModelFromSourceMap(responseMap);
actionListener.onResponse(new GetModelResponse(model));

}, actionListener::onFailure));
}

/**
Expand All @@ -430,7 +427,7 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
*/
@Override
public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
ModelDao.runWithStashedThreadContext(() -> {
ThreadContextHelper.runWithStashedThreadContext(client, () -> {
request.indices(MODEL_INDEX_NAME);
client.search(request, actionListener);
});
Expand Down Expand Up @@ -533,17 +530,16 @@ public void delete(String modelId, ActionListener<DeleteModelResponse> listener)
);

// Setup delete model request
ModelDao.runWithStashedThreadContext(() -> {
DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME);
deleteRequestBuilder.setId(modelId);
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// On model metadata removal, delete the model from the index
clearModelMetadataStep.whenComplete(
acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder),
listener::onFailure
);
});
DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME);
deleteRequestBuilder.setId(modelId);
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// On model metadata removal, delete the model from the index
clearModelMetadataStep.whenComplete(
acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder),
listener::onFailure
);

deleteModelFromIndexStep.whenComplete(deleteResponse -> {
// If model is not deleted, remove modelId from model graveyard and return with error message
if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) {
Expand Down Expand Up @@ -682,26 +678,4 @@ private String buildRemoveModelErrorMessage(String modelId, RemoveModelFromCache
return stringBuilder.toString();
}
}

/**
* Set the thread context to default, this is needed to allow actions on model system index
* when security plugin is enabled
* @param function runnable that needs to be executed after thread context has been stashed, accepts and returns nothing
*/
private static void runWithStashedThreadContext(Runnable function) {
try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) {
function.run();
}
}

/**
* Set the thread context to default, this is needed to allow actions on model system index
* when security plugin is enabled
* @param function supplier function that needs to be executed after thread context has been stashed, return object
*/
private static <T> T runWithStashedThreadContext(Supplier<T> function) {
try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) {
return function.get();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,30 @@
import org.opensearch.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.common.ThreadContextHelper;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

public class DeleteModelTransportAction extends HandledTransportAction<DeleteModelRequest, DeleteModelResponse> {

private final ModelDao modelDao;
private final Client client;

@Inject
public DeleteModelTransportAction(TransportService transportService, ActionFilters filters) {
public DeleteModelTransportAction(TransportService transportService, ActionFilters filters, Client client) {
super(DeleteModelAction.NAME, transportService, filters, DeleteModelRequest::new);
this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
this.client = client;
}

@Override
protected void doExecute(Task task, DeleteModelRequest request, ActionListener<DeleteModelResponse> listener) {
String modelID = request.getModelID();
modelDao.delete(modelID, listener);
ThreadContextHelper.runWithStashedThreadContext(client, () -> {
String modelID = request.getModelID();
modelDao.delete(modelID, listener);
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import org.opensearch.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.common.ThreadContextHelper;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -27,17 +29,20 @@ public class GetModelTransportAction extends HandledTransportAction<GetModelRequ
private static final Logger LOG = LogManager.getLogger(GetModelTransportAction.class);
private ModelDao modelDao;

private final Client client;

@Inject
public GetModelTransportAction(TransportService transportService, ActionFilters actionFilters) {
public GetModelTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(GetModelAction.NAME, transportService, actionFilters, GetModelRequest::new);
this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
this.client = client;
}

@Override
protected void doExecute(Task task, GetModelRequest request, ActionListener<GetModelResponse> actionListener) {
String modelID = request.getModelID();

modelDao.get(modelID, actionListener);

ThreadContextHelper.runWithStashedThreadContext(client, () -> {
String modelID = request.getModelID();
modelDao.get(modelID, actionListener);
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.common.ThreadContextHelper;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -26,18 +28,23 @@
public class SearchModelTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {
private ModelDao modelDao;

private final Client client;

@Inject
public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters) {
public SearchModelTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
super(SearchModelAction.NAME, transportService, actionFilters, SearchRequest::new);
this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
this.client = client;
}

@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> listener) {
try {
this.modelDao.search(request, listener);
} catch (IOException e) {
listener.onFailure(e);
}
ThreadContextHelper.runWithStashedThreadContext(client, () -> {
try {
this.modelDao.search(request, listener);
} catch (IOException e) {
listener.onFailure(e);
}
});
}
}
Loading

0 comments on commit 1b6fd48

Please sign in to comment.