diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 11199f9d7..3429b2400 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.query; import lombok.extern.log4j.Log4j2; +import org.opensearch.Version; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.KNNMethodContext; @@ -109,11 +110,11 @@ public KNNQueryBuilder(StreamInput in) throws IOException { fieldName = in.readString(); vector = in.readFloatArray(); k = in.readInt(); - if (in.readBoolean()) { - filter = in.readNamedWriteable(QueryBuilder.class); + if (in.getVersion().onOrAfter(Version.V_2_4_0)) { + filter = in.readOptionalNamedWriteable(QueryBuilder.class); } } catch (IOException ex) { - throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder: " + ex); + throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); } } @@ -184,12 +185,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeFloatArray(vector); out.writeInt(k); - if (filter != null) { - out.writeBoolean(true); - out.writeNamedWriteable(filter); - } else { - out.writeBoolean(false); - } + out.writeOptionalNamedWriteable(filter); } /** diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index f65fae030..bbf3f1442 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -47,6 +47,7 @@ public class KNNQueryBuilderTests extends KNNTestCase { private static final String FIELD_NAME = "myvector"; private static final int K = 1; private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); + private static final float[] QUERY_VECTOR = new float[]{ 1.0f, 2.0f, 3.0f, 4.0f }; public void testInvalidK() { float[] queryVector = { 1.0f, 1.0f }; @@ -222,9 +223,8 @@ public void testDoToQuery_InvalidFieldType() throws IOException { } public void testSerialization() throws Exception { - final float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - final KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); + final KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K); try (BytesStreamOutput output = new BytesStreamOutput()) { output.setVersion(Version.CURRENT); @@ -233,17 +233,11 @@ public void testSerialization() throws Exception { try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry())) { in.setVersion(Version.CURRENT); final QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class); - assertNotNull(deserializedQuery); - assertTrue(deserializedQuery instanceof KNNQueryBuilder); - final KNNQueryBuilder deserializedKnnQueryBuilder = (KNNQueryBuilder) deserializedQuery; - assertEquals(FIELD_NAME, deserializedKnnQueryBuilder.fieldName()); - assertArrayEquals(queryVector, (float[]) deserializedKnnQueryBuilder.vector(), 0.0f); - assertEquals(K, deserializedKnnQueryBuilder.getK()); - assertNull(deserializedKnnQueryBuilder.getFilter()); + assertSerialization(deserializedQuery, true); } } - final KNNQueryBuilder knnQueryBuilderWithFilter = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); + final KNNQueryBuilder knnQueryBuilderWithFilter = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, TERM_QUERY); try (BytesStreamOutput output = new BytesStreamOutput()) { output.setVersion(Version.CURRENT); @@ -252,15 +246,35 @@ public void testSerialization() throws Exception { try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry())) { in.setVersion(Version.CURRENT); final QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class); - assertNotNull(deserializedQuery); - assertTrue(deserializedQuery instanceof KNNQueryBuilder); - final KNNQueryBuilder deserializedKnnQueryBuilder = (KNNQueryBuilder) deserializedQuery; - assertEquals(FIELD_NAME, deserializedKnnQueryBuilder.fieldName()); - assertArrayEquals(queryVector, (float[]) deserializedKnnQueryBuilder.vector(), 0.0f); - assertEquals(K, deserializedKnnQueryBuilder.getK()); - assertNotNull(deserializedKnnQueryBuilder.getFilter()); - assertEquals(TERM_QUERY, deserializedKnnQueryBuilder.getFilter()); + assertSerialization(deserializedQuery, false); } } + + //test serialization from < 2.4 versions + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.setVersion(Version.V_2_3_0); + output.writeNamedWriteable(knnQueryBuilderWithFilter); + + try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry())) { + in.setVersion(Version.V_2_3_0); + final QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class); + assertSerialization(deserializedQuery, true); + } + } + } + + private void assertSerialization(final QueryBuilder deserializedQuery, boolean assertFilterIsNull) { + assertNotNull(deserializedQuery); + assertTrue(deserializedQuery instanceof KNNQueryBuilder); + final KNNQueryBuilder deserializedKnnQueryBuilder = (KNNQueryBuilder) deserializedQuery; + assertEquals(FIELD_NAME, deserializedKnnQueryBuilder.fieldName()); + assertArrayEquals(QUERY_VECTOR, (float[]) deserializedKnnQueryBuilder.vector(), 0.0f); + assertEquals(K, deserializedKnnQueryBuilder.getK()); + if (assertFilterIsNull) { + assertNull(deserializedKnnQueryBuilder.getFilter()); + } else { + assertNotNull(deserializedKnnQueryBuilder.getFilter()); + assertEquals(TERM_QUERY, deserializedKnnQueryBuilder.getFilter()); + } } }