From 5cc929d011af6c171f8a5aca192a43796fa22448 Mon Sep 17 00:00:00 2001
From: Matt Yang
Date: Tue, 13 Jul 2021 17:33:53 -0400
Subject: [PATCH] add tct-colbert-v2 doc zero shot exp (#682)

* add tct-colbert-v2 doc zero shot exp

Co-authored-by: Lin Jack
Co-authored-by: Jack Lin
Co-authored-by: Jheng-Hong Yang
---
 docs/experiments-tct_colbert-v2.md               | 110 ++++++++++++++++++
 .../tct_colbert/encode_corpus_msmarco_doc.py     | 109 +++++++++++++++++
 scripts/tct_colbert/merge_indexes.py             |  70 +++++++++++
 3 files changed, 289 insertions(+)
 create mode 100644 scripts/tct_colbert/encode_corpus_msmarco_doc.py
 create mode 100755 scripts/tct_colbert/merge_indexes.py

diff --git a/docs/experiments-tct_colbert-v2.md b/docs/experiments-tct_colbert-v2.md
index f4d1fef78..1f5d5dbaa 100644
--- a/docs/experiments-tct_colbert-v2.md
+++ b/docs/experiments-tct_colbert-v2.md
@@ -186,6 +186,116 @@ recall_1000             all     0.9759

Follow the same instructions above to perform on-the-fly query encoding.

## MS MARCO Document Ranking with TCT-ColBERT-V2 (zero-shot)

Document retrieval with TCT-ColBERT-V2, brute-force index:

Step 0: prepare `docs.json` by splitting each document into passage segments. Each line contains a JSON dict as follows:

{"id": "[doc_id]#[seg_id]", "contents": "[url]\n[title]\n[seg_text]"}

Step 1: split the collection for parallel encoding:

```bash
$ split -a 2 -d -n l/50 docs.json collection.part
```

Step 2-1: prepare the encoder (on CC); you can download the encoder using [git-lfs](https://git-lfs.github.com/).

Example (after you install git-lfs):

```bash
git clone https://huggingface.co/castorini/tct_colbert-v2-hnp-msmarco
```

Step 2-2: run encoding:

```bash
export TASK=msmarco
export ENCODER=tct_colbert-v2-hnp-msmarco
export WORKING_DIR=~/scratch

for i in $(seq -f "%02g" 0 49)
do
    srun --gres=gpu:v100:1 --mem=16G --cpus-per-task=2 --time=2:00:00 \
        python scripts/tct_colbert/encode_corpus_msmarco_doc.py \
            --corpus ${WORKING_DIR}/${TASK}/collection.part${i} \
            --encoder ${WORKING_DIR}/checkpoint/${ENCODER} \
            --index indexes/${TASK}-${ENCODER}-${i} \
            --batch 16 \
            --device cuda:0 &
done
```

Step 3: merge / filter the shard indexes; use `--segment-num -1` for MaxP, `1` for FirstP, or any other integer you like:

```bash
$ python scripts/tct_colbert/merge_indexes.py \
    --prefix indexes/${TASK}-${ENCODER}- \
    --shard-num 50 \
    --segment-num -1
```

Step 4: search (with on-the-fly query encoding):

```bash
$ python -m pyserini.dsearch --topics msmarco-doc-dev \
    --index indexes/msmarco-tct_colbert-v2-hnp-msmarco-full-maxp \
    --encoder castorini/tct_colbert-v2-hnp-msmarco \
    --output runs/run.msmarco-doc.passage.tct_colbert-v2-hnp-maxp.txt \
    --hits 1000 \
    --max-passage \
    --max-passage-hits 100 \
    --output-format msmarco \
    --batch-size 144 \
    --threads 36

$ python -m pyserini.dsearch --topics dl19-doc \
    --index indexes/msmarco-tct_colbert-v2-hnp-msmarco-full-maxp \
    --encoder castorini/tct_colbert-v2-hnp-msmarco \
    --output runs/run.dl19-doc.passage.tct_colbert-v2-hnp-maxp.txt \
    --hits 1000 \
    --max-passage \
    --max-passage-hits 100 \
    --output-format msmarco \
    --batch-size 144 \
    --threads 36
```
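The commands above use the `pyserini.dsearch` CLI. For quick interactive checks, the merged index can also be queried from Python. The snippet below is a minimal sketch: it assumes Pyserini's `SimpleDenseSearcher` and `TctColBertQueryEncoder` classes and the MaxP index path produced in Step 3; adjust names and paths to your setup.

```python
from pyserini.dsearch import SimpleDenseSearcher, TctColBertQueryEncoder

# Query encoder matching the document encoder used above (checkpoint name assumed).
encoder = TctColBertQueryEncoder('castorini/tct_colbert-v2-hnp-msmarco')

# Point the searcher at the merged MaxP index from Step 3 (path assumed).
searcher = SimpleDenseSearcher('indexes/msmarco-tct_colbert-v2-hnp-msmarco-full-maxp', encoder)

hits = searcher.search('what is the daily recommended intake of vitamin d', k=10)
for rank, hit in enumerate(hits, start=1):
    # hit.docid is '[doc_id]#[seg_id]'; strip the segment id to recover the document id.
    print(f'{rank:2} {hit.docid.split("#")[0]:15} {hit.score:.4f}')
```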
Step 5: evaluate.

For MSMARCO-Doc-dev:

```bash
$ python -m pyserini.eval.msmarco_doc_eval --judgments msmarco-doc-dev --run runs/run.msmarco-doc.passage.tct_colbert-v2-hnp-maxp.txt

#####################
MRR @100: 0.3508557690776294
QueriesRanked: 5193
#####################

$ python -m pyserini.eval.convert_msmarco_run_to_trec_run \
    --input runs/run.msmarco-doc.passage.tct_colbert-v2-hnp-maxp.txt \
    --output runs/run.msmarco-doc.passage.tct_colbert-v2-hnp-maxp.trec

$ python -m pyserini.eval.trec_eval -c -mrecall.100 -mmap -mndcg_cut.10 msmarco-doc-dev \
    runs/run.msmarco-doc.passage.tct_colbert-v2-hnp-maxp.trec

Results:
map                     all     0.3509
recall_100              all     0.8908
ndcg_cut_10             all     0.4123
```

For TREC-DL19:

```bash
$ python -m pyserini.eval.convert_msmarco_run_to_trec_run \
    --input runs/run.dl19-doc.passage.tct_colbert-v2-hnp-maxp.txt \
    --output runs/run.dl19-doc.passage.tct_colbert-v2-hnp-maxp.trec

$ python -m pyserini.eval.trec_eval -c -mrecall.100 -mmap -mndcg_cut.10 dl19-doc \
    runs/run.dl19-doc.passage.tct_colbert-v2-hnp-maxp.trec

Results:
map                     all     0.2683
recall_100              all     0.3854
ndcg_cut_10             all     0.6592
```

## Reproduction Log[*](reproducibility.md)

+ Results reproduced by [@lintool](https://github.com/lintool) on 2021-07-01 (commit [`b1576a2`](https://github.com/castorini/pyserini/commit/b1576a2c3e899349be12e897f92f3ad75ec82d6f))

diff --git a/scripts/tct_colbert/encode_corpus_msmarco_doc.py b/scripts/tct_colbert/encode_corpus_msmarco_doc.py
new file mode 100644
index 000000000..9c9b61ca9
--- /dev/null
+++ b/scripts/tct_colbert/encode_corpus_msmarco_doc.py
@@ -0,0 +1,109 @@
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import json
import os
import sys

import faiss
import numpy as np
import torch
from tqdm import tqdm
from transformers import BertModel, BertTokenizer

# We're going to explicitly use a local installation of Pyserini (as opposed to a pip-installed one).
# Comment these lines out to use a pip-installed one instead.
sys.path.insert(0, './')
sys.path.insert(0, '../pyserini/')
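
# Note on the encoding scheme implemented below: each passage is prefixed with
# '[CLS] [D] ' and tokenized with add_special_tokens=False, so the first four
# token positions hold the '[CLS] [D]' marker. TctColBertDocumentEncoder.encode
# drops those four positions and mean-pools the remaining token embeddings into
# a single dense vector (see mean_pooling below).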

def mean_pooling(last_hidden_state, attention_mask):
    token_embeddings = last_hidden_state
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


class TctColBertDocumentEncoder(torch.nn.Module):
    def __init__(self, model_name, tokenizer_name=None, device='cuda:0'):
        super().__init__()
        self.device = device
        self.model = BertModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or model_name)

    def encode(self, texts, titles=None):
        texts = ['[CLS] [D] ' + text for text in texts]
        max_length = 512  # hardcode for now
        inputs = self.tokenizer(
            texts,
            max_length=max_length,
            padding="longest",
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )
        inputs.to(self.device)
        outputs = self.model(**inputs)
        # Skip the first 4 tokens ('[CLS] [D]') and mean-pool the rest.
        embeddings = mean_pooling(outputs["last_hidden_state"][:, 4:, :], inputs['attention_mask'][:, 4:])
        return embeddings.detach().cpu().numpy()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str, help='encoder name or path', required=True)
    parser.add_argument('--dimension', type=int, help='dimension of passage embeddings', required=False, default=768)
    parser.add_argument('--corpus', type=str,
                        help='collection file to be encoded (format: jsonl)', required=True)
    parser.add_argument('--index', type=str, help='directory to store brute force index of corpus', required=True)
    parser.add_argument('--batch', type=int, help='batch size', default=8)
    parser.add_argument('--device', type=str, help='device cpu or cuda [cuda:0, cuda:1...]', default='cuda:0')
    args = parser.parse_args()

    model = TctColBertDocumentEncoder(model_name=args.encoder, device=args.device)

    index = faiss.IndexFlatIP(args.dimension)

    if not os.path.exists(args.index):
        os.mkdir(args.index)

    texts = []
    with open(os.path.join(args.index, 'docid'), 'w') as id_file:
        file = args.corpus
        print(f'Loading {file}')
        with open(file, 'r') as corpus:
            for line in tqdm(corpus.readlines()):
                info = json.loads(line)
                docid = info['id']
                text = info['contents']
                id_file.write(f'{docid}\n')
                # contents format: "[url]\n[title]\n[seg_text]"; the segment text itself may contain '\n'.
                fields = text.split('\n')
                title, body = fields[1], ' '.join(fields[2:])
                texts.append(f'{title} {body}'.lower())

    for idx in tqdm(range(0, len(texts), args.batch)):
        text_batch = texts[idx: idx + args.batch]
        embeddings = model.encode(text_batch)
        index.add(np.array(embeddings))
    faiss.write_index(index, os.path.join(args.index, 'index'))
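
# Each shard directory (args.index) ends up with two files:
#   docid - one '[doc_id]#[seg_id]' per line, in the order the vectors were added
#   index - a faiss IndexFlatIP holding one vector per passage segment
# merge_indexes.py (below) relies on this ordering to filter segments by seg_id.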

diff --git a/scripts/tct_colbert/merge_indexes.py b/scripts/tct_colbert/merge_indexes.py
new file mode 100755
index 000000000..05919bcbf
--- /dev/null
+++ b/scripts/tct_colbert/merge_indexes.py
@@ -0,0 +1,70 @@
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import os

import faiss

parser = argparse.ArgumentParser()
parser.add_argument('--dimension', type=int, help='dimension of passage embeddings', required=False, default=768)
parser.add_argument('--prefix', type=str, help='prefix shared by the per-shard index directories', required=True)
parser.add_argument('--segment-num', type=int, help='number of passage segments to keep per document; use -1 for MaxP', default=1)
parser.add_argument('--shard-num', type=int, help='number of shards', default=1)
args = parser.parse_args()

new_index = faiss.IndexFlatIP(args.dimension)
docid_list = []
for i in range(args.shard_num):
    index_path = os.path.join(args.prefix + f"{i:02d}", 'index')
    docid_path = os.path.join(args.prefix + f"{i:02d}", 'docid')
    print(f"reading ... {index_path}")

    # Decide which rows of this shard to keep, based on the per-document segment id.
    line_idx = []
    with open(docid_path, 'r') as f:
        for idx, line in enumerate(f):
            doc_id, psg_id = line.strip().split("#")
            if args.segment_num == -1 or int(psg_id) < args.segment_num:
                line_idx.append(idx)
                docid_list.append(f"{doc_id}#{psg_id}")

    index = faiss.read_index(index_path)
    vectors = index.reconstruct_n(0, index.ntotal)
    new_index.add(vectors[line_idx])  # keep only the selected segments

if args.segment_num == -1:
    postfix = 'maxp'
elif args.segment_num == 1:
    postfix = 'firstp'
else:
    postfix = f'seg{args.segment_num}'

output_dir = args.prefix + f'full-{postfix}'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

print(f"number of docs: {len(docid_list)}")
print(f"number of vecs: {new_index.ntotal}")
assert len(docid_list) == new_index.ntotal

faiss.write_index(new_index, os.path.join(output_dir, 'index'))

with open(os.path.join(output_dir, 'docid'), 'w') as wfd:
    for docid in docid_list:
        wfd.write(docid + '\n')
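
After merging, a quick way to confirm the output is to load the merged index and docid list back and check that they line up. This is a minimal sketch; the MaxP output path is assumed from Step 3.

```python
import os

import faiss

# Path assumed from Step 3: prefix 'indexes/msmarco-tct_colbert-v2-hnp-msmarco-' plus 'full-maxp'.
index_dir = 'indexes/msmarco-tct_colbert-v2-hnp-msmarco-full-maxp'

index = faiss.read_index(os.path.join(index_dir, 'index'))
with open(os.path.join(index_dir, 'docid')) as f:
    docids = [line.strip() for line in f]

# One docid per stored vector.
assert index.ntotal == len(docids)
print(f'{index.ntotal} vectors of dimension {index.d}')
```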