diff --git a/beir/retrieval/search/dense/exact_search.py b/beir/retrieval/search/dense/exact_search.py
index 642b21b..3f36570 100644
--- a/beir/retrieval/search/dense/exact_search.py
+++ b/beir/retrieval/search/dense/exact_search.py
@@ -28,6 +28,7 @@ def search(self,
                top_k: int,
                score_function: str,
                return_sorted: bool = False,
+               ignore_identical_ids: bool = True,
                **kwargs) -> Dict[str, Dict[str, float]]:
         # Create embeddings for all queries using model.encode_queries()
         # Runs semantic search against the corpus embeddings
@@ -45,6 +46,9 @@ def search(self,
 
         logger.info("Sorting Corpus by document length (Longest first)...")
         corpus_ids = sorted(corpus, key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")), reverse=True)
+        if ignore_identical_ids:
+            # We remove the query from results if it exists in corpus
+            corpus_ids = [cid for cid in corpus_ids if cid not in query_ids]
         corpus = [corpus[cid] for cid in corpus_ids]
 
         logger.info("Encoding Corpus in batches... Warning: This might take a while!")
@@ -70,7 +74,7 @@ def search(self,
             cos_scores[torch.isnan(cos_scores)] = -1
 
             # Get top-k values
-            cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(cos_scores, min(top_k+1, len(cos_scores[1])), dim=1, largest=True, sorted=return_sorted)
+            cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(cos_scores, min(top_k, len(cos_scores[1])), dim=1, largest=True, sorted=return_sorted)
             cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
             cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()
 
@@ -78,13 +82,13 @@ def search(self,
                 query_id = query_ids[query_itr]
                 for sub_corpus_id, score in zip(cos_scores_top_k_idx[query_itr], cos_scores_top_k_values[query_itr]):
                     corpus_id = corpus_ids[corpus_start_idx+sub_corpus_id]
-                    if corpus_id != query_id:
-                        if len(result_heaps[query_id]) < top_k:
-                            # Push item on the heap
-                            heapq.heappush(result_heaps[query_id], (score, corpus_id))
-                        else:
-                            # If item is larger than the smallest in the heap, push it on the heap then pop the smallest element
-                            heapq.heappushpop(result_heaps[query_id], (score, corpus_id))
+                    assert ignore_identical_ids is False or corpus_id != query_id, "Query id and corpus id should not be the same if ignore_identical_ids is set to True"
+                    if len(result_heaps[query_id]) < top_k:
+                        # Push item on the heap
+                        heapq.heappush(result_heaps[query_id], (score, corpus_id))
+                    else:
+                        # If item is larger than the smallest in the heap, push it on the heap then pop the smallest element
+                        heapq.heappushpop(result_heaps[query_id], (score, corpus_id))
 
         for qid in result_heaps:
             for score, corpus_id in result_heaps[qid]:
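For context, a minimal usage sketch of the new flag. The model name and dataset path below are illustrative assumptions, not part of this diff, and the sketch assumes the usual BEIR wiring where EvaluateRetrieval.retrieve forwards extra keyword arguments through **kwargs to DenseRetrievalExactSearch.search:

# Usage sketch (not part of the diff); model name and dataset path are
# illustrative assumptions.
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval import models
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

corpus, queries, qrels = GenericDataLoader("datasets/scifact").load(split="test")

model = DRES(models.SentenceBERT("msmarco-distilbert-base-v3"), batch_size=128)
retriever = EvaluateRetrieval(model, score_function="cos_sim")

# Default after this change: any corpus document whose id equals a query id
# is dropped from corpus_ids before encoding, so it can never be retrieved.
results = retriever.retrieve(corpus, queries)

# Opt out for datasets where query ids also occur in the corpus and those
# documents should stay retrievable; retrieve() forwards the kwarg to search().
results_with_self = retriever.retrieve(corpus, queries, ignore_identical_ids=False)

Because identical ids are now filtered out of corpus_ids before encoding, the old per-hit corpus_id != query_id check and the extra top_k+1 slot it required are no longer needed, which is why both are removed in the diff above.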