Skip to content

Commit

Permalink
Bump up pyserini to 0.9.4.0 (#64)
Browse files: browse the repository at this point in the history
  • Loading branch information
lintool authored Jul 29, 2020
1 parent 377c283 commit 5b03294
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 14 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dependencies:
- dataclasses;python_version<"3.7"
- numpy>=1.18
- pydantic==1.5
- pyserini==0.9.0.0
- pyserini==0.9.4.0
- scikit-learn>=0.22
- scipy>=1.4
- spacy==2.2.4
Expand Down
4 changes: 2 additions & 2 deletions pygaggle/data/relevance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import re

from pyserini.search import pysearch
from pyserini.search import SimpleSearcher

from pygaggle.rerank.base import Query, Text

Expand Down Expand Up @@ -43,7 +43,7 @@ class Cord19DocumentLoader:
double_space_pattern = re.compile(r'\s\s+')

def __init__(self, index_path: str):
self.searcher = pysearch.SimpleSearcher(index_path)
self.searcher = SimpleSearcher(index_path)

@lru_cache(maxsize=1024)
def load_document(self, id: str) -> Cord19Document:
Expand Down
24 changes: 14 additions & 10 deletions pygaggle/rerank/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import List, Union, Optional, Mapping, Any
import abc

from pyserini.pyclass import JSimpleSearcherResult
from pyserini.search import JSimpleSearcherResult


__all__ = ['Query', 'Text', 'Reranker', 'to_texts', 'TextType']
__all__ = ['Query', 'Text', 'Reranker', 'hits_to_texts', 'TextType']


TextType = Union['Query', 'Text']
Expand Down Expand Up @@ -36,7 +36,7 @@ class Text:
----------
text : str
The text to be reranked.
raw : Mapping[str, Any]
metadata : Mapping[str, Any]
Additional metadata and other annotations.
score : Optional[float]
The score of the text. For example, the score might be the BM25 score
Expand All @@ -45,12 +45,12 @@ class Text:

def __init__(self,
text: str,
raw: Mapping[str, Any] = None,
metadata: Mapping[str, Any] = None,
score: Optional[float] = 0):
self.text = text
if raw is None:
raw = dict()
self.raw = raw
if metadata is None:
metadata = dict()
self.metadata = metadata
self.score = score


Expand Down Expand Up @@ -78,13 +78,15 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
pass


def to_texts(hits: List[JSimpleSearcherResult]) -> List[Text]:
def hits_to_texts(hits: List[JSimpleSearcherResult], field='raw') -> List[Text]:
"""Converts hits from Pyserini into a list of texts.
Parameters
----------
hits : List[JSimpleSearcherResult]
The hits.
The hits.
field : str
Field to use.
Returns
-------
Expand All @@ -93,5 +95,7 @@ def to_texts(hits: List[JSimpleSearcherResult]) -> List[Text]:
"""
texts = []
for i in range(0, len(hits)):
texts.append(Text(hits[i].contents, hits[i].raw, hits[i].score))
t = hits[i].raw if field == 'raw' else hits[i].contents
metadata = {'raw': hits[i].raw, 'docid': hits[i].docid}
texts.append(Text(t, metadata, hits[i].score))
return texts
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ coloredlogs==14.0
dataclasses;python_version<"3.7"
numpy>=1.18
pydantic==1.5
pyserini==0.9.0.0
pyserini==0.9.4.0
scikit-learn>=0.22
scipy>=1.4
spacy==2.2.4
Expand Down

0 comments on commit 5b03294

Please sign in to comment.