Skip to content

Commit

Permalink
Added Tests for KNNQueryBuilder for Faiss Filtering (#942)
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Jul 10, 2023
1 parent 4211013 commit 6738952
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,9 @@ protected Query doToQuery(QueryShardContext context) {
);
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) && filter != null && knnEngine != KNNEngine.FAISS) {
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)
&& filter != null
&& !KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) {
throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,38 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception {
assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class));
}

public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY);
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getDimension()).thenReturn(4);
MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of());
KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext);
when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext);
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
Query query = knnQueryBuilder.doToQuery(mockQueryShardContext);
assertNotNull(query);
assertTrue(query.getClass().isAssignableFrom(KNNQuery.class));
}

public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY);
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getDimension()).thenReturn(4);
MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of());
KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.L2, methodComponentContext);
when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext);
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
}

public void testDoToQuery_FromModel() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
Expand Down

0 comments on commit 6738952

Please sign in to comment.