From 55e4961dfc26f30e2e19347a3c81e496f81460b8 Mon Sep 17 00:00:00 2001
From: Ronak
Date: Thu, 30 Apr 2020 09:13:23 +0530
Subject: [PATCH] MSMARCO support: monoBERT (#14)

* add monobert for marco
* temp change python3.7 to 3.6 for colab compatibility
* fix evaluation options
* fix issues
* add missing options in evaluate_passage_ranker
* working monobert
* update transformers, clean code
* update tokenizers
* add dataclasses if < 3.7
* cleanup todos
* update to newer transformers along with syntax, clean up settings
* model-name-or-path as str type
* fix tokenizer loading for t5
---
 pygaggle/data/__init__.py                   |   1 +
 pygaggle/data/msmarco.py                    | 137 ++++++++++++++++++
 pygaggle/data/relevance.py                  |  21 +++
 pygaggle/data/unicode.py                    |   8 ++
 pygaggle/model/decode.py                    |   5 +-
 pygaggle/model/evaluate.py                  |  17 +++
 pygaggle/rerank/base.py                     |   5 +-
 pygaggle/run/evaluate_kaggle_highlighter.py |   4 +-
 pygaggle/run/evaluate_passage_ranker.py     | 151 ++++++++++++++++++++
 pygaggle/settings.py                        |  18 ++-
 setup.py                                    |   9 +-
 11 files changed, 361 insertions(+), 15 deletions(-)
 create mode 100644 pygaggle/data/msmarco.py
 create mode 100644 pygaggle/data/unicode.py
 create mode 100644 pygaggle/run/evaluate_passage_ranker.py

diff --git a/pygaggle/data/__init__.py b/pygaggle/data/__init__.py
index 8910717c..8a82448e 100644
--- a/pygaggle/data/__init__.py
+++ b/pygaggle/data/__init__.py
@@ -1,2 +1,3 @@
 from .kaggle import *
 from .relevance import *
+from .msmarco import *
diff --git a/pygaggle/data/msmarco.py b/pygaggle/data/msmarco.py
new file mode 100644
index 00000000..3aa405f9
--- /dev/null
+++ b/pygaggle/data/msmarco.py
@@ -0,0 +1,137 @@
+import os
+from collections import OrderedDict, defaultdict
+from typing import List, Set, DefaultDict
+import json
+import logging
+from itertools import permutations
+
+from pydantic import BaseModel
+import scipy.special as sp
+import numpy as np
+
+from .relevance import RelevanceExample, MsMarcoPassageLoader
+from pygaggle.model.tokenize import SpacySenticizer
+from pygaggle.rerank.base import Query, Text
+from pygaggle.data.unicode import convert_to_unicode
+
+
+__all__ = ['MsMarcoExample', 'MsMarcoDataset']
+
+
+class MsMarcoExample(BaseModel):
+    qid: str
+    text: str
+    candidates: List[str]
+    relevant_candidates: Set[str]
+
+class MsMarcoDataset(BaseModel):
+    examples: List[MsMarcoExample]
+
+    @classmethod
+    def load_qrels(cls, path: str) -> DefaultDict[str, Set[str]]:
+        qrels = defaultdict(set)
+        with open(path) as f:
+            for i, line in enumerate(f):
+                qid, _, doc_id, relevance = line.rstrip().split('\t')
+                if int(relevance) >= 1:
+                    qrels[qid].add(doc_id)
+        return qrels
+
+    @classmethod
+    def load_run(cls, path: str):
+        '''Returns OrderedDict[str, List[str]]'''
+        run = OrderedDict()
+        with open(path) as f:
+            for i, line in enumerate(f):
+                qid, doc_title, rank = line.split('\t')
+                if qid not in run:
+                    run[qid] = []
+                run[qid].append((doc_title, int(rank)))
+        sorted_run = OrderedDict()
+        for qid, doc_titles_ranks in run.items():
+            doc_titles_ranks.sort(key=lambda x: x[1])
+            doc_titles = [doc_title for doc_title, _ in doc_titles_ranks]
+            sorted_run[qid] = doc_titles
+        return sorted_run
+
+    @classmethod
+    def load_queries(cls,
+                     path: str,
+                     qrels: DefaultDict[str, Set[str]],
+                     run) -> List[MsMarcoExample]:
+        queries = []
+        with open(path) as f:
+            for i, line in enumerate(f):
+                qid, query = line.rstrip().split('\t')
+                candidates = run[qid]
+                queries.append(MsMarcoExample(qid=qid,
+                                              text=query,
+                                              candidates=candidates,
+                                              relevant_candidates=qrels[qid]))
+        return queries
+
+    @classmethod
+    def from_folder(cls,
+                    folder: str,
+                    split: str = 'dev',
+                    is_duo: bool = False) -> 'MsMarcoDataset':
+        run_mono = "mono." if is_duo else ""
+        query_path = os.path.join(folder, f"queries.{split}.small.tsv")
+        qrels_path = os.path.join(folder, f"qrels.{split}.small.tsv")
+        run_path = os.path.join(folder, f"run.{run_mono}{split}.small.tsv")
+        return cls(examples=cls.load_queries(query_path,
+                                             cls.load_qrels(qrels_path),
+                                             cls.load_run(run_path)))
+
+
+    def query_passage_tuples(self, is_duo: bool = False):
+        return (((ex.qid, ex.text, ex.relevant_candidates), perm_pas) for ex in self.examples
+                for perm_pas in permutations(ex.candidates, r=1 + int(is_duo)))
+
+
+    def to_relevance_examples(self,
+                              index_path: str,
+                              is_duo: bool = False) -> List[RelevanceExample]:
+        loader = MsMarcoPassageLoader(index_path)
+        example_map = {}
+        for (qid, text, rel_cands), cands in self.query_passage_tuples():
+            if qid not in example_map:
+                example_map[qid] = [convert_to_unicode(text), [], [], []]
+            example_map[qid][1].append(cands[0])
+            try:
+                passages = [loader.load_passage(cand) for cand in cands]
+                example_map[qid][2].append(convert_to_unicode(passages[0].all_text))
+            except ValueError:
+                logging.warning(f'Skipping {cands}')
+                continue
+            example_map[qid][3].append(cands[0] in rel_cands)
+        mean_stats = defaultdict(list)
+        for ex in self.examples:
+            int_rels = np.array(list(map(int, example_map[ex.qid][3])))
+            p = int_rels.sum() / (len(ex.candidates) - 1) if is_duo else int_rels.sum()
+            mean_stats['Random P@1'].append(np.mean(int_rels))
+            n = len(ex.candidates) - p
+            N = len(ex.candidates)
+            if len(ex.candidates) <= 1000:
+                mean_stats['Random R@1000'].append(1 if 1 in int_rels else 0)
+            numer = np.array([sp.comb(n, i) / (N - i) for i in range(0, n + 1) if i != N]) * p
+            if n == N:
+                numer = np.append(numer, 0)
+            denom = np.array([sp.comb(N, i) for i in range(0, n + 1)])
+            rr = 1 / np.arange(1, n + 2)
+            rmrr = np.sum(numer * rr / denom)
+            mean_stats['Random MRR'].append(rmrr)
+            rmrr10 = np.sum(numer[:10] * rr[:10] / denom[:10])
+            mean_stats['Random MRR@10'].append(rmrr10)
+            ex_index = len(ex.candidates)
+            for rel_cand in ex.relevant_candidates:
+                if rel_cand in ex.candidates:
+                    ex_index = min(ex.candidates.index(rel_cand), ex_index)
+            mean_stats['Existing MRR'].append(1 / (ex_index + 1) if ex_index < len(ex.candidates) else 0)
+            mean_stats['Existing MRR@10'].append(1 / (ex_index + 1) if ex_index < 10 else 0)
+        for k, v in mean_stats.items():
+            logging.info(f'{k}: {np.mean(v)}')
+        return [RelevanceExample(Query(text=query_text, id=qid),
+                                 list(map(lambda s: Text(s[1], dict(docid=s[0])), zip(cands, cands_text))),
+                                 rel_cands)
+                for qid, (query_text, cands, cands_text, rel_cands) in example_map.items()]
diff --git a/pygaggle/data/relevance.py b/pygaggle/data/relevance.py
index a0e65620..5fbdb518 100644
--- a/pygaggle/data/relevance.py
+++ b/pygaggle/data/relevance.py
@@ -30,6 +30,15 @@ def all_text(self):
         return '\n'.join((self.abstract, self.body_text, self.ref_entries))
 
 
+@dataclass
+class MsMarcoPassage:
+    para_text: str
+
+    @property
+    def all_text(self):
+        return self.para_text
+
+
 class Cord19DocumentLoader:
     double_space_pattern = re.compile(r'\s\s+')
 
@@ -50,3 +59,15 @@ def unfold(entries):
         return Cord19Document(unfold(article['abstract']),
                               unfold(article['body_text']),
                               unfold(ref_entries))
+
+
+class MsMarcoPassageLoader:
+    def __init__(self, index_path: str):
+        self.searcher = pysearch.SimpleSearcher(index_path)
+
+    def load_passage(self, id: str) -> MsMarcoPassage:
+        try:
+            passage = self.searcher.doc(id).lucene_document().get('raw')
+        except AttributeError:
+            raise ValueError('passage unretrievable')
+        return MsMarcoPassage(passage)
diff --git a/pygaggle/data/unicode.py b/pygaggle/data/unicode.py
new file mode 100644
index 00000000..fc1a0d26
--- /dev/null
+++ b/pygaggle/data/unicode.py
@@ -0,0 +1,8 @@
+def convert_to_unicode(text):
+    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
+    if isinstance(text, str):
+        return text
+    elif isinstance(text, bytes):
+        return text.decode("utf-8", "ignore")
+    else:
+        raise ValueError("Unsupported string type: %s" % (type(text)))
\ No newline at end of file
diff --git a/pygaggle/model/decode.py b/pygaggle/model/decode.py
index 3b26f301..b6298f7d 100644
--- a/pygaggle/model/decode.py
+++ b/pygaggle/model/decode.py
@@ -19,7 +19,10 @@ def greedy_decode(model: PreTrainedModel,
     past = model.get_encoder()(input_ids, attention_mask=attention_mask)
     next_token_logits = None
     for _ in range(length):
-        model_inputs = model.prepare_inputs_for_generation(decode_ids, past=past, attention_mask=attention_mask)
+        model_inputs = model.prepare_inputs_for_generation(decode_ids,
+                                                           past=past,
+                                                           attention_mask=attention_mask,
+                                                           use_cache=True)
         outputs = model(**model_inputs)  # (batch_size, cur_len, vocab_size)
         next_token_logits = outputs[0][:, -1, :]  # (batch_size, vocab_size)
         decode_ids = torch.cat([decode_ids, next_token_logits.max(1)[1].unsqueeze(-1)], dim=-1)
diff --git a/pygaggle/model/evaluate.py b/pygaggle/model/evaluate.py
index 3c355bbb..26958234 100644
--- a/pygaggle/model/evaluate.py
+++ b/pygaggle/model/evaluate.py
@@ -100,6 +100,16 @@ class RecallAt3Metric(TopkMixin, RecallAccumulator):
     top_k = 3
 
 
+@register_metric('recall@50')
+class RecallAt50Metric(TopkMixin, RecallAccumulator):
+    top_k = 50
+
+
+@register_metric('recall@1000')
+class RecallAt1000Metric(TopkMixin, RecallAccumulator):
+    top_k = 1000
+
+
 @register_metric('mrr')
 class MrrMetric(MeanAccumulator):
     def accumulate(self, scores: List[float], gold: RelevanceExample):
@@ -108,6 +118,13 @@ def accumulate(self, scores: List[float], gold: RelevanceExample):
         self.scores.append(rr)
 
 
+@register_metric('mrr@10')
+class MrrAt10Metric(MeanAccumulator):
+    def accumulate(self, scores: List[float], gold: RelevanceExample):
+        scores = sorted(list(enumerate(scores)), key=lambda x: x[1], reverse=True)
+        rr = next((1 / (rank_idx + 1) for rank_idx, (idx, _) in enumerate(scores) if (gold.labels[idx] and rank_idx < 10)), 0)
+        self.scores.append(rr)
+
 class ThresholdedRecallMetric(DynamicThresholdingMixin, RecallAccumulator):
     threshold = 0.5
 
diff --git a/pygaggle/rerank/base.py b/pygaggle/rerank/base.py
index 6c47bd92..a8e6b908 100644
--- a/pygaggle/rerank/base.py
+++ b/pygaggle/rerank/base.py
@@ -17,9 +17,12 @@ class Query:
     ----------
     text : str
         The query text.
+    id : Optional[str]
+        The query id.
     """
-    def __init__(self, text: str):
+    def __init__(self, text: str, id: Optional[str] = None):
         self.text = text
+        self.id = id
 
 
 class Text:
diff --git a/pygaggle/run/evaluate_kaggle_highlighter.py b/pygaggle/run/evaluate_kaggle_highlighter.py
index 9810cd79..459df352 100644
--- a/pygaggle/run/evaluate_kaggle_highlighter.py
+++ b/pygaggle/run/evaluate_kaggle_highlighter.py
@@ -16,10 +16,10 @@
 from pygaggle.rerank.similarity import CosineSimilarityMatrixProvider
 from pygaggle.model import SimpleBatchTokenizer, CachedT5ModelLoader, T5BatchTokenizer, RerankerEvaluator, metric_names
 from pygaggle.data import LitReviewDataset
-from pygaggle.settings import Settings
+from pygaggle.settings import Cord19Settings
 
 
-SETTINGS = Settings()
+SETTINGS = Cord19Settings()
 METHOD_CHOICES = ('transformer', 'bm25', 't5', 'seq_class_transformer', 'qa_transformer', 'random')
 
 
diff --git a/pygaggle/run/evaluate_passage_ranker.py b/pygaggle/run/evaluate_passage_ranker.py
new file mode 100644
index 00000000..693c6292
--- /dev/null
+++ b/pygaggle/run/evaluate_passage_ranker.py
@@ -0,0 +1,151 @@
+from typing import Optional, List
+from pathlib import Path
+import logging
+
+from pydantic import BaseModel, validator
+from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, BertForSequenceClassification
+import torch
+
+from .args import ArgumentParserBuilder, opt
+from pygaggle.rerank.base import Reranker
+from pygaggle.rerank.bm25 import Bm25Reranker
+from pygaggle.rerank.transformer import UnsupervisedTransformerReranker, T5Reranker, \
+    SequenceClassificationTransformerReranker
+from pygaggle.rerank.random import RandomReranker
+from pygaggle.rerank.similarity import CosineSimilarityMatrixProvider
+from pygaggle.model import SimpleBatchTokenizer, CachedT5ModelLoader, T5BatchTokenizer, RerankerEvaluator, metric_names
+from pygaggle.data import MsMarcoDataset
+from pygaggle.settings import MsMarcoSettings
+
+
+SETTINGS = MsMarcoSettings()
+METHOD_CHOICES = ('transformer', 'bm25', 't5', 'seq_class_transformer', 'random')
+
+
+class PassageRankingEvaluationOptions(BaseModel):
+    dataset: str
+    data_dir: Path
+    method: str
+    model_name_or_path: str
+    split: str
+    batch_size: int
+    device: str
+    is_duo: bool
+    metrics: List[str]
+    model_type: Optional[str]
+    tokenizer_name: Optional[str]
+    index_dir: Optional[Path]
+
+    @validator('dataset')
+    def dataset_exists(cls, v: str):
+        assert v in ['msmarco', 'treccar']
+        return v
+
+    @validator('data_dir')
+    def datadir_exists(cls, v: str):
+        assert v.exists(), 'data directory must exist'
+        return v
+
+    @validator('index_dir')
+    def index_dir_exists(cls, v: str):
+        if v is None:
+            return SETTINGS.msmarco_index_path
+        return v
+
+    @validator('model_name_or_path')
+    def model_name_sane(cls, v: Optional[str], values, **kwargs):
+        method = values['method']
+        if method == 'transformer' and v is None:
+            raise ValueError('transformer name or path must be specified')
+        return v
+
+    @validator('tokenizer_name')
+    def tokenizer_sane(cls, v: str, values, **kwargs):
+        if v is None:
+            return values['model_name_or_path']
+        return v
+
+
+def construct_t5(options: PassageRankingEvaluationOptions) -> Reranker:
+    loader = CachedT5ModelLoader(options.model_name_or_path,
+                                 SETTINGS.cache_dir,
+                                 'ranker',
+                                 options.model_type,
+                                 SETTINGS.flush_cache)
+    device = torch.device(options.device)
+    model = loader.load().to(device).eval()
+    tokenizer = AutoTokenizer.from_pretrained(options.model_type)
+    tokenizer = T5BatchTokenizer(tokenizer, options.batch_size)
+    return T5Reranker(model, tokenizer)
+
+
+def construct_transformer(options: PassageRankingEvaluationOptions) -> Reranker:
+    device = torch.device(options.device)
+    try:
+        model = AutoModel.from_pretrained(options.model_name_or_path).to(device).eval()
+    except OSError:
+        model = AutoModel.from_pretrained(options.model_name_or_path, from_tf=True).to(device).eval()
+    tokenizer = SimpleBatchTokenizer(AutoTokenizer.from_pretrained(options.tokenizer_name),
+                                     options.batch_size)
+    provider = CosineSimilarityMatrixProvider()
+    return UnsupervisedTransformerReranker(model, tokenizer, provider)
+
+
+def construct_seq_class_transformer(options: PassageRankingEvaluationOptions) -> Reranker:
+    try:
+        model = AutoModelForSequenceClassification.from_pretrained(options.model_name_or_path)
+    except OSError:
+        try:
+            model = AutoModelForSequenceClassification.from_pretrained(options.model_name_or_path, from_tf=True)
+        except AttributeError:
+            # Hotfix for BioBERT MS MARCO. Refactor.
+            BertForSequenceClassification.bias = torch.nn.Parameter(torch.zeros(2))
+            BertForSequenceClassification.weight = torch.nn.Parameter(torch.zeros(2, 768))
+            model = BertForSequenceClassification.from_pretrained(options.model_name_or_path, from_tf=True)
+            model.classifier.weight = BertForSequenceClassification.weight
+            model.classifier.bias = BertForSequenceClassification.bias
+    device = torch.device(options.device)
+    model = model.to(device).eval()
+    tokenizer = AutoTokenizer.from_pretrained(options.tokenizer_name)
+    return SequenceClassificationTransformerReranker(model, tokenizer)
+
+
+def construct_bm25(options: PassageRankingEvaluationOptions) -> Reranker:
+    return Bm25Reranker(index_path=str(options.index_dir))
+
+
+def main():
+    apb = ArgumentParserBuilder()
+    apb.add_opts(opt('--dataset', type=str, default='msmarco'),
+                 opt('--data-dir', type=Path, default='/content/data/msmarco'),
+                 opt('--method', required=True, type=str, choices=METHOD_CHOICES),
+                 opt('--model-name-or-path', type=str),
+                 opt('--split', type=str, default='dev', choices=('dev', 'eval')),
+                 opt('--batch-size', '-bsz', type=int, default=96),
+                 opt('--device', type=str, default='cuda:0'),
+                 opt('--is-duo', action='store_true'),
+                 opt('--metrics', type=str, nargs='+', default=metric_names(), choices=metric_names()),
+                 opt('--model-type', type=str, default='bert-base'),
+                 opt('--tokenizer-name', type=str),
+                 opt('--index-dir', type=Path))
+    args = apb.parser.parse_args()
+    options = PassageRankingEvaluationOptions(**vars(args))
+    ds = MsMarcoDataset.from_folder(str(options.data_dir), split=options.split, is_duo=options.is_duo)
+    examples = ds.to_relevance_examples(SETTINGS.msmarco_index_path, is_duo=options.is_duo)
+    construct_map = dict(transformer=construct_transformer,
+                         bm25=construct_bm25,
+                         t5=construct_t5,
+                         seq_class_transformer=construct_seq_class_transformer,
+                         random=lambda _: RandomReranker())
+    reranker = construct_map[options.method](options)
+    evaluator = RerankerEvaluator(reranker, options.metrics)
+    width = max(map(len, args.metrics)) + 1
+    stdout = []
+    for metric in evaluator.evaluate(examples):
+        logging.info(f'{metric.name:<{width}}{metric.value:.5}')
+        stdout.append(f'{metric.name}\t{metric.value}')
+    print('\n'.join(stdout))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/pygaggle/settings.py b/pygaggle/settings.py
index 53c3e30b..ac239e49 100644
--- a/pygaggle/settings.py
+++ b/pygaggle/settings.py
@@ -5,13 +5,17 @@
 
 
 class Settings(BaseSettings):
-    cord19_index_path: str = 'data/lucene-index-covid-paragraph'
-
-    # T5 model settings
-    t5_model_dir: str = 'gs://neuralresearcher_data/covid/data/model_exp304'
-    t5_model_type: str = 't5-base'
-    t5_max_length: int = 512
-
     # Cache settings
     cache_dir: Path = Path(os.getenv('XDG_CACHE_HOME', str(Path.home() / '.cache'))) / 'covidex'
     flush_cache: bool = False
+
+
+class MsMarcoSettings(Settings):
+    msmarco_index_path: str = 'data/index-msmarco-passage-20191117-0ed488'
+
+class Cord19Settings(Settings):
+    cord19_index_path: str = 'data/lucene-index-covid-paragraph'
+
+    # T5 model settings
+    t5_model_dir: str = 'gs://neuralresearcher_data/covid/data/model_exp304'
+    t5_model_type: str = 't5-base'
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 5bfd1ab2..781838e1 100644
--- a/setup.py
+++ b/setup.py
@@ -5,6 +5,7 @@
     long_description = fh.read()
 
 reqs = [
+    'dataclasses;python_version<"3.7"',
    'coloredlogs==14.0',
     'numpy==1.18.2',
     'pydantic==1.5',
@@ -14,9 +15,9 @@
     'spacy==2.2.4',
     'tensorboard>=2.1.0',
     'tensorflow>=2.2.0rc1',
-    'tokenizers==0.5.2',
+    'tokenizers>=0.5.2',
     'tqdm==4.45.0',
-    'transformers==2.7.0'
+    'transformers==2.8.0'
 ]
 
 setuptools.setup(
@@ -24,7 +25,7 @@
     version='0.0.1',
     author='PyGaggle Gaggle',
     author_email='r33tang@uwaterloo.ca',
-    description='A gaggle of rerankers for CovidQA and CORD-19',
+    description='A gaggle of rerankers for CovidQA, CORD-19 and MS MARCO',
     long_description=long_description,
     long_description_content_type='text/markdown',
     url='https://github.com/castorini/pygaggle',
@@ -35,5 +36,5 @@
         'License :: OSI Approved :: Apache Software License',
         'Operating System :: OS Independent',
     ],
-    python_requires='>=3.7',
+    python_requires='>=3.6',
 )
\ No newline at end of file
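
Usage note: a minimal sketch of how the dataset API added by this patch fits together. The '/path/to/...' values below are hypothetical placeholders, not shipped defaults; substitute a local MS MARCO data folder and Anserini passage index.

    # Illustrative sketch only -- paths are placeholders.
    from pygaggle.data import MsMarcoDataset

    # from_folder expects queries.{split}.small.tsv, qrels.{split}.small.tsv and
    # run.{split}.small.tsv inside the given folder.
    ds = MsMarcoDataset.from_folder('/path/to/msmarco', split='dev', is_duo=False)

    # Resolves each candidate docid against the index and logs baseline
    # statistics (Random/Existing MRR, P@1, R@1000) as a side effect.
    examples = ds.to_relevance_examples('/path/to/index-msmarco-passage')
    print(f'{len(examples)} queries with candidate passages loaded')

The same flow is what main() in pygaggle/run/evaluate_passage_ranker.py drives end to end, e.g. python -m pygaggle.run.evaluate_passage_ranker --method seq_class_transformer --model-name-or-path <model> --data-dir <data-dir>.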