diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index a26dd8263..f40d4bf59 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -67,6 +67,7 @@ import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; @@ -141,8 +142,8 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() { assertNotNull(queryOnlyNeural); assertTrue(queryOnlyNeural instanceof HybridQuery); assertEquals(1, ((HybridQuery) queryOnlyNeural).getSubQueries().size()); - assertTrue(((HybridQuery) queryOnlyNeural).getSubQueries().iterator().next() instanceof KNNQuery); - KNNQuery knnQuery = (KNNQuery) ((HybridQuery) queryOnlyNeural).getSubQueries().iterator().next(); + assertTrue(((HybridQuery) queryOnlyNeural).getSubQueries().iterator().next() instanceof NativeEngineKnnVectorQuery); + KNNQuery knnQuery = ((NativeEngineKnnVectorQuery) ((HybridQuery) queryOnlyNeural).getSubQueries().iterator().next()).getKnnQuery(); assertEquals(VECTOR_FIELD_NAME, knnQuery.getField()); assertEquals(K, knnQuery.getK()); assertNotNull(knnQuery.getQueryVector()); @@ -183,8 +184,8 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() { // verify knn vector query Iterator queryIterator = ((HybridQuery) queryTwoSubQueries).getSubQueries().iterator(); Query firstQuery = queryIterator.next(); - assertTrue(firstQuery instanceof KNNQuery); - KNNQuery knnQuery = (KNNQuery) firstQuery; + assertTrue(firstQuery instanceof NativeEngineKnnVectorQuery); + KNNQuery knnQuery = ((NativeEngineKnnVectorQuery) firstQuery).getKnnQuery(); assertEquals(VECTOR_FIELD_NAME, knnQuery.getField()); assertEquals(K, knnQuery.getK()); assertNotNull(knnQuery.getQueryVector());