Skip to content

Commit

Permalink
Add more unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Oct 9, 2023
1 parent 5a8b3a6 commit 91bf62c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
4 changes: 3 additions & 1 deletion haystack/preview/components/samplers/top_p.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def run(self, documents: List[Document], top_p: Optional[float] = None):
if not 0 <= top_p <= 1:
raise ComponentError(f"top_p must be between 0 and 1. Got {top_p}.")

epsilon = 1e-6 # account for floating-point precision issues

similarity_scores = torch.tensor(self._collect_scores(documents), dtype=torch.float32)

# Apply softmax normalization to the similarity scores
Expand All @@ -89,7 +91,7 @@ def run(self, documents: List[Document], top_p: Optional[float] = None):
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

# Find the indices whose cumulative probabilities do not exceed top_p
top_p_indices = torch.where(cumulative_probs <= top_p)[0]
top_p_indices = torch.where(torch.BoolTensor(cumulative_probs <= (top_p + epsilon)))[0]

# Map the selected indices back to their original indices
original_indices = sorted_indices[top_p_indices]
Expand Down
20 changes: 20 additions & 0 deletions test/preview/components/samplers/test_top_p.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,26 @@ def test_run_scores(self):

assert [doc.score for doc in docs_filtered] == sorted_scores[:1]

@pytest.mark.unit
def test_run_scores_top_p_1(self):
    """
    Check that sampling with top_p=1.0 keeps every document, ordered by descending score.
    """
    sampler = TopPSampler(top_p=1.0)

    # Build the fixture from (text, score) pairs, then randomize the input order
    # so the assertions below do not depend on the order documents are passed in.
    scored_cities = (("Berlin", -10.6), ("Belgrade", -8.9), ("Sarajevo", -4.6))
    docs = [Document(text=city, score=score) for city, score in scored_cities]
    random.shuffle(docs)

    docs_filtered = sampler.run(documents=docs)["documents"]

    # With top_p=1.0 nothing should be dropped.
    assert len(docs_filtered) == len(docs)
    # The highest-scoring document is expected first.
    assert docs_filtered[0].text == "Sarajevo"

    expected_scores = sorted((doc.score for doc in docs), reverse=True)
    assert [doc.score for doc in docs_filtered] == expected_scores

# Returns an empty list if no documents are provided
@pytest.mark.unit
def test_returns_empty_list_if_no_documents_are_provided(self):
Expand Down

0 comments on commit 91bf62c

Please sign in to comment.