Skip to content

Commit

Permalink
Stash context only in methods that work with system index
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Apr 4, 2023
1 parent d04a856 commit de3e487
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 107 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add CHANGELOG ([#800](https://github.com/opensearch-project/k-NN/pull/800))
* Bump byte-buddy version from 1.12.22 to 1.14.2 ([#804](https://github.com/opensearch-project/k-NN/pull/804))
* Bump numpy version from 1.22.x to 1.24.2 ([#811](https://github.com/opensearch-project/k-NN/pull/811))
* Add support for integ tests on secured cluster ([#827](https://github.com/opensearch-project/k-NN/pull/827))
* Support .opensearch-knn-model index as system index with security enabled ([#827](https://github.com/opensearch-project/k-NN/pull/827))
### Documentation
### Maintenance
### Refactoring
Expand Down
185 changes: 79 additions & 106 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -317,26 +317,20 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
// passed then nothing needs to be removed from the cache
ActionListener<IndexResponse> onMetaListener;
onMetaListener = ActionListener.wrap(indexResponse -> {
// temporary setting thread context to default, this is needed to allow actions on model system index
// when security plugin is enabled
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.execute(
RemoveModelFromCacheAction.INSTANCE,
new RemoveModelFromCacheRequest(model.getModelID()),
ActionListener.wrap(removeModelFromCacheResponse -> {
if (!removeModelFromCacheResponse.hasFailures()) {
listener.onResponse(indexResponse);
return;
}

String failureMessage = buildRemoveModelErrorMessage(model.getModelID(), removeModelFromCacheResponse);

listener.onFailure(new RuntimeException(failureMessage));
}, listener::onFailure)
);
} catch (Exception e) {
listener.onFailure(e);
}
client.execute(
RemoveModelFromCacheAction.INSTANCE,
new RemoveModelFromCacheRequest(model.getModelID()),
ActionListener.wrap(removeModelFromCacheResponse -> {
if (!removeModelFromCacheResponse.hasFailures()) {
listener.onResponse(indexResponse);
return;
}

String failureMessage = buildRemoveModelErrorMessage(model.getModelID(), removeModelFromCacheResponse);

listener.onFailure(new RuntimeException(failureMessage));
}, listener::onFailure)
);
}, listener::onFailure);

// After the model is indexed, update metadata only if the model is in CREATED state
Expand All @@ -362,20 +356,16 @@ private ActionListener<IndexResponse> getUpdateModelMetadataListener(
ModelMetadata modelMetadata,
ActionListener<IndexResponse> listener
) {
// temporary setting thread context to default, this is needed to allow actions on model system index
// when security plugin is enabled
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
return ActionListener.wrap(
indexResponse -> client.execute(
UpdateModelMetadataAction.INSTANCE,
new UpdateModelMetadataRequest(indexResponse.getId(), false, modelMetadata),
// Here we wrap the IndexResponse listener around an AcknowledgedListener. This allows us
// to pass the indexResponse back up.
ActionListener.wrap(acknowledgedResponse -> listener.onResponse(indexResponse), listener::onFailure)
),
listener::onFailure
);
}
return ActionListener.wrap(
indexResponse -> client.execute(
UpdateModelMetadataAction.INSTANCE,
new UpdateModelMetadataRequest(indexResponse.getId(), false, modelMetadata),
// Here we wrap the IndexResponse listener around an AcknowledgedListener. This allows us
// to pass the indexResponse back up.
ActionListener.wrap(acknowledgedResponse -> listener.onResponse(indexResponse), listener::onFailure)
),
listener::onFailure
);
}

@Override
Expand Down Expand Up @@ -552,23 +542,23 @@ public void delete(String modelId, ActionListener<DeleteModelResponse> listener)
acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder),
listener::onFailure
);

deleteModelFromIndexStep.whenComplete(deleteResponse -> {
// If model is not deleted, remove modelId from model graveyard and return with error message
if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) {
updateModelGraveyardToDelete(modelId, true, unblockModelIdStep, Optional.empty());
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
listener.onResponse(new DeleteModelResponse(modelId, deleteResponse.getResult().getLowercase(), errorMessage));
return;
}

// After model is deleted from the index, make sure the model is evicted from every cache in the cluster
removeModelFromCache(modelId, clearModelFromCacheStep);
}, e -> listener.onFailure(new OpenSearchException(e)));
} catch (Exception e) {
listener.onFailure(e);
}

deleteModelFromIndexStep.whenComplete(deleteResponse -> {
// If model is not deleted, remove modelId from model graveyard and return with error message
if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) {
updateModelGraveyardToDelete(modelId, true, unblockModelIdStep, Optional.empty());
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
listener.onResponse(new DeleteModelResponse(modelId, deleteResponse.getResult().getLowercase(), errorMessage));
return;
}

// After model is deleted from the index, make sure the model is evicted from every cache in the cluster
removeModelFromCache(modelId, clearModelFromCacheStep);
}, e -> listener.onFailure(new OpenSearchException(e)));

clearModelFromCacheStep.whenComplete(removeModelFromCacheResponse -> {

// If there are any failures while removing model from the cache build the error message
Expand Down Expand Up @@ -624,76 +614,59 @@ private void updateModelGraveyardToDelete(
StepListener<AcknowledgedResponse> step,
Optional<Exception> exception
) {
// temporary setting thread context to default, this is needed to allow actions on model system index
// when security plugin is enabled
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.execute(
UpdateModelGraveyardAction.INSTANCE,
new UpdateModelGraveyardRequest(modelId, isRemoveRequest),
ActionListener.wrap(acknowledgedResponse -> {
if (exception.isEmpty()) {
step.onResponse(acknowledgedResponse);
return;
}
throw exception.get();

}, e -> {
// If it fails to remove the modelId from Model Graveyard, then log the error message
String errorMessage = String.format("Failed to remove \" %s \" from Model Graveyard", modelId);
String failureMessage = String.format("%s%s%s", errorMessage, "\n", e.getMessage());
logger.error(failureMessage);
client.execute(
UpdateModelGraveyardAction.INSTANCE,
new UpdateModelGraveyardRequest(modelId, isRemoveRequest),
ActionListener.wrap(acknowledgedResponse -> {
if (exception.isEmpty()) {
step.onResponse(acknowledgedResponse);
return;
}
throw exception.get();

if (exception.isEmpty()) {
step.onFailure(e);
return;
}
step.onFailure(exception.get());
})
);
} catch (Exception e) {
step.onFailure(e);
}
}, e -> {
// If it fails to remove the modelId from Model Graveyard, then log the error message
String errorMessage = String.format("Failed to remove \" %s \" from Model Graveyard", modelId);
String failureMessage = String.format("%s%s%s", errorMessage, "\n", e.getMessage());
logger.error(failureMessage);

if (exception.isEmpty()) {
step.onFailure(e);
return;
}
step.onFailure(exception.get());
})
);
}

// Clear the metadata of the model for a given modelId
private void clearModelMetadata(String modelId, StepListener<AcknowledgedResponse> clearModelMetadataStep) {
// temporary setting thread context to default, this is needed to allow actions on model system index
// when security plugin is enabled
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.execute(
UpdateModelMetadataAction.INSTANCE,
new UpdateModelMetadataRequest(modelId, true, null),
ActionListener.wrap(
clearModelMetadataStep::onResponse,
exception -> removeModelIdFromGraveyardOnFailure(modelId, exception, clearModelMetadataStep)
)
);
} catch (Exception e) {
clearModelMetadataStep.onFailure(e);
}
client.execute(
UpdateModelMetadataAction.INSTANCE,
new UpdateModelMetadataRequest(modelId, true, null),
ActionListener.wrap(
clearModelMetadataStep::onResponse,
exception -> removeModelIdFromGraveyardOnFailure(modelId, exception, clearModelMetadataStep)
)
);
}

// This function helps to remove the model from model graveyard and return the exception from previous step
// when the delete request fails while executing after adding modelId to model graveyard
private void removeModelIdFromGraveyardOnFailure(String modelId, Exception exceptionFromPreviousStep, StepListener<?> step) {
// temporary setting thread context to default, this is needed to allow actions on model system index
// when security plugin is enabled
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.execute(
UpdateModelGraveyardAction.INSTANCE,
new UpdateModelGraveyardRequest(modelId, true),
ActionListener.wrap(acknowledgedResponse -> { throw exceptionFromPreviousStep; }, unblockingFailedException -> {
// If it fails to remove the modelId from Model Graveyard, then log the error message and
// throw the exception that was passed as a parameter from previous step
String errorMessage = String.format("Failed to remove \" %s \" from Model Graveyard", modelId);
String failureMessage = String.format("%s%s%s", errorMessage, "\n", unblockingFailedException.getMessage());
logger.error(failureMessage);
step.onFailure(exceptionFromPreviousStep);
})
);
} catch (Exception e) {
step.onFailure(e);
}
client.execute(
UpdateModelGraveyardAction.INSTANCE,
new UpdateModelGraveyardRequest(modelId, true),
ActionListener.wrap(acknowledgedResponse -> { throw exceptionFromPreviousStep; }, unblockingFailedException -> {
// If it fails to remove the modelId from Model Graveyard, then log the error message and
// throw the exception that was passed as a parameter from previous step
String errorMessage = String.format("Failed to remove \" %s \" from Model Graveyard", modelId);
String failureMessage = String.format("%s%s%s", errorMessage, "\n", unblockingFailedException.getMessage());
logger.error(failureMessage);
step.onFailure(exceptionFromPreviousStep);
})
);
}

private String buildRemoveModelErrorMessage(String modelId, RemoveModelFromCacheResponse response) {
Expand Down

0 comments on commit de3e487

Please sign in to comment.