diff --git a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java index 85415cb63..34a9cff2d 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java @@ -165,7 +165,7 @@ private static Tuple, Map> pruneByAlphaMass( * @param pruneRatio The ratio or threshold for prune * @param sparseVector The input sparse vector as a map of string keys to float values * @return A tuple containing two maps: the first with high-scoring elements, - * the second with low-scoring elements (or null if requiresPrunedEntries is false) + * the second with low-scoring elements */ public static Tuple, Map> splitSparseVector( PruneType pruneType, diff --git a/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java index 99bbe3c92..f0869ac53 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java @@ -112,6 +112,33 @@ public void testPruneByAlphaMass() { assertEquals(1.0f, tupleResult.v2().get("d"), 0.001); } + public void testNonePrune() { + Map input = new HashMap<>(); + input.put("a", 5.0f); + input.put("b", 3.0f); + input.put("c", 4.0f); + input.put("d", 1.0f); + + // Test prune + Map result = PruneUtils.pruneSparseVector(PruneType.NONE, 2, input); + + assertEquals(4, result.size()); + assertEquals(5.0f, result.get("a"), 0.001); + assertEquals(3.0f, result.get("b"), 0.001); + assertEquals(4.0f, result.get("c"), 0.001); + assertEquals(1.0f, result.get("d"), 0.001); + + // Test split + Tuple, Map> tupleResult = PruneUtils.splitSparseVector(PruneType.NONE, 2, input); + + assertEquals(4, tupleResult.v1().size()); + assertEquals(0, tupleResult.v2().size()); + assertEquals(5.0f, tupleResult.v1().get("a"), 0.001); + assertEquals(3.0f, tupleResult.v1().get("b"), 0.001); + assertEquals(4.0f, tupleResult.v1().get("c"), 0.001); + assertEquals(1.0f, tupleResult.v1().get("d"), 0.001); + } + public void testEmptyInput() { Map input = new HashMap<>(); @@ -154,13 +181,13 @@ public void testInvalidPruneType() { // Test prune IllegalArgumentException exception1 = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.splitSparseVector(null, 2, input) + () -> PruneUtils.pruneSparseVector(null, 2, input) ); assertEquals(exception1.getMessage(), "Prune type and prune ratio must be provided"); IllegalArgumentException exception2 = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.splitSparseVector(null, 2, input) + () -> PruneUtils.pruneSparseVector(null, 2, input) ); assertEquals(exception2.getMessage(), "Prune type and prune ratio must be provided"); @@ -181,7 +208,7 @@ public void testInvalidPruneType() { public void testNullSparseVector() { IllegalArgumentException exception1 = assertThrows( IllegalArgumentException.class, - () -> PruneUtils.splitSparseVector(PruneType.TOP_K, 2, null) + () -> PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, null) ); assertEquals(exception1.getMessage(), "Sparse vector must be provided");