From 91bf62c81574d3cb03d3748e6737b42937084739 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 9 Oct 2023 11:47:24 +0200 Subject: [PATCH] Add more unit tests --- haystack/preview/components/samplers/top_p.py | 4 +++- .../preview/components/samplers/test_top_p.py | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/haystack/preview/components/samplers/top_p.py b/haystack/preview/components/samplers/top_p.py index 6f3fe57772..084c949368 100644 --- a/haystack/preview/components/samplers/top_p.py +++ b/haystack/preview/components/samplers/top_p.py @@ -79,6 +79,8 @@ def run(self, documents: List[Document], top_p: Optional[float] = None): if not 0 <= top_p <= 1: raise ComponentError(f"top_p must be between 0 and 1. Got {top_p}.") + epsilon = 1e-6 # account for floating-point precision issues + similarity_scores = torch.tensor(self._collect_scores(documents), dtype=torch.float32) # Apply softmax normalization to the similarity scores @@ -89,7 +91,7 @@ def run(self, documents: List[Document], top_p: Optional[float] = None): cumulative_probs = torch.cumsum(sorted_probs, dim=-1) # Find the indices with cumulative probabilities that exceed top_p - top_p_indices = torch.where(cumulative_probs <= top_p)[0] + top_p_indices = torch.where(torch.BoolTensor(cumulative_probs <= (top_p + epsilon)))[0] # Map the selected indices back to their original indices original_indices = sorted_indices[top_p_indices] diff --git a/test/preview/components/samplers/test_top_p.py b/test/preview/components/samplers/test_top_p.py index 958ab1f356..111dffa25c 100644 --- a/test/preview/components/samplers/test_top_p.py +++ b/test/preview/components/samplers/test_top_p.py @@ -64,6 +64,26 @@ def test_run_scores(self): assert [doc.score for doc in docs_filtered] == sorted_scores[:1] + @pytest.mark.unit + def test_run_scores_top_p_1(self): + """ + Test if the component runs correctly with top_p=1. 
+ """ + sampler = TopPSampler(top_p=1.0) + docs = [ + Document(text="Berlin", score=-10.6), + Document(text="Belgrade", score=-8.9), + Document(text="Sarajevo", score=-4.6), + ] + + random.shuffle(docs) + output = sampler.run(documents=docs) + docs_filtered = output["documents"] + assert len(docs_filtered) == len(docs) + assert docs_filtered[0].text == "Sarajevo" + + assert [doc.score for doc in docs_filtered] == sorted([doc.score for doc in docs], reverse=True) + # Returns an empty list if no documents are provided @pytest.mark.unit def test_returns_empty_list_if_no_documents_are_provided(self):