diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNIndexCache.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNIndexCache.java index 2b80974b..8d959309 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNIndexCache.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNIndexCache.java @@ -126,12 +126,12 @@ private void onRemoval(RemovalNotification removalNo * Loads corresponding index for the given key to memory and returns the index object. * * @param key indexPath where the serialized hnsw graph is stored - * @param algoParams hnsw algorithm parameters + * @param indexName index name * @return KNNIndex holding the heap pointer of the loaded graph */ - public KNNIndex getIndex(String key, final String[] algoParams) { + public KNNIndex getIndex(String key, final String indexName) { try { - final KNNIndexCacheEntry knnIndexCacheEntry = cache.get(key, () -> loadIndex(key, algoParams)); + final KNNIndexCacheEntry knnIndexCacheEntry = cache.get(key, () -> loadIndex(key, indexName)); return knnIndexCacheEntry.getKnnIndex(); } catch (ExecutionException e) { throw new RuntimeException(e); @@ -178,12 +178,12 @@ public void setCacheCapacityReached(Boolean value) { * Loads hnsw index to memory. Registers the location of the serialized graph with ResourceWatcher. 
* * @param indexPathUrl path for serialized hnsw graph - * @param algoParams hnsw algorithm parameters + * @param indexName index name * @return KNNIndex holding the heap pointer of the loaded graph * @throws Exception Exception could occur when registering the index path * to Resource watcher or if the JNI call throws */ - public KNNIndexCacheEntry loadIndex(String indexPathUrl, final String[] algoParams) throws Exception { + public KNNIndexCacheEntry loadIndex(String indexPathUrl, String indexName) throws Exception { if(Strings.isNullOrEmpty(indexPathUrl)) throw new IllegalStateException("indexPath is null while performing load index"); logger.debug("Loading index on cache miss .. {}", indexPathUrl); @@ -196,7 +196,7 @@ public KNNIndexCacheEntry loadIndex(String indexPathUrl, final String[] algoPara // the entry fileWatcher.init(); - final KNNIndex knnIndex = KNNIndex.loadIndex(indexPathUrl, algoParams); + final KNNIndex knnIndex = KNNIndex.loadIndex(indexPathUrl, getQueryParams(indexName)); // TODO verify that this is safe - ideally we'd explicitly ensure that the FileWatcher is only checked // after the guava cache has finished loading the key to avoid a race condition where the watcher @@ -236,4 +236,8 @@ public void onFileDeleted(Path indexFilePath) { getInstance().cache.invalidate(indexFilePath.toString()); } }; + + private String[] getQueryParams(String indexName) { + return new String[] {"efSearch=" + KNNSettings.getEfSearchParam(indexName)}; + } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNQuery.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNQuery.java index 4cdf6f4d..1e4b9175 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNQuery.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNQuery.java @@ -30,11 +30,13 @@ public class KNNQuery extends Query { private final String field; private final float[] queryVector; private final int k; + private 
final String indexName; - public KNNQuery(String field, float[] queryVector, int k) { + public KNNQuery(String field, float[] queryVector, int k, String indexName) { this.field = field; this.queryVector = queryVector; this.k = k; + this.indexName = indexName; } public String getField() { @@ -49,6 +51,8 @@ public int getK() { return this.k; } + public String getIndexName() { return this.indexName; } + /** * Constructs Weight implementation for this query * diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNQueryBuilder.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNQueryBuilder.java index 41ff4d43..35a9fffe 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNQueryBuilder.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNQueryBuilder.java @@ -186,7 +186,7 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio @Override protected Query doToQuery(QueryShardContext context) throws IOException { - return new KNNQuery(this.fieldName, vector, k); + return new KNNQuery(this.fieldName, vector, k, context.index().getName()); } @Override diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNSettings.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNSettings.java index f424ab71..e35df977 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNSettings.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNSettings.java @@ -339,4 +339,24 @@ public void onFailure(Exception e) { } }); } + + /** + * Returns the efSearch setting of the given index, read from the cluster state metadata; 512 is passed as the default value. + * @param index Name of the index + * @return efSearch value + */ + public static int getEfSearchParam(String index) { + return getIndexSettingValue(index, KNN_ALGO_PARAM_EF_SEARCH, 512); + } + + public static int getIndexSettingValue(String index, String settingName, int defaultValue) { + return KNNSettings.state().clusterService.state().getMetaData() + 
.index(index).getSettings() + .getAsInt(settingName, defaultValue); + } + + public void setClusterService(ClusterService clusterService) { + this.clusterService = clusterService; + } + } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNVectorFieldMapper.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNVectorFieldMapper.java index 1b8cb8a0..5fa852bc 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNVectorFieldMapper.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNVectorFieldMapper.java @@ -137,8 +137,7 @@ public static class TypeParser implements Mapper.TypeParser { Builder builder = new KNNVectorFieldMapper.Builder(name); builder.algoParams(KNNConstants.HNSW_ALGO_M, parserContext.mapperService().getIndexSettings().getValue(KNNSettings.INDEX_KNN_ALGO_PARAM_M_SETTING)); builder.algoParams(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, parserContext.mapperService().getIndexSettings().getValue(KNNSettings.INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING)); - builder.algoParams(KNNConstants.HNSW_ALGO_EF_SEARCH, parserContext.mapperService().getIndexSettings().getValue(KNNSettings.INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING)); - + /** * If dimension not provided. 
Throw Exception */ diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNWeight.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNWeight.java index dbb7d7a2..a324d2a3 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNWeight.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/KNNWeight.java @@ -16,7 +16,6 @@ package com.amazon.opendistroforelasticsearch.knn.index; import com.amazon.opendistroforelasticsearch.knn.index.codec.KNNCodecUtil; -import com.amazon.opendistroforelasticsearch.knn.index.util.KNNConstants; import com.amazon.opendistroforelasticsearch.knn.index.v1736.KNNIndex; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -36,8 +35,6 @@ import java.io.IOException; import java.nio.file.Path; -import java.security.AccessController; -import java.security.PrivilegedAction; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -100,8 +97,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { */ Path indexPath = PathUtils.get(directory, hnswFiles.get(0)); - - final KNNIndex index = knnIndexCache.getIndex(indexPath.toString(), getQueryParams(queryFieldInfo)); + final KNNIndex index = knnIndexCache.getIndex(indexPath.toString(), knnQuery.getIndexName()); final KNNQueryResult[] results = index.queryIndex( knnQuery.getQueryVector(), knnQuery.getK() @@ -127,12 +123,5 @@ public Scorer scorer(LeafReaderContext context) throws IOException { public boolean isCacheable(LeafReaderContext context) { return true; } - - private String[] getQueryParams(FieldInfo fieldInfo) { - if (fieldInfo.attributes().containsKey(KNNConstants.HNSW_ALGO_EF_SEARCH)) { - return new String[] {"efSearch=" + fieldInfo.attributes().get(KNNConstants.HNSW_ALGO_EF_SEARCH)}; - } - return new String[] {}; - } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/v1736/KNNIndex.java 
b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/v1736/KNNIndex.java index 39295dc2..a228d2f4 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/v1736/KNNIndex.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/v1736/KNNIndex.java @@ -87,11 +87,10 @@ public KNNQueryResult[] run() { public void close() { Lock writeLock = readWriteLock.writeLock(); writeLock.lock(); - // Autocloseable documentation recommends making close idempotent. We don't expect to doubly close - // but this will help prevent a crash in that situation. - if (this.isClosed) { - return; + // AutoCloseable documentation recommends making close idempotent; we don't expect a double close, but this check will help prevent a crash in that situation. + if (this.isClosed) { + return; } try { gc(this.indexPointer); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNN80HnswIndexIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNN80HnswIndexIT.java index a8c4c096..b7019b57 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNN80HnswIndexIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNN80HnswIndexIT.java @@ -16,7 +16,6 @@ package com.amazon.opendistroforelasticsearch.knn.index; import com.amazon.opendistroforelasticsearch.knn.index.codec.KNN80Codec; - import com.amazon.opendistroforelasticsearch.knn.index.codec.KNNCodecUtil; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.document.Document; @@ -31,18 +30,32 @@ import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.watcher.ResourceWatcherService; +import org.mockito.Mockito; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; +import static org.mockito.Mockito.RETURNS_DEEP_STUBS; +import 
static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + @ESIntegTestCase.ClusterScope(scope=ESIntegTestCase.Scope.SUITE, numDataNodes=1) public class KNN80HnswIndexIT extends ESIntegTestCase { + private void setUpMockClusterService() { + ClusterService clusterService = mock(ClusterService.class, RETURNS_DEEP_STUBS); + Settings settings = Settings.Builder.EMPTY_SETTINGS; + when(clusterService.state().getMetaData().index(Mockito.anyString()).getSettings()).thenReturn(settings); + KNNSettings.state().setClusterService(clusterService); + } + public void testFooter() throws Exception { + setUpMockClusterService(); Directory dir = newFSDirectory(createTempDir()); IndexWriterConfig iwc = newIndexWriterConfig(); iwc.setMergeScheduler(new SerialMergeScheduler()); @@ -55,7 +68,6 @@ public void testFooter() throws Exception { doc.add(vectorField); writer.addDocument(doc); - KNNIndexCache.setResourceWatcherService(createDisabledResourceWatcherService()); IndexReader reader = writer.getReader(); LeafReaderContext lrc = reader.getContext().leaves().iterator().next(); // leaf reader context @@ -73,7 +85,7 @@ public void testFooter() throws Exception { indexInput.close(); IndexSearcher searcher = newSearcher(reader); - assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] {1.0f, 2.5f}, 1))); + assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] {1.0f, 2.5f}, 1, "myindex"))); reader.close(); writer.close(); @@ -81,6 +93,7 @@ public void testFooter() throws Exception { } public void testMultiFieldsKnnIndex() throws Exception { + setUpMockClusterService(); Directory dir = newFSDirectory(createTempDir()); IndexWriterConfig iwc = newIndexWriterConfig(); iwc.setMergeScheduler(new SerialMergeScheduler()); @@ -121,14 +134,14 @@ public void testMultiFieldsKnnIndex() throws Exception { // query to verify distance for each of the field IndexSearcher searcher = newSearcher(reader); - float score = searcher.search(new 
KNNQuery("test_vector", new float[] {1.0f, 0.0f, 0.0f}, 1), 10).scoreDocs[0].score; - float score1 = searcher.search(new KNNQuery("my_vector", new float[] {1.0f, 2.0f}, 1), 10).scoreDocs[0].score; + float score = searcher.search(new KNNQuery("test_vector", new float[] {1.0f, 0.0f, 0.0f}, 1, "dummy"), 10).scoreDocs[0].score; + float score1 = searcher.search(new KNNQuery("my_vector", new float[] {1.0f, 2.0f}, 1, "dummy"), 10).scoreDocs[0].score; assertEquals(score, 0.1667f, 0.01f); assertEquals(score1, 0.0714f, 0.01f); // query to determine the hits - assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] {1.0f, 0.0f, 0.0f}, 1))); - assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] {1.0f, 1.0f}, 1))); + assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] {1.0f, 0.0f, 0.0f}, 1, "dummy"))); + assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] {1.0f, 1.0f}, 1, "dummy"))); reader.close(); writer.close(); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNQueryBuilderTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNQueryBuilderTests.java index d7d35a33..855aab38 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNQueryBuilderTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNQueryBuilderTests.java @@ -18,7 +18,10 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.test.ESTestCase; +import org.mockito.Mockito; public class KNNQueryBuilderTests extends ESTestCase { @@ -79,7 +82,10 @@ public void testFromXcontent() throws Exception { public void testDoToQuery() throws Exception { float[] queryVector = {1.0f, 2.0f, 3.0f, 4.0f}; KNNQueryBuilder 
knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1); - KNNQuery query = (KNNQuery)knnQueryBuilder.doToQuery(null); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = Mockito.mock(QueryShardContext.class); + Mockito.when(mockQueryShardContext.index()).thenReturn(dummyIndex); + KNNQuery query = (KNNQuery)knnQueryBuilder.doToQuery(mockQueryShardContext); assertEquals(knnQueryBuilder.getK(), query.getK()); assertEquals(knnQueryBuilder.fieldName(), query.getField()); assertEquals(knnQueryBuilder.vector(), query.getQueryVector());