Skip to content

Commit

Permalink
Adding version check for deserialization of filter field
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Sep 28, 2022
1 parent 097506a commit c385555
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import lombok.extern.log4j.Log4j2;
import org.opensearch.Version;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.KNNMethodContext;
Expand Down Expand Up @@ -109,11 +110,11 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
fieldName = in.readString();
vector = in.readFloatArray();
k = in.readInt();
if (in.readBoolean()) {
filter = in.readNamedWriteable(QueryBuilder.class);
if (in.getVersion().onOrAfter(Version.V_2_4_0)) {
filter = in.readOptionalNamedWriteable(QueryBuilder.class);
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder: " + ex);
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
}

Expand Down Expand Up @@ -184,12 +185,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(fieldName);
out.writeFloatArray(vector);
out.writeInt(k);
if (filter != null) {
out.writeBoolean(true);
out.writeNamedWriteable(filter);
} else {
out.writeBoolean(false);
}
out.writeOptionalNamedWriteable(filter);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public class KNNQueryBuilderTests extends KNNTestCase {
private static final String FIELD_NAME = "myvector";
private static final int K = 1;
private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value");
private static final float[] QUERY_VECTOR = new float[]{ 1.0f, 2.0f, 3.0f, 4.0f };

public void testInvalidK() {
float[] queryVector = { 1.0f, 1.0f };
Expand Down Expand Up @@ -222,9 +223,8 @@ public void testDoToQuery_InvalidFieldType() throws IOException {
}

public void testSerialization() throws Exception {
final float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };

final KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
final KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K);

try (BytesStreamOutput output = new BytesStreamOutput()) {
output.setVersion(Version.CURRENT);
Expand All @@ -233,17 +233,11 @@ public void testSerialization() throws Exception {
try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry())) {
in.setVersion(Version.CURRENT);
final QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class);
assertNotNull(deserializedQuery);
assertTrue(deserializedQuery instanceof KNNQueryBuilder);
final KNNQueryBuilder deserializedKnnQueryBuilder = (KNNQueryBuilder) deserializedQuery;
assertEquals(FIELD_NAME, deserializedKnnQueryBuilder.fieldName());
assertArrayEquals(queryVector, (float[]) deserializedKnnQueryBuilder.vector(), 0.0f);
assertEquals(K, deserializedKnnQueryBuilder.getK());
assertNull(deserializedKnnQueryBuilder.getFilter());
assertSerialization(deserializedQuery, true);
}
}

final KNNQueryBuilder knnQueryBuilderWithFilter = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY);
final KNNQueryBuilder knnQueryBuilderWithFilter = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, TERM_QUERY);

try (BytesStreamOutput output = new BytesStreamOutput()) {
output.setVersion(Version.CURRENT);
Expand All @@ -252,15 +246,35 @@ public void testSerialization() throws Exception {
try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry())) {
in.setVersion(Version.CURRENT);
final QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class);
assertNotNull(deserializedQuery);
assertTrue(deserializedQuery instanceof KNNQueryBuilder);
final KNNQueryBuilder deserializedKnnQueryBuilder = (KNNQueryBuilder) deserializedQuery;
assertEquals(FIELD_NAME, deserializedKnnQueryBuilder.fieldName());
assertArrayEquals(queryVector, (float[]) deserializedKnnQueryBuilder.vector(), 0.0f);
assertEquals(K, deserializedKnnQueryBuilder.getK());
assertNotNull(deserializedKnnQueryBuilder.getFilter());
assertEquals(TERM_QUERY, deserializedKnnQueryBuilder.getFilter());
assertSerialization(deserializedQuery, false);
}
}

//test serialization from < 2.4 versions
try (BytesStreamOutput output = new BytesStreamOutput()) {
output.setVersion(Version.V_2_3_0);
output.writeNamedWriteable(knnQueryBuilderWithFilter);

try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry())) {
in.setVersion(Version.V_2_3_0);
final QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class);
assertSerialization(deserializedQuery, true);
}
}
}

private void assertSerialization(final QueryBuilder deserializedQuery, boolean assertFilterIsNull) {
assertNotNull(deserializedQuery);
assertTrue(deserializedQuery instanceof KNNQueryBuilder);
final KNNQueryBuilder deserializedKnnQueryBuilder = (KNNQueryBuilder) deserializedQuery;
assertEquals(FIELD_NAME, deserializedKnnQueryBuilder.fieldName());
assertArrayEquals(QUERY_VECTOR, (float[]) deserializedKnnQueryBuilder.vector(), 0.0f);
assertEquals(K, deserializedKnnQueryBuilder.getK());
if (assertFilterIsNull) {
assertNull(deserializedKnnQueryBuilder.getFilter());
} else {
assertNotNull(deserializedKnnQueryBuilder.getFilter());
assertEquals(TERM_QUERY, deserializedKnnQueryBuilder.getFilter());
}
}
}

0 comments on commit c385555

Please sign in to comment.