Skip to content

Commit

Permalink
change option from model-name-or-path to simpler model, fix flake8 le…
Browse files Browse the repository at this point in the history
…n 120 (#40)

* change option from hgf's model-name-or-path to simpler model, fix flake8 max-len 120

* fix hgf -> huggingface
  • Loading branch information
ronakice authored May 29, 2020
1 parent 6c49f6c commit 82dc086
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 30 deletions.
38 changes: 19 additions & 19 deletions docs/experiments-msmarco-passage.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,23 +65,23 @@ First, let's evaluate using monoBERT!

```
python -um pygaggle.run.evaluate_passage_ranker --split dev \
--method seq_class_transformer \
--model-name-or-path castorini/monobert-large-msmarco \
--dataset data/msmarco_ans_small/ \
--index-dir indexes/index-msmarco-passage-20191117-0ed488 \
--task msmarco \
--output-file runs/run.monobert.ans_small.dev.tsv
--method seq_class_transformer \
--model castorini/monobert-large-msmarco \
--dataset data/msmarco_ans_small/ \
--index-dir indexes/index-msmarco-passage-20191117-0ed488 \
--task msmarco \
--output-file runs/run.monobert.ans_small.dev.tsv
```

Upon completion, the following output will be visible:

```
precision@1 0.2761904761904762
recall@3 0.42698412698412697
recall@50 0.8174603174603176
recall@1000 0.8476190476190476
mrr 0.41089693612003686
mrr@10 0.4026795162509449
precision@1 0.2761904761904762
recall@3 0.42698412698412697
recall@50 0.8174603174603176
recall@1000 0.8476190476190476
mrr 0.41089693612003686
mrr@10 0.4026795162509449
```

It takes about 52 minutes to re-rank this subset on MS MARCO using a P100.
Expand All @@ -106,7 +106,7 @@ We use the monoT5-base variant as it is the easiest to run without access to lar
```
python -um pygaggle.run.evaluate_passage_ranker --split dev \
--method t5 \
--model-name-or-path castorini/monot5-base-msmarco \
--model castorini/monot5-base-msmarco \
--dataset data/msmarco_ans_small \
--model-type t5-base \
--task msmarco \
Expand All @@ -118,12 +118,12 @@ python -um pygaggle.run.evaluate_passage_ranker --split dev \
The following output will be visible after it has finished:

```
precision@1 0.26666666666666666
recall@3 0.4603174603174603
recall@50 0.8063492063492063
recall@1000 0.8476190476190476
mrr 0.3973368360121561
mrr@10 0.39044217687074834
precision@1 0.26666666666666666
recall@3 0.4603174603174603
recall@50 0.8063492063492063
recall@1000 0.8476190476190476
mrr 0.3973368360121561
mrr@10 0.39044217687074834
```

It takes about 13 minutes to re-rank this subset on MS MARCO using a P100.
Expand Down
23 changes: 13 additions & 10 deletions pygaggle/run/evaluate_passage_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class PassageRankingEvaluationOptions(BaseModel):
dataset: Path
index_dir: Path
method: str
model_name_or_path: str
model: str
split: str
batch_size: int
device: str
Expand All @@ -63,8 +63,8 @@ def index_dir_exists(cls, v: Path):
assert v.exists(), 'index directory must exist'
return v

@validator('model_name_or_path')
def model_name_sane(cls, v: Optional[str], values, **kwargs):
@validator('model')
def model_sane(cls, v: str, values, **kwargs):
method = values['method']
if method == 'transformer' and v is None:
raise ValueError('transformer name or path must be specified')
Expand All @@ -73,13 +73,13 @@ def model_name_sane(cls, v: Optional[str], values, **kwargs):
@validator('tokenizer_name')
def tokenizer_sane(cls, v: str, values, **kwargs):
if v is None:
return values['model_name_or_path']
return values['model']
return v


def construct_t5(options: PassageRankingEvaluationOptions) -> Reranker:
device = torch.device(options.device)
model = T5ForConditionalGeneration.from_pretrained(options.model_name_or_path,
model = T5ForConditionalGeneration.from_pretrained(options.model,
from_tf=options.from_tf).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(options.model_type)
tokenizer = T5BatchTokenizer(tokenizer, options.batch_size)
Expand All @@ -89,7 +89,7 @@ def construct_t5(options: PassageRankingEvaluationOptions) -> Reranker:
def construct_transformer(options:
PassageRankingEvaluationOptions) -> Reranker:
device = torch.device(options.device)
model = AutoModel.from_pretrained(options.model_name_or_path,
model = AutoModel.from_pretrained(options.model,
from_tf=options.from_tf).to(device).eval()
tokenizer = SimpleBatchTokenizer(AutoTokenizer.from_pretrained(
options.tokenizer_name),
Expand All @@ -102,15 +102,15 @@ def construct_seq_class_transformer(options: PassageRankingEvaluationOptions
) -> Reranker:
try:
model = AutoModelForSequenceClassification.from_pretrained(
options.model_name_or_path, from_tf=options.from_tf)
options.model, from_tf=options.from_tf)
except AttributeError:
# Hotfix for BioBERT MS MARCO. Refactor.
BertForSequenceClassification.bias = torch.nn.Parameter(
torch.zeros(2))
BertForSequenceClassification.weight = torch.nn.Parameter(
torch.zeros(2, 768))
model = BertForSequenceClassification.from_pretrained(
options.model_name_or_path, from_tf=options.from_tf)
options.model, from_tf=options.from_tf)
model.classifier.weight = BertForSequenceClassification.weight
model.classifier.bias = BertForSequenceClassification.bias
device = torch.device(options.device)
Expand All @@ -134,7 +134,10 @@ def main():
required=True,
type=str,
choices=METHOD_CHOICES),
opt('--model-name-or-path', type=str),
opt('--model',
required=True,
type=str,
help='Path to pre-trained model or huggingface model name'),
opt('--output-file', type=Path, default='.'),
opt('--overwrite-output', action='store_true'),
opt('--split',
Expand All @@ -150,7 +153,7 @@ def main():
nargs='+',
default=metric_names(),
choices=metric_names()),
opt('--model-type', type=str, default='bert-base'),
opt('--model-type', type=str),
opt('--tokenizer-name', type=str))
args = apb.parser.parse_args()
options = PassageRankingEvaluationOptions(**vars(args))
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[flake8]
max-line-length = 100
max-line-length = 120

0 comments on commit 82dc086

Please sign in to comment.