Skip to content

Commit

Permalink
Multiple innerHit in nested fields
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Nov 18, 2024
1 parent a07bad1 commit 549a74e
Show file tree
Hide file tree
Showing 32 changed files with 1,061 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public enum KNNEngine implements KNNLibrary {
private static final Set<KNNEngine> CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS);
private static final Set<KNNEngine> ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_MULTI_VECTORS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);

private static Map<KNNEngine, Integer> MAX_DIMENSIONS_BY_ENGINE = Map.of(
KNNEngine.NMSLIB,
Expand Down
13 changes: 10 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.knn.indices.ModelDao;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
Expand All @@ -60,11 +61,14 @@ public class ExactSearcher {
public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext)
throws IOException {
KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
if (iterator == null) {
return Collections.emptyMap();
}
if (exactSearcherContext.getKnnQuery().getRadius() != null) {
return doRadialSearch(leafReaderContext, exactSearcherContext, iterator);
}
if (exactSearcherContext.getMatchedDocs() != null
&& exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
&& exactSearcherContext.getMatchedDocs().cost() <= exactSearcherContext.getK()) {
return scoreAllDocs(iterator);
}
return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue());
Expand Down Expand Up @@ -147,9 +151,12 @@ private Map<Integer, Float> filterDocsByMinScore(ExactSearcherContext context, K

private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException {
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final BitSet matchedDocs = exactSearcherContext.getMatchedDocs();
final DocIdSetIterator matchedDocs = exactSearcherContext.getMatchedDocs();
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
if (fieldInfo == null) {
return null;
}
final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo);

boolean isNestedRequired = exactSearcherContext.isParentHits() && knnQuery.getParentsFilter() != null;
Expand Down Expand Up @@ -233,7 +240,7 @@ public static class ExactSearcherContext {
*/
boolean useQuantizedVectorsForSearch;
int k;
BitSet matchedDocs;
DocIdSetIterator matchedDocs;
KNNQuery knnQuery;
/**
* whether the matchedDocs contains parent ids or child ids. This is relevant in the case of
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public class KNNQuery extends Query {

@Setter
private Query filterQuery;
@Getter
private BitSetProducer parentsFilter;
private Float radius;
private Context context;
Expand Down
34 changes: 26 additions & 8 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.common.QueryUtils;
import org.opensearch.knn.index.query.lucenelib.NestedKnnVectorQueryFactory;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

Expand All @@ -24,13 +26,14 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;
import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_MULTI_VECTORS;

/**
* Creates the Lucene k-NN queries
*/
@Log4j2
public class KNNQueryFactory extends BaseQueryFactory {

private static final QueryUtils QUERY_UTILS = new QueryUtils();
/**
* Creates a Lucene query for a particular engine.
* @param createQueryRequest request object that has all required fields to construct the query
Expand All @@ -48,11 +51,14 @@ public static Query create(CreateQueryRequest createQueryRequest) {
final Query filterQuery = getFilterQuery(createQueryRequest);
final Map<String, ?> methodParameters = createQueryRequest.getMethodParameters();
final RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null);
final KNNEngine knnEngine = createQueryRequest.getKnnEngine();

BitSetProducer parentFilter = null;
boolean isInnerHitQuery = false;
if (createQueryRequest.getContext().isPresent()) {
QueryShardContext context = createQueryRequest.getContext().get();
parentFilter = context.getParentFilter();
isInnerHitQuery = context.isInnerHitQuery();
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
Expand Down Expand Up @@ -95,7 +101,16 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.rescoreContext(rescoreContext)
.build();
}
return createQueryRequest.getRescoreContext().isPresent() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery;

if (createQueryRequest.getRescoreContext().isPresent()) {
return new NativeEngineKnnVectorQuery(knnQuery, QUERY_UTILS, isInnerHitQuery);
}
else if (ENGINES_SUPPORTING_MULTI_VECTORS.contains(knnEngine) && isInnerHitQuery) {
return new NativeEngineKnnVectorQuery(knnQuery, QUERY_UTILS, isInnerHitQuery);
}
else {
return knnQuery;
}
}

Integer requestEfSearch = null;
Expand All @@ -106,9 +121,9 @@ public static Query create(CreateQueryRequest createQueryRequest) {
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter);
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, isInnerHitQuery);
case FLOAT:
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter);
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter, isInnerHitQuery);
default:
throw new IllegalArgumentException(
String.format(
Expand Down Expand Up @@ -139,12 +154,14 @@ private static Query getKnnByteVectorQuery(
final byte[] byteVector,
final int k,
final Query filterQuery,
final BitSetProducer parentFilter
final BitSetProducer parentFilter,
final boolean isInnerHitQuery
) {
if (parentFilter == null) {
assert isInnerHitQuery == false;
return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery);
} else {
return new DiversifyingChildrenByteKnnVectorQuery(fieldName, byteVector, filterQuery, k, parentFilter);
return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter, isInnerHitQuery);
}
}

Expand All @@ -157,12 +174,13 @@ private static Query getKnnFloatVectorQuery(
final float[] floatVector,
final int k,
final Query filterQuery,
final BitSetProducer parentFilter
final BitSetProducer parentFilter,
final boolean isInnerHitQuery
) {
if (parentFilter == null) {
return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery);
} else {
return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter);
return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(fieldName, floatVector, k, filterQuery, parentFilter, isInnerHitQuery);
}
}
}
15 changes: 8 additions & 7 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.knn.index.query;

import com.google.common.annotations.VisibleForTesting;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
Expand Down Expand Up @@ -66,6 +68,7 @@ public class KNNWeight extends Weight {
private final float boost;

private final NativeMemoryCacheManager nativeMemoryCacheManager;
@Getter
private final Weight filterWeight;
private final ExactSearcher exactSearcher;

Expand Down Expand Up @@ -140,14 +143,14 @@ public Map<Integer, Float> searchLeaf(LeafReaderContext context, int k) throws I
* This improves the recall.
*/
if (isFilteredExactSearchPreferred(cardinality)) {
return doExactSearch(context, filterBitSet, k);
return doExactSearch(context, new BitSetIterator(filterBitSet, filterBitSet.length()), k);
}
Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k);
// See whether we have to perform exact search based on approx search results
// This is required if there are no native engine files or if approximate search returned
// results less than K, though we have more than k filtered docs
if (isExactSearchRequire(context, cardinality, docIdsToScoreMap.size())) {
final BitSet docs = filterWeight != null ? filterBitSet : null;
final BitSetIterator docs = filterWeight != null ? new BitSetIterator(filterBitSet, filterBitSet.length()) : null;
return doExactSearch(context, docs, k);
}
return docIdsToScoreMap;
Expand Down Expand Up @@ -205,17 +208,15 @@ private int[] bitSetToIntArray(final BitSet bitSet) {
return intArray;
}

private Map<Integer, Float> doExactSearch(final LeafReaderContext context, final BitSet acceptedDocs, int k) throws IOException {
private Map<Integer, Float> doExactSearch(final LeafReaderContext context, final DocIdSetIterator acceptedDocs, int k) throws IOException {
final ExactSearcherContextBuilder exactSearcherContextBuilder = ExactSearcher.ExactSearcherContext.builder()
.isParentHits(true)
.k(k)
// setting to true, so that if quantization details are present we want to do search on the quantized
// vectors as this flow is used in first pass of search.
.useQuantizedVectorsForSearch(true)
.knnQuery(knnQuery);
if (acceptedDocs != null) {
exactSearcherContextBuilder.matchedDocs(acceptedDocs);
}
.knnQuery(knnQuery)
.matchedDocs(acceptedDocs);
return exactSearch(context, exactSearcherContextBuilder.build());
}

Expand Down
10 changes: 5 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/ResultUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,19 @@ public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k)
}

/**
* Convert map to bit set
* Convert map of docs to doc id set iterator
*
* @param resultMap Map of results
* @return BitSet of results
* @return Doc id set iterator
* @throws IOException If an error occurs during the search.
*/
public static BitSet resultMapToMatchBitSet(Map<Integer, Float> resultMap) throws IOException {
public static DocIdSetIterator resultMapToDocIds(Map<Integer, Float> resultMap) throws IOException {
if (resultMap.isEmpty()) {
return BitSet.of(DocIdSetIterator.empty(), 0);
return DocIdSetIterator.empty();
}

final int maxDoc = Collections.max(resultMap.keySet()) + 1;
return BitSet.of(resultMapToDocIds(resultMap, maxDoc), maxDoc);
return resultMapToDocIds(resultMap, maxDoc);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.nativelib;
package org.opensearch.knn.index.query.common;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
Expand All @@ -24,15 +24,15 @@
/**
* This is the same as {@link org.apache.lucene.search.AbstractKnnVectorQuery.DocAndScoreQuery}
*/
final class DocAndScoreQuery extends Query {
final public class DocAndScoreQuery extends Query {

private final int k;
private final int[] docs;
private final float[] scores;
private final int[] segmentStarts;
private final Object contextIdentity;

DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
public DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
this.k = k;
this.docs = docs;
this.scores = scores;
Expand Down
Loading

0 comments on commit 549a74e

Please sign in to comment.