diff --git a/CHANGELOG.md b/CHANGELOG.md index 03a8d7974..3650de023 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x) ### Features +- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] ### Bug Fixes diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java b/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java index 7eca6287c..0c54cb370 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorSimilarityFunction.java @@ -29,7 +29,8 @@ public float compare(byte[] v1, byte[] v2) { @Override public VectorSimilarityFunction getVectorSimilarityFunction() { - throw new IllegalStateException("VectorSimilarityFunction is not available for Hamming space"); + // This is not used in binary case + return VectorSimilarityFunction.EUCLIDEAN; } }; diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index 4827a4582..a6051164a 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -40,7 +40,7 @@ public enum VectorDataType { @Override public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { - throw new IllegalStateException("Unsupported method"); + return KnnByteVectorField.createFieldType(dimension / 8, vectorSimilarityFunction); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 72187516f..f3a125838 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -114,7 +114,12 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { } } - KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth); + KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams( + params, + defaultMaxConnections, + defaultBeamWidth, + knnMethodContext.getSpaceType() + ); log.debug( "Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"", field, diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990BinaryVectorScorer.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990BinaryVectorScorer.java new file mode 100644 index 000000000..24f3081a8 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990BinaryVectorScorer.java @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; + +import java.io.IOException; + +public class KNN990BinaryVectorScorer implements FlatVectorsScorer { + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction vectorSimilarityFunction, + RandomAccessVectorValues randomAccessVectorValues + ) throws IOException { + assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes; + if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) { + return new BinaryRandomVectorScorerSupplier((RandomAccessVectorValues.Bytes) randomAccessVectorValues); + } + throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction vectorSimilarityFunction, + RandomAccessVectorValues randomAccessVectorValues, + float[] queryVector + ) throws IOException { + throw new IllegalArgumentException("binary vectors do not support float[] targets"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction vectorSimilarityFunction, + RandomAccessVectorValues randomAccessVectorValues, + byte[] queryVector + ) throws IOException { + assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes; + if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes) { + return new BinaryRandomVectorScorer((RandomAccessVectorValues.Bytes) randomAccessVectorValues, queryVector); + } + throw new IllegalArgumentException("vectorValues must be an instance of RandomAccessVectorValues.Bytes"); + } + + static class BinaryRandomVectorScorer implements RandomVectorScorer { + private final RandomAccessVectorValues.Bytes vectorValues; + private final int bitDimensions; + private final byte[] queryVector; + + BinaryRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) { + this.queryVector = query; + this.bitDimensions = vectorValues.dimension() * Byte.SIZE; + this.vectorValues = vectorValues; + } + + @Override + public float score(int node) throws IOException { + return (bitDimensions - VectorUtil.xorBitCount(queryVector, vectorValues.vectorValue(node))) / (float) bitDimensions; + } + + @Override + public int maxOrd() { + return vectorValues.size(); + } + + @Override + public int ordToDoc(int ord) { + return vectorValues.ordToDoc(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return vectorValues.getAcceptOrds(acceptDocs); + } + } + + static class BinaryRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + protected final RandomAccessVectorValues.Bytes vectorValues; + protected final RandomAccessVectorValues.Bytes vectorValues1; + protected final RandomAccessVectorValues.Bytes vectorValues2; + + public BinaryRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues) throws IOException { + this.vectorValues = vectorValues; + this.vectorValues1 = vectorValues.copy(); + this.vectorValues2 = vectorValues.copy(); + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] queryVector = vectorValues1.vectorValue(ord); + return new BinaryRandomVectorScorer(vectorValues2, queryVector); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new BinaryRandomVectorScorerSupplier(vectorValues.copy()); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990HnswBinaryVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990HnswBinaryVectorsFormat.java new file mode 100644 index 000000000..587ab2b84 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990HnswBinaryVectorsFormat.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.TaskExecutor; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +public class KNN990HnswBinaryVectorsFormat extends KnnVectorsFormat { + + private final int maxConn; + private final int beamWidth; + private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(new KNN990BinaryVectorScorer()); + private final int numMergeWorkers; + private final TaskExecutor mergeExec; + + private static final String NAME = "KNN990HnswBinaryVectorsFormat"; + + public KNN990HnswBinaryVectorsFormat() { + this(16, 100, 1, (ExecutorService) null); + } + + public KNN990HnswBinaryVectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, 1, (ExecutorService) null); + } + + public KNN990HnswBinaryVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + super(NAME); + if (maxConn > 0 && maxConn <= 512) { + if (beamWidth > 0 && beamWidth <= 3200) { + this.maxConn = maxConn; + this.beamWidth = beamWidth; + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge"); + } else { + this.numMergeWorkers = numMergeWorkers; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } + + } + } else { + throw new IllegalArgumentException("beamWidth must be positive and less than or equal to 3200; beamWidth=" + beamWidth); + } + } else { + throw new IllegalArgumentException("maxConn must be positive and less than or equal to 512; maxConn=" + maxConn); + } + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter( + state, + this.maxConn, + this.beamWidth, + flatVectorsFormat.fieldsWriter(state), + this.numMergeWorkers, + this.mergeExec + ); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 1024; + } + + @Override + public String toString() { + return "KNN990HnswBinaryVectorsFormat(name=KNN990HnswBinaryVectorsFormat, maxConn=" + + this.maxConn + + ", beamWidth=" + + this.beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java index f565dfe5b..75289f489 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java @@ -8,6 +8,7 @@ import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat; import org.opensearch.knn.index.engine.KNNEngine; @@ -24,11 +25,17 @@ public KNN990PerFieldKnnVectorsFormat(final Optional mapperServic mapperService, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, - () -> new Lucene99HnswVectorsFormat(), - knnVectorsFormatParams -> new Lucene99HnswVectorsFormat( - knnVectorsFormatParams.getMaxConnections(), - knnVectorsFormatParams.getBeamWidth() - ), + Lucene99HnswVectorsFormat::new, + knnVectorsFormatParams -> { + if (knnVectorsFormatParams.getSpaceType() == SpaceType.HAMMING) { + return new KNN990HnswBinaryVectorsFormat( + knnVectorsFormatParams.getMaxConnections(), + knnVectorsFormatParams.getBeamWidth() + ); + } else { + return new Lucene99HnswVectorsFormat(knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth()); + } + }, knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat( knnScalarQuantizedVectorsFormatParams.getMaxConnections(), knnScalarQuantizedVectorsFormatParams.getBeamWidth(), diff --git a/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java b/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java index 52134bc7e..ebf985fbb 100644 --- a/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java +++ b/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java @@ -7,6 +7,7 @@ import lombok.Getter; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.SpaceType; import java.util.Map; @@ -17,10 +18,16 @@ public class KNNVectorsFormatParams { private int maxConnections; private int beamWidth; + private final SpaceType spaceType; public KNNVectorsFormatParams(final Map params, int defaultMaxConnections, int defaultBeamWidth) { + this(params, defaultMaxConnections, defaultBeamWidth, SpaceType.UNDEFINED); + } + + public KNNVectorsFormatParams(final Map params, int defaultMaxConnections, int defaultBeamWidth, SpaceType spaceType) { initMaxConnections(params, defaultMaxConnections); initBeamWidth(params, defaultBeamWidth); + this.spaceType = spaceType; } public boolean validate(final Map params) { diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java index 57cc016a6..701f79768 100644 --- a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java @@ -30,13 +30,18 @@ */ public class LuceneHNSWMethod extends AbstractKNNMethod { - private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT, VectorDataType.BYTE); + private static final Set SUPPORTED_DATA_TYPES = ImmutableSet.of( + VectorDataType.FLOAT, + VectorDataType.BYTE, + VectorDataType.BINARY + ); public final static List SUPPORTED_SPACES = Arrays.asList( SpaceType.UNDEFINED, SpaceType.L2, SpaceType.COSINESIMIL, - SpaceType.INNER_PRODUCT + SpaceType.INNER_PRODUCT, + SpaceType.HAMMING ); final static Encoder SQ_ENCODER = new LuceneSQEncoder(); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index dab2e08c8..7e597ad64 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -106,6 +106,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); switch (vectorDataType) { case BYTE: + case BINARY: return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter); case FLOAT: return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter); diff --git a/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index d799c3869..fbdb77887 100644 --- a/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -10,3 +10,4 @@ # org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat +org.opensearch.knn.index.codec.KNN990Codec.KNN990HnswBinaryVectorsFormat diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java index 4fb267eb5..dcd7e6da8 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java @@ -5,6 +5,7 @@ package org.opensearch.knn.integ; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Floats; import lombok.SneakyThrows; @@ -27,6 +28,7 @@ import java.io.IOException; import java.net.URL; import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.Set; import java.util.stream.Collectors; @@ -41,6 +43,16 @@ public class BinaryIndexIT extends KNNRestTestCase { private static TestUtils.TestData testData; private static final int NEVER_BUILD_GRAPH = -1; private static final int ALWAYS_BUILD_GRAPH = 0; + private final KNNEngine engine; + + public BinaryIndexIT(KNNEngine engine) { + this.engine = engine; + } + + @ParametersFactory + public static Collection parameters() { + return Arrays.asList(new Object[] { KNNEngine.LUCENE }, new Object[] { KNNEngine.FAISS }); + } @BeforeClass public static void setUpClass() throws IOException { @@ -68,7 +80,7 @@ public void cleanUp() { @SneakyThrows public void testFaissHnswBinary_whenSmallDataSet_thenCreateIngestQueryWorks() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 16); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 16); // Ingest Byte[] vector1 = { 0b00000001, 0b00000001 }; @@ -95,7 +107,7 @@ public void testFaissHnswBinary_whenSmallDataSet_thenCreateIngestQueryWorks() { @SneakyThrows public void testFaissHnswBinary_when1000Data_thenRecallIsAboveNinePointZero() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 128); ingestTestData(INDEX_NAME, FIELD_NAME); int k = 100; @@ -112,7 +124,7 @@ public void testFaissHnswBinary_when1000Data_thenRecallIsAboveNinePointZero() { @SneakyThrows public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_thenBuildGraphBasedOnSetting() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, NEVER_BUILD_GRAPH); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 128, NEVER_BUILD_GRAPH); ingestTestData(INDEX_NAME, FIELD_NAME); assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); @@ -135,7 +147,7 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_ @SneakyThrows public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, testData.indexData.docs.length); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 128, testData.indexData.docs.length); ingestTestData(INDEX_NAME, FIELD_NAME, false); assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); @@ -158,7 +170,7 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_ @SneakyThrows public void testFaissHnswBinary_whenRadialSearch_thenThrowException() { // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 16); + createKnnHnswBinaryIndex(engine, INDEX_NAME, FIELD_NAME, 16); // Query float[] queryVector = { (byte) 0b10001111, (byte) 0b10000000 }; diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java index 29e710ec1..a706dd0cd 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexInvalidMappingIT.java @@ -46,11 +46,6 @@ public void cleanUp() { public static Collection parameters() throws IOException { return Arrays.asList( $$( - $( - "Creation of binary index with lucene engine should fail", - createKnnHnswBinaryIndexMapping(KNNEngine.LUCENE, FIELD_NAME, 16, null), - "Validation Failed" - ), $( "Creation of binary index with nmslib engine should fail", createKnnHnswBinaryIndexMapping(KNNEngine.NMSLIB, FIELD_NAME, 16, null),