diff --git a/pyserini/hsearch/__main__.py b/pyserini/hsearch/__main__.py
index b7f6ee82d..790c756cd 100644
--- a/pyserini/hsearch/__main__.py
+++ b/pyserini/hsearch/__main__.py
@@ -38,7 +38,9 @@ def define_fusion_args(parser):
     parser.add_argument('--alpha', type=float, metavar='num', required=False, default=0.1,
                         help="alpha for hybrid search")
+    parser.add_argument('--hits', type=int, required=False, default=10, help='number of hits from dense and sparse')
     parser.add_argument('--normalization', action='store_true', required=False, help='hybrid score with normalization')
+    parser.add_argument('--weight-on-dense', action='store_true', required=False, help='weight on dense part')
 
 
 def parse_args(parser, commands):
@@ -160,7 +162,7 @@ def parse_args(parser, commands):
     batch_topic_ids = list()
     for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()))):
         if args.run.batch_size <= 1 and args.run.threads <= 1:
-            hits = hsearcher.search(text, args.run.hits, args.fusion.alpha, args.fusion.normalization)
+            hits = hsearcher.search(text, args.fusion.hits, args.run.hits, args.fusion.alpha, args.fusion.normalization, args.fusion.weight_on_dense)
             results = [(topic_id, hits)]
         else:
             batch_topic_ids.append(str(topic_id))
@@ -168,8 +170,8 @@ def parse_args(parser, commands):
             if (index + 1) % args.run.batch_size == 0 or \
                     index == len(topics.keys()) - 1:
                 results = hsearcher.batch_search(
-                    batch_topics, batch_topic_ids, args.run.hits, args.run.threads,
-                    args.fusion.alpha, args.fusion.normalization)
+                    batch_topics, batch_topic_ids, args.fusion.hits, args.run.hits, args.run.threads,
+                    args.fusion.alpha, args.fusion.normalization, args.fusion.weight_on_dense)
                 results = [(id_, results[id_]) for id_ in batch_topic_ids]
                 batch_topic_ids.clear()
                 batch_topics.clear()
diff --git a/pyserini/hsearch/_hybrid.py b/pyserini/hsearch/_hybrid.py
index f9e00273a..a329e6011 100644
--- a/pyserini/hsearch/_hybrid.py
+++ b/pyserini/hsearch/_hybrid.py
@@ -36,24 +36,24 @@ def __init__(self, dense_searcher, sparse_searcher):
         self.dense_searcher = dense_searcher
         self.sparse_searcher = sparse_searcher
 
-    def search(self, query: str, k: int = 10, alpha: float = 0.1, normalization: bool = False) -> List[DenseSearchResult]:
-        dense_hits = self.dense_searcher.search(query, k)
-        sparse_hits = self.sparse_searcher.search(query, k)
-        return self._hybrid_results(dense_hits, sparse_hits, alpha, k, normalization)
+    def search(self, query: str, k0: int = 10, k: int = 10, alpha: float = 0.1, normalization: bool = False, weight_on_dense: bool = False) -> List[DenseSearchResult]:
+        dense_hits = self.dense_searcher.search(query, k0)
+        sparse_hits = self.sparse_searcher.search(query, k0)
+        return self._hybrid_results(dense_hits, sparse_hits, alpha, k, normalization, weight_on_dense)
 
-    def batch_search(self, queries: List[str], q_ids: List[str], k: int = 10, threads: int = 1,
-                     alpha: float = 0.1, normalization: bool = False) \
+    def batch_search(self, queries: List[str], q_ids: List[str], k0: int = 10, k: int = 10, threads: int = 1,
+                     alpha: float = 0.1, normalization: bool = False, weight_on_dense: bool = False) \
             -> Dict[str, List[DenseSearchResult]]:
-        dense_result = self.dense_searcher.batch_search(queries, q_ids, k, threads)
-        sparse_result = self.sparse_searcher.batch_search(queries, q_ids, k, threads)
+        dense_result = self.dense_searcher.batch_search(queries, q_ids, k0, threads)
+        sparse_result = self.sparse_searcher.batch_search(queries, q_ids, k0, threads)
         hybrid_result = {
-            key: self._hybrid_results(dense_result[key], sparse_result[key], alpha, k, normalization)
+            key: self._hybrid_results(dense_result[key], sparse_result[key], alpha, k, normalization, weight_on_dense)
             for key in dense_result
         }
         return hybrid_result
 
     @staticmethod
-    def _hybrid_results(dense_results, sparse_results, alpha, k, normalization=False):
+    def _hybrid_results(dense_results, sparse_results, alpha, k, normalization=False, weight_on_dense=False):
         dense_hits = {hit.docid: hit.score for hit in dense_results}
         sparse_hits = {hit.docid: hit.score for hit in sparse_results}
         hybrid_result = []
@@ -76,6 +76,6 @@ def _hybrid_results(dense_results, sparse_results, alpha, k, normalization=False
                                / (max_sparse_score - min_sparse_score)
                 dense_score = (dense_score - (min_dense_score + max_dense_score) / 2) \
                               / (max_dense_score - min_dense_score)
-            score = alpha * sparse_score + dense_score
+            score = alpha * sparse_score + dense_score if not weight_on_dense else sparse_score + alpha * dense_score
             hybrid_result.append(DenseSearchResult(doc, score))
         return sorted(hybrid_result, key=lambda x: x.score, reverse=True)[:k]
diff --git a/pyserini/search/_impact_searcher.py b/pyserini/search/_impact_searcher.py
index 0d767d4ce..e63e28328 100644
--- a/pyserini/search/_impact_searcher.py
+++ b/pyserini/search/_impact_searcher.py
@@ -23,7 +23,6 @@ from typing import Dict, List, Optional, Union
 import numpy as np
 
 from ._base import Document
-from pyserini.index import IndexReader
 from pyserini.pyclass import autoclass, JFloat, JArrayList, JHashMap, JString
 from pyserini.util import download_prebuilt_index
 from pyserini.encode import QueryEncoder, TokFreqQueryEncoder, UniCoilQueryEncoder, CachedDataQueryEncoder
@@ -230,6 +229,7 @@ def _init_query_encoder_from_str(query_encoder):
 
     @staticmethod
     def _compute_idf(index_path):
+        from pyserini.index import IndexReader
        index_reader = IndexReader(index_path)
        tokens = []
        dfs = []
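
A minimal usage sketch of the extended `HybridSearcher` API after this change; the prebuilt index names and query encoder below are illustrative assumptions, not part of the patch. `k0` is the retrieval depth requested from each underlying searcher (the new `--hits` in the `fusion` argument group), `k` is the number of fused hits kept (the existing `--hits` in the `run` group), and `weight_on_dense` moves `alpha` from the sparse score onto the dense score.

```python
# Illustrative sketch only -- index and encoder names are assumptions, not part of this patch.
from pyserini.search import SimpleSearcher
from pyserini.dsearch import SimpleDenseSearcher, TctColBertQueryEncoder
from pyserini.hsearch import HybridSearcher

ssearcher = SimpleSearcher.from_prebuilt_index('msmarco-passage')
encoder = TctColBertQueryEncoder('castorini/tct_colbert-msmarco')
dsearcher = SimpleDenseSearcher.from_prebuilt_index('msmarco-passage-tct_colbert-hnsw', encoder)

hsearcher = HybridSearcher(dsearcher, ssearcher)

# k0: depth of each underlying dense/sparse run; k: fused hits returned.
# weight_on_dense=False: score = alpha * sparse + dense
# weight_on_dense=True:  score = sparse + alpha * dense
hits = hsearcher.search('what is a lobster roll', k0=1000, k=10, alpha=0.1,
                        normalization=True, weight_on_dense=False)
for i, hit in enumerate(hits):
    print(f'{i + 1:2} {hit.docid:15} {hit.score:.5f}')
```

`batch_search` takes the same `k0`/`k` split and forwards `weight_on_dense` to `_hybrid_results` for each query id.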