Skip to content

Commit

Permalink
[BUG FIX] Fix knn index shard to get bwc engine paths (opensearch-pro…
Browse files Browse the repository at this point in the history
…ject#310)

Fixes getEnginePaths in KNNIndexShard to retrieve all engine paths,
regardless of what version the index was created. Prevents silent
failure when warmup completes but doesnt load any segments.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] authored Mar 8, 2022
1 parent 61c03eb commit 807b612
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 61 deletions.
58 changes: 41 additions & 17 deletions src/main/java/org/opensearch/knn/index/KNNIndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@

import java.io.IOException;
import java.nio.file.Path;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.index.IndexUtil.getParametersAtLoading;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFilePrefix;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileSuffix;

/**
* KNNIndexShard wraps IndexShard and adds methods to perform k-NN related operations against the shard
Expand Down Expand Up @@ -83,12 +84,14 @@ public void warmup() throws IOException {
getAllEnginePaths(searcher.getIndexReader()).forEach((key, value) -> {
try {
nativeMemoryCacheManager.get(
new NativeMemoryEntryContext.IndexEntryContext(
key,
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(),
getParametersAtLoading(value, KNNEngine.getEngineNameFromPath(key), getIndexName()),
getIndexName()
), true);
new NativeMemoryEntryContext.IndexEntryContext(
key,
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(),
getParametersAtLoading(value, KNNEngine.getEngineNameFromPath(key), getIndexName()),
getIndexName()
),
true
);
} catch (ExecutionException ex) {
throw new RuntimeException(ex);
}
Expand Down Expand Up @@ -118,25 +121,46 @@ private Map<String, SpaceType> getEnginePaths(IndexReader indexReader, KNNEngine
SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader());
Path shardPath = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory();
String fileExtension = reader.getSegmentInfo().info.getUseCompoundFile()
? knnEngine.getCompoundExtension() : knnEngine.getExtension();
? knnEngine.getCompoundExtension()
: knnEngine.getExtension();

for (FieldInfo fieldInfo : reader.getFieldInfos()) {
if (fieldInfo.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) {
// Space Type will not be present on ES versions 7.1 and 7.4 because the only available space type
// was L2. So, if Space Type is not present, just fall back to L2
String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue());
SpaceType spaceType = SpaceType.getSpace(spaceTypeName);
String engineFileName = buildEngineFileName(reader.getSegmentInfo().info.name,
knnEngine.getLatestBuildVersion(), fieldInfo.name, fileExtension);

engineFiles.putAll(reader.getSegmentInfo().files().stream()
.filter(fileName -> fileName.equals(engineFileName))
.map(fileName -> shardPath.resolve(fileName).toString())
.filter(Objects::nonNull)
.collect(Collectors.toMap(fileName -> fileName, fileName -> spaceType)));

engineFiles.putAll(
getEnginePaths(
reader.getSegmentInfo().files(),
reader.getSegmentInfo().info.name,
fieldInfo.name,
fileExtension,
shardPath,
spaceType
)
);
}
}
}
return engineFiles;
}

protected Map<String, SpaceType> getEnginePaths(
Collection<String> files,
String segmentName,
String fieldName,
String fileExtension,
Path shardPath,
SpaceType spaceType
) {
String prefix = buildEngineFilePrefix(segmentName);
String suffix = buildEngineFileSuffix(fieldName, fileExtension);
return files.stream()
.filter(fileName -> fileName.startsWith(prefix))
.filter(fileName -> fileName.endsWith(suffix))
.map(fileName -> shardPath.resolve(fileName).toString())
.collect(Collectors.toMap(fileName -> fileName, fileName -> spaceType));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,18 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOExcep
}
docIdList.add(doc);
}
return new KNNCodecUtil.Pair(docIdList.stream().mapToInt(Integer::intValue).toArray(), vectorList.toArray(new float[][]{}));
return new KNNCodecUtil.Pair(docIdList.stream().mapToInt(Integer::intValue).toArray(), vectorList.toArray(new float[][] {}));
}

public static String buildEngineFileName(String segmentName, String latestBuildVersion, String fieldName,
String extension) {
return String.format("%s_%s_%s%s", segmentName, latestBuildVersion, fieldName, extension);
public static String buildEngineFileName(String segmentName, String latestBuildVersion, String fieldName, String extension) {
return String.format("%s%s%s", buildEngineFilePrefix(segmentName), latestBuildVersion, buildEngineFileSuffix(fieldName, extension));
}

public static String buildEngineFilePrefix(String segmentName) {
return String.format("%s_", segmentName);
}

public static String buildEngineFileSuffix(String fieldName, String extension) {
return String.format("_%s%s", fieldName, extension);
}
}
52 changes: 20 additions & 32 deletions src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,34 +77,27 @@ protected IndexService createKNNIndex(String indexName) {
*/
protected void createKnnIndexMapping(String indexName, String fieldName, Integer dimensions) {
PutMappingRequest request = new PutMappingRequest(indexName).type("_doc");
request.source(fieldName, "type=knn_vector,dimension="+dimensions);
request.source(fieldName, "type=knn_vector,dimension=" + dimensions);
OpenSearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet());
}

/**
* Get default k-NN settings for test cases
*/
protected Settings getKNNDefaultIndexSettings() {
return Settings.builder()
.put("number_of_shards", 1)
.put("number_of_replicas", 0)
.put("index.knn", true)
.build();
return Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).put("index.knn", true).build();
}

/**
* Add a k-NN doc to an index
*/
protected void addKnnDoc(String index, String docId, String fieldName, Object[] vector)
throws IOException, InterruptedException, ExecutionException {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject()
.field(fieldName, vector)
.endObject();
IndexRequest indexRequest = new IndexRequest()
.index(index)
.id(docId)
.source(builder)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
protected void addKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException, InterruptedException,
ExecutionException {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, vector).endObject();
IndexRequest indexRequest = new IndexRequest().index(index)
.id(docId)
.source(builder)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

IndexResponse response = client().index(indexRequest).get();
assertEquals(response.status(), RestStatus.CREATED);
Expand All @@ -113,16 +106,13 @@ protected void addKnnDoc(String index, String docId, String fieldName, Object[]
/**
* Add any document to index
*/
protected void addDoc(String index, String docId, String fieldName, String dummyValue)
throws IOException, InterruptedException, ExecutionException {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject()
.field(fieldName, dummyValue)
.endObject();
IndexRequest indexRequest = new IndexRequest()
.index(index)
.id(docId)
.source(builder)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
protected void addDoc(String index, String docId, String fieldName, String dummyValue) throws IOException, InterruptedException,
ExecutionException {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, dummyValue).endObject();
IndexRequest indexRequest = new IndexRequest().index(index)
.id(docId)
.source(builder)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

IndexResponse response = client().index(indexRequest).get();
assertEquals(response.status(), RestStatus.CREATED);
Expand All @@ -132,18 +122,16 @@ protected void addDoc(String index, String docId, String fieldName, String dummy
* Run a search against a k-NN index
*/
protected void searchKNNIndex(String index, String fieldName, float[] vector, int k) {
SearchResponse response = client().prepareSearch(index).setQuery(new KNNQueryBuilder(fieldName, vector, k))
.get();
SearchResponse response = client().prepareSearch(index).setQuery(new KNNQueryBuilder(fieldName, vector, k)).get();
assertEquals(response.status(), RestStatus.OK);
}

public Map<String, Object> xContentBuilderToMap(XContentBuilder xContentBuilder) {
return XContentHelper.convertToMap(BytesReference.bytes(xContentBuilder), true,
xContentBuilder.contentType()).v2();
return XContentHelper.convertToMap(BytesReference.bytes(xContentBuilder), true, xContentBuilder.contentType()).v2();
}

public void assertTrainingSucceeds(ModelDao modelDao, String modelId, int attempts, int delayInMillis)
throws InterruptedException, ExecutionException {
public void assertTrainingSucceeds(ModelDao modelDao, String modelId, int attempts, int delayInMillis) throws InterruptedException,
ExecutionException {

int attemptNum = 0;
ModelMetadata modelMetadata;
Expand Down
52 changes: 44 additions & 8 deletions src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,27 @@

package org.opensearch.knn.index;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.opensearch.knn.KNNSingleNodeTestCase;
import org.opensearch.index.IndexService;
import org.opensearch.index.engine.Engine;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.opensearch.knn.index.memory.NativeMemoryCacheManager.GRAPH_COUNT;


public class KNNIndexShardTests extends KNNSingleNodeTestCase {

private final String testIndexName = "test-index";
Expand All @@ -29,7 +35,7 @@ public class KNNIndexShardTests extends KNNSingleNodeTestCase {
public void testGetIndexShard() throws InterruptedException, ExecutionException, IOException {
IndexService indexService = createKNNIndex(testIndexName);
createKnnIndexMapping(testIndexName, testFieldName, dimensions);
addKnnDoc(testIndexName, "1", testFieldName, new Float[] {2.5F, 3.5F});
addKnnDoc(testIndexName, "1", testFieldName, new Float[] { 2.5F, 3.5F });

IndexShard indexShard = indexService.iterator().next();
KNNIndexShard knnIndexShard = new KNNIndexShard(indexShard);
Expand All @@ -39,7 +45,7 @@ public void testGetIndexShard() throws InterruptedException, ExecutionException,
public void testGetIndexName() throws InterruptedException, ExecutionException, IOException {
IndexService indexService = createKNNIndex(testIndexName);
createKnnIndexMapping(testIndexName, testFieldName, dimensions);
addKnnDoc(testIndexName, "1", testFieldName, new Float[] {2.5F, 3.5F});
addKnnDoc(testIndexName, "1", testFieldName, new Float[] { 2.5F, 3.5F });

IndexShard indexShard = indexService.iterator().next();
KNNIndexShard knnIndexShard = new KNNIndexShard(indexShard);
Expand All @@ -59,9 +65,9 @@ public void testWarmup_emptyIndex() throws IOException {
public void testWarmup_shardPresentInCache() throws InterruptedException, ExecutionException, IOException {
IndexService indexService = createKNNIndex(testIndexName);
createKnnIndexMapping(testIndexName, testFieldName, dimensions);
addKnnDoc(testIndexName, "1", testFieldName, new Float[] {2.5F, 3.5F});
addKnnDoc(testIndexName, "1", testFieldName, new Float[] { 2.5F, 3.5F });

searchKNNIndex(testIndexName, testFieldName, new float[] {1.0f, 2.0f}, 1);
searchKNNIndex(testIndexName, testFieldName, new float[] { 1.0f, 2.0f }, 1);
assertEquals(1, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().get(testIndexName).get(GRAPH_COUNT));

IndexShard indexShard = indexService.iterator().next();
Expand All @@ -76,15 +82,15 @@ public void testWarmup_shardNotPresentInCache() throws InterruptedException, Exe
IndexShard indexShard;
KNNIndexShard knnIndexShard;

addKnnDoc(testIndexName, "1", testFieldName, new Float[] {2.5F, 3.5F});
addKnnDoc(testIndexName, "1", testFieldName, new Float[] { 2.5F, 3.5F });
client().admin().indices().prepareFlush(testIndexName).execute();

indexShard = indexService.iterator().next();
knnIndexShard = new KNNIndexShard(indexShard);
knnIndexShard.warmup();
assertEquals(1, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().get(testIndexName).get(GRAPH_COUNT));

addKnnDoc(testIndexName, "2", testFieldName, new Float[] {2.5F, 3.5F});
addKnnDoc(testIndexName, "2", testFieldName, new Float[] { 2.5F, 3.5F });
indexShard = indexService.iterator().next();
knnIndexShard = new KNNIndexShard(indexShard);
knnIndexShard.warmup();
Expand All @@ -107,7 +113,7 @@ public void testGetHNSWPaths() throws IOException, ExecutionException, Interrupt
assertEquals(0, hnswPaths.size());
searcher.close();

addKnnDoc(testIndexName, "1", testFieldName, new Float[] {2.5F, 3.5F});
addKnnDoc(testIndexName, "1", testFieldName, new Float[] { 2.5F, 3.5F });

searcher = indexShard.acquireSearcher("test-hnsw-paths-2");
hnswPaths = knnIndexShard.getAllEnginePaths(searcher.getIndexReader());
Expand All @@ -116,4 +122,34 @@ public void testGetHNSWPaths() throws IOException, ExecutionException, Interrupt
assertTrue(paths.get(0).contains("hnsw") || paths.get(0).contains("hnswc"));
searcher.close();
}

public void testGetEnginePaths() {
// Check that the correct engine paths are being returned by the KNNIndexShard
String segmentName = "_0";
String fieldName = "test_field";
String fileExt = ".test";
SpaceType spaceType = SpaceType.L2;

Set<String> includedFileNames = ImmutableSet.of(
String.format("%s_111_%s%s", segmentName, fieldName, fileExt),
String.format("%s_7_%s%s", segmentName, fieldName, fileExt),
String.format("%s_53_%s%s", segmentName, fieldName, fileExt)
);

List<String> excludedFileNames = ImmutableList.of(
String.format("_111_%s%s", fieldName, fileExt), // missing segment name
String.format("%s_111_%s", segmentName, fileExt), // missing field name
String.format("%s_111_%s.invalid", segmentName, fieldName) // missing extension
);

List<String> files = Stream.concat(includedFileNames.stream(), excludedFileNames.stream()).collect(Collectors.toList());

KNNIndexShard knnIndexShard = new KNNIndexShard(null);

Path path = Paths.get("");
Map<String, SpaceType> included = knnIndexShard.getEnginePaths(files, segmentName, fieldName, fileExt, path, spaceType);

assertEquals(includedFileNames.size(), included.size());
included.keySet().forEach(o -> assertTrue(includedFileNames.contains(o)));
}
}

0 comments on commit 807b612

Please sign in to comment.