Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #1137 - Adding query_image to Neural query #1138

Merged
merged 3 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ This section is for maintaining a changelog for all breaking changes for the cli
## [Unreleased 2.x]

### Added
- Adds `queryImage` (query_image) field to `NeuralQuery`, following definition in ([Neural Query](https://opensearch.org/docs/latest/query-dsl/specialized/neural/)) ([#1137](https://github.com/opensearch-project/opensearch-java/pull/1138))

### Dependencies

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
import org.opensearch.client.json.ObjectBuilderDeserializer;
import org.opensearch.client.json.ObjectDeserializer;
import org.opensearch.client.util.ApiTypeHelper;
import org.opensearch.client.util.MissingRequiredPropertiesException;
import org.opensearch.client.util.ObjectBuilder;

@JsonpDeserializable
public class NeuralQuery extends QueryBase implements QueryVariant {

private final String field;
private final String queryText;
private final String queryImage;
private final int k;
@Nullable
private final String modelId;
Expand All @@ -34,7 +36,11 @@ private NeuralQuery(NeuralQuery.Builder builder) {
super(builder);

this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field");
this.queryText = ApiTypeHelper.requireNonNull(builder.queryText, this, "queryText");
if (builder.queryText == null && builder.queryImage == null && !ApiTypeHelper.requiredPropertiesCheckDisabled()) {
throw new MissingRequiredPropertiesException(this, "queryText", "queryImage");
}
this.queryText = builder.queryText;
this.queryImage = builder.queryImage;
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.modelId = builder.modelId;
this.filter = builder.filter;
Expand Down Expand Up @@ -64,14 +70,25 @@ public final String field() {
}

/**
* Required - Search query text.
* Required - The query_text if query_image is not set.
* Optional - The query_text if query_image is set.
*
* @return Search query text.
*/
public final String queryText() {
return this.queryText;
}

/**
* Required - The query_image if query_text is not set.
* Optional - The query_image if query_text is set.
*
* @return Search query image.
*/
public final String queryImage() {
return this.queryImage;
}

/**
* Required - The number of neighbors to return.
*
Expand Down Expand Up @@ -112,7 +129,13 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

super.serializeInternal(generator, mapper);

generator.write("query_text", this.queryText);
if (this.queryText != null) {
generator.write("query_text", this.queryText);
}

if (this.queryImage != null) {
generator.write("query_image", this.queryImage);
}

if (this.modelId != null) {
generator.write("model_id", this.modelId);
Expand All @@ -129,7 +152,7 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
}

public Builder toBuilder() {
return new Builder().field(field).queryText(queryText).k(k).modelId(modelId).filter(filter);
return new Builder().field(field).queryText(queryText).queryImage(queryImage).k(k).modelId(modelId).filter(filter);
}

/**
Expand All @@ -138,6 +161,7 @@ public Builder toBuilder() {
public static class Builder extends QueryBase.AbstractBuilder<NeuralQuery.Builder> implements ObjectBuilder<NeuralQuery> {
private String field;
private String queryText;
private String queryImage;
private Integer k;
@Nullable
private String modelId;
Expand All @@ -156,7 +180,8 @@ public NeuralQuery.Builder field(@Nullable String field) {
}

/**
* Required - Search query text.
* Required - The query_text if query_image is not set.
* Optional - The query_text if query_image is set.
*
* @param queryText Search query text.
* @return This builder.
Expand All @@ -166,6 +191,18 @@ public NeuralQuery.Builder queryText(@Nullable String queryText) {
return this;
}

/**
* Required - The query_image if query_text is not set.
* Optional - The query_image if query_text is set.
*
* @param queryImage Search query image.
* @return This builder.
*/
public NeuralQuery.Builder queryImage(@Nullable String queryImage) {
this.queryImage = queryImage;
return this;
}

/**
* Optional - The model_id field if the default model for the index or field is set.
* Required - The model_id field if there is no default model set for the index or field.
Expand Down Expand Up @@ -227,6 +264,7 @@ protected static void setupNeuralQueryDeserializer(ObjectDeserializer<NeuralQuer
setupQueryBaseDeserializer(op);

op.add(NeuralQuery.Builder::queryText, JsonpDeserializer.stringDeserializer(), "query_text");
op.add(NeuralQuery.Builder::queryImage, JsonpDeserializer.stringDeserializer(), "query_image");
op.add(NeuralQuery.Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
op.add(NeuralQuery.Builder::k, JsonpDeserializer.integerDeserializer(), "k");
op.add(NeuralQuery.Builder::filter, Query._DESERIALIZER, "filter");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.client.util;

import java.util.StringJoiner;

/**
* Thrown by {@link ObjectBuilder#build()} when one of the required properties is missing.
* <p>
* If you think this is an error and that the reported property is actually optional, a workaround is
* available in {@link ApiTypeHelper} to disable checks. Use with caution.
*/
public class MissingRequiredPropertiesException extends RuntimeException {
private Class<?> clazz;
private String[] properties;

public MissingRequiredPropertiesException(Object obj, String... properties) {
super(
"Missing at least one required property between "
+ buildPropertiesMsg(properties)
+ " in '"
+ obj.getClass().getSimpleName()
+ "'"
);
this.clazz = obj.getClass();
this.properties = properties;
}

/**
* The class where the missing property was found
*/
public Class<?> getObjectClass() {
return clazz;
}

public String[] getPropertiesName() {
return properties;
}

private static String buildPropertiesMsg(String[] properties) {
final StringJoiner sj = new StringJoiner(",", "'", "'");
for (final String property : properties) {
sj.add(property);
}
return sj.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@

import org.junit.Test;
import org.opensearch.client.opensearch.model.ModelTestCase;
import org.opensearch.client.util.MissingRequiredPropertiesException;

public class NeuralQueryTest extends ModelTestCase {
@Test
public void toBuilder() {
public void toBuilder_queryText() {
NeuralQuery origin = new NeuralQuery.Builder().field("field")
.queryText("queryText")
.k(1)
Expand All @@ -23,4 +24,37 @@ public void toBuilder() {

assertEquals(toJson(copied), toJson(origin));
}

@Test
public void toBuilder_queryImage() {
NeuralQuery origin = new NeuralQuery.Builder().field("field")
.queryImage("queryImage")
.k(1)
.filter(IdsQuery.of(builder -> builder.values("Some_ID")).toQuery())
.build();
NeuralQuery copied = origin.toBuilder().build();

assertEquals(toJson(copied), toJson(origin));
}

@Test
public void toBuilder_both() {
NeuralQuery origin = new NeuralQuery.Builder().field("field")
.queryText("queryText")
.queryImage("queryImage")
.k(1)
.filter(IdsQuery.of(builder -> builder.values("Some_ID")).toQuery())
.build();
NeuralQuery copied = origin.toBuilder().build();

assertEquals(toJson(copied), toJson(origin));
}

@Test
public void toBuilder_missing_query() {
assertThrows(
MissingRequiredPropertiesException.class,
() -> new NeuralQuery.Builder().field("field").k(1).filter(IdsQuery.of(builder -> builder.values("Some_ID")).toQuery()).build()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ public void testNeuralQueryFromJson() {
+ " \"neural\": {\n"
+ " \"passage_embedding\": {\n"
+ " \"query_text\": \"Hi world!\",\n"
+ " \"query_image\": \"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAAAXNSR0IArs4c6QAAAA1JREFUGFdj+L+U4T8ABu8CpCYJ1DQAAAAASUVORK5CYII=\",\n"
+ " \"model_id\": \"bQ1J8ooBpBj3wT4HVUsb\",\n"
+ " \"k\": 100\n"
+ " }\n"
Expand All @@ -245,6 +246,10 @@ public void testNeuralQueryFromJson() {

assertEquals("passage_embedding", searchRequest.query().neural().field());
assertEquals("Hi world!", searchRequest.query().neural().queryText());
assertEquals(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAAAXNSR0IArs4c6QAAAA1JREFUGFdj+L+U4T8ABu8CpCYJ1DQAAAAASUVORK5CYII=",
searchRequest.query().neural().queryImage()
);
assertEquals("bQ1J8ooBpBj3wT4HVUsb", searchRequest.query().neural().modelId());
assertEquals(100, searchRequest.query().neural().k());
}
Expand Down
Loading