Skip to content

Commit

Permalink
Add support neural query type (#674)
Browse files Browse the repository at this point in the history
* Add support neural query type

Signed-off-by: Kirill_Ostanin <[email protected]>

* fix spotless

Signed-off-by: Kirill Ostanin <[email protected]>

* Update CHANGELOG.md

Signed-off-by: Kirill Ostanin <[email protected]>

* refactoring

Signed-off-by: Kirill Ostanin <[email protected]>

* attempt to restart integration tests

Signed-off-by: Kirill Ostanin <[email protected]>

---------

Signed-off-by: Kirill_Ostanin <[email protected]>
Signed-off-by: Kirill Ostanin <[email protected]>
Co-authored-by: Kirill Ostanin <[email protected]>
  • Loading branch information
GranT1337 and Kirill Ostanin authored Oct 24, 2023
1 parent 9a4273f commit 4603af9
Show file tree
Hide file tree
Showing 5 changed files with 276 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ This section is for maintaining a changelog for all breaking changes for the cli
- Added support for "script_fields" in multi search request ([#632](https://github.com/opensearch-project/opensearch-java/pull/632))
- Added size attribute to MultiTermsAggregation ([#627](https://github.com/opensearch-project/opensearch-java/pull/627))
- Added version increment workflow that executes after release ([#664](https://github.com/opensearch-project/opensearch-java/pull/664))
- Added support for neural query type ([#674](https://github.com/opensearch-project/opensearch-java/pull/674))

### Dependencies
- Bumps `org.ajoberstar.grgit:grgit-gradle` from 5.0.0 to 5.2.0
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.client.opensearch._types.query_dsl;

import jakarta.json.stream.JsonGenerator;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.opensearch.client.json.JsonpDeserializable;
import org.opensearch.client.json.JsonpDeserializer;
import org.opensearch.client.json.JsonpMapper;
import org.opensearch.client.json.ObjectBuilderDeserializer;
import org.opensearch.client.json.ObjectDeserializer;
import org.opensearch.client.util.ApiTypeHelper;
import org.opensearch.client.util.ObjectBuilder;

@JsonpDeserializable
public class NeuralQuery extends QueryBase implements QueryVariant {

private final String field;
private final String queryText;
private final int k;
@Nullable
private final String modelId;

private NeuralQuery(NeuralQuery.Builder builder) {
super(builder);

this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field");
this.queryText = ApiTypeHelper.requireNonNull(builder.queryText, this, "queryText");
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.modelId = builder.modelId;
}

public static NeuralQuery of(Function<NeuralQuery.Builder, ObjectBuilder<NeuralQuery>> fn) {
return fn.apply(new NeuralQuery.Builder()).build();
}

/**
* Query variant kind.
*
* @return The query variant kind.
*/
@Override
public Query.Kind _queryKind() {
return Query.Kind.Neural;
}

/**
* Required - The target field.
*
* @return The target field.
*/
public final String field() {
return this.field;
}

/**
* Required - Search query text.
*
* @return Search query text.
*/
public final String queryText() {
return this.queryText;
}

/**
* Required - The number of neighbors to return.
*
* @return The number of neighbors to return.
*/
public final int k() {
return this.k;
}

/**
* Builder for {@link NeuralQuery}.
*/

/**
* Optional - The model_id field if the default model for the index or field is set.
* Required - The model_id field if there is no default model set for the index or field.
*
* @return The model_id field.
*/
@Nullable
public final String modelId() {
return this.modelId;
}

@Override
protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
generator.writeStartObject(this.field);

super.serializeInternal(generator, mapper);

generator.write("query_text", this.queryText);

if (this.modelId != null) {
generator.write("model_id", this.modelId);
}

generator.write("k", this.k);

generator.writeEnd();
}

/**
* Builder for {@link NeuralQuery}.
*/
public static class Builder extends QueryBase.AbstractBuilder<NeuralQuery.Builder> implements ObjectBuilder<NeuralQuery> {
private String field;
private String queryText;
private Integer k;
@Nullable
private String modelId;

/**
* Required - The target field.
*
* @param field The target field.
* @return This builder.
*/
public NeuralQuery.Builder field(@Nullable String field) {
this.field = field;
return this;
}

/**
* Required - Search query text.
*
* @param queryText Search query text.
* @return This builder.
*/
public NeuralQuery.Builder queryText(@Nullable String queryText) {
this.queryText = queryText;
return this;
}

/**
* Optional - The model_id field if the default model for the index or field is set.
* Required - The model_id field if there is no default model set for the index or field.
*
* @param modelId The model_id field.
* @return This builder.
*/
public NeuralQuery.Builder modelId(@Nullable String modelId) {
this.modelId = modelId;
return this;
}

/**
* Required - The number of neighbors to return.
*
* @param k The number of neighbors to return.
* @return This builder.
*/
public NeuralQuery.Builder k(@Nullable Integer k) {
this.k = k;
return this;
}

@Override
protected NeuralQuery.Builder self() {
return this;
}

/**
* Builds a {@link NeuralQuery}.
*
* @return The built {@link NeuralQuery}.
*/
@Override
public NeuralQuery build() {
_checkSingleUse();

return new NeuralQuery(this);
}
}

public static final JsonpDeserializer<NeuralQuery> _DESERIALIZER = ObjectBuilderDeserializer.lazy(
NeuralQuery.Builder::new,
NeuralQuery::setupNeuralQueryDeserializer
);

protected static void setupNeuralQueryDeserializer(ObjectDeserializer<NeuralQuery.Builder> op) {
setupQueryBaseDeserializer(op);

op.add(NeuralQuery.Builder::queryText, JsonpDeserializer.stringDeserializer(), "query_text");
op.add(NeuralQuery.Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
op.add(NeuralQuery.Builder::k, JsonpDeserializer.integerDeserializer(), "k");

op.setKey(NeuralQuery.Builder::field, JsonpDeserializer.stringDeserializer());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ public enum Kind implements JsonEnum {

Nested("nested"),

Neural("neural"),

ParentId("parent_id"),

Percolate("percolate"),
Expand Down Expand Up @@ -706,6 +708,23 @@ public NestedQuery nested() {
return TaggedUnionUtils.get(this, Kind.Nested);
}

/**
* Is this variant instance of kind {@code neural}?
*/
public boolean isNeural() {
return _kind == Kind.Neural;
}

/**
* Get the {@code neural} variant value.
*
* @throws IllegalStateException
* if the current variant is not of the {@code neural} kind.
*/
public NeuralQuery neural() {
return TaggedUnionUtils.get(this, Kind.Neural);
}

/**
* Is this variant instance of kind {@code parent_id}?
*/
Expand Down Expand Up @@ -1450,6 +1469,16 @@ public ObjectBuilder<Query> nested(Function<NestedQuery.Builder, ObjectBuilder<N
return this.nested(fn.apply(new NestedQuery.Builder()).build());
}

public ObjectBuilder<Query> neural(NeuralQuery v) {
this._kind = Kind.Neural;
this._value = v;
return this;
}

public ObjectBuilder<Query> neural(Function<NeuralQuery.Builder, ObjectBuilder<NeuralQuery>> fn) {
return this.neural(fn.apply(new NeuralQuery.Builder()).build());
}

public ObjectBuilder<Query> parentId(ParentIdQuery v) {
this._kind = Kind.ParentId;
this._value = v;
Expand Down Expand Up @@ -1747,6 +1776,7 @@ protected static void setupQueryDeserializer(ObjectDeserializer<Builder> op) {
op.add(Builder::moreLikeThis, MoreLikeThisQuery._DESERIALIZER, "more_like_this");
op.add(Builder::multiMatch, MultiMatchQuery._DESERIALIZER, "multi_match");
op.add(Builder::nested, NestedQuery._DESERIALIZER, "nested");
op.add(Builder::neural, NeuralQuery._DESERIALIZER, "neural");
op.add(Builder::parentId, ParentIdQuery._DESERIALIZER, "parent_id");
op.add(Builder::percolate, PercolateQuery._DESERIALIZER, "percolate");
op.add(Builder::pinned, PinnedQuery._DESERIALIZER, "pinned");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ public static NestedQuery.Builder nested() {
return new NestedQuery.Builder();
}

/**
* Creates a builder for the {@link NeuralQuery nested} {@code Query} variant.
*/
public static NeuralQuery.Builder neural() {
return new NeuralQuery.Builder();
}

/**
* Creates a builder for the {@link ParentIdQuery parent_id} {@code Query}
* variant.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,42 @@ public void testNestedVariantsWithContainerProperties() {
assertEquals("m1 value", search.aggregations().get("agg1").meta().get("m1").to(String.class));
assertEquals("m2 value", search.aggregations().get("agg1").meta().get("m2").to(String.class));
}

@Test
public void testNeuralQuery() {

SearchRequest searchRequest = SearchRequest.of(
s -> s.query(q -> q.neural(n -> n.field("passage_embedding").queryText("Hi world").modelId("bQ1J8ooBpBj3wT4HVUsb").k(100)))
);

assertEquals("passage_embedding", searchRequest.query().neural().field());
assertEquals("Hi world", searchRequest.query().neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
assertEquals(100, searchRequest.query().neural().k());
}

@Test
public void testNeuralQueryFromJson() {

String json = "{\n"
+ " \"from\": 0,\n"
+ " \"size\": 100,\n"
+ " \"query\": {\n"
+ " \"neural\": {\n"
+ " \"passage_embedding\": {\n"
+ " \"query_text\": \"Hi world!\",\n"
+ " \"model_id\": \"bQ1J8ooBpBj3wT4HVUsb\",\n"
+ " \"k\": 100\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}";

SearchRequest searchRequest = ModelTestCase.fromJson(json, SearchRequest.class, mapper);

assertEquals("passage_embedding", searchRequest.query().neural().field());
assertEquals("Hi world!", searchRequest.query().neural().queryText());
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
assertEquals(100, searchRequest.query().neural().k());
}
}

0 comments on commit 4603af9

Please sign in to comment.