Skip to content

Commit

Permalink
[k-NN] Add Clear Cache API (#740)
Browse files Browse the repository at this point in the history
* Add Clear Cache API

Signed-off-by: Naveen Tatikonda <[email protected]>

* Add Unit and Integration tests

Signed-off-by: Naveen Tatikonda <[email protected]>

* Add BWC Tests

Signed-off-by: Naveen Tatikonda <[email protected]>

* Add CHANGELOG

Signed-off-by: Naveen Tatikonda <[email protected]>

* Address Review Comments

Signed-off-by: Naveen Tatikonda <[email protected]>

---------

Signed-off-by: Naveen Tatikonda <[email protected]>
(cherry picked from commit 12f4a51)
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Apr 12, 2024
1 parent cee100f commit b8abb56
Show file tree
Hide file tree
Showing 17 changed files with 843 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x)
### Features
* Add Clear Cache API [#740](https://github.com/opensearch-project/k-NN/pull/740)
### Enhancements
* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549)
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.bwc;

import java.util.Collections;
import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER;

/**
 * Restart-upgrade backwards-compatibility test for the k-NN Clear Cache API.
 *
 * Against the old cluster the test only creates and populates the index; against the
 * upgraded cluster it warms the cache up, clears it, and checks that search still works,
 * both before and after indexing a second batch of documents.
 */
public class ClearCacheIT extends AbstractRestartUpgradeTestCase {
    private static final String TEST_FIELD = "test-field";
    private static final int DIMENSIONS = 5;
    private static final int NUM_DOCS = 10;
    private static final int K = 5;

    // Restart Upgrade BWC Tests to validate Clear Cache API
    public void testClearCache() throws Exception {
        waitForClusterHealthGreen(NODES_BWC_CLUSTER);

        if (isRunningAgainstOldCluster()) {
            // Old cluster: create the index and add the first batch, starting at doc id 0.
            createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS));
            addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, 0, NUM_DOCS);
            return;
        }

        // Upgraded cluster: validate against the documents indexed before the upgrade.
        validateClearCacheOnUpgrade(NUM_DOCS);

        // Add a second batch whose ids start right after the first batch, then validate again.
        addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, NUM_DOCS);
        validateClearCacheOnUpgrade(2 * NUM_DOCS);

        deleteKNNIndex(testIndex);
    }

    /**
     * Validation steps for the Clear Cache API after upgrading the node to the new version:
     * warm up the index and expect the number of cached graphs to grow, run a search,
     * clear the cache and expect zero cached graphs, then run the same search again.
     *
     * @param queryCount value forwarded to {@code validateKNNSearch}
     *                   (presumably the number of indexed documents to query; confirm against the helper)
     */
    private void validateClearCacheOnUpgrade(int queryCount) throws Exception {
        final int graphsBeforeWarmup = getTotalGraphsInCache();
        knnWarmup(Collections.singletonList(testIndex));
        assertTrue(getTotalGraphsInCache() > graphsBeforeWarmup);
        validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, queryCount, K);

        clearCache(Collections.singletonList(testIndex));
        assertEquals(0, getTotalGraphsInCache());
        validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, queryCount, K);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.bwc;

import java.util.Collections;

import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER;

/**
 * Rolling-upgrade backwards-compatibility test for the k-NN Clear Cache API.
 *
 * On the old cluster the test creates and populates the index; once the cluster is fully
 * upgraded it warms the cache up, clears it, and checks that search still works, both
 * before and after indexing a second batch of documents.
 */
public class ClearCacheIT extends AbstractRollingUpgradeTestCase {
    private static final String TEST_FIELD = "test-field";
    private static final int DIMENSIONS = 5;
    private static final int K = 5;
    private static final int NUM_DOCS = 10;

    // Rolling Upgrade BWC Tests to validate Clear Cache API
    public void testClearCache() throws Exception {
        waitForClusterHealthGreen(NODES_BWC_CLUSTER);
        switch (getClusterType()) {
            case OLD:
                // Create the index and add the first batch, starting at doc id 0.
                createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS));
                addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, 0, NUM_DOCS);
                break;
            case UPGRADED:
                // Validate against the documents indexed before the upgrade.
                validateClearCacheOnUpgrade(NUM_DOCS);

                // Add a second batch whose ids start right after the first batch, then validate again.
                addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, NUM_DOCS);
                validateClearCacheOnUpgrade(2 * NUM_DOCS);

                deleteKNNIndex(testIndex);
                break;
            default:
                // Any other cluster type is intentionally not exercised by this test.
                break;
        }
    }

    /**
     * Validation steps for the Clear Cache API after upgrading all nodes from the old version
     * to the new version: warm up the index and expect the number of cached graphs to grow,
     * run a search, clear the cache and expect zero cached graphs, then run the same search again.
     *
     * @param queryCount value forwarded to {@code validateKNNSearch}
     *                   (presumably the number of indexed documents to query; confirm against the helper)
     */
    public void validateClearCacheOnUpgrade(int queryCount) throws Exception {
        final int graphsBeforeWarmup = getTotalGraphsInCache();
        knnWarmup(Collections.singletonList(testIndex));
        assertTrue(getTotalGraphsInCache() > graphsBeforeWarmup);
        validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, queryCount, K);

        clearCache(Collections.singletonList(testIndex));
        assertEquals(0, getTotalGraphsInCache());
        validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, queryCount, K);
    }

}
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,7 @@ public class KNNConstants {
// Please refer to this GitHub issue for more details on how this value was chosen:
// https://github.com/opensearch-project/k-NN/issues/1049#issuecomment-1694741092
public static int MAX_DISTANCE_COMPUTATIONS = 2048000;

// API Constants
public static final String CLEAR_CACHE = "clear_cache";
}
39 changes: 34 additions & 5 deletions src/main/java/org/opensearch/knn/index/KNNIndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
import com.google.common.annotations.VisibleForTesting;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
Expand All @@ -20,6 +19,7 @@
import org.opensearch.index.engine.Engine;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
Expand All @@ -30,6 +30,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

Expand All @@ -42,11 +43,11 @@
/**
* KNNIndexShard wraps IndexShard and adds methods to perform k-NN related operations against the shard
*/
@Log4j2
public class KNNIndexShard {
private IndexShard indexShard;
private NativeMemoryCacheManager nativeMemoryCacheManager;

private static Logger logger = LogManager.getLogger(KNNIndexShard.class);
private static final String INDEX_SHARD_CLEAR_CACHE_SEARCHER = "knn-clear-cache";

/**
* Constructor to generate KNNIndexShard. We do not perform validation that the index the shard is from
Expand Down Expand Up @@ -84,7 +85,7 @@ public String getIndexName() {
* @throws IOException Thrown when getting the HNSW Paths to be loaded in
*/
public void warmup() throws IOException {
logger.info("[KNN] Warming up index: " + getIndexName());
log.info("[KNN] Warming up index: [{}]", getIndexName());
try (Engine.Searcher searcher = indexShard.acquireSearcher("knn-warmup")) {
getAllEngineFileContexts(searcher.getIndexReader()).forEach((engineFileContext) -> {
try {
Expand All @@ -109,6 +110,34 @@ public void warmup() throws IOException {
}
}

/**
 * Removes all the k-NN segments for this shard from the cache.
 * <p>
 * If the cache holds an allocation for this shard's index, a write lock is taken on that
 * allocation, every engine file of the shard is invalidated in the cache, and the lock is
 * released again. The lock is used to avoid conflicts with queries fired on this index while
 * it is being evicted. If no allocation is found for the index, the method does nothing.
 * <p>
 * NOTE(review): only the single allocation returned by
 * {@code getIndexMemoryAllocation} is locked, while every engine file of the shard is
 * invalidated; confirm that locking one allocation is sufficient.
 *
 * @throws RuntimeException wrapping the {@link IOException} if the engine files cannot be read
 */
public void clearCache() {
    final String indexName = getIndexName();
    final Optional<NativeMemoryAllocation> indexAllocationOptional = nativeMemoryCacheManager.getIndexMemoryAllocation(indexName);
    if (!indexAllocationOptional.isPresent()) {
        // Nothing from this index is cached, so there is nothing to evict.
        return;
    }

    final NativeMemoryAllocation indexAllocation = indexAllocationOptional.get();
    log.info("[KNN] Evicting index from cache: [{}]", indexName);

    // Acquire the lock immediately before the try block so that nothing can fail between
    // taking the lock and entering the region whose finally clause releases it.
    indexAllocation.writeLock();
    try (Engine.Searcher searcher = indexShard.acquireSearcher(INDEX_SHARD_CLEAR_CACHE_SEARCHER)) {
        getAllEngineFileContexts(searcher.getIndexReader()).forEach(
            (engineFileContext) -> nativeMemoryCacheManager.invalidate(engineFileContext.getIndexPath())
        );
    } catch (IOException ex) {
        log.error("[KNN] Failed to evict index from cache: [{}]", indexName, ex);
        throw new RuntimeException(ex);
    } finally {
        indexAllocation.writeUnlock();
    }
}

/**
* For the given shard, get all of its engine paths
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.io.Closeable;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -303,6 +304,23 @@ public NativeMemoryAllocation get(NativeMemoryEntryContext<?> nativeMemoryEntryC
return cache.get(nativeMemoryEntryContext.getKey(), nativeMemoryEntryContext::load);
}

/**
 * Looks up a cached allocation that belongs to the given OpenSearch index.
 * <p>
 * Only cache entries that are {@link NativeMemoryAllocation.IndexAllocation} instances are
 * considered; the first one whose OpenSearch index name equals {@code indexName} is returned.
 *
 * @param indexName name of OpenSearch index; must not be null
 * @return the first matching NativeMemoryAllocation, or an empty Optional if none matches
 */
public Optional<NativeMemoryAllocation> getIndexMemoryAllocation(String indexName) {
    Validate.notNull(indexName, "Index name cannot be null");
    for (NativeMemoryAllocation candidate : cache.asMap().values()) {
        // Skip entries that are not index allocations.
        if (!(candidate instanceof NativeMemoryAllocation.IndexAllocation)) {
            continue;
        }
        NativeMemoryAllocation.IndexAllocation indexAllocation = (NativeMemoryAllocation.IndexAllocation) candidate;
        if (indexName.equals(indexAllocation.getOpenSearchIndexName())) {
            return Optional.of(candidate);
        }
    }
    return Optional.empty();
}

/**
* Invalidate entry from the cache.
*
Expand Down
10 changes: 8 additions & 2 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.knn.plugin.rest.RestKNNWarmupHandler;
import org.opensearch.knn.plugin.rest.RestSearchModelHandler;
import org.opensearch.knn.plugin.rest.RestTrainModelHandler;
import org.opensearch.knn.plugin.rest.RestClearCacheHandler;
import org.opensearch.knn.plugin.script.KNNScoringScriptEngine;
import org.opensearch.knn.plugin.stats.KNNStats;
import org.opensearch.knn.plugin.transport.DeleteModelAction;
Expand All @@ -41,6 +42,8 @@
import org.opensearch.knn.plugin.transport.KNNStatsTransportAction;
import org.opensearch.knn.plugin.transport.KNNWarmupAction;
import org.opensearch.knn.plugin.transport.KNNWarmupTransportAction;
import org.opensearch.knn.plugin.transport.ClearCacheAction;
import org.opensearch.knn.plugin.transport.ClearCacheTransportAction;
import com.google.common.collect.ImmutableList;

import org.opensearch.action.ActionRequest;
Expand Down Expand Up @@ -236,14 +239,16 @@ public List<RestHandler> getRestHandlers(
RestDeleteModelHandler restDeleteModelHandler = new RestDeleteModelHandler();
RestTrainModelHandler restTrainModelHandler = new RestTrainModelHandler();
RestSearchModelHandler restSearchModelHandler = new RestSearchModelHandler();
RestClearCacheHandler restClearCacheHandler = new RestClearCacheHandler(clusterService, indexNameExpressionResolver);

return ImmutableList.of(
restKNNStatsHandler,
restKNNWarmupHandler,
restGetModelHandler,
restDeleteModelHandler,
restTrainModelHandler,
restSearchModelHandler
restSearchModelHandler,
restClearCacheHandler
);
}

Expand All @@ -263,7 +268,8 @@ public List<RestHandler> getRestHandlers(
new ActionHandler<>(TrainingModelAction.INSTANCE, TrainingModelTransportAction.class),
new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class),
new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class),
new ActionHandler<>(UpdateModelGraveyardAction.INSTANCE, UpdateModelGraveyardTransportAction.class)
new ActionHandler<>(UpdateModelGraveyardAction.INSTANCE, UpdateModelGraveyardTransportAction.class),
new ActionHandler<>(ClearCacheAction.INSTANCE, ClearCacheTransportAction.class)
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.plugin.rest;

import com.google.common.collect.ImmutableList;
import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.common.Strings;
import org.opensearch.core.index.Index;
import org.opensearch.knn.common.exception.KNNInvalidIndicesException;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.knn.plugin.transport.ClearCacheAction;
import org.opensearch.knn.plugin.transport.ClearCacheRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;

import static org.opensearch.action.support.IndicesOptions.strictExpandOpen;
import static org.opensearch.knn.common.KNNConstants.CLEAR_CACHE;
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX;

/**
 * RestHandler for the k-NN Clear Cache API. The API lets a user evict the given k-NN indices from the cache.
 * <p>
 * The handler resolves the index expressions from the request path, rejects the request if any
 * resolved index does not have the k-NN index setting set to "true", and otherwise forwards a
 * {@link ClearCacheRequest} to {@link ClearCacheAction}.
 */
@AllArgsConstructor
@Log4j2
public class RestClearCacheHandler extends BaseRestHandler {
    // Name of the path placeholder; used both to build the route and to read the request parameter.
    private static final String INDEX = "index";
    public static final String NAME = "knn_clear_cache_action";
    private final ClusterService clusterService;
    private final IndexNameExpressionResolver indexNameExpressionResolver;

    /**
     * @return name of Clear Cache API action
     */
    @Override
    public String getName() {
        return NAME;
    }

    /**
     * @return Immutable List of Clear Cache API endpoint
     */
    @Override
    public List<Route> routes() {
        return ImmutableList.of(
            new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/%s/{%s}", KNNPlugin.KNN_BASE_URI, CLEAR_CACHE, INDEX))
        );
    }

    /**
     * Builds and validates the clear cache request, then returns a consumer that executes it.
     *
     * @param request RestRequest
     * @param client NodeClient
     * @return RestChannelConsumer
     * @throws KNNInvalidIndicesException if any resolved index is not a k-NN index
     */
    @Override
    protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
        ClearCacheRequest clearCacheRequest = createClearCacheRequest(request);
        log.info("[KNN] ClearCache started for the following indices: [{}]", String.join(",", clearCacheRequest.indices()));
        return channel -> client.execute(ClearCacheAction.INSTANCE, clearCacheRequest, new RestToXContentListener<>(channel));
    }

    // Create a clear cache request by processing the rest request and validating the indices.
    // The request carries the index expressions exactly as given by the user; the resolved
    // concrete indices are used for validation only.
    private ClearCacheRequest createClearCacheRequest(RestRequest request) {
        String[] indexNames = Strings.splitStringByCommaToArray(request.param(INDEX));
        Index[] indices = indexNameExpressionResolver.concreteIndices(clusterService.state(), strictExpandOpen(), indexNames);
        validateIndices(indices);

        return new ClearCacheRequest(indexNames);
    }

    // Validate if the given indices are k-NN indices or not. If there are any invalid indices,
    // the request is rejected and an exception is thrown.
    // NOTE(review): the cluster state is read again here for every index, separately from the
    // state used to resolve the indices; presumably they agree in practice, but a single
    // snapshot shared by both steps would rule out a mismatch.
    private void validateIndices(Index[] indices) {
        List<String> invalidIndexNames = Arrays.stream(indices)
            .filter(index -> !"true".equals(clusterService.state().metadata().getIndexSafe(index).getSettings().get(KNN_INDEX)))
            .map(Index::getName)
            .collect(Collectors.toList());

        if (!invalidIndexNames.isEmpty()) {
            throw new KNNInvalidIndicesException(
                invalidIndexNames,
                "ClearCache request rejected. One or more indices have 'index.knn' set to false."
            );
        }
    }
}
Loading

0 comments on commit b8abb56

Please sign in to comment.