From 5b03294e5456c0b46ee61bb8a195403d0a3f6f9b Mon Sep 17 00:00:00 2001 From: Jimmy Lin Date: Wed, 29 Jul 2020 10:28:15 -0400 Subject: [PATCH] Bump up pyserini to 0.9.4.0 (#64) --- environment.yml | 2 +- pygaggle/data/relevance.py | 4 ++-- pygaggle/rerank/base.py | 24 ++++++++++++++---------- requirements.txt | 2 +- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/environment.yml b/environment.yml index 41fb9a56..05230340 100644 --- a/environment.yml +++ b/environment.yml @@ -7,7 +7,7 @@ dependencies: - dataclasses;python_version<"3.7" - numpy>=1.18 - pydantic==1.5 - - pyserini==0.9.0.0 + - pyserini==0.9.4.0 - scikit-learn>=0.22 - scipy>=1.4 - spacy==2.2.4 diff --git a/pygaggle/data/relevance.py b/pygaggle/data/relevance.py index dc6db1ea..ab49686e 100644 --- a/pygaggle/data/relevance.py +++ b/pygaggle/data/relevance.py @@ -4,7 +4,7 @@ import json import re -from pyserini.search import pysearch +from pyserini.search import SimpleSearcher from pygaggle.rerank.base import Query, Text @@ -43,7 +43,7 @@ class Cord19DocumentLoader: double_space_pattern = re.compile(r'\s\s+') def __init__(self, index_path: str): - self.searcher = pysearch.SimpleSearcher(index_path) + self.searcher = SimpleSearcher(index_path) @lru_cache(maxsize=1024) def load_document(self, id: str) -> Cord19Document: diff --git a/pygaggle/rerank/base.py b/pygaggle/rerank/base.py index 5f07bd65..f3c33cda 100644 --- a/pygaggle/rerank/base.py +++ b/pygaggle/rerank/base.py @@ -1,10 +1,10 @@ from typing import List, Union, Optional, Mapping, Any import abc -from pyserini.pyclass import JSimpleSearcherResult +from pyserini.search import JSimpleSearcherResult -__all__ = ['Query', 'Text', 'Reranker', 'to_texts', 'TextType'] +__all__ = ['Query', 'Text', 'Reranker', 'hits_to_texts', 'TextType'] TextType = Union['Query', 'Text'] @@ -36,7 +36,7 @@ class Text: ---------- text : str The text to be reranked. - raw : Mapping[str, Any] + metadata : Mapping[str, Any] Additional metadata and other annotations. score : Optional[float] The score of the text. For example, the score might be the BM25 score @@ -45,12 +45,12 @@ class Text: def __init__(self, text: str, - raw: Mapping[str, Any] = None, + metadata: Mapping[str, Any] = None, score: Optional[float] = 0): self.text = text - if raw is None: - raw = dict() - self.raw = raw + if metadata is None: + metadata = dict() + self.metadata = metadata self.score = score @@ -78,13 +78,15 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]: pass -def to_texts(hits: List[JSimpleSearcherResult]) -> List[Text]: +def hits_to_texts(hits: List[JSimpleSearcherResult], field='raw') -> List[Text]: """Converts hits from Pyserini into a list of texts. Parameters ---------- hits : List[JSimpleSearcherResult] - The hits. + The hits. + field : str + Field to use. Returns ------- @@ -93,5 +95,7 @@ def to_texts(hits: List[JSimpleSearcherResult]) -> List[Text]: """ texts = [] for i in range(0, len(hits)): - texts.append(Text(hits[i].contents, hits[i].raw, hits[i].score)) + t = hits[i].raw if field == 'raw' else hits[i].contents + metadata = {'raw': hits[i].raw, 'docid': hits[i].docid} + texts.append(Text(t, metadata, hits[i].score)) return texts diff --git a/requirements.txt b/requirements.txt index 1c76c3a1..e53100c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ coloredlogs==14.0 dataclasses;python_version<"3.7" numpy>=1.18 pydantic==1.5 -pyserini==0.9.0.0 +pyserini==0.9.4.0 scikit-learn>=0.22 scipy>=1.4 spacy==2.2.4