Skip to content

Commit

Permalink
Adding serialization for KnnQueryBuilder
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Sep 28, 2022
1 parent 417947e commit bf7469c
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
fieldName = in.readString();
vector = in.readFloatArray();
k = in.readInt();
if (in.readBoolean()) {
filter = in.readNamedWriteable(QueryBuilder.class);
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder: " + ex);
}
Expand Down Expand Up @@ -181,6 +184,12 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(fieldName);
out.writeFloatArray(vector);
out.writeInt(k);
if (filter != null) {
out.writeBoolean(true);
out.writeNamedWriteable(filter);
} else {
out.writeBoolean(false);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
import com.google.common.collect.ImmutableMap;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.cluster.ClusterModule;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
Expand Down Expand Up @@ -39,42 +44,46 @@

public class KNNQueryBuilderTests extends KNNTestCase {

// Name of the vector field passed to the KNNQueryBuilder instances built in these tests.
private static final String FIELD_NAME = "myvector";
// Value passed as the k argument when constructing KNNQueryBuilder instances in these tests.
private static final int K = 1;
// Term query passed as the filter argument in the "with filter" test cases.
private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value");

/**
 * Checks that constructing a {@link KNNQueryBuilder} fails with an
 * {@link IllegalArgumentException} for a negative k, a zero k, and a k greater than
 * {@link KNNQueryBuilder#K_MAX}.
 */
public void testInvalidK() {
    float[] queryVector = { 1.0f, 1.0f };

    // negative k
    expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, -1));

    // zero k
    expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, 0));

    // k > KNNQueryBuilder.K_MAX
    expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, KNNQueryBuilder.K_MAX + 1));
}

/**
 * Checks that constructing a {@link KNNQueryBuilder} fails with an
 * {@link IllegalArgumentException} when the query vector is {@code null} or has no elements.
 */
public void testEmptyVector() {
    // null query vector
    float[] queryVector = null;
    expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, K));

    // empty query vector
    float[] queryVector1 = {};
    expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector1, K));
}

public void testFromXcontent() throws Exception {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1);
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
builder.startObject(knnQueryBuilder.fieldName());
Expand All @@ -90,7 +99,7 @@ public void testFromXcontent() throws Exception {

public void testFromXcontent_WithFilter() throws Exception {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1, QueryBuilders.termQuery("field", "value"));
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY);
XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
builder.startObject(knnQueryBuilder.fieldName());
Expand Down Expand Up @@ -118,9 +127,17 @@ protected NamedXContentRegistry xContentRegistry() {
return registry;
}

/**
 * Builds the named-writeable registry for these tests: the entries returned by
 * {@code ClusterModule.getNamedWriteables()} plus reader entries for
 * {@link KNNQueryBuilder} and {@link TermQueryBuilder}, both registered under
 * {@link QueryBuilder}.
 */
@Override
protected NamedWriteableRegistry writableRegistry() {
    final List<NamedWriteableRegistry.Entry> namedWriteables = ClusterModule.getNamedWriteables();

    final NamedWriteableRegistry.Entry knnQueryEntry = new NamedWriteableRegistry.Entry(
        QueryBuilder.class,
        KNNQueryBuilder.NAME,
        KNNQueryBuilder::new
    );
    final NamedWriteableRegistry.Entry termQueryEntry = new NamedWriteableRegistry.Entry(
        QueryBuilder.class,
        TermQueryBuilder.NAME,
        TermQueryBuilder::new
    );

    namedWriteables.add(knnQueryEntry);
    namedWriteables.add(termQueryEntry);
    return new NamedWriteableRegistry(namedWriteables);
}

public void testDoToQuery_Normal() throws Exception {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1);
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
Expand All @@ -135,7 +152,7 @@ public void testDoToQuery_Normal() throws Exception {

public void testDoToQuery_KnnQueryWithFilter() throws Exception {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1, QueryBuilders.termQuery("field", "value"));
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY);
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
Expand All @@ -152,14 +169,14 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception {

public void testDoToQuery_FromModel() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("myvector", queryVector, 1);
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);

// Dimension is -1. In this case, model metadata will need to provide dimension
when(mockKNNVectorField.getDimension()).thenReturn(-1);
when(mockKNNVectorField.getDimension()).thenReturn(-K);
when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null);
String modelId = "test-model-id";
when(mockKNNVectorField.getModelId()).thenReturn(modelId);
Expand All @@ -181,26 +198,69 @@ public void testDoToQuery_FromModel() {

/**
 * Checks that {@code doToQuery} throws an {@link IllegalArgumentException} when the mocked
 * field reports a dimension (400, then 1) that differs from the 4-element query vector.
 */
public void testDoToQuery_InvalidDimensions() {
    float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
    KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
    Index dummyIndex = new Index("dummy", "dummy");
    QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
    KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
    when(mockQueryShardContext.index()).thenReturn(dummyIndex);

    // Field dimension larger than the query vector length
    when(mockKNNVectorField.getDimension()).thenReturn(400);
    when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
    expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));

    // Field dimension smaller than the query vector length; a plain literal is used because
    // this value is a dimension, not a neighbour count
    when(mockKNNVectorField.getDimension()).thenReturn(1);
    expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
}

/**
 * Checks that {@code doToQuery} throws an {@link IllegalArgumentException} when the field
 * mapper returned by the shard context is a number field type rather than a k-NN vector
 * field type.
 */
public void testDoToQuery_InvalidFieldType() throws IOException {
    float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
    KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder("mynumber", queryVector, K);
    Index dummyIndex = new Index("dummy", "dummy");
    QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
    NumberFieldMapper.NumberFieldType mockNumberField = mock(NumberFieldMapper.NumberFieldType.class);
    when(mockQueryShardContext.index()).thenReturn(dummyIndex);
    when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockNumberField);
    expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
}

/**
 * Round-trips a {@link KNNQueryBuilder} through the stream serialization, once without a
 * filter and once with {@code TERM_QUERY} as the filter, and checks that field name, vector,
 * k and filter all survive.
 */
public void testSerialization() throws Exception {
    final float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };

    // Without a filter: the deserialized builder must report no filter
    final KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
    final KNNQueryBuilder deserializedWithoutFilter = roundTripThroughStream(knnQueryBuilder, queryVector);
    assertNull(deserializedWithoutFilter.getFilter());

    // With a filter: the deserialized filter must equal the original one
    final KNNQueryBuilder knnQueryBuilderWithFilter = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY);
    final KNNQueryBuilder deserializedWithFilter = roundTripThroughStream(knnQueryBuilderWithFilter, queryVector);
    assertNotNull(deserializedWithFilter.getFilter());
    assertEquals(TERM_QUERY, deserializedWithFilter.getFilter());
}

/**
 * Writes {@code original} as a named writeable, reads it back using
 * {@link #writableRegistry()}, and asserts the parts every case shares: the result is a
 * {@link KNNQueryBuilder} with field name {@code FIELD_NAME}, the expected vector, and k
 * equal to {@code K}.
 *
 * @param original       builder to serialize
 * @param expectedVector vector the deserialized builder must carry
 * @return the deserialized builder, for filter-specific assertions by the caller
 * @throws IOException if writing to or reading from the stream fails
 */
private KNNQueryBuilder roundTripThroughStream(final KNNQueryBuilder original, final float[] expectedVector) throws IOException {
    try (BytesStreamOutput output = new BytesStreamOutput()) {
        output.setVersion(Version.CURRENT);
        output.writeNamedWriteable(original);

        // The reader must be registry-aware so the named writeable can be resolved by name
        try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry())) {
            in.setVersion(Version.CURRENT);
            final QueryBuilder deserializedQuery = in.readNamedWriteable(QueryBuilder.class);
            assertNotNull(deserializedQuery);
            assertTrue(deserializedQuery instanceof KNNQueryBuilder);
            final KNNQueryBuilder deserializedKnnQueryBuilder = (KNNQueryBuilder) deserializedQuery;
            assertEquals(FIELD_NAME, deserializedKnnQueryBuilder.fieldName());
            assertArrayEquals(expectedVector, (float[]) deserializedKnnQueryBuilder.vector(), 0.0f);
            assertEquals(K, deserializedKnnQueryBuilder.getK());
            return deserializedKnnQueryBuilder;
        }
    }
}
}

0 comments on commit bf7469c

Please sign in to comment.