Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[k-NN] Add Clear Cache API #740

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.9...2.x)
### Features
* Add Clear Cache API [#740](https://github.com/opensearch-project/k-NN/pull/740)
### Enhancements
* Enabled the IVF algorithm to work with Filters of K-NN Query. [#1013](https://github.com/opensearch-project/k-NN/pull/1013)
### Bug Fixes
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.bwc;

import java.util.Collections;
import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER;

/**
 * Restart-upgrade backward-compatibility test for the k-NN Clear Cache API.
 *
 * <p>Against the old cluster the test only creates a k-NN index and adds {@code NUM_DOCS} documents.
 * Against the upgraded cluster it runs the warmup / clear-cache validation twice: once on the
 * documents indexed before the upgrade and once more after indexing a second batch.
 */
public class ClearCacheIT extends AbstractRestartUpgradeTestCase {
    private static final String TEST_FIELD = "test-field";
    private static final int DIMENSIONS = 5;
    private static final int NUM_DOCS = 10;
    private static final int K = 5;

    // Restart Upgrade BWC Tests to validate Clear Cache API
    public void testClearCache() throws Exception {
        waitForClusterHealthGreen(NODES_BWC_CLUSTER);

        if (isRunningAgainstOldCluster()) {
            createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS));
            // First batch always starts at doc id 0; no shared mutable state is consulted.
            addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, 0, NUM_DOCS);
            return;
        }

        // Upgraded cluster: validate against the documents indexed before the upgrade.
        validateClearCacheOnUpgrade(NUM_DOCS);

        // Index a second batch, continuing the doc ids right after the first batch.
        addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, NUM_DOCS);

        // Validate again, now expecting both batches.
        validateClearCacheOnUpgrade(NUM_DOCS + NUM_DOCS);
        deleteKNNIndex(testIndex);
    }

    /**
     * Validation steps for Clear Cache API after upgrading node to new version.
     *
     * <p>Warms up the test index and asserts the number of graphs in cache grew, then clears the
     * cache and asserts it holds zero graphs. {@code validateKNNSearch} is invoked after each step.
     *
     * @param queryCount value forwarded to {@code validateKNNSearch}
     * @throws Exception propagated from the REST helper calls
     */
    private void validateClearCacheOnUpgrade(int queryCount) throws Exception {
        int graphsBeforeWarmup = getTotalGraphsInCache();
        knnWarmup(Collections.singletonList(testIndex));
        assertTrue(getTotalGraphsInCache() > graphsBeforeWarmup);
        validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, queryCount, K);

        clearCache(Collections.singletonList(testIndex));
        assertEquals(0, getTotalGraphsInCache());
        validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, queryCount, K);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.bwc;

import java.util.Collections;

import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER;

/**
 * Rolling-upgrade backward-compatibility test for the k-NN Clear Cache API.
 *
 * <p>On the OLD cluster the test creates a k-NN index and adds {@code NUM_DOCS} documents. On the
 * UPGRADED cluster it runs the warmup / clear-cache validation, indexes a second batch, validates
 * again and deletes the index. Any other cluster type is a no-op.
 */
public class ClearCacheIT extends AbstractRollingUpgradeTestCase {
    private static final String TEST_FIELD = "test-field";
    private static final int DIMENSIONS = 5;
    private static final int K = 5;
    private static final int NUM_DOCS = 10;

    // Rolling Upgrade BWC Tests to validate Clear Cache API
    public void testClearCache() throws Exception {
        waitForClusterHealthGreen(NODES_BWC_CLUSTER);
        switch (getClusterType()) {
            case OLD:
                createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS));
                // First batch starts at doc id 0.
                addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, 0, NUM_DOCS);
                break;
            case UPGRADED:
                // Validate against the documents indexed on the old cluster.
                validateClearCacheOnUpgrade(NUM_DOCS);

                // Second batch continues the doc ids right after the first batch.
                addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, NUM_DOCS);

                validateClearCacheOnUpgrade(NUM_DOCS + NUM_DOCS);
                deleteKNNIndex(testIndex);
                break;
            default:
                // Remaining cluster types: nothing to validate.
                break;
        }
    }

    /**
     * Validation steps for Clear Cache API after upgrading all nodes from old version to new version.
     *
     * <p>Warms up the test index and asserts the number of graphs in cache grew, then clears the
     * cache and asserts it holds zero graphs. {@code validateKNNSearch} is invoked after each step.
     *
     * @param queryCount value forwarded to {@code validateKNNSearch}
     * @throws Exception propagated from the REST helper calls
     */
    public void validateClearCacheOnUpgrade(int queryCount) throws Exception {
        int graphsBeforeWarmup = getTotalGraphsInCache();
        knnWarmup(Collections.singletonList(testIndex));
        assertTrue(getTotalGraphsInCache() > graphsBeforeWarmup);
        validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, queryCount, K);

        clearCache(Collections.singletonList(testIndex));
        assertEquals(0, getTotalGraphsInCache());
        validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, queryCount, K);
    }

}
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,7 @@ public class KNNConstants {
// Common prefix shared by the JNI library names composed below.
private static final String JNI_LIBRARY_PREFIX = "opensearchknn_";
// JNI library names: the shared prefix followed by the engine name constant.
public static final String FAISS_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + FAISS_NAME;
public static final String NMSLIB_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + NMSLIB_NAME;

// API Constants
// Identifier of the Clear Cache API.
public static final String CLEAR_CACHE = "clear_cache";
}
37 changes: 32 additions & 5 deletions src/main/java/org/opensearch/knn/index/KNNIndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

package org.opensearch.knn.index;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
Expand All @@ -17,6 +16,7 @@
import org.opensearch.index.engine.Engine;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
Expand All @@ -27,6 +27,7 @@
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

Expand All @@ -38,11 +39,11 @@
/**
* KNNIndexShard wraps IndexShard and adds methods to perform k-NN related operations against the shard
*/
@Log4j2
public class KNNIndexShard {
private IndexShard indexShard;
private NativeMemoryCacheManager nativeMemoryCacheManager;

private static Logger logger = LogManager.getLogger(KNNIndexShard.class);
private static final String INDEX_SHARD_CLEAR_CACHE_SEARCHER = "knn-clear-cache";

/**
* Constructor to generate KNNIndexShard. We do not perform validation that the index the shard is from
Expand Down Expand Up @@ -80,7 +81,7 @@ public String getIndexName() {
* @throws IOException Thrown when getting the HNSW Paths to be loaded in
*/
public void warmup() throws IOException {
logger.info("[KNN] Warming up index: " + getIndexName());
log.info("[KNN] Warming up index: [{}]", getIndexName());
try (Engine.Searcher searcher = indexShard.acquireSearcher("knn-warmup")) {
getAllEnginePaths(searcher.getIndexReader()).forEach((key, value) -> {
try {
Expand All @@ -100,6 +101,32 @@ public void warmup() throws IOException {
}
}

/**
* Removes all the k-NN segments for this shard from the cache.
* Adding write lock onto the NativeMemoryAllocation of the index that needs to be evicted from cache.
* Write lock will be unlocked after the index is evicted. This locking mechanism is used to avoid
* conflicts with queries fired on this index when the index is being evicted from cache.
*/
public void clearCache() {
String indexName = getIndexName();
Optional<NativeMemoryAllocation> indexAllocationOptional;
NativeMemoryAllocation indexAllocation;
indexAllocationOptional = nativeMemoryCacheManager.getIndexMemoryAllocation(indexName);
if (indexAllocationOptional.isPresent()) {
indexAllocation = indexAllocationOptional.get();
indexAllocation.writeLock();
navneet1v marked this conversation as resolved.
Show resolved Hide resolved
log.info("[KNN] Evicting index from cache: [{}]", indexName);
try (Engine.Searcher searcher = indexShard.acquireSearcher(INDEX_SHARD_CLEAR_CACHE_SEARCHER)) {
getAllEnginePaths(searcher.getIndexReader()).forEach((key, value) -> nativeMemoryCacheManager.invalidate(key));
} catch (IOException ex) {
log.error("[KNN] Failed to evict index from cache: [{}]", indexName, ex);
throw new RuntimeException(ex);
} finally {
indexAllocation.writeUnlock();
}
}
}

/**
* For the given shard, get all of its engine paths
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.io.Closeable;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -303,6 +304,23 @@ public NativeMemoryAllocation get(NativeMemoryEntryContext<?> nativeMemoryEntryC
return cache.get(nativeMemoryEntryContext.getKey(), nativeMemoryEntryContext::load);
}

/**
* Returns the NativeMemoryAllocation associated with given index
* @param indexName name of OpenSearch index
* @return NativeMemoryAllocation associated with given index
*/
public Optional<NativeMemoryAllocation> getIndexMemoryAllocation(String indexName) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think our convention is to have params in public methods as final. not a critical though, just for future PRs

Validate.notNull(indexName, "Index name cannot be null");
return cache.asMap()
.values()
.stream()
.filter(nativeMemoryAllocation -> nativeMemoryAllocation instanceof NativeMemoryAllocation.IndexAllocation)
.filter(
indexAllocation -> indexName.equals(((NativeMemoryAllocation.IndexAllocation) indexAllocation).getOpenSearchIndexName())
)
.findFirst();
}

/**
* Invalidate entry from the cache.
*
Expand Down
10 changes: 8 additions & 2 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.knn.plugin.rest.RestKNNWarmupHandler;
import org.opensearch.knn.plugin.rest.RestSearchModelHandler;
import org.opensearch.knn.plugin.rest.RestTrainModelHandler;
import org.opensearch.knn.plugin.rest.RestClearCacheHandler;
import org.opensearch.knn.plugin.script.KNNScoringScriptEngine;
import org.opensearch.knn.plugin.stats.KNNStats;
import org.opensearch.knn.plugin.transport.DeleteModelAction;
Expand All @@ -41,6 +42,8 @@
import org.opensearch.knn.plugin.transport.KNNStatsTransportAction;
import org.opensearch.knn.plugin.transport.KNNWarmupAction;
import org.opensearch.knn.plugin.transport.KNNWarmupTransportAction;
import org.opensearch.knn.plugin.transport.ClearCacheAction;
import org.opensearch.knn.plugin.transport.ClearCacheTransportAction;
import com.google.common.collect.ImmutableList;

import org.opensearch.action.ActionRequest;
Expand Down Expand Up @@ -231,14 +234,16 @@ public List<RestHandler> getRestHandlers(
RestDeleteModelHandler restDeleteModelHandler = new RestDeleteModelHandler();
RestTrainModelHandler restTrainModelHandler = new RestTrainModelHandler();
RestSearchModelHandler restSearchModelHandler = new RestSearchModelHandler();
RestClearCacheHandler restClearCacheHandler = new RestClearCacheHandler(clusterService, indexNameExpressionResolver);

return ImmutableList.of(
restKNNStatsHandler,
restKNNWarmupHandler,
restGetModelHandler,
restDeleteModelHandler,
restTrainModelHandler,
restSearchModelHandler
restSearchModelHandler,
restClearCacheHandler
);
}

Expand All @@ -258,7 +263,8 @@ public List<RestHandler> getRestHandlers(
new ActionHandler<>(TrainingModelAction.INSTANCE, TrainingModelTransportAction.class),
new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class),
new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class),
new ActionHandler<>(UpdateModelGraveyardAction.INSTANCE, UpdateModelGraveyardTransportAction.class)
new ActionHandler<>(UpdateModelGraveyardAction.INSTANCE, UpdateModelGraveyardTransportAction.class),
new ActionHandler<>(ClearCacheAction.INSTANCE, ClearCacheTransportAction.class)
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.plugin.rest;

import com.google.common.collect.ImmutableList;
import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.common.Strings;
import org.opensearch.core.index.Index;
import org.opensearch.knn.common.exception.KNNInvalidIndicesException;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.knn.plugin.transport.ClearCacheAction;
import org.opensearch.knn.plugin.transport.ClearCacheRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;

import static org.opensearch.action.support.IndicesOptions.strictExpandOpen;
import static org.opensearch.knn.common.KNNConstants.CLEAR_CACHE;
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX;

/**
* RestHandler for k-NN Clear Cache API. API provides the ability for a user to evict those indices from Cache.
*/
@AllArgsConstructor
@Log4j2
public class RestClearCacheHandler extends BaseRestHandler {
private static final String INDEX = "index";
public static String NAME = "knn_clear_cache_action";
private final ClusterService clusterService;
private final IndexNameExpressionResolver indexNameExpressionResolver;

/**
* @return name of Clear Cache API action
*/
@Override
public String getName() {
return NAME;
}

/**
* @return Immutable List of Clear Cache API endpoint
*/
@Override
public List<Route> routes() {
return ImmutableList.of(
new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/%s/{%s}", KNNPlugin.KNN_BASE_URI, CLEAR_CACHE, INDEX))
);
}

/**
* @param request RestRequest
* @param client NodeClient
* @return RestChannelConsumer
*/
@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
ClearCacheRequest clearCacheRequest = createClearCacheRequest(request);
log.info("[KNN] ClearCache started for the following indices: [{}]", String.join(",", clearCacheRequest.indices()));
return channel -> client.execute(ClearCacheAction.INSTANCE, clearCacheRequest, new RestToXContentListener<>(channel));
}

// Create a clear cache request by processing the rest request and validating the indices
private ClearCacheRequest createClearCacheRequest(RestRequest request) {
String[] indexNames = Strings.splitStringByCommaToArray(request.param("index"));
Index[] indices = indexNameExpressionResolver.concreteIndices(clusterService.state(), strictExpandOpen(), indexNames);
validateIndices(indices);

return new ClearCacheRequest(indexNames);
}

// Validate if the given indices are k-NN indices or not. If there are any invalid indices,
// the request is rejected and an exception is thrown.
private void validateIndices(Index[] indices) {
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
List<String> invalidIndexNames = Arrays.stream(indices)
.filter(index -> !"true".equals(clusterService.state().metadata().getIndexSafe(index).getSettings().get(KNN_INDEX)))
.map(Index::getName)
.collect(Collectors.toList());

if (!invalidIndexNames.isEmpty()) {
throw new KNNInvalidIndicesException(
invalidIndexNames,
"ClearCache request rejected. One or more indices have 'index.knn' set to false."
);
}
}
}
Loading