From 08d82b1fe461e10f09238e7bf4381fde1a9a21b6 Mon Sep 17 00:00:00 2001 From: Liyun Xiu Date: Fri, 19 Jul 2024 01:56:10 -0700 Subject: [PATCH] Fix KNN SpaceType.getVectorSimilarityFunction renamed issue (#834) * Fix KNN SpaceType.getVectorSimilarityFunction renamed issue. Signed-off-by: Liyun Xiu * Fix NPE for a test Signed-off-by: Liyun Xiu --------- Signed-off-by: Liyun Xiu --- .../org/opensearch/neuralsearch/query/HybridQueryTests.java | 2 ++ .../java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java index b74bd010c..afb9ecb44 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryTests.java @@ -39,6 +39,7 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.KNNQueryBuilder; @@ -119,6 +120,7 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() { when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(VECTOR_FIELD_NAME, VECTOR_QUERY, K); Query knnQuery = knnQueryBuilder.toQuery(mockQueryShardContext); diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 3841e6dd0..13c8e230a 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -867,7 +867,7 @@ protected float computeExpectedScore( final String queryText ) { float[] queryVector = runInference(modelId, queryText); - return spaceType.getVectorSimilarityFunction().compare(queryVector, indexVector); + return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, indexVector); } protected Map getTaskQueryResponse(final String taskId) throws Exception {