From c7b37d6073cda62685f64d6d0b99dc46f0718346 Mon Sep 17 00:00:00 2001
From: stephaniewhoo
Date: Thu, 3 Jun 2021 21:34:47 +0800
Subject: [PATCH] LTR refactoring for modularization (#636)

Modularization LTR reranking and split ltr doc
---
 ...eriments-ltr-msmarco-passage-reranking.md} |  22 +-
 ...xperiments-ltr-msmarco-passage-training.md |  62 +++
 .../sparse/test_ltr_msmarco_passage.py        |   7 +-
 .../ltr/search_msmarco_passage/__init__.py    |  18 +
 .../ltr/search_msmarco_passage/__main__.py    | 238 +++++++++
 .../_search_msmarco_passage.py                | 236 +++++++++
 .../rerank_with_ltr_model.py                  | 450 ------------------
 .../ltr_msmarco-passage/train_ltr_model.py    |   6 +-
 8 files changed, 565 insertions(+), 474 deletions(-)
 rename docs/{experiments-ltr-msmarco-passage.md => experiments-ltr-msmarco-passage-reranking.md} (87%)
 create mode 100644 docs/experiments-ltr-msmarco-passage-training.md
 create mode 100644 pyserini/ltr/search_msmarco_passage/__init__.py
 create mode 100644 pyserini/ltr/search_msmarco_passage/__main__.py
 create mode 100644 pyserini/ltr/search_msmarco_passage/_search_msmarco_passage.py
 delete mode 100644 scripts/ltr_msmarco-passage/rerank_with_ltr_model.py

diff --git a/docs/experiments-ltr-msmarco-passage.md b/docs/experiments-ltr-msmarco-passage-reranking.md
similarity index 87%
rename from docs/experiments-ltr-msmarco-passage.md
rename to docs/experiments-ltr-msmarco-passage-reranking.md
index e384690f9..f2db132b3 100644
--- a/docs/experiments-ltr-msmarco-passage.md
+++ b/docs/experiments-ltr-msmarco-passage-reranking.md
@@ -1,4 +1,4 @@
-# Pyserini: Learning-To-Rank Baseline for MS MARCO Passage
+# Pyserini: Learning-To-Rank Reranking Baseline for MS MARCO Passage
 
 This guide contains instructions for running learning-to-rank baseline on the [MS MARCO *passage* reranking task](https://microsoft.github.io/msmarco/).
 Learning-to-rank serves as a second stage reranker after BM25 retrieval.
@@ -52,7 +52,7 @@ tar -xzvf runs/msmarco-passage-ltr-mrr-v1.tar.gz -C runs
 Next we can run our inference script to get our reranking result.
 
 ```bash
-python scripts/ltr_msmarco-passage/rerank_with_ltr_model.py \
+python -m pyserini.ltr.search_msmarco_passage \
   --input runs/run.msmarco-passage.bm25tuned.txt \
   --input-format tsv \
   --model runs/msmarco-passage-ltr-mrr-v1 \
@@ -60,7 +60,9 @@ python scripts/ltr_msmarco-passage/rerank_with_ltr_model.py \
   --output runs/run.ltr.msmarco-passage.tsv
 ```
 
-Here, our model is trained to maximize MRR@10.
+Here, our model is trained to maximize MRR@10.
+
+Note that we can also train other models from scratch by following the [training guide](experiments-ltr-msmarco-passage-training.md) and replacing the `--model` argument with the directory of your trained model.
 
 Inference speed will vary, on orca, it takes ~0.25s/query.
 
@@ -99,20 +101,6 @@ Average precision or AP (also called mean average precision, MAP) and recall@100
 AP captures aspects of both precision and recall in a single metric, and is the most common metric used by information retrieval researchers.
 On the other hand, recall@1000 provides the upper bound effectiveness of downstream reranking modules (i.e., rerankers are useless if there isn't a relevant document in the results).
 
-## Training the Model From Scratch
-
-```bash
-wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz -P collections/msmarco-passage/
-gzip -d collections/msmarco-passage/qidpidtriples.train.full.2.tsv.gz
-```
-First download the file which has training triples and uncompress it.
-
-```bash
-python scripts/ltr_msmarco-passage/train_ltr_model.py \
-  --index indexes/index-msmarco-passage-ltr-20210519-e25e33f
-```
-The above scripts will train a model at `runs/` with your running date in the file name. You can use this as the `--ltr_model_path` parameter for `predict_passage.py`.
-
 ## Building the Index From Scratch
 
 Equivalently, we can preprocess collection and queries with our scripts:
diff --git a/docs/experiments-ltr-msmarco-passage-training.md b/docs/experiments-ltr-msmarco-passage-training.md
new file mode 100644
index 000000000..5a60a6cb8
--- /dev/null
+++ b/docs/experiments-ltr-msmarco-passage-training.md
@@ -0,0 +1,62 @@
+# Pyserini: Train Learning-To-Rank Reranking Models for MS MARCO Passage
+
+## Data Preprocessing
+
+Please first follow the [Pyserini BM25 retrieval guide](experiments-msmarco-passage.md) to obtain the reranking candidates.
+
+```bash
+wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz -P collections/msmarco-passage/
+gzip -d collections/msmarco-passage/qidpidtriples.train.full.2.tsv.gz
+```
+The above commands download the file containing the training triples and uncompress it.
+
+Next, we're going to use `collections/msmarco-ltr-passage/` as the working directory to store preprocessed data.
+
+```bash
+mkdir collections/msmarco-ltr-passage/
+
+python scripts/ltr_msmarco-passage/convert_queries.py \
+  --input collections/msmarco-passage/queries.eval.small.tsv \
+  --output collections/msmarco-ltr-passage/queries.eval.small.json
+
+python scripts/ltr_msmarco-passage/convert_queries.py \
+  --input collections/msmarco-passage/queries.dev.small.tsv \
+  --output collections/msmarco-ltr-passage/queries.dev.small.json
+
+python scripts/ltr_msmarco-passage/convert_queries.py \
+  --input collections/msmarco-passage/queries.train.tsv \
+  --output collections/msmarco-ltr-passage/queries.train.json
+```
+
+The above scripts convert queries to json objects with `text`, `text_unlemm`, `raw`, and `text_bert_tok` fields.
+The first two scripts take ~1 min and the third one is a bit longer (~1.5h).
+
+```bash
+python -c "from pyserini.search import SimpleSearcher; SimpleSearcher.from_prebuilt_index('msmarco-passage-ltr')"
+```
+
+We run the above command to download the pre-built index into the cache.
+
+Note that you can also build the index from scratch by following [this guide](./experiments-ltr-msmarco-passage-reranking.md#L104).
+
+```bash
+wget https://www.dropbox.com/s/vlrfcz3vmr4nt0q/ibm_model.tar.gz -P collections/msmarco-ltr-passage/
+tar -xzvf collections/msmarco-ltr-passage/ibm_model.tar.gz -C collections/msmarco-ltr-passage/
+```
+The above commands download and extract the pretrained IBM translation models.
+
+## Training the Model From Scratch
+```bash
+python scripts/ltr_msmarco-passage/train_ltr_model.py \
+  --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3
+```
+The above script will train a model under `runs/` with the current date in the file name. You can use this as the `--model` parameter for [reranking](experiments-ltr-msmarco-passage-reranking.md#L58).
+
+The number of negative samples used in training can be changed with `--neg-sample`; the default is 10.
+
+## Change the Optimization Goal of Your Trained Model
+The script trains a model which optimizes MRR@10 by default.
+
+You can change `mrr_at_10` in [this function](../scripts/ltr_msmarco-passage/train_ltr_model.py#L621) and [here](../scripts/ltr_msmarco-passage/train_ltr_model.py#L358) to `recall_at_20` to train a model that optimizes recall@20.
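+
+For intuition, here is a rough, hypothetical sketch of what such a recall@20 computation does for a single query's candidates (the actual signature and data structures in `train_ltr_model.py` may differ):
+
+```python
+import pandas as pd
+
+# Hypothetical helper, not part of train_ltr_model.py: fraction of a query's
+# relevant passages that appear in the top 20 after sorting by model score.
+def recall_at_20(group: pd.DataFrame, total_rel: int) -> float:
+    # stable sort by score, matching the mergesort used elsewhere in the LTR scripts
+    top = group.sort_values('score', ascending=False, kind='mergesort').head(20)
+    return (top['rel'] > 0).sum() / total_rel if total_rel > 0 else 0.
+```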
+
+You can also define your own function in the same format as [this one](../scripts/ltr_msmarco-passage/train_ltr_model.py#L300) and change the corresponding places mentioned above to use a different optimization goal.
diff --git a/integrations/sparse/test_ltr_msmarco_passage.py b/integrations/sparse/test_ltr_msmarco_passage.py
index 1bfa89fef..cdc59aaf5 100644
--- a/integrations/sparse/test_ltr_msmarco_passage.py
+++ b/integrations/sparse/test_ltr_msmarco_passage.py
@@ -47,11 +47,8 @@ def test_reranking(self):
             os.system(f'tar -xzvf ltr_test/{ibm_model_tar_name} -C ltr_test')
             #queries process
             os.system('python scripts/ltr_msmarco-passage/convert_queries.py --input tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt --output ltr_test/queries.dev.small.json')
-            if(os.getcwd().endswith('sparse')):
-                os.system(f'python ../../scripts/ltr_msmarco-passage/rerank_with_ltr_model.py --input ltr_test/{inp} --input-format tsv --model ltr_test/msmarco-passage-ltr-mrr-v1 --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 --ibm-model ltr_test/ibm_model/ --queries ltr_test --output ltr_test/{outp}')
-            else:
-                os.system(f'python scripts/ltr_msmarco-passage/rerank_with_ltr_model.py --input ltr_test/{inp} --input-format tsv --model ltr_test/msmarco-passage-ltr-mrr-v1 --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 --ibm-model ltr_test/ibm_model/ --queries ltr_test --output ltr_test/{outp}')
-            result = subprocess.check_output(f'python tools/scripts/msmarco/msmarco_passage_eval.py tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt ltr_test/{outp}', shell=True).decode(sys.stdout.encoding)
+            os.system(f'python -m pyserini.ltr.search_msmarco_passage --input ltr_test/{inp} --input-format tsv --model ltr_test/msmarco-passage-ltr-mrr-v1 --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 --ibm-model ltr_test/ibm_model/ --queries ltr_test --output ltr_test/{outp}')
+            result = subprocess.check_output(f'python tools/scripts/msmarco/msmarco_passage_eval.py tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt ltr_test/{outp}', shell=True).decode(sys.stdout.encoding)
             a,b = result.find('#####################\nMRR @10:'), result.find('\nQueriesRanked: 6980\n#####################\n')
             mrr = result[a+31:b]
             self.assertAlmostEqual(float(mrr),0.24709612498294367, delta=0.000001)
diff --git a/pyserini/ltr/search_msmarco_passage/__init__.py b/pyserini/ltr/search_msmarco_passage/__init__.py
new file mode 100644
index 000000000..14b70b6bf
--- /dev/null
+++ b/pyserini/ltr/search_msmarco_passage/__init__.py
@@ -0,0 +1,18 @@
+#
+# Pyserini: Reproducible IR research with sparse and dense representations
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +from ._search_msmarco_passage import MsmarcoPassageLtrSearcher +__all__ = ['MsmarcoPassageLtrSearcher'] \ No newline at end of file diff --git a/pyserini/ltr/search_msmarco_passage/__main__.py b/pyserini/ltr/search_msmarco_passage/__main__.py new file mode 100644 index 000000000..afe158d92 --- /dev/null +++ b/pyserini/ltr/search_msmarco_passage/__main__.py @@ -0,0 +1,238 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +# We're going to explicitly use a local installation of Pyserini (as opposed to a pip-installed one). +# Comment these lines out to use a pip-installed one instead. +sys.path.insert(0, './') + +import argparse +import json +import multiprocessing +import os +import pickle +import time + +import numpy as np +import pandas as pd +from tqdm import tqdm +from pyserini.ltr.search_msmarco_passage._search_msmarco_passage import MsmarcoPassageLtrSearcher +from pyserini.ltr import * + +""" +Running prediction on candidates +""" +def dev_data_loader(file, format, top=100): + if format == 'tsv': + dev = pd.read_csv(file, sep="\t", + names=['qid', 'pid', 'rank'], + dtype={'qid': 'S','pid': 'S', 'rank':'i',}) + elif format == 'trec': + dev = pd.read_csv(file, sep="\s+", + names=['qid', 'q0', 'pid', 'rank', 'score', 'tag'], + usecols=['qid', 'pid', 'rank'], + dtype={'qid': 'S','pid': 'S', 'rank':'i',}) + else: + raise Exception('unknown parameters') + assert dev['qid'].dtype == np.object + assert dev['pid'].dtype == np.object + assert dev['rank'].dtype == np.int32 + dev = dev[dev['rank']<=top] + dev_qrel = pd.read_csv('tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt', sep=" ", + names=["qid", "q0", "pid", "rel"], usecols=['qid', 'pid', 'rel'], + dtype={'qid': 'S','pid': 'S', 'rel':'i'}) + assert dev['qid'].dtype == np.object + assert dev['pid'].dtype == np.object + assert dev['rank'].dtype == np.int32 + dev = dev.merge(dev_qrel, left_on=['qid', 'pid'], right_on=['qid', 'pid'], how='left') + dev['rel'] = dev['rel'].fillna(0).astype(np.int32) + dev = dev.sort_values(['qid', 'pid']).set_index(['qid', 'pid']) + + print(dev.shape) + print(dev.index.get_level_values('qid').drop_duplicates().shape) + print(dev.groupby('qid').count().mean()) + print(dev.head(10)) + print(dev.info()) + + dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel'] + + recall_point = [10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000] + recall_curve = {k: [] for k in recall_point} + for qid, group in tqdm(dev.groupby('qid')): + group = group.reset_index() + assert len(group['pid'].tolist()) == len(set(group['pid'].tolist())) + total_rel = dev_rel_num.loc[qid] + query_recall = [0 for k in recall_point] + for t in group.sort_values('rank').itertuples(): + if t.rel > 0: + for i, p in enumerate(recall_point): + if t.rank <= p: + query_recall[i] += 1 + for i, p in enumerate(recall_point): + if total_rel > 0: + recall_curve[p].append(query_recall[i] / total_rel) + else: + 
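+                # no judged relevant passages for this query, so its recall is counted as 0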
recall_curve[p].append(0.) + + for k, v in recall_curve.items(): + avg = np.mean(v) + print(f'recall@{k}:{avg}') + + return dev, dev_qrel + + +def query_loader(): + queries = {} + with open(f'{args.queries}/queries.dev.small.json') as f: + for line in f: + query = json.loads(line) + qid = query.pop('id') + query['analyzed'] = query['analyzed'].split(" ") + query['text'] = query['text_unlemm'].split(" ") + query['text_unlemm'] = query['text_unlemm'].split(" ") + query['text_bert_tok'] = query['text_bert_tok'].split(" ") + queries[qid] = query + return queries + + +def eval_mrr(dev_data): + score_tie_counter = 0 + score_tie_query = set() + MRR = [] + for qid, group in tqdm(dev_data.groupby('qid')): + group = group.reset_index() + rank = 0 + prev_score = None + assert len(group['pid'].tolist()) == len(set(group['pid'].tolist())) + # stable sort is also used in LightGBM + + for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples(): + if prev_score is not None and abs(t.score - prev_score) < 1e-8: + score_tie_counter += 1 + score_tie_query.add(qid) + prev_score = t.score + rank += 1 + if t.rel > 0: + MRR.append(1.0 / rank) + break + elif rank == 10 or rank == len(group): + MRR.append(0.) + break + + score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries' + print(score_tie) + mrr_10 = np.mean(MRR).item() + print(f'MRR@10:{mrr_10} with {len(MRR)} queries') + return {'score_tie': score_tie, 'mrr_10': mrr_10} + + +def eval_recall(dev_qrel, dev_data): + dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel'] + + score_tie_counter = 0 + score_tie_query = set() + + recall_point = [10,20,50,100,200,250,300,333,400,500,1000] + recall_curve = {k: [] for k in recall_point} + for qid, group in tqdm(dev_data.groupby('qid')): + group = group.reset_index() + rank = 0 + prev_score = None + assert len(group['pid'].tolist()) == len(set(group['pid'].tolist())) + # stable sort is also used in LightGBM + total_rel = dev_rel_num.loc[qid] + query_recall = [0 for k in recall_point] + for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples(): + if prev_score is not None and abs(t.score - prev_score) < 1e-8: + score_tie_counter += 1 + score_tie_query.add(qid) + prev_score = t.score + rank += 1 + if t.rel > 0: + for i, p in enumerate(recall_point): + if rank <= p: + query_recall[i] += 1 + for i, p in enumerate(recall_point): + if total_rel > 0: + recall_curve[p].append(query_recall[i] / total_rel) + else: + recall_curve[p].append(0.) 
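+    # summarize: report score-tie statistics and the recall curve averaged over queries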
+ + score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries' + print(score_tie) + res = {'score_tie': score_tie} + + for k, v in recall_curve.items(): + avg = np.mean(v) + print(f'recall@{k}:{avg}') + res[f'recall@{k}'] = avg + + return res + + +def output(file, dev_data): + score_tie_counter = 0 + score_tie_query = set() + output_file = open(file,'w') + + for qid, group in tqdm(dev_data.groupby('qid')): + group = group.reset_index() + rank = 0 + prev_score = None + assert len(group['pid'].tolist()) == len(set(group['pid'].tolist())) + # stable sort is also used in LightGBM + + for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples(): + if prev_score is not None and abs(t.score - prev_score) < 1e-8: + score_tie_counter += 1 + score_tie_query.add(qid) + prev_score = t.score + rank += 1 + output_file.write(f"{qid}\t{t.pid}\t{rank}\n") + + score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries' + print(score_tie) + +if __name__ == "__main__": + os.environ["ANSERINI_CLASSPATH"] = "./pyserini/resources/jars" + parser = argparse.ArgumentParser(description='Learning to rank reranking') + parser.add_argument('--input', required=True) + parser.add_argument('--reranking-top', type=int, default=1000) + parser.add_argument('--input-format', required=True) + parser.add_argument('--model', required=True) + parser.add_argument('--index', required=True) + parser.add_argument('--output', required=True) + parser.add_argument('--ibm-model',default='./collections/msmarco-ltr-passage/ibm_model/') + parser.add_argument('--queries',default='./collections/msmarco-ltr-passage/') + + args = parser.parse_args() + searcher = MsmarcoPassageLtrSearcher(args.model, args.ibm_model, args.index) + searcher.add_fe() + print("load dev") + dev, dev_qrel = dev_data_loader(args.input, args.input_format, args.reranking_top) + print("load queries") + queries = query_loader() + + batch_info = searcher.search(dev, queries) + del dev, queries + + eval_res = eval_mrr(batch_info) + eval_recall(dev_qrel, batch_info) + output(args.output, batch_info) + print('Done!') + + diff --git a/pyserini/ltr/search_msmarco_passage/_search_msmarco_passage.py b/pyserini/ltr/search_msmarco_passage/_search_msmarco_passage.py new file mode 100644 index 000000000..157d6e796 --- /dev/null +++ b/pyserini/ltr/search_msmarco_passage/_search_msmarco_passage.py @@ -0,0 +1,236 @@ +# +# Pyserini: Python interface to the Anserini IR toolkit built on Lucene +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module provides Pyserini's Python ltr search interface on MS MARCO passage. The main entry point is the ``MsmarcoPassageLtrSearcher`` +class. 
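+
+A minimal usage sketch (paths are illustrative; see __main__.py for the full
+pipeline that loads candidates and queries):
+
+    searcher = MsmarcoPassageLtrSearcher('runs/msmarco-passage-ltr-mrr-v1',
+                                         'collections/msmarco-ltr-passage/ibm_model/',
+                                         'indexes/msmarco-passage-ltr-index')
+    searcher.add_fe()
+    batch_info = searcher.search(dev, queries)  # dev: candidate DataFrame, queries: dict of qid -> parsed fields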
+""" + +import logging +from typing import Dict, List, Optional, Union +import multiprocessing +import time +from tqdm import tqdm +import pickle + +from pyserini.ltr._base import * + + +logger = logging.getLogger(__name__) + +class MsmarcoPassageLtrSearcher: + def __init__(self, model: str, ibm_model:str, index:str): + self.model = model + self.ibm_model = ibm_model + self.fe = FeatureExtractor(index, max(multiprocessing.cpu_count()//2, 1)) + + def add_fe(self): + for qfield, ifield in [('analyzed', 'contents'), + ('text_unlemm', 'text_unlemm'), + ('text_bert_tok', 'text_bert_tok')]: + print(qfield, ifield) + self.fe.add(BM25Stat(SumPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield)) + self.fe.add(BM25Stat(AvgPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield)) + self.fe.add(BM25Stat(MedianPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield)) + self.fe.add(BM25Stat(MaxPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield)) + self.fe.add(BM25Stat(MinPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield)) + self.fe.add(BM25Stat(MaxMinRatioPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield)) + + self.fe.add(LmDirStat(SumPooler(), mu=1000, field=ifield, qfield=qfield)) + self.fe.add(LmDirStat(AvgPooler(), mu=1000, field=ifield, qfield=qfield)) + self.fe.add(LmDirStat(MedianPooler(), mu=1000, field=ifield, qfield=qfield)) + self.fe.add(LmDirStat(MaxPooler(), mu=1000, field=ifield, qfield=qfield)) + self.fe.add(LmDirStat(MinPooler(), mu=1000, field=ifield, qfield=qfield)) + self.fe.add(LmDirStat(MaxMinRatioPooler(), mu=1000, field=ifield, qfield=qfield)) + + self.fe.add(NormalizedTfIdf(field=ifield, qfield=qfield)) + self.fe.add(ProbalitySum(field=ifield, qfield=qfield)) + + self.fe.add(DfrGl2Stat(SumPooler(), field=ifield, qfield=qfield)) + self.fe.add(DfrGl2Stat(AvgPooler(), field=ifield, qfield=qfield)) + self.fe.add(DfrGl2Stat(MedianPooler(), field=ifield, qfield=qfield)) + self.fe.add(DfrGl2Stat(MaxPooler(), field=ifield, qfield=qfield)) + self.fe.add(DfrGl2Stat(MinPooler(), field=ifield, qfield=qfield)) + self.fe.add(DfrGl2Stat(MaxMinRatioPooler(), field=ifield, qfield=qfield)) + + self.fe.add(DfrInExpB2Stat(SumPooler(), field=ifield, qfield=qfield)) + self.fe.add(DfrInExpB2Stat(AvgPooler(), field=ifield, qfield=qfield)) + self.fe.add(DfrInExpB2Stat(MedianPooler(), field=ifield, qfield=qfield)) + self.fe.add(DfrInExpB2Stat(MaxPooler(), field=ifield, qfield=qfield)) + self.fe.add(DfrInExpB2Stat(MinPooler(), field=ifield, qfield=qfield)) + self.fe.add(DfrInExpB2Stat(MaxMinRatioPooler(), field=ifield, qfield=qfield)) + + self.fe.add(DphStat(SumPooler(), field=ifield, qfield=qfield)) + self.fe.add(DphStat(AvgPooler(), field=ifield, qfield=qfield)) + self.fe.add(DphStat(MedianPooler(), field=ifield, qfield=qfield)) + self.fe.add(DphStat(MaxPooler(), field=ifield, qfield=qfield)) + self.fe.add(DphStat(MinPooler(), field=ifield, qfield=qfield)) + self.fe.add(DphStat(MaxMinRatioPooler(), field=ifield, qfield=qfield)) + + self.fe.add(Proximity(field=ifield, qfield=qfield)) + self.fe.add(TpScore(field=ifield, qfield=qfield)) + self.fe.add(TpDist(field=ifield, qfield=qfield)) + + self.fe.add(DocSize(field=ifield)) + + self.fe.add(QueryLength(qfield=qfield)) + self.fe.add(QueryCoverageRatio(qfield=qfield)) + self.fe.add(UniqueTermCount(qfield=qfield)) + self.fe.add(MatchingTermCount(field=ifield, qfield=qfield)) + self.fe.add(SCS(field=ifield, qfield=qfield)) + + self.fe.add(TfStat(AvgPooler(), field=ifield, qfield=qfield)) + self.fe.add(TfStat(MedianPooler(), field=ifield, 
qfield=qfield)) + self.fe.add(TfStat(SumPooler(), field=ifield, qfield=qfield)) + self.fe.add(TfStat(MinPooler(), field=ifield, qfield=qfield)) + self.fe.add(TfStat(MaxPooler(), field=ifield, qfield=qfield)) + self.fe.add(TfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield)) + + self.fe.add(TfIdfStat(True, AvgPooler(), field=ifield, qfield=qfield)) + self.fe.add(TfIdfStat(True, MedianPooler(), field=ifield, qfield=qfield)) + self.fe.add(TfIdfStat(True, SumPooler(), field=ifield, qfield=qfield)) + self.fe.add(TfIdfStat(True, MinPooler(), field=ifield, qfield=qfield)) + self.fe.add(TfIdfStat(True, MaxPooler(), field=ifield, qfield=qfield)) + self.fe.add(TfIdfStat(True, MaxMinRatioPooler(), field=ifield, qfield=qfield)) + + self.fe.add(NormalizedTfStat(AvgPooler(), field=ifield, qfield=qfield)) + self.fe.add(NormalizedTfStat(MedianPooler(), field=ifield, qfield=qfield)) + self.fe.add(NormalizedTfStat(SumPooler(), field=ifield, qfield=qfield)) + self.fe.add(NormalizedTfStat(MinPooler(), field=ifield, qfield=qfield)) + self.fe.add(NormalizedTfStat(MaxPooler(), field=ifield, qfield=qfield)) + self.fe.add(NormalizedTfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield)) + + self.fe.add(IdfStat(AvgPooler(), field=ifield, qfield=qfield)) + self.fe.add(IdfStat(MedianPooler(), field=ifield, qfield=qfield)) + self.fe.add(IdfStat(SumPooler(), field=ifield, qfield=qfield)) + self.fe.add(IdfStat(MinPooler(), field=ifield, qfield=qfield)) + self.fe.add(IdfStat(MaxPooler(), field=ifield, qfield=qfield)) + self.fe.add(IdfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield)) + + self.fe.add(IcTfStat(AvgPooler(), field=ifield, qfield=qfield)) + self.fe.add(IcTfStat(MedianPooler(), field=ifield, qfield=qfield)) + self.fe.add(IcTfStat(SumPooler(), field=ifield, qfield=qfield)) + self.fe.add(IcTfStat(MinPooler(), field=ifield, qfield=qfield)) + self.fe.add(IcTfStat(MaxPooler(), field=ifield, qfield=qfield)) + self.fe.add(IcTfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield)) + + self.fe.add(UnorderedSequentialPairs(3, field=ifield, qfield=qfield)) + self.fe.add(UnorderedSequentialPairs(8, field=ifield, qfield=qfield)) + self.fe.add(UnorderedSequentialPairs(15, field=ifield, qfield=qfield)) + self.fe.add(OrderedSequentialPairs(3, field=ifield, qfield=qfield)) + self.fe.add(OrderedSequentialPairs(8, field=ifield, qfield=qfield)) + self.fe.add(OrderedSequentialPairs(15, field=ifield, qfield=qfield)) + self.fe.add(UnorderedQueryPairs(3, field=ifield, qfield=qfield)) + self.fe.add(UnorderedQueryPairs(8, field=ifield, qfield=qfield)) + self.fe.add(UnorderedQueryPairs(15, field=ifield, qfield=qfield)) + self.fe.add(OrderedQueryPairs(3, field=ifield, qfield=qfield)) + self.fe.add(OrderedQueryPairs(8, field=ifield, qfield=qfield)) + self.fe.add(OrderedQueryPairs(15, field=ifield, qfield=qfield)) + + start = time.time() + self.fe.add( + IbmModel1(f"{self.ibm_model}/title_unlemm", "text_unlemm", "title_unlemm", + "text_unlemm")) + end = time.time() + print('IBM model Load takes %.2f seconds' % (end - start)) + start = end + self.fe.add(IbmModel1(f"{self.ibm_model}url_unlemm", "text_unlemm", "url_unlemm", + "text_unlemm")) + end = time.time() + print('IBM model Load takes %.2f seconds' % (end - start)) + start = end + self.fe.add( + IbmModel1(f"{self.ibm_model}body", "text_unlemm", "body", "text_unlemm")) + end = time.time() + print('IBM model Load takes %.2f seconds' % (end - start)) + start = end + self.fe.add(IbmModel1(f"{self.ibm_model}text_bert_tok", "text_bert_tok", + "text_bert_tok", "text_bert_tok")) + 
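+        # each IbmModel1 feature loads a large pretrained translation table, so each load is timed below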
end = time.time() + print('IBM model Load takes %.2f seconds' % (end - start)) + start = end + + def batch_extract(self, df, queries, fe): + tasks = [] + task_infos = [] + group_lst = [] + + for qid, group in tqdm(df.groupby('qid')): + task = { + "qid": qid, + "docIds": [], + "rels": [], + "query_dict": queries[qid] + } + for t in group.reset_index().itertuples(): + task["docIds"].append(t.pid) + task_infos.append((qid, t.pid, t.rel)) + tasks.append(task) + group_lst.append((qid, len(task['docIds']))) + if len(tasks) == 1000: + features = fe.batch_extract(tasks) + task_infos = pd.DataFrame(task_infos, columns=['qid', 'pid', 'rel']) + group = pd.DataFrame(group_lst, columns=['qid', 'count']) + print(features.shape) + print(task_infos.qid.drop_duplicates().shape) + print(group.mean()) + print(features.head(10)) + print(features.info()) + yield task_infos, features, group + tasks = [] + task_infos = [] + group_lst = [] + # deal with rest + if len(tasks) > 0: + features = fe.batch_extract(tasks) + task_infos = pd.DataFrame(task_infos, columns=['qid', 'pid', 'rel']) + group = pd.DataFrame(group_lst, columns=['qid', 'count']) + print(features.shape) + print(task_infos.qid.drop_duplicates().shape) + print(group.mean()) + print(features.head(10)) + print(features.info()) + yield task_infos, features, group + + return + + def batch_predict(self, models, dev_extracted, feature_name): + task_infos, features, group = dev_extracted + dev_X = features.loc[:, feature_name] + + task_infos['score'] = 0. + for gbm in models: + task_infos['score'] += gbm.predict(dev_X) + + def search(self, dev, queries): + batch_info = [] + start_extract = time.time() + models = pickle.load(open(self.model+'/model.pkl', 'rb')) + metadata = json.load(open(self.model+'/metadata.json', 'r')) + feature_used = metadata['feature_names'] + for dev_extracted in self.batch_extract(dev, queries, self.fe): + end_extract = time.time() + print(f'extract 1000 queries take {end_extract - start_extract}s') + task_infos, features, group = dev_extracted + start_predict = time.time() + self.batch_predict(models, dev_extracted, feature_used) + end_predict = time.time() + print(f'predict 1000 queries take {end_predict - start_predict}s') + batch_info.append(task_infos) + start_extract = time.time() + batch_info = pd.concat(batch_info, axis=0, ignore_index=True) + return batch_info + diff --git a/scripts/ltr_msmarco-passage/rerank_with_ltr_model.py b/scripts/ltr_msmarco-passage/rerank_with_ltr_model.py deleted file mode 100644 index 1220f2c00..000000000 --- a/scripts/ltr_msmarco-passage/rerank_with_ltr_model.py +++ /dev/null @@ -1,450 +0,0 @@ -# -# Pyserini: Reproducible IR research with sparse and dense representations -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys - -# We're going to explicitly use a local installation of Pyserini (as opposed to a pip-installed one). -# Comment these lines out to use a pip-installed one instead. 
-sys.path.insert(0, './') - -import argparse -import json -import multiprocessing -import os -import pickle -import time - -import numpy as np -import pandas as pd -from tqdm import tqdm -from pyserini.ltr import * - -""" -Running prediction on candidates -""" -def dev_data_loader(file, format, top=100): - if format == 'tsv': - dev = pd.read_csv(file, sep="\t", - names=['qid', 'pid', 'rank'], - dtype={'qid': 'S','pid': 'S', 'rank':'i',}) - elif format == 'trec': - dev = pd.read_csv(file, sep="\s+", - names=['qid', 'q0', 'pid', 'rank', 'score', 'tag'], - usecols=['qid', 'pid', 'rank'], - dtype={'qid': 'S','pid': 'S', 'rank':'i',}) - else: - raise Exception('unknown parameters') - assert dev['qid'].dtype == np.object - assert dev['pid'].dtype == np.object - assert dev['rank'].dtype == np.int32 - dev = dev[dev['rank']<=top] - dev_qrel = pd.read_csv('tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt', sep=" ", - names=["qid", "q0", "pid", "rel"], usecols=['qid', 'pid', 'rel'], - dtype={'qid': 'S','pid': 'S', 'rel':'i'}) - assert dev['qid'].dtype == np.object - assert dev['pid'].dtype == np.object - assert dev['rank'].dtype == np.int32 - dev = dev.merge(dev_qrel, left_on=['qid', 'pid'], right_on=['qid', 'pid'], how='left') - dev['rel'] = dev['rel'].fillna(0).astype(np.int32) - dev = dev.sort_values(['qid', 'pid']).set_index(['qid', 'pid']) - - print(dev.shape) - print(dev.index.get_level_values('qid').drop_duplicates().shape) - print(dev.groupby('qid').count().mean()) - print(dev.head(10)) - print(dev.info()) - - dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel'] - - recall_point = [10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000] - recall_curve = {k: [] for k in recall_point} - for qid, group in tqdm(dev.groupby('qid')): - group = group.reset_index() - assert len(group['pid'].tolist()) == len(set(group['pid'].tolist())) - total_rel = dev_rel_num.loc[qid] - query_recall = [0 for k in recall_point] - for t in group.sort_values('rank').itertuples(): - if t.rel > 0: - for i, p in enumerate(recall_point): - if t.rank <= p: - query_recall[i] += 1 - for i, p in enumerate(recall_point): - if total_rel > 0: - recall_curve[p].append(query_recall[i] / total_rel) - else: - recall_curve[p].append(0.) 
- - for k, v in recall_curve.items(): - avg = np.mean(v) - print(f'recall@{k}:{avg}') - - return dev, dev_qrel - - -def query_loader(): - queries = {} - ''' - with open('collections/msmarco-ltr-passage/queries.train.json') as f: - for line in f: - query = json.loads(line) - qid = query.pop('id') - query['analyzed'] = query['analyzed'].split(" ") - query['text'] = query['text_unlemm'].split(" ") - query['text_unlemm'] = query['text_unlemm'].split(" ") - query['text_bert_tok'] = query['text_bert_tok'].split(" ") - queries[qid] = query - ''' - with open(f'{args.queries}/queries.dev.small.json') as f: - for line in f: - query = json.loads(line) - qid = query.pop('id') - query['analyzed'] = query['analyzed'].split(" ") - query['text'] = query['text_unlemm'].split(" ") - query['text_unlemm'] = query['text_unlemm'].split(" ") - query['text_bert_tok'] = query['text_bert_tok'].split(" ") - queries[qid] = query - ''' - with open('collections/msmarco-ltr-passage/queries.eval.small.json') as f: - for line in f: - query = json.loads(line) - qid = query.pop('id') - query['analyzed'] = query['analyzed'].split(" ") - query['text'] = query['text_unlemm'].split(" ") - query['text_unlemm'] = query['text_unlemm'].split(" ") - query['text_bert_tok'] = query['text_bert_tok'].split(" ") - queries[qid] = query - ''' - return queries - - -def batch_extract(df, queries, fe): - tasks = [] - task_infos = [] - group_lst = [] - - for qid, group in tqdm(df.groupby('qid')): - task = { - "qid": qid, - "docIds": [], - "rels": [], - "query_dict": queries[qid] - } - for t in group.reset_index().itertuples(): - task["docIds"].append(t.pid) - task_infos.append((qid, t.pid, t.rel)) - tasks.append(task) - group_lst.append((qid, len(task['docIds']))) - if len(tasks) == 1000: - features = fe.batch_extract(tasks) - task_infos = pd.DataFrame(task_infos, columns=['qid', 'pid', 'rel']) - group = pd.DataFrame(group_lst, columns=['qid', 'count']) - print(features.shape) - print(task_infos.qid.drop_duplicates().shape) - print(group.mean()) - print(features.head(10)) - print(features.info()) - yield task_infos, features, group - tasks = [] - task_infos = [] - group_lst = [] - # deal with rest - if len(tasks) > 0: - features = fe.batch_extract(tasks) - task_infos = pd.DataFrame(task_infos, columns=['qid', 'pid', 'rel']) - group = pd.DataFrame(group_lst, columns=['qid', 'count']) - print(features.shape) - print(task_infos.qid.drop_duplicates().shape) - print(group.mean()) - print(features.head(10)) - print(features.info()) - yield task_infos, features, group - - return - -def batch_predict(models, dev_extracted, feature_name): - task_infos, features, group = dev_extracted - dev_X = features.loc[:, feature_name] - - task_infos['score'] = 0. - for gbm in models: - task_infos['score'] += gbm.predict(dev_X) - - -def eval_mrr(dev_data): - score_tie_counter = 0 - score_tie_query = set() - MRR = [] - for qid, group in tqdm(dev_data.groupby('qid')): - group = group.reset_index() - rank = 0 - prev_score = None - assert len(group['pid'].tolist()) == len(set(group['pid'].tolist())) - # stable sort is also used in LightGBM - - for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples(): - if prev_score is not None and abs(t.score - prev_score) < 1e-8: - score_tie_counter += 1 - score_tie_query.add(qid) - prev_score = t.score - rank += 1 - if t.rel > 0: - MRR.append(1.0 / rank) - break - elif rank == 10 or rank == len(group): - MRR.append(0.) 
- break - - score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries' - print(score_tie) - mrr_10 = np.mean(MRR).item() - print(f'MRR@10:{mrr_10} with {len(MRR)} queries') - return {'score_tie': score_tie, 'mrr_10': mrr_10} - - -def eval_recall(dev_qrel, dev_data): - dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel'] - - score_tie_counter = 0 - score_tie_query = set() - - recall_point = [10,20,50,100,200,250,300,333,400,500,1000] - recall_curve = {k: [] for k in recall_point} - for qid, group in tqdm(dev_data.groupby('qid')): - group = group.reset_index() - rank = 0 - prev_score = None - assert len(group['pid'].tolist()) == len(set(group['pid'].tolist())) - # stable sort is also used in LightGBM - total_rel = dev_rel_num.loc[qid] - query_recall = [0 for k in recall_point] - for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples(): - if prev_score is not None and abs(t.score - prev_score) < 1e-8: - score_tie_counter += 1 - score_tie_query.add(qid) - prev_score = t.score - rank += 1 - if t.rel > 0: - for i, p in enumerate(recall_point): - if rank <= p: - query_recall[i] += 1 - for i, p in enumerate(recall_point): - if total_rel > 0: - recall_curve[p].append(query_recall[i] / total_rel) - else: - recall_curve[p].append(0.) - - score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries' - print(score_tie) - res = {'score_tie': score_tie} - - for k, v in recall_curve.items(): - avg = np.mean(v) - print(f'recall@{k}:{avg}') - res[f'recall@{k}'] = avg - - return res - - -def output(file, dev_data): - score_tie_counter = 0 - score_tie_query = set() - output_file = open(file,'w') - - for qid, group in tqdm(dev_data.groupby('qid')): - group = group.reset_index() - rank = 0 - prev_score = None - assert len(group['pid'].tolist()) == len(set(group['pid'].tolist())) - # stable sort is also used in LightGBM - - for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples(): - if prev_score is not None and abs(t.score - prev_score) < 1e-8: - score_tie_counter += 1 - score_tie_query.add(qid) - prev_score = t.score - rank += 1 - output_file.write(f"{qid}\t{t.pid}\t{rank}\n") - - score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries' - print(score_tie) - - -if __name__ == '__main__': - os.environ["ANSERINI_CLASSPATH"] = "./pyserini/resources/jars" - parser = argparse.ArgumentParser(description='Learning to rank reranking') - parser.add_argument('--input', required=True) - parser.add_argument('--reranking-top', type=int, default=1000) - parser.add_argument('--input-format', required=True) - parser.add_argument('--model', required=True) - parser.add_argument('--index', required=True) - parser.add_argument('--output', required=True) - parser.add_argument('--ibm-model',default='./collections/msmarco-ltr-passage/ibm_model/') - parser.add_argument('--queries',default='./collections/msmarco-ltr-passage/') - - args = parser.parse_args() - print("load dev") - dev, dev_qrel = dev_data_loader(args.input, args.input_format, args.reranking_top) - print("load queries") - queries = query_loader() - print("add feature") - fe = FeatureExtractor(args.index, max(multiprocessing.cpu_count()//2, 1)) - for qfield, ifield in [('analyzed', 'contents'), - ('text_unlemm', 'text_unlemm'), - ('text_bert_tok', 'text_bert_tok')]: - print(qfield, ifield) - fe.add(BM25Stat(SumPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield)) - fe.add(BM25Stat(AvgPooler(), k1=2.0, 
b=0.75, field=ifield, qfield=qfield)) - fe.add(BM25Stat(MedianPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield)) - fe.add(BM25Stat(MaxPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield)) - fe.add(BM25Stat(MinPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield)) - fe.add(BM25Stat(MaxMinRatioPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield)) - - fe.add(LmDirStat(SumPooler(), mu=1000, field=ifield, qfield=qfield)) - fe.add(LmDirStat(AvgPooler(), mu=1000, field=ifield, qfield=qfield)) - fe.add(LmDirStat(MedianPooler(), mu=1000, field=ifield, qfield=qfield)) - fe.add(LmDirStat(MaxPooler(), mu=1000, field=ifield, qfield=qfield)) - fe.add(LmDirStat(MinPooler(), mu=1000, field=ifield, qfield=qfield)) - fe.add(LmDirStat(MaxMinRatioPooler(), mu=1000, field=ifield, qfield=qfield)) - - fe.add(NormalizedTfIdf(field=ifield, qfield=qfield)) - fe.add(ProbalitySum(field=ifield, qfield=qfield)) - - fe.add(DfrGl2Stat(SumPooler(), field=ifield, qfield=qfield)) - fe.add(DfrGl2Stat(AvgPooler(), field=ifield, qfield=qfield)) - fe.add(DfrGl2Stat(MedianPooler(), field=ifield, qfield=qfield)) - fe.add(DfrGl2Stat(MaxPooler(), field=ifield, qfield=qfield)) - fe.add(DfrGl2Stat(MinPooler(), field=ifield, qfield=qfield)) - fe.add(DfrGl2Stat(MaxMinRatioPooler(), field=ifield, qfield=qfield)) - - fe.add(DfrInExpB2Stat(SumPooler(), field=ifield, qfield=qfield)) - fe.add(DfrInExpB2Stat(AvgPooler(), field=ifield, qfield=qfield)) - fe.add(DfrInExpB2Stat(MedianPooler(), field=ifield, qfield=qfield)) - fe.add(DfrInExpB2Stat(MaxPooler(), field=ifield, qfield=qfield)) - fe.add(DfrInExpB2Stat(MinPooler(), field=ifield, qfield=qfield)) - fe.add(DfrInExpB2Stat(MaxMinRatioPooler(), field=ifield, qfield=qfield)) - - fe.add(DphStat(SumPooler(), field=ifield, qfield=qfield)) - fe.add(DphStat(AvgPooler(), field=ifield, qfield=qfield)) - fe.add(DphStat(MedianPooler(), field=ifield, qfield=qfield)) - fe.add(DphStat(MaxPooler(), field=ifield, qfield=qfield)) - fe.add(DphStat(MinPooler(), field=ifield, qfield=qfield)) - fe.add(DphStat(MaxMinRatioPooler(), field=ifield, qfield=qfield)) - - fe.add(Proximity(field=ifield, qfield=qfield)) - fe.add(TpScore(field=ifield, qfield=qfield)) - fe.add(TpDist(field=ifield, qfield=qfield)) - - fe.add(DocSize(field=ifield)) - - fe.add(QueryLength(qfield=qfield)) - fe.add(QueryCoverageRatio(qfield=qfield)) - fe.add(UniqueTermCount(qfield=qfield)) - fe.add(MatchingTermCount(field=ifield, qfield=qfield)) - fe.add(SCS(field=ifield, qfield=qfield)) - - fe.add(TfStat(AvgPooler(), field=ifield, qfield=qfield)) - fe.add(TfStat(MedianPooler(), field=ifield, qfield=qfield)) - fe.add(TfStat(SumPooler(), field=ifield, qfield=qfield)) - fe.add(TfStat(MinPooler(), field=ifield, qfield=qfield)) - fe.add(TfStat(MaxPooler(), field=ifield, qfield=qfield)) - fe.add(TfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield)) - - fe.add(TfIdfStat(True, AvgPooler(), field=ifield, qfield=qfield)) - fe.add(TfIdfStat(True, MedianPooler(), field=ifield, qfield=qfield)) - fe.add(TfIdfStat(True, SumPooler(), field=ifield, qfield=qfield)) - fe.add(TfIdfStat(True, MinPooler(), field=ifield, qfield=qfield)) - fe.add(TfIdfStat(True, MaxPooler(), field=ifield, qfield=qfield)) - fe.add(TfIdfStat(True, MaxMinRatioPooler(), field=ifield, qfield=qfield)) - - fe.add(NormalizedTfStat(AvgPooler(), field=ifield, qfield=qfield)) - fe.add(NormalizedTfStat(MedianPooler(), field=ifield, qfield=qfield)) - fe.add(NormalizedTfStat(SumPooler(), field=ifield, qfield=qfield)) - fe.add(NormalizedTfStat(MinPooler(), field=ifield, qfield=qfield)) 
- fe.add(NormalizedTfStat(MaxPooler(), field=ifield, qfield=qfield)) - fe.add(NormalizedTfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield)) - - fe.add(IdfStat(AvgPooler(), field=ifield, qfield=qfield)) - fe.add(IdfStat(MedianPooler(), field=ifield, qfield=qfield)) - fe.add(IdfStat(SumPooler(), field=ifield, qfield=qfield)) - fe.add(IdfStat(MinPooler(), field=ifield, qfield=qfield)) - fe.add(IdfStat(MaxPooler(), field=ifield, qfield=qfield)) - fe.add(IdfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield)) - - fe.add(IcTfStat(AvgPooler(), field=ifield, qfield=qfield)) - fe.add(IcTfStat(MedianPooler(), field=ifield, qfield=qfield)) - fe.add(IcTfStat(SumPooler(), field=ifield, qfield=qfield)) - fe.add(IcTfStat(MinPooler(), field=ifield, qfield=qfield)) - fe.add(IcTfStat(MaxPooler(), field=ifield, qfield=qfield)) - fe.add(IcTfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield)) - - fe.add(UnorderedSequentialPairs(3, field=ifield, qfield=qfield)) - fe.add(UnorderedSequentialPairs(8, field=ifield, qfield=qfield)) - fe.add(UnorderedSequentialPairs(15, field=ifield, qfield=qfield)) - fe.add(OrderedSequentialPairs(3, field=ifield, qfield=qfield)) - fe.add(OrderedSequentialPairs(8, field=ifield, qfield=qfield)) - fe.add(OrderedSequentialPairs(15, field=ifield, qfield=qfield)) - fe.add(UnorderedQueryPairs(3, field=ifield, qfield=qfield)) - fe.add(UnorderedQueryPairs(8, field=ifield, qfield=qfield)) - fe.add(UnorderedQueryPairs(15, field=ifield, qfield=qfield)) - fe.add(OrderedQueryPairs(3, field=ifield, qfield=qfield)) - fe.add(OrderedQueryPairs(8, field=ifield, qfield=qfield)) - fe.add(OrderedQueryPairs(15, field=ifield, qfield=qfield)) - - start = time.time() - fe.add( - IbmModel1(f"{args.ibm_model}/title_unlemm", "text_unlemm", "title_unlemm", - "text_unlemm")) - end = time.time() - print('IBM model Load takes %.2f seconds' % (end - start)) - start = end - fe.add(IbmModel1(f"{args.ibm_model}url_unlemm", "text_unlemm", "url_unlemm", - "text_unlemm")) - end = time.time() - print('IBM model Load takes %.2f seconds' % (end - start)) - start = end - fe.add( - IbmModel1(f"{args.ibm_model}body", "text_unlemm", "body", "text_unlemm")) - end = time.time() - print('IBM model Load takes %.2f seconds' % (end - start)) - start = end - fe.add(IbmModel1(f"{args.ibm_model}text_bert_tok", "text_bert_tok", - "text_bert_tok", "text_bert_tok")) - end = time.time() - print('IBM model Load takes %.2f seconds' % (end - start)) - start = end - - models = pickle.load(open(args.model+'/model.pkl', 'rb')) - metadata = json.load(open(args.model+'/metadata.json', 'r')) - feature_used = metadata['feature_names'] - - batch_info = [] - start_extract = time.time() - for dev_extracted in batch_extract(dev, queries, fe): - end_extract = time.time() - print(f'extract 1000 queries take {end_extract - start_extract}s') - task_infos, features, group = dev_extracted - start_predict = time.time() - batch_predict(models, dev_extracted, feature_used) - end_predict = time.time() - print(f'predict 1000 queries take {end_predict - start_predict}s') - batch_info.append(task_infos) - start_extract = time.time() - batch_info = pd.concat(batch_info, axis=0, ignore_index=True) - del dev, queries, fe - - eval_res = eval_mrr(batch_info) - eval_recall(dev_qrel, batch_info) - output(args.output, batch_info) - print('Done!') \ No newline at end of file diff --git a/scripts/ltr_msmarco-passage/train_ltr_model.py b/scripts/ltr_msmarco-passage/train_ltr_model.py index e35bf80c7..cce499c35 100644 --- 
a/scripts/ltr_msmarco-passage/train_ltr_model.py
+++ b/scripts/ltr_msmarco-passage/train_ltr_model.py
@@ -480,9 +480,11 @@ def save_exp(dirname,
     os.environ["ANSERINI_CLASSPATH"] = "pyserini/resources/jars"
     parser = argparse.ArgumentParser(description='Learning to rank training')
     parser.add_argument('--index', required=True)
+    parser.add_argument('--neg-sample', type=int, default=10)
+    parser.add_argument('--opt', default='mrr_at_10')
     args = parser.parse_args()
     total_start_time = time.time()
-    sampled_train = train_data_loader(task='triple', neg_sample=10)
+    sampled_train = train_data_loader(task='triple', neg_sample=args.neg_sample)
     dev, dev_qrel = dev_data_loader(task='anserini')
     queries = query_loader()
 
@@ -615,7 +617,7 @@ def save_exp(dirname,
     print("dev extracted")
     feature_name = fe.feature_names()
     del sampled_train, dev, queries, fe
-    eval_fn = gen_dev_group_rel_num(dev_qrel, dev_extracted)
+    recall_at_20 = gen_dev_group_rel_num(dev_qrel, dev_extracted)
     print("start train")
     train_res = train(train_extracted, dev_extracted, feature_name, mrr_at_10)
     print("end train")