Skip to content

Commit

Permalink
Add constructor functions for model and tokenizer of MonoBERT/T5 (#93)
Browse files Browse the repository at this point in the history
* Add constructor functions for model and tokenizer of MonoBERT/T5

* Remove useless import

* Add pretrained_name as argument
  • Loading branch information
yuxuan-ji authored Oct 23, 2020
1 parent 6a487b7 commit 978a071
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 58 deletions.
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ Here's how to initialize the T5 reranker from [Document Ranking with a Pretrained
from pygaggle.rerank.base import Query, Text
from pygaggle.rerank.transformer import MonoT5

model_name = 'castorini/monot5-base-msmarco'
tokenizer_name = 't5-base'
reranker = MonoT5(model_name, tokenizer_name)
reranker = MonoT5()
```

Alternatively, here's the BERT reranker from [Passage Re-ranking with BERT](https://arxiv.org/pdf/1901.04085.pdf), which isn't as good as the T5 reranker:
Expand All @@ -49,9 +47,7 @@ Alternatively, here's the BERT reranker from [Passage Re-ranking with BERT](http
from pygaggle.rerank.base import Query, Text
from pygaggle.rerank.transformer import MonoBERT

model_name = 'castorini/monobert-large-msmarco'
tokenizer_name = 'bert-large-uncased'
reranker = MonoBERT(model_name, tokenizer_name)
reranker = MonoBERT()
```

Either way, continue with a complete reranking example:
Expand Down
59 changes: 36 additions & 23 deletions pygaggle/rerank/transformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import List, Union
from typing import List

from transformers import (AutoTokenizer,
AutoModelForSequenceClassification,
Expand All @@ -26,19 +26,27 @@

class MonoT5(Reranker):
def __init__(self,
model_name_or_instance: Union[str, T5ForConditionalGeneration] = 'castorini/monot5-base-msmarco',
tokenizer_name_or_instance: Union[str, QueryDocumentBatchTokenizer] = 't5-base'):
if isinstance(model_name_or_instance, str):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_name_or_instance = T5ForConditionalGeneration.from_pretrained(model_name_or_instance).to(device).eval()
self.model = model_name_or_instance

if isinstance(tokenizer_name_or_instance, str):
tokenizer_name_or_instance = T5BatchTokenizer(AutoTokenizer.from_pretrained(tokenizer_name_or_instance), batch_size=8)
self.tokenizer = tokenizer_name_or_instance

model: T5ForConditionalGeneration = None,
tokenizer: QueryDocumentBatchTokenizer = None):
self.model = model or self.get_model()
self.tokenizer = tokenizer or self.get_tokenizer()
self.device = next(self.model.parameters(), None).device

@staticmethod
def get_model(pretrained_model_name_or_path: str = 'castorini/monot5-base-msmarco',
*args, device: str = None, **kwargs) -> T5ForConditionalGeneration:
device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device(device)
return T5ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path, *args, **kwargs).to(device).eval()

@staticmethod
def get_tokenizer(pretrained_model_name_or_path: str = 't5-base',
*args, batch_size: int = 8, **kwargs) -> T5BatchTokenizer:
return T5BatchTokenizer(
AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *args, **kwargs),
batch_size=batch_size
)

def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
texts = deepcopy(texts)
batch_input = QueryDocumentBatch(query=query, documents=texts)
Expand Down Expand Up @@ -108,19 +116,24 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:

class MonoBERT(Reranker):
def __init__(self,
model_name_or_instance: Union[str, PreTrainedModel] = 'castorini/monobert-large-msmarco',
tokenizer_name_or_instance: Union[str, PreTrainedTokenizer] = 'bert-large-uncased'):
if isinstance(model_name_or_instance, str):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_name_or_instance = AutoModelForSequenceClassification.from_pretrained(model_name_or_instance).to(device).eval()
self.model = model_name_or_instance

if isinstance(tokenizer_name_or_instance, str):
tokenizer_name_or_instance = AutoTokenizer.from_pretrained(tokenizer_name_or_instance)
self.tokenizer = tokenizer_name_or_instance

model: PreTrainedModel = None,
tokenizer: PreTrainedTokenizer = None):
self.model = model or self.get_model()
self.tokenizer = tokenizer or self.get_tokenizer()
self.device = next(self.model.parameters(), None).device

@staticmethod
def get_model(pretrained_model_name_or_path: str = 'castorini/monobert-large-msmarco',
*args, device: str = None, **kwargs) -> AutoModelForSequenceClassification:
device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device(device)
return AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *args, **kwargs).to(device).eval()

@staticmethod
def get_tokenizer(pretrained_model_name_or_path: str = 'bert-large-uncased',
*args, **kwargs) -> AutoTokenizer:
return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

@torch.no_grad()
def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
texts = deepcopy(texts)
Expand Down
15 changes: 6 additions & 9 deletions pygaggle/run/evaluate_document_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,10 @@ def tokenizer_sane(cls, v: str, values, **kwargs):


def construct_t5(options: DocumentRankingEvaluationOptions) -> Reranker:
device = torch.device(options.device)
model = T5ForConditionalGeneration.from_pretrained(options.model,
from_tf=options.from_tf).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(options.model_type)
tokenizer = T5BatchTokenizer(tokenizer, options.batch_size)
model = MonoT5.get_model(options.model,
from_tf=options.from_tf,
device=options.device)
tokenizer = MonoT5.get_tokenizer(options.model_type, batch_size=options.batch_size)
return MonoT5(model, tokenizer)


Expand All @@ -102,10 +101,8 @@ def construct_transformer(options:

def construct_seq_class_transformer(options: DocumentRankingEvaluationOptions
) -> Reranker:
model = AutoModelForSequenceClassification.from_pretrained(options.model, from_tf=options.from_tf)
device = torch.device(options.device)
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(options.tokenizer_name)
model = MonoBERT.get_model(options.model, from_tf=options.from_tf, device=options.device)
tokenizer = MonoBERT.get_tokenizer(options.tokenizer_name)
return MonoBERT(model, tokenizer)


Expand Down
20 changes: 10 additions & 10 deletions pygaggle/run/evaluate_kaggle_highlighter.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def construct_t5(options: KaggleEvaluationOptions) -> Reranker:
SETTINGS.flush_cache)
device = torch.device(options.device)
model = loader.load().to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(
options.model_name, do_lower_case=options.do_lower_case)
tokenizer = T5BatchTokenizer(tokenizer, options.batch_size)
tokenizer = MonoT5.get_tokenizer(options.model_type,
do_lower_case=options.do_lower_case,
batch_size=options.batch_size)
return MonoT5(model, tokenizer)


Expand All @@ -103,13 +103,13 @@ def construct_transformer(options: KaggleEvaluationOptions) -> Reranker:
def construct_seq_class_transformer(options:
KaggleEvaluationOptions) -> Reranker:
try:
model = AutoModelForSequenceClassification.from_pretrained(
options.model_name)
model = MonoBERT.get_model(options.model_name, device=options.device)
except OSError:
try:
model = AutoModelForSequenceClassification.from_pretrained(
model = MonoBERT.get_model(
options.model_name,
from_tf=True)
from_tf=True,
device=options.device)
except AttributeError:
# Hotfix for BioBERT MS MARCO. Refactor.
BertForSequenceClassification.bias = torch.nn.Parameter(
Expand All @@ -120,9 +120,9 @@ def construct_seq_class_transformer(options:
options.model_name, from_tf=True)
model.classifier.weight = BertForSequenceClassification.weight
model.classifier.bias = BertForSequenceClassification.bias
device = torch.device(options.device)
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(
device = torch.device(options.device)
model = model.to(device).eval()
tokenizer = MonoBERT.get_tokenizer(
options.tokenizer_name, do_lower_case=options.do_lower_case)
return MonoBERT(model, tokenizer)

Expand Down
19 changes: 9 additions & 10 deletions pygaggle/run/evaluate_passage_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,10 @@ def tokenizer_sane(cls, v: str, values, **kwargs):


def construct_t5(options: PassageRankingEvaluationOptions) -> Reranker:
device = torch.device(options.device)
model = T5ForConditionalGeneration.from_pretrained(options.model,
from_tf=options.from_tf).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(options.model_type)
tokenizer = T5BatchTokenizer(tokenizer, options.batch_size)
model = MonoT5.get_model(options.model,
from_tf=options.from_tf,
device=options.device)
tokenizer = MonoT5.get_tokenizer(options.model_type, batch_size=options.batch_size)
return MonoT5(model, tokenizer)


Expand All @@ -101,8 +100,8 @@ def construct_transformer(options:
def construct_seq_class_transformer(options: PassageRankingEvaluationOptions
) -> Reranker:
try:
model = AutoModelForSequenceClassification.from_pretrained(
options.model, from_tf=options.from_tf)
model = MonoBERT.get_model(
options.model, from_tf=options.from_tf, device=options.device)
except AttributeError:
# Hotfix for BioBERT MS MARCO. Refactor.
BertForSequenceClassification.bias = torch.nn.Parameter(
Expand All @@ -113,9 +112,9 @@ def construct_seq_class_transformer(options: PassageRankingEvaluationOptions
options.model, from_tf=options.from_tf)
model.classifier.weight = BertForSequenceClassification.weight
model.classifier.bias = BertForSequenceClassification.bias
device = torch.device(options.device)
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(options.tokenizer_name)
device = torch.device(options.device)
model = model.to(device).eval()
tokenizer = MonoBERT.get_tokenizer(options.tokenizer_name)
return MonoBERT(model, tokenizer)


Expand Down

0 comments on commit 978a071

Please sign in to comment.