Skip to content

Commit

Permalink
HSEARCH-5039 Make knn predicates append tenant filters when necessary
Browse files Browse the repository at this point in the history
(cherry picked from commit a2024e1)
  • Loading branch information
marko-bekhta authored and yrodiere committed Feb 26, 2024
1 parent 879c35d commit 9e06526
Show file tree
Hide file tree
Showing 20 changed files with 364 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.hibernate.search.backend.elasticsearch.gson.impl.JsonArrayAccessor;
import org.hibernate.search.backend.elasticsearch.gson.impl.JsonObjectAccessor;
import org.hibernate.search.backend.elasticsearch.logging.impl.Log;
import org.hibernate.search.backend.elasticsearch.lowlevel.query.impl.Queries;
import org.hibernate.search.backend.elasticsearch.search.common.impl.AbstractElasticsearchCodecAwareSearchQueryElementFactory;
import org.hibernate.search.backend.elasticsearch.search.common.impl.ElasticsearchSearchIndexScope;
import org.hibernate.search.backend.elasticsearch.search.common.impl.ElasticsearchSearchIndexValueFieldContext;
Expand Down Expand Up @@ -45,6 +46,19 @@ private ElasticsearchKnnPredicate(AbstractKnnBuilder<?> builder) {
builder.vector = null;
}

protected JsonObject prepareFilter(PredicateRequestContext context) {
String tenantIdentifier = context.getTenantId();
if ( tenantIdentifier != null ) {
JsonObject tenantFilter = context.getSearchIndexScope().filterOrNull( tenantIdentifier );
if ( tenantFilter != null ) {
JsonArray filters = new JsonArray();
filters.add( tenantFilter );
return filter == null ? tenantFilter : Queries.boolFilter( filter.toJsonQuery( context ), filters );
}
}
return filter == null ? null : filter.toJsonQuery( context );
}

public static class Elasticsearch812Factory<F>
extends AbstractElasticsearchCodecAwareSearchQueryElementFactory<KnnPredicateBuilder, F> {
public Elasticsearch812Factory(ElasticsearchFieldCodec<F> codec) {
Expand Down Expand Up @@ -163,8 +177,9 @@ protected JsonObject doToJsonQuery(PredicateRequestContext context, JsonObject o
NUM_CANDIDATES_ACCESSOR.set( innerObject, k );
VECTOR_ACCESSOR.set( innerObject, vector );

JsonObject filter = prepareFilter( context );
if ( filter != null ) {
FILTER_ACCESSOR.set( innerObject, filter.toJsonQuery( context ) );
FILTER_ACCESSOR.set( innerObject, filter );
}
if ( similarity != null ) {
SIMILARITY_ACCESSOR.set( innerObject, similarity );
Expand Down Expand Up @@ -210,8 +225,9 @@ protected JsonObject doToJsonQuery(PredicateRequestContext context, JsonObject o
KNN_ACCESSOR.set( outerObject, field );

field.add( absoluteFieldPath, innerObject );
JsonObject filter = prepareFilter( context );
if ( filter != null ) {
FILTER_ACCESSOR.set( innerObject, filter.toJsonQuery( context ) );
FILTER_ACCESSOR.set( innerObject, filter );
}
K_ACCESSOR.set( innerObject, k );
VECTOR_ACCESSOR.set( innerObject, vector );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,23 @@
*/
package org.hibernate.search.backend.elasticsearch.search.predicate.impl;

import org.hibernate.search.backend.elasticsearch.search.common.impl.ElasticsearchSearchIndexScope;
import org.hibernate.search.engine.backend.session.spi.BackendSessionContext;

public class PredicateRequestContext {

private final BackendSessionContext sessionContext;
private final ElasticsearchSearchIndexScope<?> searchIndexScope;
private final String nestedPath;

public PredicateRequestContext(BackendSessionContext sessionContext) {
this.sessionContext = sessionContext;
this.nestedPath = null;
public PredicateRequestContext(BackendSessionContext sessionContext, ElasticsearchSearchIndexScope<?> searchIndexScope) {
this( sessionContext, searchIndexScope, null );
}

private PredicateRequestContext(BackendSessionContext sessionContext, String nestedPath) {
private PredicateRequestContext(BackendSessionContext sessionContext, ElasticsearchSearchIndexScope<?> searchIndexScope,
String nestedPath) {
this.sessionContext = sessionContext;
this.searchIndexScope = searchIndexScope;
this.nestedPath = nestedPath;
}

Expand All @@ -31,7 +34,11 @@ String getTenantId() {
return sessionContext.tenantIdentifier();
}

public ElasticsearchSearchIndexScope<?> getSearchIndexScope() {
return searchIndexScope;
}

public PredicateRequestContext withNestedPath(String path) {
return new PredicateRequestContext( sessionContext, path );
return new PredicateRequestContext( sessionContext, searchIndexScope, path );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public class ElasticsearchSearchQueryBuilder<H>
private final Integer scrollTimeout;

private final Set<String> routingKeys;
private JsonObject jsonPredicate;
private ElasticsearchSearchPredicate elasticsearchPredicate;
private JsonArray jsonSort;
private Map<DistanceSortKey, Integer> distanceSorts;
private Map<AggregationKey<?>, ElasticsearchSearchAggregation<?>> aggregations;
Expand Down Expand Up @@ -101,7 +101,7 @@ public ElasticsearchSearchQueryBuilder(
this.sessionContext = sessionContext;
this.routingKeys = new HashSet<>();

this.rootPredicateContext = new PredicateRequestContext( sessionContext );
this.rootPredicateContext = new PredicateRequestContext( sessionContext, scope );
this.loadingContextBuilder = loadingContextBuilder;
this.rootProjection = rootProjection;
this.scrollTimeout = scrollTimeout;
Expand All @@ -110,7 +110,7 @@ public ElasticsearchSearchQueryBuilder(
@Override
public void predicate(SearchPredicate predicate) {
ElasticsearchSearchPredicate elasticsearchPredicate = ElasticsearchSearchPredicate.from( scope, predicate );
this.jsonPredicate = elasticsearchPredicate.toJsonQuery( rootPredicateContext );
this.elasticsearchPredicate = elasticsearchPredicate;
}

@Override
Expand Down Expand Up @@ -232,6 +232,8 @@ public ElasticsearchSearchQuery<H> build() {
filters.add( Queries.anyTerm( "_routing", routingKeys ) );
}

JsonObject jsonPredicate = elasticsearchPredicate.toJsonQuery( rootPredicateContext );

JsonObject jsonQuery = Queries.boolFilter( jsonPredicate, filters );
if ( jsonQuery != null ) {
payload.add( "query", jsonQuery );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import org.hibernate.search.backend.lucene.lowlevel.collector.impl.CollectorKey;
import org.hibernate.search.backend.lucene.lowlevel.join.impl.NestedDocsProvider;
import org.hibernate.search.backend.lucene.search.extraction.impl.CollectorSet;
import org.hibernate.search.backend.lucene.search.predicate.impl.PredicateRequestContext;
import org.hibernate.search.backend.lucene.search.query.impl.LuceneSearchQueryIndexScope;
import org.hibernate.search.engine.backend.session.spi.BackendSessionContext;
import org.hibernate.search.engine.backend.types.converter.runtime.FromDocumentValueConvertContext;

import org.apache.lucene.index.IndexReader;
Expand All @@ -17,18 +20,27 @@

public class AggregationExtractContext {

private final LuceneSearchQueryIndexScope<?> queryIndexScope;
private final BackendSessionContext sessionContext;
private final IndexReader indexReader;
private final FromDocumentValueConvertContext fromDocumentValueConvertContext;
private final CollectorSet collectors;

public AggregationExtractContext(IndexReader indexReader,
public AggregationExtractContext(LuceneSearchQueryIndexScope<?> queryIndexScope, BackendSessionContext sessionContext,
IndexReader indexReader,
FromDocumentValueConvertContext fromDocumentValueConvertContext,
CollectorSet collectors) {
this.queryIndexScope = queryIndexScope;
this.sessionContext = sessionContext;
this.indexReader = indexReader;
this.fromDocumentValueConvertContext = fromDocumentValueConvertContext;
this.collectors = collectors;
}

public PredicateRequestContext toPredicateRequestContext(String absolutePath) {
return PredicateRequestContext.withSession( queryIndexScope, sessionContext ).withNestedPath( absolutePath );
}

public IndexReader getIndexReader() {
return indexReader;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public final Query toQuery(PredicateRequestContext context) {
// We'll make sure to wrap it in nested predicates as appropriate in the next few lines,
// so that the Query is actually executed in this context.
PredicateRequestContext contextAfterImplicitNesting =
new PredicateRequestContext( expectedNestedPath );
context.withNestedPath( expectedNestedPath );

Query result = super.toQuery( contextAfterImplicitNesting );

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.lang.reflect.Array;

import org.hibernate.search.backend.lucene.logging.impl.Log;
import org.hibernate.search.backend.lucene.lowlevel.query.impl.Queries;
import org.hibernate.search.backend.lucene.lowlevel.query.impl.VectorSimilarityFilterQuery;
import org.hibernate.search.backend.lucene.search.common.impl.AbstractLuceneValueFieldSearchQueryElementFactory;
import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexScope;
Expand Down Expand Up @@ -47,24 +48,34 @@ private LuceneKnnPredicate(Builder<?> builder) {

@Override
protected Query doToQuery(PredicateRequestContext context) {

if ( vector instanceof byte[] ) {
byte[] byteVector = (byte[]) vector;
KnnByteVectorQuery query = new KnnByteVectorQuery(
absoluteFieldPath, byteVector, k, filter == null ? null : filter.toQuery( context ) );
KnnByteVectorQuery query = new KnnByteVectorQuery( absoluteFieldPath, byteVector, k, prepareFilter( context ) );
return similarity == null
? query
: VectorSimilarityFilterQuery.create( query, similarity, byteVector.length, similarityFunction );
}
if ( vector instanceof float[] ) {
KnnFloatVectorQuery query = new KnnFloatVectorQuery(
absoluteFieldPath, (float[]) vector, k, filter == null ? null : filter.toQuery( context ) );
KnnFloatVectorQuery query =
new KnnFloatVectorQuery( absoluteFieldPath, (float[]) vector, k, prepareFilter( context ) );
return similarity == null ? query : VectorSimilarityFilterQuery.create( query, similarity, similarityFunction );
}

throw new UnsupportedOperationException(
"Unknown vector type " + vector.getClass() + ". only byte[] and float[] vectors are supported." );
}

private Query prepareFilter(PredicateRequestContext context) {
Query tenantFilter = context.tenantFilterOrNull();
if ( tenantFilter != null ) {
return filter == null ? tenantFilter : Queries.boolFilter( filter.toQuery( context ), tenantFilter );
}
else {
return filter == null ? null : filter.toQuery( context );
}
}

public static class DefaultFactory<F>
extends AbstractLuceneValueFieldSearchQueryElementFactory<KnnPredicateBuilder, F> {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ private LuceneNestedPredicate(Builder builder) {

@Override
protected Query doToQuery(PredicateRequestContext context) {
PredicateRequestContext childContext = new PredicateRequestContext( absoluteFieldPath );
PredicateRequestContext childContext = context.withNestedPath( absoluteFieldPath );
return createNestedQuery( context.getNestedPath(), absoluteFieldPath, nestedPredicate.toQuery( childContext ) );
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,85 @@
*/
package org.hibernate.search.backend.lucene.search.predicate.impl;

public class PredicateRequestContext {
import java.lang.invoke.MethodHandles;

private static final PredicateRequestContext ROOT = new PredicateRequestContext( null );
import org.hibernate.search.backend.lucene.logging.impl.Log;
import org.hibernate.search.backend.lucene.search.query.impl.LuceneSearchQueryIndexScope;
import org.hibernate.search.engine.backend.session.spi.BackendSessionContext;
import org.hibernate.search.util.common.AssertionFailure;
import org.hibernate.search.util.common.impl.Contracts;
import org.hibernate.search.util.common.logging.impl.LoggerFactory;

import org.apache.lucene.search.Query;

public abstract class PredicateRequestContext {
private static final Log log = LoggerFactory.make( Log.class, MethodHandles.lookup() );
private final String nestedPath;

public PredicateRequestContext(String nestedPath) {
private PredicateRequestContext(String nestedPath) {
this.nestedPath = nestedPath;
}

public String getNestedPath() {
return nestedPath;
}

public static PredicateRequestContext root() {
return ROOT;
public abstract Query tenantFilterOrNull();

public abstract PredicateRequestContext withNestedPath(String nestedPath);

public static PredicateRequestContext withSession(LuceneSearchQueryIndexScope<?> scope,
BackendSessionContext sessionContext) {
Contracts.assertNotNull( scope, "scope" );
Contracts.assertNotNull( scope, "sessionContext" );
return new FullPredicateRequestContext( null, scope, sessionContext );
}

public static PredicateRequestContext withoutSession() {
return new LimitedPredicateRequestContext( null );
}

private static class LimitedPredicateRequestContext extends PredicateRequestContext {

public LimitedPredicateRequestContext(String nestedPath) {
super( nestedPath );
}

@Override
public Query tenantFilterOrNull() {
// this context is created via migration utils, where the predicates are created as queries,
// hence we should not expect that a knn predicate is passed in.
// Alternatively it can be created in a place which we have total control over, and we only need to create an exists predicate,
// which does not need the session context anyway.
throw new AssertionFailure( "A tenant/routing filter requires session context." );
}

@Override
public PredicateRequestContext withNestedPath(String nestedPath) {
return new LimitedPredicateRequestContext( nestedPath );
}
}

private static class FullPredicateRequestContext extends PredicateRequestContext {
private final LuceneSearchQueryIndexScope<?> scope;

private final BackendSessionContext sessionContext;

private FullPredicateRequestContext(String nestedPath, LuceneSearchQueryIndexScope<?> scope,
BackendSessionContext sessionContext) {
super( nestedPath );
this.scope = scope;
this.sessionContext = sessionContext;
}

public Query tenantFilterOrNull() {
String tenantIdentifier = sessionContext.tenantIdentifier();

return tenantIdentifier == null ? null : scope.filterOrNull( tenantIdentifier );
}

public PredicateRequestContext withNestedPath(String nestedPath) {
return new FullPredicateRequestContext( nestedPath, scope, sessionContext );
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ public Builder create(LuceneSearchIndexScope<?> scope, LuceneSearchIndexComposit
}
try {
filter = LuceneSearchPredicate.from( scope, node.queryElement( PredicateTypeKeys.EXISTS, scope ).build() )
.toQuery( PredicateRequestContext.root() );
// We are creating an exists predicate that does not need any session info,
// hence it should be safe here to use the context without session:
.toQuery( PredicateRequestContext.withoutSession() );
}
catch (SearchException e) {
throw node.cannotUseQueryElement( ProjectionTypeKeys.OBJECT, e.getMessage(), e );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ private List<Object> extractHits(ProjectionHitMapper<?> projectionHitMapper, int

private Map<AggregationKey<?>, ?> extractAggregations() throws IOException {
AggregationExtractContext aggregationExtractContext = new AggregationExtractContext(
requestContext.getQueryIndexScope(),
requestContext.getSessionContext(),
indexSearcher.getIndexReader(),
fromDocumentValueConvertContext,
luceneCollectors.getCollectorsForAllMatchingDocs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public LuceneSearchQueryBuilder(
@Override
public void predicate(SearchPredicate predicate) {
LuceneSearchPredicate lucenePredicate = LuceneSearchPredicate.from( scope, predicate );
this.luceneQuery = lucenePredicate.toQuery( PredicateRequestContext.root() );
this.luceneQuery = lucenePredicate.toQuery( PredicateRequestContext.withSession( scope, sessionContext ) );
}

@Override
Expand Down Expand Up @@ -194,6 +194,11 @@ public void collectSortFields(SortField[] sortFields) {
Collections.addAll( this.sortFields, sortFields );
}

@Override
public PredicateRequestContext toPredicateRequestContext(String absoluteNestedPath) {
return PredicateRequestContext.withSession( scope, sessionContext ).withNestedPath( absoluteNestedPath );
}

@Override
public LuceneSearchQuery<H> build() {
SearchLoadingContext<?> loadingContext = loadingContextBuilder.build();
Expand Down Expand Up @@ -222,7 +227,7 @@ public LuceneSearchQuery<H> build() {
}

LuceneSearchQueryRequestContext requestContext = new LuceneSearchQueryRequestContext(
sessionContext, loadingContext, definitiveLuceneQuery, luceneSort
scope, sessionContext, loadingContext, definitiveLuceneQuery, luceneSort
);

LuceneAbstractSearchHighlighter resolvedGlobalHighlighter =
Expand Down
Loading

0 comments on commit 9e06526

Please sign in to comment.