"""
In this example, we show how to evaluate different floating point precision (fp16, fp8) based faiss indexes for evaluation in BEIR.
We currently support SQFaissSearch from faiss indexes. Faiss indexes are stored and retrieved using the CPU.
Some good notes for information on different faiss indexes can be found here:
1. https://github.com/facebookresearch/faiss/wiki/Faiss-indexes#supported-operations
2. https://github.com/facebookresearch/faiss/wiki/Faiss-building-blocks:-clustering,-PCA,-quantization
For more information, please refer here: https://github.com/facebookresearch/faiss/wiki
PS: You can also save/load your corpus embeddings as a faiss index! Instead of exact search, use FlatIPFaissSearch
which implements exhaustive search using a faiss index.
Usage: python evaluate_fp16.py --dataset nfcorpus
"""
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import SQFaissSearch
import logging
import pathlib, os
import random
import argparse
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, required=True)
parser.add_argument("--model_path", type=str, default="msmarco-distilbert-base-tas-b")
args = parser.parse_args()
#### Read the dataset name and (optional) model path from the command line
dataset = args.dataset
model_path = args.model_path
#### Download the {dataset}.zip dataset (e.g. nfcorpus.zip) and unzip it
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
data_path = util.download_and_unzip(url, out_dir)
#### Provide the data path where nfcorpus has been downloaded and unzipped to the data loader
# data folder would contain these files:
# (1) nfcorpus/corpus.jsonl (format: jsonlines)
# (2) nfcorpus/queries.jsonl (format: jsonlines)
# (3) nfcorpus/qrels/test.tsv (format: tsv ("\t"))
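# Illustrative line formats following the BEIR convention (field values here are placeholders):
#   corpus.jsonl  : {"_id": "doc1", "title": "...", "text": "..."}
#   queries.jsonl : {"_id": "q1", "text": "..."}
#   qrels/test.tsv: query-id <tab> corpus-id <tab> score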
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
#### Dense Retrieval using Different Faiss Indexes (Flat or ANN) ####
#### Provide any Sentence-Transformer or Dense Retriever model.
model = models.SentenceBERT(model_path)
#### Scalar Quantization (Exhaustive Search)
###################################################
#### fp-16 floating point precision evaluation ####
###################################################
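# Note: with no quantizer_type given, SQFaissSearch is assumed to default to
# fp16 scalar quantization (faiss QT_fp16); pass quantizer_type explicitly
# to select a different precision, as in the fp-8 example below.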
faiss_search = SQFaissSearch(model, batch_size=128)
##################################################
#### fp-8 floating point precision evaluation ####
##################################################
# faiss_search = SQFaissSearch(model, batch_size=128, quantizer_type="QT_8bit_uniform")
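###########################################################
#### fp-32 exhaustive search via FlatIPFaissSearch     ####
###########################################################
# As noted in the docstring, FlatIPFaissSearch implements exhaustive (exact)
# inner-product search backed by a flat faiss index. A sketch for comparison
# against the quantized variants above, mirroring the calls in this script:
# from beir.retrieval.search.dense import FlatIPFaissSearch
# faiss_search = FlatIPFaissSearch(model, batch_size=128)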
#### Load faiss index from file or disk ####
# Two files need to be present within the input_dir!
# 1. input_dir/{prefix}.{ext}.faiss => loads the faiss index.
# 2. input_dir/{prefix}.{ext}.tsv   => loads the mapping of ids, i.e. (beir-doc-id \t faiss-doc-id).
prefix = "my-index" # (default value)
ext = "sq" # or "pq", "hnsw"
input_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "faiss-index")
if os.path.exists(os.path.join(input_dir, "{}.{}.faiss".format(prefix, ext))):
    faiss_search.load(input_dir=input_dir, prefix=prefix, ext=ext)
#### Retrieve dense results (format of results is identical to qrels)
retriever = EvaluateRetrieval(faiss_search, score_function="dot") # or "cos_sim"
results = retriever.retrieve(corpus, queries)
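# results is a nested dict in the same shape as qrels (illustrative values):
# {"q1": {"doc1": 0.93, "doc42": 0.87, ...}, ...}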
#### Save faiss index to file or disk ####
# Unfortunately, faiss only supports integer doc-ids, so we need to save two files in output_dir:
# 1. output_dir/{prefix}.{ext}.faiss => saves the faiss index.
# 2. output_dir/{prefix}.{ext}.tsv   => saves the mapping of ids, i.e. (beir-doc-id \t faiss-doc-id).
prefix = "my-index" # (default value)
ext = "sq" # or "pq", "hnsw"
output_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "faiss-index")
os.makedirs(output_dir, exist_ok=True)
if not os.path.exists(os.path.join(output_dir, "{}.{}.faiss".format(prefix, ext))):
    faiss_search.save(output_dir=output_dir, prefix=prefix, ext=ext)
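# After saving (per the naming scheme described above), output_dir should contain:
#   my-index.sq.faiss  (serialized faiss index)
#   my-index.sq.tsv    (beir-doc-id \t faiss-doc-id mapping)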
#### Evaluate your retrieval using NDCG@k, MAP@k ...
logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="r_cap")
hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
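# Each metric comes back as a dict keyed by cutoff (illustrative shape):
# ndcg = {"NDCG@1": ..., "NDCG@10": ..., ...}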
#### Print top-k documents retrieved ####
top_k = 10
query_id, ranking_scores = random.choice(list(results.items()))
scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
logging.info("Query : %s\n" % queries[query_id])
for rank in range(top_k):
    doc_id = scores_sorted[rank][0]
    # Format: Rank x: ID [Title] Body
    logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))