diff --git a/CHANGELOG.md b/CHANGELOG.md index d46a9c41b4..e3c07fc1db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x) ### Features +* Add Clear Cache API [#740](https://github.com/opensearch-project/k-NN/pull/740) ### Enhancements * Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) * Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ClearCacheIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ClearCacheIT.java new file mode 100644 index 0000000000..045821e09e --- /dev/null +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ClearCacheIT.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.bwc; + +import java.util.Collections; +import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER; + +public class ClearCacheIT extends AbstractRestartUpgradeTestCase { + private static final String TEST_FIELD = "test-field"; + private static final int DIMENSIONS = 5; + private static int docId = 0; + private static final int NUM_DOCS = 10; + private static int queryCnt = 0; + private static final int K = 5; + + // Restart Upgrade BWC Tests to validate Clear Cache API + public void testClearCache() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + if (isRunningAgainstOldCluster()) { + createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS)); + addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, docId, NUM_DOCS); + } else { + queryCnt = NUM_DOCS; + validateClearCacheOnUpgrade(queryCnt); + + docId = NUM_DOCS; + addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, docId, NUM_DOCS); + + queryCnt = queryCnt + NUM_DOCS; + validateClearCacheOnUpgrade(queryCnt); + deleteKNNIndex(testIndex); + } + } + + // validation steps for Clear Cache API after upgrading node to new version + private void validateClearCacheOnUpgrade(int queryCount) throws Exception { + int graphCount = getTotalGraphsInCache(); + knnWarmup(Collections.singletonList(testIndex)); + assertTrue(getTotalGraphsInCache() > graphCount); + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, queryCount, K); + + clearCache(Collections.singletonList(testIndex)); + assertEquals(0, getTotalGraphsInCache()); + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, queryCount, K); + } +} diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/ClearCacheIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/ClearCacheIT.java new file mode 100644 index 0000000000..24c674d0d2 --- /dev/null +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/ClearCacheIT.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.bwc; + +import java.util.Collections; + +import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER; + +public class ClearCacheIT extends AbstractRollingUpgradeTestCase { + private static final String TEST_FIELD = "test-field"; + private static final int DIMENSIONS = 5; + private static int docId = 0; + private static final int K = 5; + private static final int NUM_DOCS = 10; + private static int queryCnt = 0; + + // Rolling Upgrade BWC Tests to validate Clear Cache API + public void testClearCache() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + switch (getClusterType()) { + case OLD: + createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS)); + int docIdOld = 0; + addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, docIdOld, NUM_DOCS); + break; + case UPGRADED: + queryCnt = NUM_DOCS; + validateClearCacheOnUpgrade(queryCnt); + + docId = NUM_DOCS; + addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, docId, NUM_DOCS); + + queryCnt = queryCnt + NUM_DOCS; + validateClearCacheOnUpgrade(queryCnt); + deleteKNNIndex(testIndex); + } + + } + + // validation steps for Clear Cache API after upgrading all nodes from old version to new version + public void validateClearCacheOnUpgrade(int queryCount) throws Exception { + int graphCount = getTotalGraphsInCache(); + knnWarmup(Collections.singletonList(testIndex)); + assertTrue(getTotalGraphsInCache() > graphCount); + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, queryCount, K); + + clearCache(Collections.singletonList(testIndex)); + assertEquals(0, getTotalGraphsInCache()); + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, queryCount, K); + } + +} diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 6c92afabc0..a2e539375c 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -127,4 +127,7 @@ public class KNNConstants { // Please refer this github issue for more details for choosing this value: // https://github.com/opensearch-project/k-NN/issues/1049#issuecomment-1694741092 public static int MAX_DISTANCE_COMPUTATIONS = 2048000; + + // API Constants + public static final String CLEAR_CACHE = "clear_cache"; } diff --git a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java index f7597dce63..25c71f2271 100644 --- a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java +++ b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java @@ -8,9 +8,8 @@ import com.google.common.annotations.VisibleForTesting; import lombok.AllArgsConstructor; import lombok.Getter; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.apache.lucene.index.FilterLeafReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; @@ -20,6 +19,7 @@ import org.opensearch.index.engine.Engine; import org.opensearch.index.shard.IndexShard; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; @@ -30,6 +30,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; @@ -42,11 +43,11 @@ /** * KNNIndexShard wraps IndexShard and adds methods to perform k-NN related operations against the shard */ +@Log4j2 public class KNNIndexShard { private IndexShard indexShard; private NativeMemoryCacheManager nativeMemoryCacheManager; - - private static Logger logger = LogManager.getLogger(KNNIndexShard.class); + private static final String INDEX_SHARD_CLEAR_CACHE_SEARCHER = "knn-clear-cache"; /** * Constructor to generate KNNIndexShard. We do not perform validation that the index the shard is from @@ -84,7 +85,7 @@ public String getIndexName() { * @throws IOException Thrown when getting the HNSW Paths to be loaded in */ public void warmup() throws IOException { - logger.info("[KNN] Warming up index: " + getIndexName()); + log.info("[KNN] Warming up index: [{}]", getIndexName()); try (Engine.Searcher searcher = indexShard.acquireSearcher("knn-warmup")) { getAllEngineFileContexts(searcher.getIndexReader()).forEach((engineFileContext) -> { try { @@ -109,6 +110,34 @@ public void warmup() throws IOException { } } + /** + * Removes all the k-NN segments for this shard from the cache. + * Adding write lock onto the NativeMemoryAllocation of the index that needs to be evicted from cache. + * Write lock will be unlocked after the index is evicted. This locking mechanism is used to avoid + * conflicts with queries fired on this index when the index is being evicted from cache. + */ + public void clearCache() { + String indexName = getIndexName(); + Optional indexAllocationOptional; + NativeMemoryAllocation indexAllocation; + indexAllocationOptional = nativeMemoryCacheManager.getIndexMemoryAllocation(indexName); + if (indexAllocationOptional.isPresent()) { + indexAllocation = indexAllocationOptional.get(); + indexAllocation.writeLock(); + log.info("[KNN] Evicting index from cache: [{}]", indexName); + try (Engine.Searcher searcher = indexShard.acquireSearcher(INDEX_SHARD_CLEAR_CACHE_SEARCHER)) { + getAllEngineFileContexts(searcher.getIndexReader()).forEach( + (engineFileContext) -> nativeMemoryCacheManager.invalidate(engineFileContext.getIndexPath()) + ); + } catch (IOException ex) { + log.error("[KNN] Failed to evict index from cache: [{}]", indexName, ex); + throw new RuntimeException(ex); + } finally { + indexAllocation.writeUnlock(); + } + } + } + /** * For the given shard, get all of its engine paths * diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java index 8b3a3bce15..9478e1e006 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java @@ -27,6 +27,7 @@ import java.io.Closeable; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -303,6 +304,23 @@ public NativeMemoryAllocation get(NativeMemoryEntryContext nativeMemoryEntryC return cache.get(nativeMemoryEntryContext.getKey(), nativeMemoryEntryContext::load); } + /** + * Returns the NativeMemoryAllocation associated with given index + * @param indexName name of OpenSearch index + * @return NativeMemoryAllocation associated with given index + */ + public Optional getIndexMemoryAllocation(String indexName) { + Validate.notNull(indexName, "Index name cannot be null"); + return cache.asMap() + .values() + .stream() + .filter(nativeMemoryAllocation -> nativeMemoryAllocation instanceof NativeMemoryAllocation.IndexAllocation) + .filter( + indexAllocation -> indexName.equals(((NativeMemoryAllocation.IndexAllocation) indexAllocation).getOpenSearchIndexName()) + ) + .findFirst(); + } + /** * Invalidate entry from the cache. * diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 3edff1879d..2e5a550923 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -31,6 +31,7 @@ import org.opensearch.knn.plugin.rest.RestKNNWarmupHandler; import org.opensearch.knn.plugin.rest.RestSearchModelHandler; import org.opensearch.knn.plugin.rest.RestTrainModelHandler; +import org.opensearch.knn.plugin.rest.RestClearCacheHandler; import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; import org.opensearch.knn.plugin.stats.KNNStats; import org.opensearch.knn.plugin.transport.DeleteModelAction; @@ -41,6 +42,8 @@ import org.opensearch.knn.plugin.transport.KNNStatsTransportAction; import org.opensearch.knn.plugin.transport.KNNWarmupAction; import org.opensearch.knn.plugin.transport.KNNWarmupTransportAction; +import org.opensearch.knn.plugin.transport.ClearCacheAction; +import org.opensearch.knn.plugin.transport.ClearCacheTransportAction; import com.google.common.collect.ImmutableList; import org.opensearch.action.ActionRequest; @@ -236,6 +239,7 @@ public List getRestHandlers( RestDeleteModelHandler restDeleteModelHandler = new RestDeleteModelHandler(); RestTrainModelHandler restTrainModelHandler = new RestTrainModelHandler(); RestSearchModelHandler restSearchModelHandler = new RestSearchModelHandler(); + RestClearCacheHandler restClearCacheHandler = new RestClearCacheHandler(clusterService, indexNameExpressionResolver); return ImmutableList.of( restKNNStatsHandler, @@ -243,7 +247,8 @@ public List getRestHandlers( restGetModelHandler, restDeleteModelHandler, restTrainModelHandler, - restSearchModelHandler + restSearchModelHandler, + restClearCacheHandler ); } @@ -263,7 +268,8 @@ public List getRestHandlers( new ActionHandler<>(TrainingModelAction.INSTANCE, TrainingModelTransportAction.class), new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class), new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class), - new ActionHandler<>(UpdateModelGraveyardAction.INSTANCE, UpdateModelGraveyardTransportAction.class) + new ActionHandler<>(UpdateModelGraveyardAction.INSTANCE, UpdateModelGraveyardTransportAction.class), + new ActionHandler<>(ClearCacheAction.INSTANCE, ClearCacheTransportAction.class) ); } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestClearCacheHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestClearCacheHandler.java new file mode 100644 index 0000000000..2cbc9cd760 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestClearCacheHandler.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.rest; + +import com.google.common.collect.ImmutableList; +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.Strings; +import org.opensearch.core.index.Index; +import org.opensearch.knn.common.exception.KNNInvalidIndicesException; +import org.opensearch.knn.plugin.KNNPlugin; +import org.opensearch.knn.plugin.transport.ClearCacheAction; +import org.opensearch.knn.plugin.transport.ClearCacheRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.stream.Collectors; + +import static org.opensearch.action.support.IndicesOptions.strictExpandOpen; +import static org.opensearch.knn.common.KNNConstants.CLEAR_CACHE; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; + +/** + * RestHandler for k-NN Clear Cache API. API provides the ability for a user to evict those indices from Cache. + */ +@AllArgsConstructor +@Log4j2 +public class RestClearCacheHandler extends BaseRestHandler { + private static final String INDEX = "index"; + public static String NAME = "knn_clear_cache_action"; + private final ClusterService clusterService; + private final IndexNameExpressionResolver indexNameExpressionResolver; + + /** + * @return name of Clear Cache API action + */ + @Override + public String getName() { + return NAME; + } + + /** + * @return Immutable List of Clear Cache API endpoint + */ + @Override + public List routes() { + return ImmutableList.of( + new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/%s/{%s}", KNNPlugin.KNN_BASE_URI, CLEAR_CACHE, INDEX)) + ); + } + + /** + * @param request RestRequest + * @param client NodeClient + * @return RestChannelConsumer + */ + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + ClearCacheRequest clearCacheRequest = createClearCacheRequest(request); + log.info("[KNN] ClearCache started for the following indices: [{}]", String.join(",", clearCacheRequest.indices())); + return channel -> client.execute(ClearCacheAction.INSTANCE, clearCacheRequest, new RestToXContentListener<>(channel)); + } + + // Create a clear cache request by processing the rest request and validating the indices + private ClearCacheRequest createClearCacheRequest(RestRequest request) { + String[] indexNames = Strings.splitStringByCommaToArray(request.param("index")); + Index[] indices = indexNameExpressionResolver.concreteIndices(clusterService.state(), strictExpandOpen(), indexNames); + validateIndices(indices); + + return new ClearCacheRequest(indexNames); + } + + // Validate if the given indices are k-NN indices or not. If there are any invalid indices, + // the request is rejected and an exception is thrown. + private void validateIndices(Index[] indices) { + List invalidIndexNames = Arrays.stream(indices) + .filter(index -> !"true".equals(clusterService.state().metadata().getIndexSafe(index).getSettings().get(KNN_INDEX))) + .map(Index::getName) + .collect(Collectors.toList()); + + if (!invalidIndexNames.isEmpty()) { + throw new KNNInvalidIndicesException( + invalidIndexNames, + "ClearCache request rejected. One or more indices have 'index.knn' set to false." + ); + } + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheAction.java b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheAction.java new file mode 100644 index 0000000000..358027a8bc --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheAction.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.core.common.io.stream.Writeable; + +/** + * Action associated with ClearCache + */ +public class ClearCacheAction extends ActionType { + + public static final ClearCacheAction INSTANCE = new ClearCacheAction(); + public static final String NAME = "cluster:admin/clear_cache_action"; + + private ClearCacheAction() { + super(NAME, ClearCacheResponse::new); + } + + @Override + public Writeable.Reader getResponseReader() { + return ClearCacheResponse::new; + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheRequest.java new file mode 100644 index 0000000000..029b9babec --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheRequest.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.support.broadcast.BroadcastRequest; +import org.opensearch.core.common.io.stream.StreamInput; + +import java.io.IOException; + +/** + * Clear Cache Request. This request contains a list of indices which needs to be evicted from Cache. + */ +public class ClearCacheRequest extends BroadcastRequest { + + /** + * Constructor + * + * @param in input stream + * @throws IOException if read from stream fails + */ + public ClearCacheRequest(StreamInput in) throws IOException { + super(in); + } + + /** + * Constructor + * + * @param indices list of indices which needs to be evicted from cache + */ + public ClearCacheRequest(String... indices) { + super(indices); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheResponse.java b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheResponse.java new file mode 100644 index 0000000000..3da7b07384 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheResponse.java @@ -0,0 +1,49 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.core.action.support.DefaultShardOperationFailedException; +import org.opensearch.action.support.broadcast.BroadcastResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.ToXContentObject; + +import java.io.IOException; +import java.util.List; + +/** + * {@link ClearCacheResponse} represents Response returned by {@link ClearCacheRequest}. + * Returns total number of shards on which ClearCache was performed on, as well as + * the number of shards that succeeded and the number of shards that failed. + */ +public class ClearCacheResponse extends BroadcastResponse implements ToXContentObject { + + /** + * Constructor + * + * @param in input stream + * @throws IOException if read from stream fails + */ + public ClearCacheResponse(StreamInput in) throws IOException { + super(in); + } + + /** + * Constructor + * + * @param totalShards total number of shards on which ClearCache was performed + * @param successfulShards number of shards that succeeded + * @param failedShards number of shards that failed + * @param shardFailures list of shard failure exceptions + */ + public ClearCacheResponse( + int totalShards, + int successfulShards, + int failedShards, + List shardFailures + ) { + super(totalShards, successfulShards, failedShards, shardFailures); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheTransportAction.java new file mode 100644 index 0000000000..4a294e73b5 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/ClearCacheTransportAction.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.broadcast.node.TransportBroadcastByNodeAction; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlockException; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.ShardsIterator; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.support.DefaultShardOperationFailedException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.index.Index; +import org.opensearch.index.IndexService; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.indices.IndicesService; +import org.opensearch.knn.index.KNNIndexShard; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.util.List; + +/** + * Transport Action to evict k-NN indices from Cache. TransportBroadcastByNodeAction will distribute the request to + * all shards across the cluster for the given indices. For each shard, shardOperation will be called and the + * indices will be cleared from cache. + */ +public class ClearCacheTransportAction extends TransportBroadcastByNodeAction< + ClearCacheRequest, + ClearCacheResponse, + TransportBroadcastByNodeAction.EmptyResult> { + + private IndicesService indicesService; + + /** + * Constructor + * + * @param clusterService ClusterService + * @param transportService TransportService + * @param actionFilters ActionFilters + * @param indexNameExpressionResolver IndexNameExpressionResolver + * @param indicesService IndicesService + */ + @Inject + public ClearCacheTransportAction( + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, + IndicesService indicesService + ) { + super( + ClearCacheAction.NAME, + clusterService, + transportService, + actionFilters, + indexNameExpressionResolver, + ClearCacheRequest::new, + ThreadPool.Names.SEARCH + ); + this.indicesService = indicesService; + } + + /** + * @param streamInput StreamInput + * @return EmptyResult + * @throws IOException + */ + @Override + protected EmptyResult readShardResult(StreamInput streamInput) throws IOException { + return EmptyResult.readEmptyResultFrom(streamInput); + } + + /** + * @param request ClearCacheRequest + * @param totalShards total number of shards on which ClearCache was performed + * @param successfulShards number of shards that succeeded + * @param failedShards number of shards that failed + * @param emptyResults List of EmptyResult + * @param shardFailures list of shard failure exceptions + * @param clusterState ClusterState + * @return {@link ClearCacheResponse} + */ + @Override + protected ClearCacheResponse newResponse( + ClearCacheRequest request, + int totalShards, + int successfulShards, + int failedShards, + List emptyResults, + List shardFailures, + ClusterState clusterState + ) { + return new ClearCacheResponse(totalShards, successfulShards, failedShards, shardFailures); + } + + /** + * @param streamInput StreamInput + * @return {@link ClearCacheRequest} + * @throws IOException + */ + @Override + protected ClearCacheRequest readRequestFrom(StreamInput streamInput) throws IOException { + return new ClearCacheRequest(streamInput); + } + + /** + * Operation performed at a shard level on all the shards of given index where the index is removed from the cache. + * + * @param request ClearCacheRequest + * @param shardRouting ShardRouting of given shard + * @return EmptyResult + * @throws IOException + */ + @Override + protected EmptyResult shardOperation(ClearCacheRequest request, ShardRouting shardRouting) throws IOException { + Index index = shardRouting.shardId().getIndex(); + IndexService indexService = indicesService.indexServiceSafe(index); + IndexShard indexShard = indexService.getShard(shardRouting.shardId().id()); + KNNIndexShard knnIndexShard = new KNNIndexShard(indexShard); + knnIndexShard.clearCache(); + return EmptyResult.INSTANCE; + } + + /** + * @param clusterState ClusterState + * @param request ClearCacheRequest + * @param concreteIndices Indices in the request + * @return ShardsIterator with all the shards for given concrete indices + */ + @Override + protected ShardsIterator shards(ClusterState clusterState, ClearCacheRequest request, String[] concreteIndices) { + return clusterState.routingTable().allShards(concreteIndices); + } + + /** + * @param clusterState ClusterState + * @param request ClearCacheRequest + * @return ClusterBlockException if there is any global cluster block at a cluster block level of "METADATA_WRITE" + */ + @Override + protected ClusterBlockException checkGlobalBlock(ClusterState clusterState, ClearCacheRequest request) { + return clusterState.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } + + /** + * @param clusterState ClusterState + * @param request ClearCacheRequest + * @param concreteIndices Indices in the request + * @return ClusterBlockException if there is any cluster block on any of the given indices at a cluster block level of "METADATA_WRITE" + */ + @Override + protected ClusterBlockException checkRequestBlock(ClusterState clusterState, ClearCacheRequest request, String[] concreteIndices) { + return clusterState.blocks().indicesBlockedException(ClusterBlockLevel.METADATA_WRITE, concreteIndices); + } +} diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java index 323442fff4..1c9e75c3af 100644 --- a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java @@ -6,6 +6,12 @@ package org.opensearch.knn; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlock; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.block.ClusterBlocks; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; @@ -32,9 +38,12 @@ import java.io.IOException; import java.util.Collection; import java.util.Collections; +import java.util.EnumSet; import java.util.Map; import java.util.concurrent.ExecutionException; +import static org.mockito.Mockito.when; + public class KNNSingleNodeTestCase extends OpenSearchSingleNodeTestCase { @Override public void setUp() throws Exception { @@ -208,4 +217,25 @@ public void assertTrainingSucceeds(ModelDao modelDao, String modelId, int attemp fail("Training did not succeed after " + attempts + " attempts with a delay of " + delayInMillis + " ms."); } + + // Add Global Cluster Block with the given ClusterBlockLevel + protected void addGlobalClusterBlock(ClusterService clusterService, String description, EnumSet clusterBlockLevels) { + ClusterBlock block = new ClusterBlock(randomInt(), description, false, false, false, RestStatus.FORBIDDEN, clusterBlockLevels); + ClusterBlocks clusterBlocks = ClusterBlocks.builder().addGlobalBlock(block).build(); + ClusterState state = ClusterState.builder(ClusterName.DEFAULT).blocks(clusterBlocks).build(); + when(clusterService.state()).thenReturn(state); + } + + // Add Cluster Block for an Index with given ClusterBlockLevel + protected void addIndexClusterBlock( + ClusterService clusterService, + String description, + EnumSet clusterBlockLevels, + String testIndex + ) { + ClusterBlock block = new ClusterBlock(randomInt(), description, false, false, false, RestStatus.FORBIDDEN, clusterBlockLevels); + ClusterBlocks clusterBlocks = ClusterBlocks.builder().addIndexBlock(testIndex, block).build(); + ClusterState state = ClusterState.builder(ClusterName.DEFAULT).blocks(clusterBlocks).build(); + when(clusterService.state()).thenReturn(state); + } } diff --git a/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java b/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java index 079d18e667..7b6f96d5a8 100644 --- a/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import lombok.SneakyThrows; import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.index.IndexService; import org.opensearch.index.engine.Engine; @@ -155,4 +156,30 @@ public void testGetEngineFileContexts() { assertEquals(includedFileNames.size(), included.size()); included.stream().map(KNNIndexShard.EngineFileContext::getIndexPath).forEach(o -> assertTrue(includedFileNames.contains(o))); } + + @SneakyThrows + public void testClearCache_emptyIndex() { + IndexService indexService = createKNNIndex(testIndexName); + createKnnIndexMapping(testIndexName, testFieldName, dimensions); + + IndexShard indexShard = indexService.iterator().next(); + KNNIndexShard knnIndexShard = new KNNIndexShard(indexShard); + knnIndexShard.clearCache(); + assertNull(NativeMemoryCacheManager.getInstance().getIndicesCacheStats().get(testIndexName)); + } + + @SneakyThrows + public void testClearCache_shardPresentInCache() { + IndexService indexService = createKNNIndex(testIndexName); + createKnnIndexMapping(testIndexName, testFieldName, dimensions); + addKnnDoc(testIndexName, String.valueOf(randomInt()), testFieldName, new Float[] { randomFloat(), randomFloat() }); + + IndexShard indexShard = indexService.iterator().next(); + KNNIndexShard knnIndexShard = new KNNIndexShard(indexShard); + knnIndexShard.warmup(); + assertEquals(1, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().get(testIndexName).get(GRAPH_COUNT)); + + knnIndexShard.clearCache(); + assertNull(NativeMemoryCacheManager.getInstance().getIndicesCacheStats().get(testIndexName)); + } } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestClearCacheHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestClearCacheHandlerIT.java new file mode 100644 index 0000000000..e7b454ebbe --- /dev/null +++ b/src/test/java/org/opensearch/knn/plugin/action/RestClearCacheHandlerIT.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.action; + +import lombok.SneakyThrows; +import org.opensearch.client.Request; +import org.opensearch.client.ResponseException; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.plugin.KNNPlugin; +import org.opensearch.rest.RestRequest; + +import java.util.Arrays; +import java.util.Collections; + +import static org.opensearch.knn.common.KNNConstants.CLEAR_CACHE; + +/** + * Integration tests to validate ClearCache API + */ + +public class RestClearCacheHandlerIT extends KNNRestTestCase { + private static final String TEST_FIELD = "test-field"; + private static final int DIMENSIONS = 2; + private static final String ALL_INDICES = "_all"; + + @SneakyThrows + public void testNonExistentIndex() { + String nonExistentIndex = "non-existent-index"; + + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, CLEAR_CACHE, nonExistentIndex); + Request request = new Request(RestRequest.Method.POST.name(), restURI); + + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); + assertTrue(ex.getMessage().contains(nonExistentIndex)); + } + + @SneakyThrows + public void testNotKnnIndex() { + String notKNNIndex = "not-knn-index"; + createIndex(notKNNIndex, Settings.EMPTY); + + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, CLEAR_CACHE, notKNNIndex); + Request request = new Request(RestRequest.Method.POST.name(), restURI); + + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); + assertTrue(ex.getMessage().contains(notKNNIndex)); + } + + @SneakyThrows + public void testClearCacheSingleIndex() { + String testIndex = getTestName().toLowerCase(); + int graphCountBefore = getTotalGraphsInCache(); + createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS)); + addKnnDoc(testIndex, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() }); + + knnWarmup(Collections.singletonList(testIndex)); + + assertEquals(graphCountBefore + 1, getTotalGraphsInCache()); + + clearCache(Collections.singletonList(testIndex)); + assertEquals(graphCountBefore, getTotalGraphsInCache()); + } + + @SneakyThrows + public void testClearCacheMultipleIndices() { + String testIndex1 = getTestName().toLowerCase(); + String testIndex2 = getTestName().toLowerCase() + 1; + int graphCountBefore = getTotalGraphsInCache(); + + createKnnIndex(testIndex1, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS)); + addKnnDoc(testIndex1, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() }); + + createKnnIndex(testIndex2, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS)); + addKnnDoc(testIndex2, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() }); + + knnWarmup(Arrays.asList(testIndex1, testIndex2)); + + assertEquals(graphCountBefore + 2, getTotalGraphsInCache()); + + clearCache(Arrays.asList(ALL_INDICES)); + assertEquals(graphCountBefore, getTotalGraphsInCache()); + } + + @SneakyThrows + public void testClearCacheMultipleIndicesWithPatterns() { + String testIndex1 = getTestName().toLowerCase(); + String testIndex2 = getTestName().toLowerCase() + 1; + String testIndex3 = "abc" + getTestName().toLowerCase(); + int graphCountBefore = getTotalGraphsInCache(); + + createKnnIndex(testIndex1, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS)); + addKnnDoc(testIndex1, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() }); + + createKnnIndex(testIndex2, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS)); + addKnnDoc(testIndex2, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() }); + + createKnnIndex(testIndex3, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS)); + addKnnDoc(testIndex3, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() }); + + knnWarmup(Arrays.asList(testIndex1, testIndex2, testIndex3)); + + assertEquals(graphCountBefore + 3, getTotalGraphsInCache()); + String indexPattern = getTestName().toLowerCase() + "*"; + + clearCache(Arrays.asList(indexPattern)); + assertEquals(graphCountBefore + 1, getTotalGraphsInCache()); + } +} diff --git a/src/test/java/org/opensearch/knn/plugin/transport/ClearCacheTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/ClearCacheTransportActionTests.java new file mode 100644 index 0000000000..3222a3eb74 --- /dev/null +++ b/src/test/java/org/opensearch/knn/plugin/transport/ClearCacheTransportActionTests.java @@ -0,0 +1,117 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import lombok.SneakyThrows; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlockException; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.ShardsIterator; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.index.IndexService; +import org.opensearch.knn.KNNSingleNodeTestCase; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; + +import java.util.EnumSet; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ClearCacheTransportActionTests extends KNNSingleNodeTestCase { + private static final String TEST_FIELD = "test-field"; + private static final int DIMENSIONS = 2; + + @SneakyThrows + public void testShardOperation() { + String testIndex = getTestName().toLowerCase(); + KNNWarmupRequest knnWarmupRequest = new KNNWarmupRequest(testIndex); + KNNWarmupTransportAction knnWarmupTransportAction = node().injector().getInstance(KNNWarmupTransportAction.class); + assertEquals(0, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().size()); + + IndexService indexService = createKNNIndex(testIndex); + createKnnIndexMapping(testIndex, TEST_FIELD, DIMENSIONS); + addKnnDoc(testIndex, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() }); + ShardRouting shardRouting = indexService.iterator().next().routingEntry(); + + knnWarmupTransportAction.shardOperation(knnWarmupRequest, shardRouting); + assertEquals(1, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().size()); + + ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex); + ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class); + clearCacheTransportAction.shardOperation(clearCacheRequest, shardRouting); + assertEquals(0, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().size()); + } + + @SneakyThrows + public void testShards() { + String testIndex = getTestName().toLowerCase(); + ClusterService clusterService = node().injector().getInstance(ClusterService.class); + ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class); + ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex); + + createKNNIndex(testIndex); + createKnnIndexMapping(testIndex, TEST_FIELD, DIMENSIONS); + addKnnDoc(testIndex, String.valueOf(randomInt()), TEST_FIELD, new Float[] { randomFloat(), randomFloat() }); + + ShardsIterator shardsIterator = clearCacheTransportAction.shards( + clusterService.state(), + clearCacheRequest, + new String[] { testIndex } + ); + assertEquals(1, shardsIterator.size()); + } + + public void testCheckGlobalBlock_throwsClusterBlockException() { + String testIndex = getTestName().toLowerCase(); + String description = "testing metadata block"; + ClusterService clusterService = mock(ClusterService.class); + addGlobalClusterBlock(clusterService, description, EnumSet.of(ClusterBlockLevel.METADATA_WRITE)); + ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class); + ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex); + ClusterBlockException ex = clearCacheTransportAction.checkGlobalBlock(clusterService.state(), clearCacheRequest); + assertTrue(ex.getMessage().contains(description)); + } + + public void testCheckGlobalBlock_notThrowsClusterBlockException() { + String testIndex = getTestName().toLowerCase(); + ClusterService clusterService = mock(ClusterService.class); + ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class); + ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex); + ClusterState state = ClusterState.builder(ClusterName.DEFAULT).build(); + when(clusterService.state()).thenReturn(state); + assertNull(clearCacheTransportAction.checkGlobalBlock(clusterService.state(), clearCacheRequest)); + } + + public void testCheckRequestBlock_throwsClusterBlockException() { + String testIndex = getTestName().toLowerCase(); + String description = "testing index metadata block"; + ClusterService clusterService = mock(ClusterService.class); + addIndexClusterBlock(clusterService, description, EnumSet.of(ClusterBlockLevel.METADATA_WRITE), testIndex); + + ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class); + ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex); + ClusterBlockException ex = clearCacheTransportAction.checkRequestBlock( + clusterService.state(), + clearCacheRequest, + new String[] { testIndex } + ); + assertTrue(ex.getMessage().contains(testIndex)); + assertTrue(ex.getMessage().contains(description)); + + } + + public void testCheckRequestBlock_notThrowsClusterBlockException() { + String testIndex = getTestName().toLowerCase(); + ClusterService clusterService = mock(ClusterService.class); + ClearCacheTransportAction clearCacheTransportAction = node().injector().getInstance(ClearCacheTransportAction.class); + ClearCacheRequest clearCacheRequest = new ClearCacheRequest(testIndex); + ClusterState state = ClusterState.builder(ClusterName.DEFAULT).build(); + when(clusterService.state()).thenReturn(state); + assertNull(clearCacheTransportAction.checkRequestBlock(clusterService.state(), clearCacheRequest, new String[] { testIndex })); + } +} diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 60c4016480..0331d49e5e 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -84,6 +84,7 @@ import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.CLEAR_CACHE; import static org.opensearch.knn.TestUtils.NUMBER_OF_REPLICAS; import static org.opensearch.knn.TestUtils.NUMBER_OF_SHARDS; @@ -692,6 +693,20 @@ protected Response executeWarmupRequest(List indices, final String baseU return client().performRequest(request); } + /** + * Evicts valid k-NN indices from the cache. + * + * @param indices list of k-NN indices that needs to be removed from cache + * @return Response of clear Cache API request + * @throws IOException + */ + protected Response clearCache(List indices) throws IOException { + String indicesSuffix = String.join(",", indices); + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, CLEAR_CACHE, indicesSuffix); + Request request = new Request("POST", restURI); + return client().performRequest(request); + } + /** * Parse KNN Cluster stats from response */