Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add filter option for query type #88

Merged
merged 2 commits into from
Dec 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.query;

import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD;
import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray;

import java.io.IOException;
Expand Down Expand Up @@ -80,6 +81,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
@Getter(AccessLevel.PACKAGE)
@Setter(AccessLevel.PACKAGE)
private Supplier<float[]> vectorSupplier;
private QueryBuilder filter;

/**
* Constructor from stream input
Expand All @@ -93,6 +95,7 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
this.queryText = in.readString();
this.modelId = in.readString();
this.k = in.readVInt();
this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
}

@Override
Expand All @@ -101,6 +104,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(this.queryText);
out.writeString(this.modelId);
out.writeVInt(this.k);
out.writeOptionalNamedWriteable(this.filter);
}

@Override
Expand All @@ -110,6 +114,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
xContentBuilder.field(K_FIELD.getPreferredName(), k);
if (filter != null) {
xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter);
}
printBoostAndQueryName(xContentBuilder);
xContentBuilder.endObject();
xContentBuilder.endObject();
Expand All @@ -125,7 +132,8 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
* "model_id": "string",
* "k": int,
* "name": "string", (optional)
* "boost": float (optional)
* "boost": float (optional),
* "filter": map (optional)
* }
* }
*
Expand Down Expand Up @@ -184,6 +192,10 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n
"[" + NAME + "] query does not support [" + currentFieldName + "]"
);
}
} else if (token == XContentParser.Token.START_OBJECT) {
if (FILTER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
neuralQueryBuilder.filter(parseInnerQueryBuilder(parser));
}
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand All @@ -205,7 +217,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
// create a new builder. Once the supplier's value gets set, we return a KNNQueryBuilder. Otherwise, we just
// return the current unmodified query builder.
if (vectorSupplier() != null) {
return vectorSupplier().get() == null ? this : new KNNQueryBuilder(fieldName(), vectorSupplier.get(), k());
return vectorSupplier().get() == null ? this : new KNNQueryBuilder(fieldName(), vectorSupplier.get(), k(), filter());
}

SetOnce<float[]> vectorSetOnce = new SetOnce<>();
Expand All @@ -215,7 +227,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
actionListener.onResponse(null);
}, actionListener::onFailure)))
);
return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get);
return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get, filter());
}

@Override
Expand All @@ -233,6 +245,7 @@ protected boolean doEquals(NeuralQueryBuilder obj) {
equalsBuilder.append(queryText, obj.queryText);
equalsBuilder.append(modelId, obj.modelId);
equalsBuilder.append(k, obj.k);
equalsBuilder.append(filter, obj.filter);
return equalsBuilder.isEquals();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,19 @@ protected Map<String, Object> getFirstInnerHit(Map<String, Object> searchRespons
return (Map<String, Object>) hits2List.get(0);
}

/**
* Parse the total number of hits from the search
*
* @param searchResponseAsMap Complete search response as a map
* @return number of hits from the search
*/
@SuppressWarnings("unchecked")
protected int getHitCount(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hits1map = (Map<String, Object>) searchResponseAsMap.get("hits");
List<Object> hits1List = (List<Object>) hits1map.get("hits");
return hits1List.size();
}

/**
* Create a k-NN index from a list of KNNFieldConfigs
*
Expand Down
Loading