diff --git a/src/main/java/org/opensearch/knn/index/KNNQuery.java b/src/main/java/org/opensearch/knn/index/KNNQuery.java index 76709ae7e2..631b36d7a4 100644 --- a/src/main/java/org/opensearch/knn/index/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/KNNQuery.java @@ -42,7 +42,9 @@ public int getK() { return this.k; } - public String getIndexName() { return this.indexName; } + public String getIndexName() { + return this.indexName; + } /** * Constructs Weight implementation for this query @@ -77,8 +79,7 @@ public int hashCode() { @Override public boolean equals(Object other) { - return sameClassAs(other) && - equalsTo(getClass().cast(other)); + return sameClassAs(other) && equalsTo(getClass().cast(other)); } private boolean equalsTo(KNNQuery other) { diff --git a/src/main/java/org/opensearch/knn/index/KNNWeight.java b/src/main/java/org/opensearch/knn/index/KNNWeight.java index 3f777c9703..7defb1eeb2 100644 --- a/src/main/java/org/opensearch/knn/index/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/KNNWeight.java @@ -5,7 +5,6 @@ package org.opensearch.knn.index; -import com.google.common.collect.ImmutableMap; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.memory.NativeMemoryAllocation; @@ -19,7 +18,6 @@ import org.apache.lucene.index.FilterLeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SegmentReader; -import org.apache.lucene.index.Term; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Scorer; @@ -36,10 +34,8 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; @@ -79,108 +75,116 @@ public Explanation explain(LeafReaderContext context, int doc) { @Override public Scorer scorer(LeafReaderContext context) throws IOException { - SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); - String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); - - FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); - - if (fieldInfo == null) { - logger.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), - reader.getSegmentName()); - return null; - } - - KNNEngine knnEngine; - SpaceType spaceType; - - // Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's - // metadata. - String modelId = fieldInfo.getAttribute(MODEL_ID); - if (modelId != null) { - ModelMetadata modelMetadata = modelDao.getMetadata(modelId); - if (modelMetadata == null) { - throw new RuntimeException("Model \"" + modelId + "\" does not exist."); - } - - knnEngine = modelMetadata.getKnnEngine(); - spaceType = modelMetadata.getSpaceType(); - } else { - String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName()); - knnEngine = KNNEngine.getEngine(engineName); - String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); - spaceType = SpaceType.getSpace(spaceTypeName); + SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); + String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); + + FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + + if (fieldInfo == null) { + logger.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName()); + return null; + } + + KNNEngine knnEngine; + SpaceType spaceType; + + // Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's + // metadata. + String modelId = fieldInfo.getAttribute(MODEL_ID); + if (modelId != null) { + ModelMetadata modelMetadata = modelDao.getMetadata(modelId); + if (modelMetadata == null) { + throw new RuntimeException("Model \"" + modelId + "\" does not exist."); } - /* - * In case of compound file, extension would be + c otherwise - */ - String engineExtension = reader.getSegmentInfo().info.getUseCompoundFile() - ? knnEngine.getExtension() + KNNConstants.COMPOUND_EXTENSION : knnEngine.getExtension(); - String engineSuffix = knnQuery.getField() + engineExtension; - List engineFiles = reader.getSegmentInfo().files().stream() - .filter(fileName -> fileName.endsWith(engineSuffix)) - .collect(Collectors.toList()); - - if(engineFiles.isEmpty()) { - logger.debug("[KNN] No engine index found for field {} for segment {}", - knnQuery.getField(), reader.getSegmentName()); - return null; + knnEngine = modelMetadata.getKnnEngine(); + spaceType = modelMetadata.getSpaceType(); + } else { + String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName()); + knnEngine = KNNEngine.getEngine(engineName); + String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); + spaceType = SpaceType.getSpace(spaceTypeName); + } + + /* + * In case of compound file, extension would be + c otherwise + */ + String engineExtension = reader.getSegmentInfo().info.getUseCompoundFile() + ? knnEngine.getExtension() + KNNConstants.COMPOUND_EXTENSION + : knnEngine.getExtension(); + String engineSuffix = knnQuery.getField() + engineExtension; + List engineFiles = reader.getSegmentInfo() + .files() + .stream() + .filter(fileName -> fileName.endsWith(engineSuffix)) + .collect(Collectors.toList()); + + if (engineFiles.isEmpty()) { + logger.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName()); + return null; + } + + Path indexPath = PathUtils.get(directory, engineFiles.get(0)); + final KNNQueryResult[] results; + KNNCounter.GRAPH_QUERY_REQUESTS.increment(); + + // We need to first get index allocation + NativeMemoryAllocation indexAllocation; + try { + indexAllocation = nativeMemoryCacheManager.get( + new NativeMemoryEntryContext.IndexEntryContext( + indexPath.toString(), + NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), + getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()), + knnQuery.getIndexName() + ), + true + ); + } catch (ExecutionException e) { + GRAPH_QUERY_ERRORS.increment(); + throw new RuntimeException(e); + } + + // Now that we have the allocation, we need to readLock it + indexAllocation.readLock(); + + try { + if (indexAllocation.isClosed()) { + throw new RuntimeException("Index has already been closed"); } - Path indexPath = PathUtils.get(directory, engineFiles.get(0)); - final KNNQueryResult[] results; - KNNCounter.GRAPH_QUERY_REQUESTS.increment(); - - // We need to first get index allocation - NativeMemoryAllocation indexAllocation; - try { - indexAllocation = nativeMemoryCacheManager.get( - new NativeMemoryEntryContext.IndexEntryContext( - indexPath.toString(), - NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()), - knnQuery.getIndexName() - ), true); - } catch (ExecutionException e) { - GRAPH_QUERY_ERRORS.increment(); - throw new RuntimeException(e); - } - - // Now that we have the allocation, we need to readLock it - indexAllocation.readLock(); - - try { - if (indexAllocation.isClosed()) { - throw new RuntimeException("Index has already been closed"); - } - - results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), knnQuery.getQueryVector(), knnQuery.getK(), knnEngine.getName()); - } catch (Exception e) { - GRAPH_QUERY_ERRORS.increment(); - throw new RuntimeException(e); - } finally { - indexAllocation.readUnlock(); - } - - /* - * Scores represent the distance of the documents with respect to given query vector. - * Lesser the score, the closer the document is to the query vector. - * Since by default results are retrieved in the descending order of scores, to get the nearest - * neighbors we are inverting the scores. - */ - if (results.length == 0) { - logger.debug("[KNN] Query yielded 0 results"); - return null; - } - - Map scores = Arrays.stream(results).collect( - Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); - int maxDoc = Collections.max(scores.keySet()) + 1; - DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc); - DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(maxDoc); - Arrays.stream(results).forEach(result -> setAdder.add(result.getId())); - DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator(); - return new KNNScorer(this, docIdSetIter, scores, boost); + results = JNIService.queryIndex( + indexAllocation.getMemoryAddress(), + knnQuery.getQueryVector(), + knnQuery.getK(), + knnEngine.getName() + ); + } catch (Exception e) { + GRAPH_QUERY_ERRORS.increment(); + throw new RuntimeException(e); + } finally { + indexAllocation.readUnlock(); + } + + /* + * Scores represent the distance of the documents with respect to given query vector. + * Lesser the score, the closer the document is to the query vector. + * Since by default results are retrieved in the descending order of scores, to get the nearest + * neighbors we are inverting the scores. + */ + if (results.length == 0) { + logger.debug("[KNN] Query yielded 0 results"); + return null; + } + + Map scores = Arrays.stream(results) + .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); + int maxDoc = Collections.max(scores.keySet()) + 1; + DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc); + DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(maxDoc); + Arrays.stream(results).forEach(result -> setAdder.add(result.getId())); + DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator(); + return new KNNScorer(this, docIdSetIter, scores, boost); } @Override @@ -189,9 +193,7 @@ public boolean isCacheable(LeafReaderContext context) { } public static float normalizeScore(float score) { - if (score >= 0) - return 1 / (1 + score); + if (score >= 0) return 1 / (1 + score); return -score + 1; } } - diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java index 5eb36e748b..0064f49fef 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80Codec.java @@ -53,8 +53,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) { * This function returns the Lucene80 Codec. */ public Codec getDelegatee() { - if (lucene80Codec == null) - lucene80Codec = Codec.forName(LUCENE_80); + if (lucene80Codec == null) lucene80Codec = Codec.forName(LUCENE_80); return lucene80Codec; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java index c4ae7ab2e9..b55365c094 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN84Codec/KNN84Codec.java @@ -43,7 +43,7 @@ public KNN84Codec() { super(KNN_84); // Note that DocValuesFormat can use old Codec's DocValuesFormat. For instance Lucene84 uses Lucene80 // DocValuesFormat. Refer to defaultDVFormat in LuceneXXCodec.java to find out which version it uses - this.docValuesFormat = new KNN80DocValuesFormat(); + this.docValuesFormat = new KNN80DocValuesFormat(); this.perFieldDocValuesFormat = new PerFieldDocValuesFormat() { @Override public DocValuesFormat getDocValuesFormatForField(String field) { @@ -57,8 +57,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) { * This function returns the Lucene84 Codec. */ public Codec getDelegatee() { - if (lucene84Codec == null) - lucene84Codec = Codec.forName(LUCENE_84); + if (lucene84Codec == null) lucene84Codec = Codec.forName(LUCENE_84); return lucene84Codec; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java index 154f44f1f9..a3b34559a5 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN86Codec/KNN86Codec.java @@ -44,7 +44,7 @@ public KNN86Codec() { super(KNN_86); // Note that DocValuesFormat can use old Codec's DocValuesFormat. For instance Lucene84 uses Lucene80 // DocValuesFormat. Refer to defaultDVFormat in LuceneXXCodec.java to find out which version it uses - this.docValuesFormat = new KNN80DocValuesFormat(); + this.docValuesFormat = new KNN80DocValuesFormat(); this.perFieldDocValuesFormat = new PerFieldDocValuesFormat() { @Override public DocValuesFormat getDocValuesFormatForField(String field) { @@ -58,8 +58,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) { * This function returns the Lucene84 Codec. */ public Codec getDelegatee() { - if (lucene86Codec == null) - lucene86Codec = Codec.forName(LUCENE_86); + if (lucene86Codec == null) lucene86Codec = Codec.forName(LUCENE_86); return lucene86Codec; } @@ -74,7 +73,6 @@ public DocValuesFormat docValuesFormat() { * approach of manually overriding. */ - public void setPostingsFormat(PostingsFormat postingsFormat) { this.postingsFormat = postingsFormat; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91Codec.java index 1d3bb95801..e635855739 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN91Codec/KNN91Codec.java @@ -35,6 +35,7 @@ public final class KNN91Codec extends FilterCodec { public KNN91Codec() { this(new Lucene91Codec()); } + /** * Constructor that takes a Codec delegate to delegate all methods this code does not implement to. * diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index f54c8bc34e..6ab1b2cca3 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -40,10 +40,7 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Strings; -import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.XContentHelper; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.mapper.MapperService; @@ -205,7 +202,7 @@ public void create(ActionListener actionListener) throws IO return; } String mapping = Strings.toString( - JsonXContent.contentBuilder().startObject().startObject(MapperService.SINGLE_MAPPING_NAME).endObject().endObject() + JsonXContent.contentBuilder().startObject().startObject(MapperService.SINGLE_MAPPING_NAME).endObject().endObject() ); CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(mapping) .settings( diff --git a/src/test/java/org/opensearch/knn/TestUtils.java b/src/test/java/org/opensearch/knn/TestUtils.java index ce09394d23..03b50280f1 100644 --- a/src/test/java/org/opensearch/knn/TestUtils.java +++ b/src/test/java/org/opensearch/knn/TestUtils.java @@ -36,7 +36,7 @@ class DistVector { public float dist; public String docID; - public DistVector (float dist, String docID) { + public DistVector(float dist, String docID) { this.dist = dist; this.docID = docID; } @@ -117,10 +117,10 @@ public static List> computeGroundTruthValues(float[][] indexVectors, } if (pq.size() < k) { - pq.add(new DistVector(dist, String.valueOf(j+1))); + pq.add(new DistVector(dist, String.valueOf(j + 1))); } else if (pq.peek().getDist() > dist) { pq.poll(); - pq.add(new DistVector(dist, String.valueOf(j+1))); + pq.add(new DistVector(dist, String.valueOf(j + 1))); } } @@ -137,7 +137,7 @@ public static List> computeGroundTruthValues(float[][] indexVectors, public static float[][] getQueryVectors(int queryCount, int dimensions, int docCount, boolean isStandard) { if (isStandard) { - return randomlyGenerateStandardVectors(queryCount, dimensions, docCount+1); + return randomlyGenerateStandardVectors(queryCount, dimensions, docCount + 1); } else { return generateRandomVectors(queryCount, dimensions); } @@ -169,8 +169,8 @@ public static double calculateRecallValue(List> searchResults, List recalls.add(recallVal / k); } - double sum = recalls.stream().reduce((a,b)->a+b).get(); - return sum/recalls.size(); + double sum = recalls.stream().reduce((a, b) -> a + b).get(); + return sum / recalls.size(); } /** @@ -192,14 +192,15 @@ private KNNCodecUtil.Pair readIndexData(String path) throws IOException { BufferedReader reader = new BufferedReader(new FileReader(path)); String line = reader.readLine(); while (line != null) { - Map doc = XContentFactory.xContent(XContentType.JSON).createParser( - NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, line).map(); + Map doc = XContentFactory.xContent(XContentType.JSON) + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, line) + .map(); idsList.add((Integer) doc.get("id")); @SuppressWarnings("unchecked") ArrayList vector = (ArrayList) doc.get("vector"); Float[] floatArray = new Float[vector.size()]; - for (int i =0; i< vector.size(); i++) { + for (int i = 0; i < vector.size(); i++) { floatArray[i] = vector.get(i).floatValue(); } vectorsList.add(floatArray); @@ -208,7 +209,7 @@ private KNNCodecUtil.Pair readIndexData(String path) throws IOException { } reader.close(); - int[] idsArray = new int [idsList.size()]; + int[] idsArray = new int[idsList.size()]; float[][] vectorsArray = new float[vectorsList.size()][vectorsList.get(0).length]; for (int i = 0; i < idsList.size(); i++) { idsArray[i] = idsList.get(i); diff --git a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java index 3c79051261..e0b2c05cfe 100644 --- a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java @@ -17,7 +17,6 @@ import org.opensearch.common.Strings; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.util.KNNEngine; @@ -48,14 +47,23 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException // "Train" a faiss flat index - this really just creates an empty index that does brute force k-NN long vectorsPointer = JNIService.transferVectors(0, new float[0][0]); - byte [] modelBlob = JNIService.trainIndex(ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, "Flat", - SPACE_TYPE, spaceType.getValue()), dimension, vectorsPointer, - KNNEngine.FAISS.getName()); + byte[] modelBlob = JNIService.trainIndex( + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, "Flat", SPACE_TYPE, spaceType.getValue()), + dimension, + vectorsPointer, + KNNEngine.FAISS.getName() + ); // Setup model - ModelMetadata modelMetadata = new ModelMetadata(knnEngine, spaceType, dimension, ModelState.CREATED, - ZonedDateTime.now(ZoneOffset.UTC).toString(), "", ""); + ModelMetadata modelMetadata = new ModelMetadata( + knnEngine, + spaceType, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ); Model model = new Model(modelMetadata, modelBlob, modelId); @@ -68,35 +76,30 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException String fieldName = "test-field"; final String mapping = Strings.toString( - XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("model_id", modelId) - .endObject() - .endObject() - .endObject()); + XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("model_id", modelId) + .endObject() + .endObject() + .endObject() + ); modelDao.put(model, ActionListener.wrap(indexResponse -> { - CreateIndexRequestBuilder createIndexRequestBuilder = client().admin().indices().prepareCreate(indexName) - .setSettings(Settings.builder() - .put("number_of_shards", 1) - .put("number_of_replicas", 0) - .put("index.knn", true) - .build() - ).setMapping(mapping); - - client().admin().indices().create(createIndexRequestBuilder.request(), - ActionListener.wrap( - createIndexResponse -> { - assertTrue(createIndexResponse.isAcknowledged()); - inProgressLatch.countDown(); - }, e -> fail("Unable to create index: " + e.getMessage()) - ) - ); - - }, e ->fail("Unable to put model: " + e.getMessage()))); + CreateIndexRequestBuilder createIndexRequestBuilder = client().admin() + .indices() + .prepareCreate(indexName) + .setSettings(Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).put("index.knn", true).build()) + .setMapping(mapping); + + client().admin().indices().create(createIndexRequestBuilder.request(), ActionListener.wrap(createIndexResponse -> { + assertTrue(createIndexResponse.isAcknowledged()); + inProgressLatch.countDown(); + }, e -> fail("Unable to create index: " + e.getMessage()))); + + }, e -> fail("Unable to put model: " + e.getMessage()))); assertTrue(inProgressLatch.await(20, TimeUnit.SECONDS)); } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java index e1d1889f4d..8bda1aefc7 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java @@ -43,9 +43,11 @@ private void createKNNVectorDocument(Directory directory) throws IOException { IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); knnDocument.add( - new BinaryDocValuesField( - MOCK_INDEX_FIELD_NAME, - new VectorField(MOCK_INDEX_FIELD_NAME, new float[]{1.0f, 2.0f}, new FieldType()).binaryValue())); + new BinaryDocValuesField( + MOCK_INDEX_FIELD_NAME, + new VectorField(MOCK_INDEX_FIELD_NAME, new float[] { 1.0f, 2.0f }, new FieldType()).binaryValue() + ) + ); knnDocument.add(new NumericDocValuesField(MOCK_NUMERIC_INDEX_FIELD_NAME, 1000)); writer.addDocument(knnDocument); writer.commit(); @@ -67,16 +69,14 @@ public void testGetScriptValues() { } public void testGetScriptValuesWrongFieldName() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData( - leafReaderContext.reader(), "invalid"); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "invalid"); ScriptDocValues scriptValues = leafFieldData.getScriptValues(); assertNotNull(scriptValues); } public void testGetScriptValuesWrongFieldType() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData( - leafReaderContext.reader(), MOCK_NUMERIC_INDEX_FIELD_NAME); - expectThrows(IllegalStateException.class, ()->leafFieldData.getScriptValues()); + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), MOCK_NUMERIC_INDEX_FIELD_NAME); + expectThrows(IllegalStateException.class, () -> leafFieldData.getScriptValues()); } public void testRamBytesUsed() { @@ -86,7 +86,6 @@ public void testRamBytesUsed() { public void testGetBytesValues() { KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), ""); - expectThrows(UnsupportedOperationException.class, - () -> leafFieldData.getBytesValues()); + expectThrows(UnsupportedOperationException.class, () -> leafFieldData.getBytesValues()); } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java index 3460526187..8523c4146f 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorIndexFieldDataTests.java @@ -72,13 +72,14 @@ public void testLoadDirect() throws IOException { public void testSortField() { - expectThrows(UnsupportedOperationException.class, - () -> indexFieldData.sortField(null, null, null, false)); + expectThrows(UnsupportedOperationException.class, () -> indexFieldData.sortField(null, null, null, false)); } public void testNewBucketedSort() { - expectThrows(UnsupportedOperationException.class, - () -> indexFieldData.newBucketedSort(null, null, null, null, null, null, 0, null)); + expectThrows( + UnsupportedOperationException.class, + () -> indexFieldData.newBucketedSort(null, null, null, null, null, null, 0, null) + ); } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index e54e19141e..8761179409 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -23,7 +23,7 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { private static final String MOCK_INDEX_FIELD_NAME = "test-index-field-name"; - private static final float[] SAMPLE_VECTOR_DATA = new float[]{1.0f, 2.0f}; + private static final float[] SAMPLE_VECTOR_DATA = new float[] { 1.0f, 2.0f }; private KNNVectorScriptDocValues scriptDocValues; private Directory directory; private DirectoryReader reader; @@ -36,7 +36,9 @@ public void setUp() throws Exception { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); scriptDocValues = new KNNVectorScriptDocValues( - leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), MOCK_INDEX_FIELD_NAME); + leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), + MOCK_INDEX_FIELD_NAME + ); } private void createKNNVectorDocument(Directory directory) throws IOException { @@ -44,9 +46,11 @@ private void createKNNVectorDocument(Directory directory) throws IOException { IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); knnDocument.add( - new BinaryDocValuesField( - MOCK_INDEX_FIELD_NAME, - new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue())); + new BinaryDocValuesField( + MOCK_INDEX_FIELD_NAME, + new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue() + ) + ); writer.addDocument(knnDocument); writer.commit(); writer.close(); @@ -64,8 +68,7 @@ public void testGetValue() throws IOException { Assert.assertArrayEquals(SAMPLE_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); } - - //Test getValue without calling setNextDocId + // Test getValue without calling setNextDocId public void testGetValueFails() throws IOException { expectThrows(IllegalStateException.class, () -> scriptDocValues.getValue()); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index dd7b82cf3d..47d13a87e2 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -19,7 +19,6 @@ import org.opensearch.knn.index.KNNWeight; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorField; -import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; import org.apache.lucene.codecs.Codec; import org.apache.lucene.document.Document; import org.apache.lucene.document.FieldType; diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 0375b6b489..7f66e909a8 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -35,36 +35,35 @@ private List getTestQueryVector() { } public void testL2SquaredScoringFunction() { - float[] queryVector = {1.0f, 1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; Float distance = KNNScoringUtil.l2Squared(queryVector, inputVector); assertTrue(distance == 27.0f); } public void testWrongDimensionL2SquaredScoringFunction() { - float[] queryVector = {1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.l2Squared(queryVector, inputVector)); } public void testCosineSimilScoringFunction() { - float[] queryVector = {1.0f, 1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; float queryVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(queryVector); float inputVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(inputVector); float dotProduct = 12.0f; float expectedScore = (float) (dotProduct / (Math.sqrt(queryVectorMagnitude * inputVectorMagnitude))); - Float actualScore = KNNScoringUtil.cosinesimil(queryVector, inputVector); assertEquals(expectedScore, actualScore, 0.0001); } public void testCosineSimilOptimizedScoringFunction() { - float[] queryVector = {1.0f, 1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; float queryVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(queryVector); float inputVectorMagnitude = KNNScoringSpaceUtil.getVectorMagnitudeSquared(inputVector); float dotProduct = 12.0f; @@ -86,26 +85,26 @@ public void testConvertInvalidVectorToPrimitive() { } public void testCosineSimilQueryVectorZeroMagnitude() { - float[] queryVector = {0, 0}; - float[] inputVector = {4.0f, 4.0f}; + float[] queryVector = { 0, 0 }; + float[] inputVector = { 4.0f, 4.0f }; assertEquals(0, KNNScoringUtil.cosinesimil(queryVector, inputVector), 0.00001); } public void testCosineSimilOptimizedQueryVectorZeroMagnitude() { - float[] inputVector = {4.0f, 4.0f}; - float[] queryVector = {0, 0}; + float[] inputVector = { 4.0f, 4.0f }; + float[] queryVector = { 0, 0 }; assertTrue(0 == KNNScoringUtil.cosinesimilOptimized(queryVector, inputVector, 0.0f)); } public void testWrongDimensionCosineSimilScoringFunction() { - float[] queryVector = {1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.cosinesimil(queryVector, inputVector)); } public void testWrongDimensionCosineSimilOPtimizedScoringFunction() { - float[] queryVector = {1.0f, 1.0f}; - float[] inputVector = {4.0f, 4.0f, 4.0f}; + float[] queryVector = { 1.0f, 1.0f }; + float[] inputVector = { 4.0f, 4.0f, 4.0f }; expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.cosinesimilOptimized(queryVector, inputVector, 1.0f)); } @@ -173,7 +172,7 @@ public void testBitHammingDistance_Long() { public void testL2SquaredWhitelistedScoringFunction() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); scriptDocValues.setNextDocId(0); Float distance = KNNScoringUtil.l2Squared(queryVector, scriptDocValues); @@ -184,7 +183,7 @@ public void testL2SquaredWhitelistedScoringFunction() throws IOException { public void testScriptDocValuesFailsL2() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); expectThrows(IllegalStateException.class, () -> KNNScoringUtil.l2Squared(queryVector, scriptDocValues)); dataset.close(); @@ -193,7 +192,7 @@ public void testScriptDocValuesFailsL2() throws IOException { public void testCosineSimilarityScoringFunction() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); scriptDocValues.setNextDocId(0); @@ -205,7 +204,7 @@ public void testCosineSimilarityScoringFunction() throws IOException { public void testScriptDocValuesFailsCosineSimilarity() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); expectThrows(IllegalStateException.class, () -> KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues)); dataset.close(); @@ -214,7 +213,7 @@ public void testScriptDocValuesFailsCosineSimilarity() throws IOException { public void testCosineSimilarityOptimizedScoringFunction() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); scriptDocValues.setNextDocId(0); Float actualScore = KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues, 3.0f); @@ -225,7 +224,7 @@ public void testCosineSimilarityOptimizedScoringFunction() throws IOException { public void testScriptDocValuesFailsCosineSimilarityOptimized() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); - dataset.createKNNVectorDocument(new float[]{4.0f, 4.0f, 4.0f}, "test-index-field-name"); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); expectThrows(IllegalStateException.class, () -> KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues, 3.0f)); dataset.close(); @@ -244,16 +243,14 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx if (scriptDocValues == null) { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues(leafReaderContext.reader().getBinaryDocValues(fieldName),fieldName ); + scriptDocValues = new KNNVectorScriptDocValues(leafReaderContext.reader().getBinaryDocValues(fieldName), fieldName); } return scriptDocValues; } public void close() throws IOException { - if (reader != null) - reader.close(); - if (directory != null) - directory.close(); + if (reader != null) reader.close(); + if (directory != null) directory.close(); } public void createKNNVectorDocument(final float[] content, final String fieldName) throws IOException { @@ -261,10 +258,7 @@ public void createKNNVectorDocument(final float[] content, final String fieldNam IndexWriter writer = new IndexWriter(directory, conf); conf.setMergePolicy(NoMergePolicy.INSTANCE); // prevent merges for this test Document knnDocument = new Document(); - knnDocument.add( - new BinaryDocValuesField( - fieldName, - new VectorField(fieldName, content, new FieldType()).binaryValue())); + knnDocument.add(new BinaryDocValuesField(fieldName, new VectorField(fieldName, content, new FieldType()).binaryValue())); writer.addDocument(knnDocument); writer.commit(); writer.close();