Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

lazy load efSearch #52

Merged
merged 3 commits into from
Feb 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@ private void onRemoval(RemovalNotification<String, KNNIndexCacheEntry> removalNo
* Loads corresponding index for the given key to memory and returns the index object.
*
* @param key indexPath where the serialized hnsw graph is stored
* @param algoParams hnsw algorithm parameters
* @param indexName index name
* @return KNNIndex holding the heap pointer of the loaded graph
*/
public KNNIndex getIndex(String key, final String[] algoParams) {
public KNNIndex getIndex(String key, final String indexName) {
try {
final KNNIndexCacheEntry knnIndexCacheEntry = cache.get(key, () -> loadIndex(key, algoParams));
final KNNIndexCacheEntry knnIndexCacheEntry = cache.get(key, () -> loadIndex(key, indexName));
return knnIndexCacheEntry.getKnnIndex();
} catch (ExecutionException e) {
throw new RuntimeException(e);
Expand Down Expand Up @@ -178,12 +178,12 @@ public void setCacheCapacityReached(Boolean value) {
* Loads hnsw index to memory. Registers the location of the serialized graph with ResourceWatcher.
*
* @param indexPathUrl path for serialized hnsw graph
* @param algoParams hnsw algorithm parameters
* @param indexName index name
* @return KNNIndexCacheEntry wrapping the KNNIndex that holds the heap pointer of the loaded graph
* @throws Exception Exception could occur when registering the index path
* to Resource watcher or if the JNI call throws
*/
public KNNIndexCacheEntry loadIndex(String indexPathUrl, final String[] algoParams) throws Exception {
public KNNIndexCacheEntry loadIndex(String indexPathUrl, String indexName) throws Exception {
if(Strings.isNullOrEmpty(indexPathUrl))
throw new IllegalStateException("indexPath is null while performing load index");
logger.debug("Loading index on cache miss .. {}", indexPathUrl);
Expand All @@ -196,7 +196,7 @@ public KNNIndexCacheEntry loadIndex(String indexPathUrl, final String[] algoPara
// the entry
fileWatcher.init();

final KNNIndex knnIndex = KNNIndex.loadIndex(indexPathUrl, algoParams);
final KNNIndex knnIndex = KNNIndex.loadIndex(indexPathUrl, getQueryParams(indexName));

// TODO verify that this is safe - ideally we'd explicitly ensure that the FileWatcher is only checked
// after the guava cache has finished loading the key to avoid a race condition where the watcher
Expand Down Expand Up @@ -236,4 +236,8 @@ public void onFileDeleted(Path indexFilePath) {
getInstance().cache.invalidate(indexFilePath.toString());
}
};

private String[] getQueryParams(String indexName) {
return new String[] {"efSearch=" + KNNSettings.getEfSearchParam(indexName)};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why only one query parameter is being passed in? Why are M and efConstruction not retrieved as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

efSearch is required only during searches. Other algorithm params are needed for indexing.

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ public class KNNQuery extends Query {
private final String field;
private final float[] queryVector;
private final int k;
private final String indexName;

public KNNQuery(String field, float[] queryVector, int k) {
public KNNQuery(String field, float[] queryVector, int k, String indexName) {
this.field = field;
this.queryVector = queryVector;
this.k = k;
this.indexName = indexName;
}

public String getField() {
Expand All @@ -49,6 +51,8 @@ public int getK() {
return this.k;
}

/**
 * Returns the index name that was supplied when this query was constructed.
 *
 * @return the index name held by this query
 */
public String getIndexName() {
    return this.indexName;
}

/**
* Constructs Weight implementation for this query
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio

@Override
protected Query doToQuery(QueryShardContext context) throws IOException {
return new KNNQuery(this.fieldName, vector, k);
return new KNNQuery(this.fieldName, vector, k, context.index().getName());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,4 +339,24 @@ public void onFailure(Exception e) {
}
});
}

/**
 * Looks up the efSearch value for the given index.
 *
 * Delegates to {@link #getIndexSettingValue(String, String, int)} with the
 * {@code KNN_ALGO_PARAM_EF_SEARCH} setting name, passing 512 as the default value.
 *
 * @param index Name of the index
 * @return efSearch value
 */
public static int getEfSearchParam(String index) {
    final int fallbackEfSearch = 512;
    return getIndexSettingValue(index, KNN_ALGO_PARAM_EF_SEARCH, fallbackEfSearch);
}

/**
 * Reads an integer index-level setting from the current cluster state.
 *
 * @param index        Name of the index whose settings are consulted
 * @param settingName  Key of the setting to read
 * @param defaultValue Value returned when the setting is absent, or when the index
 *                     has no metadata in the current cluster state
 * @return the setting value as an int, or {@code defaultValue} if it cannot be resolved
 */
public static int getIndexSettingValue(String index, String settingName, int defaultValue) {
    // The metadata lookup yields null for an index that is not in the cluster state
    // (for example, one deleted concurrently). Fall back to the default instead of
    // failing with a NullPointerException when dereferencing it.
    return java.util.Optional
            .ofNullable(KNNSettings.state().clusterService.state().getMetaData().index(index))
            .map(indexMetaData -> indexMetaData.getSettings().getAsInt(settingName, defaultValue))
            .orElse(defaultValue);
}

/**
 * Stores the cluster service on this settings instance; the static setting
 * lookups read cluster state through it.
 *
 * @param clusterService cluster service to store
 */
public void setClusterService(final ClusterService clusterService) {
    this.clusterService = clusterService;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ public static class TypeParser implements Mapper.TypeParser {
Builder builder = new KNNVectorFieldMapper.Builder(name);
builder.algoParams(KNNConstants.HNSW_ALGO_M, parserContext.mapperService().getIndexSettings().getValue(KNNSettings.INDEX_KNN_ALGO_PARAM_M_SETTING));
builder.algoParams(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, parserContext.mapperService().getIndexSettings().getValue(KNNSettings.INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING));
builder.algoParams(KNNConstants.HNSW_ALGO_EF_SEARCH, parserContext.mapperService().getIndexSettings().getValue(KNNSettings.INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING));


/**
* If dimension not provided. Throw Exception
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package com.amazon.opendistroforelasticsearch.knn.index;

import com.amazon.opendistroforelasticsearch.knn.index.codec.KNNCodecUtil;
import com.amazon.opendistroforelasticsearch.knn.index.util.KNNConstants;
import com.amazon.opendistroforelasticsearch.knn.index.v1736.KNNIndex;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -36,8 +35,6 @@

import java.io.IOException;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -100,8 +97,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
*/

Path indexPath = PathUtils.get(directory, hnswFiles.get(0));

final KNNIndex index = knnIndexCache.getIndex(indexPath.toString(), getQueryParams(queryFieldInfo));
final KNNIndex index = knnIndexCache.getIndex(indexPath.toString(), knnQuery.getIndexName());
final KNNQueryResult[] results = index.queryIndex(
knnQuery.getQueryVector(),
knnQuery.getK()
Expand All @@ -127,12 +123,5 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
public boolean isCacheable(LeafReaderContext context) {
return true;
}

/**
 * Builds the query-time parameter array from the attributes of the given field.
 *
 * @param fieldInfo field whose attributes are inspected for the efSearch entry
 * @return a single-element array of the form {@code "efSearch=<value>"} when the
 *         attribute is present, otherwise an empty array
 */
private String[] getQueryParams(FieldInfo fieldInfo) {
    final boolean hasEfSearch = fieldInfo.attributes().containsKey(KNNConstants.HNSW_ALGO_EF_SEARCH);
    if (!hasEfSearch) {
        return new String[] {};
    }
    final String efSearchParam = "efSearch=" + fieldInfo.attributes().get(KNNConstants.HNSW_ALGO_EF_SEARCH);
    return new String[] {efSearchParam};
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,10 @@ public KNNQueryResult[] run() {
public void close() {
Lock writeLock = readWriteLock.writeLock();
writeLock.lock();

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this no longer needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Bad merge. Will add it back.

// Autocloseable documentation recommends making close idempotent. We don't expect to doubly close
// but this will help prevent a crash in that situation.
if (this.isClosed) {
return;
// but this will help prevent a crash in that situation.
if (this.isClosed) {
return;
}
try {
gc(this.indexPointer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package com.amazon.opendistroforelasticsearch.knn.index;

import com.amazon.opendistroforelasticsearch.knn.index.codec.KNN80Codec;

import com.amazon.opendistroforelasticsearch.knn.index.codec.KNNCodecUtil;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.document.Document;
Expand All @@ -31,18 +30,32 @@
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.watcher.ResourceWatcherService;
import org.mockito.Mockito;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@ESIntegTestCase.ClusterScope(scope=ESIntegTestCase.Scope.SUITE, numDataNodes=1)
public class KNN80HnswIndexIT extends ESIntegTestCase {

/**
 * Installs a deep-stubbed ClusterService mock on KNNSettings so that index
 * settings lookups for any index name resolve to empty settings.
 */
private void setUpMockClusterService() {
    final ClusterService stubbedClusterService = mock(ClusterService.class, RETURNS_DEEP_STUBS);
    // Any index name resolves to empty settings.
    when(stubbedClusterService.state().getMetaData().index(Mockito.anyString()).getSettings())
            .thenReturn(Settings.Builder.EMPTY_SETTINGS);
    KNNSettings.state().setClusterService(stubbedClusterService);
}

public void testFooter() throws Exception {
setUpMockClusterService();
Directory dir = newFSDirectory(createTempDir());
IndexWriterConfig iwc = newIndexWriterConfig();
iwc.setMergeScheduler(new SerialMergeScheduler());
Expand All @@ -55,7 +68,6 @@ public void testFooter() throws Exception {
doc.add(vectorField);
writer.addDocument(doc);


KNNIndexCache.setResourceWatcherService(createDisabledResourceWatcherService());
IndexReader reader = writer.getReader();
LeafReaderContext lrc = reader.getContext().leaves().iterator().next(); // leaf reader context
Expand All @@ -73,14 +85,15 @@ public void testFooter() throws Exception {
indexInput.close();

IndexSearcher searcher = newSearcher(reader);
assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] {1.0f, 2.5f}, 1)));
assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] {1.0f, 2.5f}, 1, "myindex")));

reader.close();
writer.close();
dir.close();
}

public void testMultiFieldsKnnIndex() throws Exception {
setUpMockClusterService();
Directory dir = newFSDirectory(createTempDir());
IndexWriterConfig iwc = newIndexWriterConfig();
iwc.setMergeScheduler(new SerialMergeScheduler());
Expand Down Expand Up @@ -121,14 +134,14 @@ public void testMultiFieldsKnnIndex() throws Exception {

// query to verify distance for each of the field
IndexSearcher searcher = newSearcher(reader);
float score = searcher.search(new KNNQuery("test_vector", new float[] {1.0f, 0.0f, 0.0f}, 1), 10).scoreDocs[0].score;
float score1 = searcher.search(new KNNQuery("my_vector", new float[] {1.0f, 2.0f}, 1), 10).scoreDocs[0].score;
float score = searcher.search(new KNNQuery("test_vector", new float[] {1.0f, 0.0f, 0.0f}, 1, "dummy"), 10).scoreDocs[0].score;
float score1 = searcher.search(new KNNQuery("my_vector", new float[] {1.0f, 2.0f}, 1, "dummy"), 10).scoreDocs[0].score;
assertEquals(score, 0.1667f, 0.01f);
assertEquals(score1, 0.0714f, 0.01f);

// query to determine the hits
assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] {1.0f, 0.0f, 0.0f}, 1)));
assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] {1.0f, 1.0f}, 1)));
assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] {1.0f, 0.0f, 0.0f}, 1, "dummy")));
assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] {1.0f, 1.0f}, 1, "dummy")));

reader.close();
writer.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.test.ESTestCase;
import org.mockito.Mockito;

public class KNNQueryBuilderTests extends ESTestCase {

Expand Down Expand Up @@ -79,7 +82,10 @@ public void testFromXcontent() throws Exception {
public void testDoToQuery() throws Exception {
float[] queryVector = {1.0f, 2.0f, 3.0f, 4.0f};
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1);
KNNQuery query = (KNNQuery)knnQueryBuilder.doToQuery(null);
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = Mockito.mock(QueryShardContext.class);
Mockito.when(mockQueryShardContext.index()).thenReturn(dummyIndex);
KNNQuery query = (KNNQuery)knnQueryBuilder.doToQuery(mockQueryShardContext);
assertEquals(knnQueryBuilder.getK(), query.getK());
assertEquals(knnQueryBuilder.fieldName(), query.getField());
assertEquals(knnQueryBuilder.vector(), query.getQueryVector());
Expand Down