add tct-colbert-v2 doc zero shot exp (#682)
* add tct-colbert-v2 doc zero shot exp

Co-authored-by: Lin Jack <[email protected]>
Co-authored-by: Jack Lin <[email protected]>
Co-authored-by: Jheng-Hong Yang <[email protected]>
5 people authored Jul 13, 2021
1 parent 853d90f commit 5cc929d
Showing 3 changed files with 289 additions and 0 deletions.
110 changes: 110 additions & 0 deletions docs/experiments-tct_colbert-v2.md
@@ -186,6 +186,116 @@ recall_1000 all 0.9759

Follow the same instructions above to perform on-the-fly query encoding.


## MS MARCO Document Ranking with TCT-ColBERT-V2 (zero-shot)

Document retrieval with TCT-ColBERT-V2, using a brute-force (flat) index:

Step 0: prepare `docs.json` by splitting each document into segments of passages. Each line contains a JSON dict as follows:
```json
{"id": "[doc_id]#[seg_id]", "contents": "[url]\n[title]\n[seg_text]"}
```


Step 1: split the collection into 50 parts for parallel encoding:
```bash
$ split -a 2 -d -n l/50 docs.json collection.part
```

Step 2-1: prepare the encoder (on CC); you can download the encoder using [git-lfs](https://git-lfs.github.com/).

Example (after installing git-lfs):
```bash
git clone https://huggingface.co/castorini/tct_colbert-v2-hnp-msmarco
```

Step 2-2: run encoding (one GPU job per shard; the `srun` invocation assumes a Slurm cluster):
```bash
export TASK=msmarco
export ENCODER=tct_colbert-v2-hnp-msmarco  # matches the model directory cloned above
export WORKING_DIR=~/scratch

for i in $(seq -f "%02g" 0 49)
do
    srun --gres=gpu:v100:1 --mem=16G --cpus-per-task=2 --time=2:00:00 \
        python scripts/tct_colbert/encode_corpus_msmarco_doc.py \
            --corpus ${WORKING_DIR}/${TASK}/collection.part${i} \
            --encoder ${WORKING_DIR}/checkpoint/${ENCODER} \
            --index indexes/${TASK}-${ENCODER}-${i} \
            --batch 16 \
            --device cuda:0 &
done
```

Step 3: merge / filter the index. Use `--segment-num -1` for MaxP (keep all segments; the best-scoring one per document is selected at search time), `--segment-num 1` for FirstP (keep only each document's first segment), or any other integer you like:
```bash
$ python scripts/tct_colbert/merge_indexes.py \
    --prefix <path_to_index> \
    --shard-num 50 \
    --segment-num -1
```
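
For intuition, here is what MaxP means at search time (Step 4's `--max-passage` flag handles this inside Pyserini): a passage-level ranking with ids of the form `doc_id#seg_id` is collapsed to a document-level ranking by keeping each document's best-scoring segment. A minimal sketch, assuming hits are `(passage_id, score)` pairs:

```python
# Hedged sketch of MaxP aggregation over a passage-level result list.
def maxp(hits, k=100):
    best = {}
    for pid, score in hits:
        doc_id = pid.split('#')[0]
        # Keep the highest-scoring segment per document.
        if doc_id not in best or score > best[doc_id]:
            best[doc_id] = score
    return sorted(best.items(), key=lambda x: -x[1])[:k]

# Example: two segments of D1 and one of D2.
print(maxp([('D1#0', 0.7), ('D1#1', 0.9), ('D2#0', 0.8)]))
# [('D1', 0.9), ('D2', 0.8)]
```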

Step 4: search (with on-the-fly query encoding):
```bash
$ python -m pyserini.dsearch --topics msmarco-doc-dev \
    --index <path_to_index>/msmarco-tct_colbert-v2-hnp-msmarco-full-maxp \
    --encoder castorini/tct_colbert-v2-hnp-msmarco \
    --output runs/run.msmarco-doc.passage.tct_colbert-v2-hnp-maxp.txt \
    --hits 1000 \
    --max-passage \
    --max-passage-hits 100 \
    --output-format msmarco \
    --batch-size 144 \
    --threads 36

$ python -m pyserini.dsearch --topics dl19-doc \
    --index <path_to_index>/msmarco-tct_colbert-v2-hnp-msmarco-full-maxp \
    --encoder castorini/tct_colbert-v2-hnp-msmarco \
    --output runs/run.dl19-doc.passage.tct_colbert-v2-hnp-maxp.txt \
    --hits 1000 \
    --max-passage \
    --max-passage-hits 100 \
    --output-format msmarco \
    --batch-size 144 \
    --threads 36
```
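
Alternatively, the same search can be issued from Python. The snippet below is a sketch based on the `pyserini.dsearch` API at the time of this commit (`SimpleDenseSearcher`, `TctColBertQueryEncoder`); the index path is a placeholder and the query is illustrative:

```python
from pyserini.dsearch import SimpleDenseSearcher, TctColBertQueryEncoder

# On-the-fly query encoding with the same checkpoint used above.
encoder = TctColBertQueryEncoder('castorini/tct_colbert-v2-hnp-msmarco')
searcher = SimpleDenseSearcher(
    '<path_to_index>/msmarco-tct_colbert-v2-hnp-msmarco-full-maxp',  # merged index dir
    encoder)

hits = searcher.search('what is the daily life of thai people', k=10)
for i, hit in enumerate(hits):
    print(f'{i + 1:2} {hit.docid:25} {hit.score:.5f}')
```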

Step 5: evaluate.

For MS MARCO doc dev:
```bash
$ python -m pyserini.eval.msmarco_doc_eval --judgments msmarco-doc-dev --run runs/run.msmarco-doc.passage.tct_colbert-v2-hnp-maxp.txt

#####################
MRR @100: 0.3508557690776294
QueriesRanked: 5193
#####################

$ python -m pyserini.eval.convert_msmarco_run_to_trec_run --input runs/run.msmarco-doc.passage.tct_colbert-v2-hnp-maxp.txt \
    --output runs/run.msmarco-doc.passage.tct_colbert-v2-hnp-maxp.trec
$ python -m pyserini.eval.trec_eval -c -mrecall.100 -mmap -mndcg_cut.10 msmarco-doc-dev \
    runs/run.msmarco-doc.passage.tct_colbert-v2-hnp-maxp.trec

Results:
map all 0.3509
recall_100 all 0.8908
ndcg_cut_10 all 0.4123

```

For TREC DL19:
```bash
$ python -m pyserini.eval.convert_msmarco_run_to_trec_run --input runs/run.dl19-doc.passage.tct_colbert-v2-hnp-maxp.txt \
    --output runs/run.dl19-doc.passage.tct_colbert-v2-hnp-maxp.trec
$ python -m pyserini.eval.trec_eval -c -mrecall.100 -mmap -mndcg_cut.10 dl19-doc \
    runs/run.dl19-doc.passage.tct_colbert-v2-hnp-maxp.trec

Results:
map all 0.2683
recall_100 all 0.3854
ndcg_cut_10 all 0.6592
```



## Reproduction Log[*](reproducibility.md)

+ Results reproduced by [@lintool](https://github.com/lintool) on 2021-07-01 (commit [`b1576a2`](https://github.com/castorini/pyserini/commit/b1576a2c3e899349be12e897f92f3ad75ec82d6f))
109 changes: 109 additions & 0 deletions scripts/tct_colbert/encode_corpus_msmarco_doc.py
@@ -0,0 +1,109 @@
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import json
import os
import sys
import numpy as np
import faiss
import torch
from tqdm import tqdm

from transformers import BertTokenizer, BertModel

# We're going to explicitly use a local installation of Pyserini (as opposed to a pip-installed one).
# Comment these lines out to use a pip-installed one instead.
sys.path.insert(0, './')
sys.path.insert(0, '../pyserini/')

def mean_pooling(last_hidden_state, attention_mask):
    # Average the token embeddings, ignoring padding positions via the attention mask.
    token_embeddings = last_hidden_state
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


class TctColBertDocumentEncoder(torch.nn.Module):
    def __init__(self, model_name, tokenizer_name=None, device='cuda:0'):
        super().__init__()
        self.device = device
        self.model = BertModel.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()  # disable dropout for deterministic encoding
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name or model_name)

    def encode(self, texts, titles=None):
        # Prepend the '[CLS] [D]' marker tokens expected by TCT-ColBERT document encoding.
        texts = ['[CLS] [D] ' + text for text in texts]
        max_length = 512  # hardcode for now
        inputs = self.tokenizer(
            texts,
            max_length=max_length,
            padding="longest",
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )
        inputs.to(self.device)
        with torch.no_grad():  # no gradients needed at inference time
            outputs = self.model(**inputs)
        # Mean-pool over tokens, skipping the first 4 marker tokens.
        embeddings = mean_pooling(outputs["last_hidden_state"][:, 4:, :], inputs['attention_mask'][:, 4:])
        return embeddings.detach().cpu().numpy()



if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str, help='encoder name or path', required=True)
    parser.add_argument('--dimension', type=int, help='dimension of passage embeddings', required=False, default=768)
    parser.add_argument('--corpus', type=str,
                        help='collection file to be encoded (format: jsonl)', required=True)
    parser.add_argument('--index', type=str, help='directory to store brute force index of corpus', required=True)
    parser.add_argument('--batch', type=int, help='batch size', default=8)
    parser.add_argument('--device', type=str, help='device cpu or cuda [cuda:0, cuda:1...]', default='cuda:0')
    args = parser.parse_args()

    # tokenizer = AutoTokenizer.from_pretrained(args.encoder)
    # model = AutoModel.from_pretrained(args.encoder)
    model = TctColBertDocumentEncoder(model_name=args.encoder, device=args.device)

    index = faiss.IndexFlatIP(args.dimension)

    if not os.path.exists(args.index):
        os.mkdir(args.index)

    texts = []
    with open(os.path.join(args.index, 'docid'), 'w') as id_file:
        file = os.path.join(args.corpus)
        print(f'Loading {file}')
        with open(file, 'r') as corpus:
            for idx, line in enumerate(tqdm(corpus.readlines())):
                info = json.loads(line)
                docid = info['id']
                text = info['contents']
                id_file.write(f'{docid}\n')
                # contents format: url\ntitle\nseg_text; the segment text itself may contain '\n'
                fields = text.split('\n')
                title, text = fields[1], ' '.join(fields[2:])
                text = f"{title} {text}"
                texts.append(text.lower())

    for idx in tqdm(range(0, len(texts), args.batch)):
        text_batch = texts[idx: idx + args.batch]
        embeddings = model.encode(text_batch)
        index.add(np.array(embeddings))
    faiss.write_index(index, os.path.join(args.index, 'index'))
70 changes: 70 additions & 0 deletions scripts/tct_colbert/merge_indexes.py
@@ -0,0 +1,70 @@
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse

import faiss
import os


parser = argparse.ArgumentParser()
parser.add_argument('--dimension', type=int, help='dimension of passage embeddings', required=False, default=768)
parser.add_argument('--prefix', type=str, help='directory to store brute force index of corpus', required=True)
parser.add_argument('--segment-num', type=int, help='number of passage segments, use -1 for MaxP', default=1)
parser.add_argument('--shard-num', type=int, help='number of shards', default=1)
args = parser.parse_args()

new_index = faiss.IndexFlatIP(args.dimension)
docid_list = []
for i in range(args.shard_num):
    index = os.path.join(args.prefix + f"{i:02d}", 'index')
    docid = os.path.join(args.prefix + f"{i:02d}", 'docid')
    print(f"reading ... {index}")
    line_idx = []
    with open(docid, 'r') as f:
        for idx, line in enumerate(f):
            doc_id, psg_id = line.strip().split("#")
            if args.segment_num == -1:
                # MaxP: keep every segment of every document
                line_idx.append(idx)
                docid_list.append(doc_id + "#" + psg_id)
            elif int(psg_id) < args.segment_num:
                # keep only the first `segment_num` segments of each document
                line_idx.append(idx)
                docid_list.append(doc_id + "#" + psg_id)

    index = faiss.read_index(index)
    vectors = index.reconstruct_n(0, index.ntotal)
    new_index.add(vectors[line_idx])  # filter segments

if args.segment_num == -1:
    postfix = 'maxp'
elif args.segment_num == 1:
    postfix = 'firstp'
else:
    postfix = f'seg{args.segment_num}'

if not os.path.exists(args.prefix + f'full-{postfix}'):
    os.mkdir(args.prefix + f'full-{postfix}')


print(f"number of docs: {len(docid_list)}")
print(f"number of vecs: {new_index.ntotal}")
assert len(docid_list) == new_index.ntotal
faiss.write_index(new_index, os.path.join(args.prefix + f'full-{postfix}', 'index'))

with open(os.path.join(args.prefix + f'full-{postfix}', 'docid'), 'w') as wfd:
    for docid in docid_list:
        wfd.write(docid + '\n')
