Skip to content

Commit

Permalink
Add colbert version of IBM reranker (#918)
Browse files Browse the repository at this point in the history
* add colbert version of ibm reranker

* fix a bug on selftranslation ratio adjustment

Co-authored-by: Yuqi Liu <[email protected]>
Co-authored-by: stephanie <[email protected]>
  • Loading branch information
3 people authored Jan 3, 2022
1 parent ab72a80 commit 937ec63
Showing 1 changed file with 26 additions and 22 deletions.
48 changes: 26 additions & 22 deletions scripts/rank_ibm.py → scripts/reranker_ibm_colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def get_ibm_score(arguments):
target_lookup = arguments['target_lookup']
tran = arguments['tran']
collect_probs = arguments['collect_probs']
max_sim = arguments['max_sim']

if searcher.documentRaw(test_doc) ==None:
print(f'{test_doc} is not found in searcher')
Expand All @@ -121,22 +122,24 @@ def get_ibm_score(arguments):
target_map = {}
total_tran_prob = 0
collect_prob = collect_probs[querytoken]
max_sim_score = 0
if querytoken in target_lookup.keys():
query_word_id = target_lookup[querytoken]
if query_word_id in tran.keys():
target_map = tran[query_word_id]
for doctoken in doc_token_lst:
tran_prob = 0
doc_word_id = 0
if querytoken==doctoken:
tran_prob = SELF_TRAN
doc_word_id = 0
if doctoken in source_lookup.keys():
doc_word_id = source_lookup[doctoken]
if doc_word_id in target_map.keys():
tran_prob = max(target_map[doc_word_id],tran_prob)
total_tran_prob += (tran_prob/doc_size)

query_word_prob=math.log((1 - LAMBDA_VALUE) * total_tran_prob + LAMBDA_VALUE * collect_prob)
tran_prob = max(target_map[doc_word_id],tran_prob)
max_sim_score = max(tran_prob, max_sim_score)
total_tran_prob += (tran_prob/doc_size)
if (max_sim):
query_word_prob=math.log((1 - LAMBDA_VALUE) * max_sim_score + LAMBDA_VALUE * collect_prob)
else:
query_word_prob=math.log((1 - LAMBDA_VALUE) * total_tran_prob + LAMBDA_VALUE * collect_prob)

total_query_prob += query_word_prob
return total_query_prob /query_size
Expand Down Expand Up @@ -164,28 +167,28 @@ def intbits_to_float(b: bytes):
def rescale(source_lookup: Dict[str,int],target_lookup: Dict[str,int],tran_lookup: Dict[str,Dict[str,float]],\
target_voc: Dict[int,str],source_voc: Dict[int,str]):
for target_id in tran_lookup:
target_probs = tran_lookup[target_id]
if target_id > 0:
adjust_mult = (1 - SELF_TRAN)
else:
adjust_mult = 1
#adjust the prob with adjust_mult and add SELF_TRAN prob to self-translation pair
for source_id in target_probs.keys():
tran_prob = target_probs[source_id]
for source_id in tran_lookup[target_id].keys():
tran_prob = tran_lookup[target_id][source_id]
if source_id >0:
source_word = source_voc[source_id]
target_word = target_voc[target_id]
tran_prob *= adjust_mult
if (source_word== target_word):
tran_prob += SELF_TRAN
target_probs[source_id]= tran_prob
tran_lookup[target_id][source_id]= tran_prob
# in case if self-translation pair was not included in TransTable
if target_id not in target_probs.keys():
target_probs[target_id]= SELF_TRAN
if target_id not in tran_lookup[target_id].keys():
target_word = target_voc[target_id]
source_id = source_lookup[target_word]
tran_lookup[target_id][source_id]= SELF_TRAN
return source_lookup,target_lookup,tran_lookup



def load_tranprobs_table(dir_path: str):
source_path = dir_path +"/source.vcb"
source_lookup = {}
Expand Down Expand Up @@ -229,7 +232,7 @@ def load_tranprobs_table(dir_path: str):


def rank(qrels: str, base: str,tran_path:str, query_path:str, lucene_index_path: str,output_path:str, \
score_path:str,field_name:str, tag: str,alpha:int,num_threads:int):
score_path:str,field_name:str, tag: str,alpha:int,num_threads:int, max_sim:bool):

pool = ThreadPool(num_threads)
searcher = JSimpleSearcher(JString(lucene_index_path))
Expand Down Expand Up @@ -257,7 +260,7 @@ def rank(qrels: str, base: str,tran_path:str, query_path:str, lucene_index_path:
collect_probs[querytoken] = max(reader.totalTermFreq(JTerm(field_name, querytoken))/total_term_freq, MIN_COLLECT_PROB)
arguments = [{"query_text_lst":query_text_lst,"test_doc":test_doc, "searcher":searcher,\
"field_name":field_name,"source_lookup":source_lookup,"target_lookup":target_lookup,\
"tran":tran,"collect_probs":collect_probs} for test_doc in test_docs]
"tran":tran,"collect_probs":collect_probs, "max_sim":max_sim} for test_doc in test_docs]
rank_scores = pool.map(get_ibm_score, arguments)

ibm_scores = normalize([p for p in rank_scores])
Expand All @@ -270,7 +273,6 @@ def rank(qrels: str, base: str,tran_path:str, query_path:str, lucene_index_path:
rank = index + 1
f.write(f'{topic} Q0 {doc_id} {rank} {score} {tag}\n')


f.close()
map_score,ndcg_score = evaluate(qrels, output_path)
with open(score_path, 'w') as outfile:
Expand All @@ -286,25 +288,27 @@ def rank(qrels: str, base: str,tran_path:str, query_path:str, lucene_index_path:
metavar="path_to_qrels", help='path to new_qrels file')
parser.add_argument('-base', type=str, default="../ibm/run.msmarco-passage.bm25tuned.trec",
metavar="path_to_base_run", help='path to base run')
parser.add_argument('-tran_path', type=str, default="../ibm/ibm_model/text_bert_tok",
parser.add_argument('-tran_path', type=str, default="../ibm/ibm_model/text_bert_tok_raw",
metavar="directory_path", help='directory path to source.vcb target.vcb and Transtable bin file')
parser.add_argument('-query_path', type=str, default="../ibm/queries.dev.small.json",
metavar="path_to_query", help='path to dev queries file')
parser.add_argument('-index', type=str, default="../ibm/index-msmarco-passage-ltr-20210519-e25e33f",
metavar="path_to_lucene_index", help='path to lucene index folder')
parser.add_argument('-output', type=str, default="../ibm/runs/result-text-bert-tuned0.1.txt",
parser.add_argument('-output', type=str, default="../ibm/runs/result-colbert-test-alpha0.3.txt",
metavar="path_to_reranked_run", help='the path to store reranked run file')
parser.add_argument('-score_path', type=str, default="../ibm/result-ibm-0.1.json",
parser.add_argument('-score_path', type=str, default="../ibm/runs/result-colbert-test-alpha0.3.json",
metavar="path_to_base_run", help='the path to map and ndcg scores')
parser.add_argument('-field_name', type=str, default="text_bert_tok",
metavar="type of field", help='type of field used for training')
parser.add_argument('-alpha', type=float, default="0.1",
parser.add_argument('-alpha', type=float, default="0.3",
metavar="type of field", help='interpolation weight')
parser.add_argument('-num_threads', type=int, default="12",
metavar="num_of_threads", help='number of threads to use')
parser.add_argument('-max_sim', type=bool, default=True,
metavar="bool for max sim operator", help='whether we use max sim operator or avg instead')
args = parser.parse_args()

print('Using base run:', args.base)

rank(args.qrels, args.base, args.tran_path, args.query_path, args.index, args.output, \
args.score_path,args.field_name, args.tag,args.alpha,args.num_threads)
args.score_path,args.field_name, args.tag,args.alpha,args.num_threads, args.max_sim)

0 comments on commit 937ec63

Please sign in to comment.