Skip to content

Commit

Permalink
Multiple innerHit in nested fields
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Nov 27, 2024
1 parent 7523cc3 commit d49ca8a
Show file tree
Hide file tree
Showing 39 changed files with 1,846 additions and 155 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x)
### Features
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
### Bug Fixes
Expand Down
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ dependencies {
testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.15.10'
testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3'
testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.15.4'
testFixturesImplementation 'com.jayway.jsonpath:json-path:2.8.0'
testFixturesImplementation "org.opensearch:common-utils:${version}"
implementation 'com.github.oshi:oshi-core:6.4.13'
api "net.java.dev.jna:jna:5.13.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public enum KNNEngine implements KNNLibrary {
private static final Set<KNNEngine> CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS);
private static final Set<KNNEngine> ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_MULTI_VECTORS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);

private static Map<KNNEngine, Integer> MAX_DIMENSIONS_BY_ENGINE = Map.of(
KNNEngine.NMSLIB,
Expand Down
10 changes: 5 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.BitSet;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
Expand Down Expand Up @@ -68,8 +67,8 @@ public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext,
if (exactSearcherContext.getKnnQuery().getRadius() != null) {
return doRadialSearch(leafReaderContext, exactSearcherContext, iterator);
}
if (exactSearcherContext.getMatchedDocs() != null
&& exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
if (exactSearcherContext.getMatchedDocsIterator() != null
&& exactSearcherContext.numberOfMatchedDocs <= exactSearcherContext.getK()) {
return scoreAllDocs(iterator);
}
return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue());
Expand Down Expand Up @@ -155,7 +154,7 @@ private Map<Integer, Float> filterDocsByMinScore(ExactSearcherContext context, K

private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException {
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final BitSet matchedDocs = exactSearcherContext.getMatchedDocs();
final DocIdSetIterator matchedDocs = exactSearcherContext.getMatchedDocsIterator();
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());
if (fieldInfo == null) {
Expand Down Expand Up @@ -245,7 +244,8 @@ public static class ExactSearcherContext {
*/
boolean useQuantizedVectorsForSearch;
int k;
BitSet matchedDocs;
DocIdSetIterator matchedDocsIterator;
long numberOfMatchedDocs;
KNNQuery knnQuery;
/**
* whether the matchedDocs contains parent ids or child ids. This is relevant in the case of
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public class KNNQuery extends Query {

@Setter
private Query filterQuery;
@Getter
private BitSetProducer parentsFilter;
private Float radius;
private Context context;
Expand Down
45 changes: 37 additions & 8 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.common.QueryUtils;
import org.opensearch.knn.index.query.lucenelib.NestedKnnVectorQueryFactory;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

Expand All @@ -24,13 +26,13 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;
import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_MULTI_VECTORS;

/**
* Creates the Lucene k-NN queries
*/
@Log4j2
public class KNNQueryFactory extends BaseQueryFactory {

/**
* Creates a Lucene query for a particular engine.
* @param createQueryRequest request object that has all required fields to construct the query
Expand All @@ -48,11 +50,14 @@ public static Query create(CreateQueryRequest createQueryRequest) {
final Query filterQuery = getFilterQuery(createQueryRequest);
final Map<String, ?> methodParameters = createQueryRequest.getMethodParameters();
final RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null);
final KNNEngine knnEngine = createQueryRequest.getKnnEngine();

BitSetProducer parentFilter = null;
boolean isInnerHitQuery = false;
if (createQueryRequest.getContext().isPresent()) {
QueryShardContext context = createQueryRequest.getContext().get();
parentFilter = context.getParentFilter();
isInnerHitQuery = context.isInnerHitQuery();
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
Expand Down Expand Up @@ -95,7 +100,14 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.rescoreContext(rescoreContext)
.build();
}
return createQueryRequest.getRescoreContext().isPresent() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery;

if (createQueryRequest.getRescoreContext().isPresent()) {
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, isInnerHitQuery);
} else if (ENGINES_SUPPORTING_MULTI_VECTORS.contains(knnEngine) && isInnerHitQuery) {
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, isInnerHitQuery);
} else {
return knnQuery;
}
}

Integer requestEfSearch = null;
Expand All @@ -106,9 +118,9 @@ public static Query create(CreateQueryRequest createQueryRequest) {
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter);
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, isInnerHitQuery);
case FLOAT:
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter);
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter, isInnerHitQuery);
default:
throw new IllegalArgumentException(
String.format(
Expand Down Expand Up @@ -139,12 +151,21 @@ private static Query getKnnByteVectorQuery(
final byte[] byteVector,
final int k,
final Query filterQuery,
final BitSetProducer parentFilter
final BitSetProducer parentFilter,
final boolean isInnerHitQuery
) {
if (parentFilter == null) {
assert isInnerHitQuery == false;
return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery);
} else {
return new DiversifyingChildrenByteKnnVectorQuery(fieldName, byteVector, filterQuery, k, parentFilter);
return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(
fieldName,
byteVector,
k,
filterQuery,
parentFilter,
isInnerHitQuery
);
}
}

Expand All @@ -157,12 +178,20 @@ private static Query getKnnFloatVectorQuery(
final float[] floatVector,
final int k,
final Query filterQuery,
final BitSetProducer parentFilter
final BitSetProducer parentFilter,
final boolean isInnerHitQuery
) {
if (parentFilter == null) {
return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery);
} else {
return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter);
return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(
fieldName,
floatVector,
k,
filterQuery,
parentFilter,
isInnerHitQuery
);
}
}
}
22 changes: 14 additions & 8 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import com.google.common.annotations.VisibleForTesting;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
Expand Down Expand Up @@ -66,6 +67,7 @@ public class KNNWeight extends Weight {
private final float boost;

private final NativeMemoryCacheManager nativeMemoryCacheManager;
@Getter
private final Weight filterWeight;
private final ExactSearcher exactSearcher;

Expand Down Expand Up @@ -140,15 +142,15 @@ public Map<Integer, Float> searchLeaf(LeafReaderContext context, int k) throws I
* This improves the recall.
*/
if (isFilteredExactSearchPreferred(cardinality)) {
return doExactSearch(context, filterBitSet, k);
return doExactSearch(context, new BitSetIterator(filterBitSet, cardinality), cardinality, k);
}
Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k);
// See whether we have to perform exact search based on approx search results
// This is required if there are no native engine files or if approximate search returned
// results less than K, though we have more than k filtered docs
if (isExactSearchRequire(context, cardinality, docIdsToScoreMap.size())) {
final BitSet docs = filterWeight != null ? filterBitSet : null;
return doExactSearch(context, docs, k);
final BitSetIterator docs = filterWeight != null ? new BitSetIterator(filterBitSet, cardinality) : null;
return doExactSearch(context, docs, cardinality, k);
}
return docIdsToScoreMap;
}
Expand Down Expand Up @@ -205,17 +207,21 @@ private int[] bitSetToIntArray(final BitSet bitSet) {
return intArray;
}

private Map<Integer, Float> doExactSearch(final LeafReaderContext context, final BitSet acceptedDocs, int k) throws IOException {
private Map<Integer, Float> doExactSearch(
final LeafReaderContext context,
final DocIdSetIterator acceptedDocs,
final long numberOfAcceptedDocs,
int k
) throws IOException {
final ExactSearcherContextBuilder exactSearcherContextBuilder = ExactSearcher.ExactSearcherContext.builder()
.isParentHits(true)
.k(k)
// setting to true, so that if quantization details are present we want to do search on the quantized
// vectors as this flow is used in first pass of search.
.useQuantizedVectorsForSearch(true)
.knnQuery(knnQuery);
if (acceptedDocs != null) {
exactSearcherContextBuilder.matchedDocs(acceptedDocs);
}
.knnQuery(knnQuery)
.matchedDocsIterator(acceptedDocs)
.numberOfMatchedDocs(numberOfAcceptedDocs);
return exactSearch(context, exactSearcherContextBuilder.build());
}

Expand Down
14 changes: 6 additions & 8 deletions src/main/java/org/opensearch/knn/index/query/ResultUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.DocIdSetBuilder;

import java.io.IOException;
Expand Down Expand Up @@ -58,19 +57,18 @@ public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k)
}

/**
* Convert map to bit set, if resultMap is empty or null then returns an Optional. Returning an optional here to
* ensure that the caller is aware that BitSet may not be present
* Convert map of docs to doc id set iterator
*
* @param resultMap Map of results
* @return BitSet of results; null is returned if the result map is empty
* @return Doc id set iterator
* @throws IOException If an error occurs during the search.
*/
public static BitSet resultMapToMatchBitSet(Map<Integer, Float> resultMap) throws IOException {
if (resultMap == null || resultMap.isEmpty()) {
return null;
public static DocIdSetIterator resultMapToDocIds(Map<Integer, Float> resultMap) throws IOException {
if (resultMap.isEmpty()) {
return DocIdSetIterator.empty();
}
final int maxDoc = Collections.max(resultMap.keySet()) + 1;
return BitSet.of(resultMapToDocIds(resultMap, maxDoc), maxDoc);
return resultMapToDocIds(resultMap, maxDoc);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.nativelib;
package org.opensearch.knn.index.query.common;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
Expand Down Expand Up @@ -32,7 +32,7 @@ final class DocAndScoreQuery extends Query {
private final int[] segmentStarts;
private final Object contextIdentity;

DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
public DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
this.k = k;
this.docs = docs;
this.scores = scores;
Expand Down
Loading

0 comments on commit d49ca8a

Please sign in to comment.