Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Add support for approximate k-NN queries #559

Merged
merged 2 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

## [Unreleased]
### Added
- Add support for knn_vector field type ([#529](https://github.com/opensearch-project/opensearch-java/pull/524))
- Add support for knn_vector field type ([#524](https://github.com/opensearch-project/opensearch-java/pull/524))
- Add translog option object and missing translog sync interval option in index settings ([#518](https://github.com/opensearch-project/opensearch-java/pull/518))
- Adds the option to set slices=auto for UpdateByQueryRequest, DeleteByQueryRequest and ReindexRequest ([#538](https://github.com/opensearch-project/opensearch-java/pull/538))
- Add support for approximate k-NN queries ([#548](https://github.com/opensearch-project/opensearch-java/pull/548))

### Dependencies
- Bumps `com.github.jk1.dependency-license-report` from 2.2 to 2.4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ public final Builder parameters(@Nullable Map<String, JsonData> map) {
return this;
}

/**
* API name: {@code parameters}
*/
public final Builder parameters(String key, JsonData value) {
this.parameters = _mapPut(this.parameters, key, value);
return this;
}

/**
* Builds a {@link KnnVectorMethod}.
*
Expand Down Expand Up @@ -194,10 +202,12 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

if (this.parameters != null) {
generator.writeKey("parameters");
generator.writeStartObject();
for (Map.Entry<String, JsonData> item0 : this.parameters.entrySet()) {
generator.writeKey(item0.getKey());
item0.getValue().serialize(generator, mapper);
}
generator.writeEnd();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ public final Builder method(@Nullable KnnVectorMethod value) {
return this;
}

/**
* API name: {@code method}
*/
public final Builder method(Function<KnnVectorMethod.Builder, ObjectBuilder<KnnVectorMethod>> fn) {
return this.method(fn.apply(new KnnVectorMethod.Builder()).build());
}

/**
* Builds a {@link KnnVectorProperty}.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.client.opensearch._types.query_dsl;

import jakarta.json.stream.JsonGenerator;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.opensearch.client.json.JsonpDeserializable;
import org.opensearch.client.json.JsonpDeserializer;
import org.opensearch.client.json.JsonpMapper;
import org.opensearch.client.json.ObjectBuilderDeserializer;
import org.opensearch.client.json.ObjectDeserializer;
import org.opensearch.client.util.ApiTypeHelper;
import org.opensearch.client.util.ObjectBuilder;

@JsonpDeserializable
public class KnnQuery extends QueryBase implements QueryVariant {
private final String field;
private final float[] vector;
private final int k;
@Nullable
private final Query filter;

private KnnQuery(Builder builder) {
super(builder);

this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field");
this.vector = ApiTypeHelper.requireNonNull(builder.vector, this, "vector");
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.filter = builder.filter;
}

public static KnnQuery of(Function<Builder, ObjectBuilder<KnnQuery>> fn) {
return fn.apply(new Builder()).build();
}

/**
* Query variant kind.
* @return The query variant kind.
*/
@Override
public Query.Kind _queryKind() {
return Query.Kind.Knn;
}

/**
* Required - The target field.
* @return The target field.
*/
public final String field() {
return this.field;
}

/**
* Required - The vector to search for.
* @return The vector to search for.
*/
public final float[] vector() {
return this.vector;
}

/**
* Required - The number of neighbors the search of each graph will return.
* @return The number of neighbors to return.
*/
public final int k() {
return this.k;
}

/**
* Optional - A query to filter the results of the query.
* @return The filter query.
*/
@Nullable
public final Query filter() {
return this.filter;
}

@Override
protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
generator.writeStartObject(this.field);

super.serializeInternal(generator, mapper);

// TODO: Implement the rest of the serialization.

generator.writeKey("vector");
generator.writeStartArray();
for (float value : this.vector) {
generator.write(value);
}
generator.writeEnd();

generator.write("k", this.k);

if (this.filter != null) {
generator.writeKey("filter");
this.filter.serialize(generator, mapper);
}

generator.writeEnd();
}

/**
* Builder for {@link KnnQuery}.
*/
public static class Builder extends QueryBase.AbstractBuilder<Builder> implements ObjectBuilder<KnnQuery> {
@Nullable
private String field;
@Nullable
private float[] vector;
@Nullable
private Integer k;
@Nullable
private Query filter;

/**
* Required - The target field.
* @param field The target field.
* @return This builder.
*/
public Builder field(@Nullable String field) {
this.field = field;
return this;
}

/**
* Required - The vector to search for.
*
* @param vector The vector to search for.
* @return This builder.
*/
public Builder vector(@Nullable float[] vector) {
this.vector = vector;
return this;
}

/**
* Required - The number of neighbors the search of each graph will return.
*
* @param k The number of neighbors to return.
* @return This builder.
*/
public Builder k(@Nullable Integer k) {
this.k = k;
return this;
}

/**
* Optional - A query to filter the results of the knn query.
*
* @param filter The filter query.
* @return This builder.
*/
public Builder filter(@Nullable Query filter) {
this.filter = filter;
return this;
}

@Override
protected Builder self() {
return this;
}

/**
* Builds a {@link KnnQuery}.
*
* @return The built {@link KnnQuery}.
*/
@Override
public KnnQuery build() {
_checkSingleUse();

return new KnnQuery(this);
}
}

public static final JsonpDeserializer<KnnQuery> _DESERIALIZER = ObjectBuilderDeserializer
.lazy(Builder::new, KnnQuery::setupKnnQueryDeserializer);

protected static void setupKnnQueryDeserializer(ObjectDeserializer<Builder> op) {
setupQueryBaseDeserializer(op);
op.add((b, v) -> {
float[] vector = new float[v.size()];
int i = 0;
for (Float value : v) {
vector[i++] = value;
}
b.vector(vector);
}, JsonpDeserializer.arrayDeserializer(JsonpDeserializer.floatDeserializer()), "vector");
op.add(Builder::k, JsonpDeserializer.integerDeserializer(), "k");
op.add(Builder::filter, Query._DESERIALIZER, "filter");

op.setKey(Builder::field, JsonpDeserializer.stringDeserializer());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ public enum Kind implements JsonEnum {

Intervals("intervals"),

Knn("knn"),

Match("match"),

MatchAll("match_all"),
Expand Down Expand Up @@ -535,6 +537,23 @@ public IntervalsQuery intervals() {
return TaggedUnionUtils.get(this, Kind.Intervals);
}

/**
* Is this variant instance of kind {@code knn}?
*/
public boolean isKnn() {
return _kind == Kind.Knn;
}

/**
* Get the {@code knn} variant value.
*
* @throws IllegalStateException
* if the current variant is not of the {@code knn} kind.
*/
public KnnQuery knn() {
return TaggedUnionUtils.get(this, Kind.Knn);
}

/**
* Is this variant instance of kind {@code match}?
*/
Expand Down Expand Up @@ -1340,6 +1359,16 @@ public ObjectBuilder<Query> intervals(Function<IntervalsQuery.Builder, ObjectBui
return this.intervals(fn.apply(new IntervalsQuery.Builder()).build());
}

public ObjectBuilder<Query> knn(KnnQuery v) {
this._kind = Kind.Knn;
this._value = v;
return this;
}

public ObjectBuilder<Query> knn(Function<KnnQuery.Builder, ObjectBuilder<KnnQuery>> fn) {
return this.knn(fn.apply(new KnnQuery.Builder()).build());
}

public ObjectBuilder<Query> match(MatchQuery v) {
this._kind = Kind.Match;
this._value = v;
Expand Down Expand Up @@ -1728,6 +1757,7 @@ protected static void setupQueryDeserializer(ObjectDeserializer<Builder> op) {
op.add(Builder::hasParent, HasParentQuery._DESERIALIZER, "has_parent");
op.add(Builder::ids, IdsQuery._DESERIALIZER, "ids");
op.add(Builder::intervals, IntervalsQuery._DESERIALIZER, "intervals");
op.add(Builder::knn, KnnQuery._DESERIALIZER, "knn");
op.add(Builder::match, MatchQuery._DESERIALIZER, "match");
op.add(Builder::matchAll, MatchAllQuery._DESERIALIZER, "match_all");
op.add(Builder::matchBoolPrefix, MatchBoolPrefixQuery._DESERIALIZER, "match_bool_prefix");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ public static IntervalsQuery.Builder intervals() {
return new IntervalsQuery.Builder();
}

/**
* Creates a builder for the {@link KnnQuery knn} {@code Query} variant.
*/
public static KnnQuery.Builder knn() {
return new KnnQuery.Builder();
}

/**
* Creates a builder for the {@link MatchQuery match} {@code Query} variant.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

}
if (this.knnAlgoParamEfSearch != null) {
generator.writeKey("knn.algo_param_ef_search");
generator.writeKey("knn.algo_param.ef_search");
generator.write(this.knnAlgoParamEfSearch);

}
Expand Down
Loading