onnx & fp16 for tct_colbert encoding (#697)
* onnx & fp16 for tct_colbert encoding

* fix bug
MXueguang authored Jul 13, 2021
1 parent 5cc929d commit 228d5c9
Showing 4 changed files with 35 additions and 15 deletions.
pyserini/encode/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -21,5 +21,5 @@
 from ._dpr import DprDocumentEncoder, DprQueryEncoder
 from ._tct_colbert import TctColBertDocumentEncoder, TctColBertQueryEncoder
 from ._unicoil import UniCoilEncoder, UniCoilDocumentEncoder, UniCoilQueryEncoder
-from ._pseudo import PseudoQueryEncoder
+from ._cached_data import CachedDataQueryEncoder
 from ._tok_freq import TokFreqQueryEncoder
pyserini/encode/_base.py (3 changes: 1 addition & 2 deletions)
@@ -18,6 +18,7 @@

 import faiss
 import torch
+import numpy as np
 from tqdm import tqdm

@@ -144,8 +145,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         faiss.write_index(self.index, os.path.join(self.dir_path, self.index_name))

     def write(self, batch_info, fields=None):
-        if fields:
-            print("Warning, for Faiss Index, we do not save contents")
         for id_ in batch_info['id']:
             self.id_file.write(f'{id_}\n')
         self.index.add(batch_info['vector'])
pyserini/encode/_tct_colbert.py (42 changes: 32 additions & 10 deletions)
@@ -15,24 +15,34 @@
 #

 import numpy as np
-from transformers import BertModel, BertTokenizer
+import torch
+from torch.cuda.amp import autocast
+from transformers import BertModel, BertTokenizer, BertTokenizerFast

 from pyserini.encode import DocumentEncoder, QueryEncoder
+from onnxruntime import ExecutionMode, SessionOptions, InferenceSession


 class TctColBertDocumentEncoder(DocumentEncoder):
-    def __init__(self, model_name, tokenizer_name=None, device='cuda:0'):
+    def __init__(self, model_name: str, tokenizer_name=None, device='cuda:0'):
         self.device = device
-        self.model = BertModel.from_pretrained(model_name)
-        self.model.to(self.device)
-        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or model_name)
+        self.onnx = False
+        if model_name.endswith('onnx'):
+            options = SessionOptions()
+            self.session = InferenceSession(model_name, options)
+            self.onnx = True
+            self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name or model_name[:-5])
+        else:
+            self.model = BertModel.from_pretrained(model_name)
+            self.model.to(self.device)
+            self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name or model_name)

-    def encode(self, texts, titles=None, **kwargs):
+    def encode(self, texts, titles=None, fp16=False, **kwargs):
         if titles is not None:
             texts = [f'[CLS] [D] {title} {text}' for title, text in zip(titles, texts)]
         else:
             texts = ['[CLS] [D] ' + text for text in texts]
-        max_length = 154  # hardcode for now
+        max_length = 512  # hardcode for now
         inputs = self.tokenizer(
             texts,
             max_length=max_length,
@@ -41,9 +51,21 @@ def encode(self, texts, titles=None, **kwargs):
             add_special_tokens=False,
             return_tensors='pt'
         )
-        inputs.to(self.device)
-        outputs = self.model(**inputs)
-        embeddings = self._mean_pooling(outputs["last_hidden_state"][:, 4:, :], inputs['attention_mask'][:, 4:])
+        if self.onnx:
+            inputs_onnx = {name: np.atleast_2d(value) for name, value in inputs.items()}
+            inputs.to(self.device)
+            outputs, _ = self.session.run(None, inputs_onnx)
+            outputs = torch.from_numpy(outputs).to(self.device)
+            embeddings = self._mean_pooling(outputs[:, 4:, :], inputs['attention_mask'][:, 4:])
+        else:
+            inputs.to(self.device)
+            if fp16:
+                with autocast():
+                    with torch.no_grad():
+                        outputs = self.model(**inputs)
+            else:
+                outputs = self.model(**inputs)
+            embeddings = self._mean_pooling(outputs["last_hidden_state"][:, 4:, :], inputs['attention_mask'][:, 4:])
         return embeddings.detach().cpu().numpy()
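
Note: the new ONNX branch above expects a serialized model whose path ends in 'onnx', and by default loads the tokenizer from that path with the '.onnx' suffix stripped. A minimal sketch of producing such a file, assuming the convert_graph_to_onnx helper that shipped with transformers at the time; the checkpoint name, output path, and opset value are illustrative, not part of this commit:

from pathlib import Path
from transformers.convert_graph_to_onnx import convert

# Trace the PyTorch checkpoint and serialize it as an ONNX graph.
# convert() expects the output's parent directory to be empty or absent.
convert(framework='pt',                             # export from PyTorch weights
        model='castorini/tct_colbert-msmarco',      # assumed HF checkpoint name
        output=Path('onnx-out/tct_colbert.onnx'),   # illustrative target path
        opset=11)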


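And a usage sketch of the two resulting code paths (again not part of the commit; the checkpoint name and file paths are illustrative):

from pyserini.encode import TctColBertDocumentEncoder

# PyTorch path: fp16=True wraps the forward pass in torch.cuda.amp.autocast,
# so it only takes effect on a CUDA device.
encoder = TctColBertDocumentEncoder('castorini/tct_colbert-msmarco', device='cuda:0')
vectors = encoder.encode(['some passage text'], titles=['some title'], fp16=True)

# ONNX path: a model path ending in 'onnx' routes inference through onnxruntime;
# pass tokenizer_name explicitly if no tokenizer files sit next to the model.
onnx_encoder = TctColBertDocumentEncoder('onnx-out/tct_colbert.onnx',
                                         tokenizer_name='castorini/tct_colbert-msmarco',
                                         device='cpu')
vectors = onnx_encoder.encode(['some passage text'])
print(vectors.shape)  # (num_texts, hidden_size)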
pyserini/search/_impact_searcher.py (3 changes: 1 addition & 2 deletions)
@@ -25,8 +25,7 @@
 from ._base import Document
 from pyserini.pyclass import autoclass, JFloat, JArrayList, JHashMap, JString
 from pyserini.util import download_prebuilt_index
-from pyserini.encode import QueryEncoder, TokFreqQueryEncoder, UniCoilQueryEncoder
-from ..encode._pseudo import CachedDataQueryEncoder
+from pyserini.encode import QueryEncoder, TokFreqQueryEncoder, UniCoilQueryEncoder, CachedDataQueryEncoder

 logger = logging.getLogger(__name__)
