
# fix term weight issue (#3306)
### What problem does this PR solve?

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
KevinHuSh authored Nov 8, 2024
1 parent 74d1eeb commit 004487c
Showing 2 changed files with 20 additions and 12 deletions.
### rag/benchmark.py (18 additions, 10 deletions)
```diff
@@ -34,12 +34,13 @@
 
 class Benchmark:
     def __init__(self, kb_id):
-        e, kb = KnowledgebaseService.get_by_id(kb_id)
-        self.similarity_threshold = kb.similarity_threshold
-        self.vector_similarity_weight = kb.vector_similarity_weight
-        self.embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
+        e, self.kb = KnowledgebaseService.get_by_id(kb_id)
+        self.similarity_threshold = self.kb.similarity_threshold
+        self.vector_similarity_weight = self.kb.vector_similarity_weight
+        self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
 
     def _get_benchmarks(self, query, dataset_idxnm, count=16):
 
         req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
         sres = retrievaler.search(req, search.index_name(dataset_idxnm), self.embd_mdl)
         return sres
```
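The constructor now keeps the whole Knowledgebase record on `self.kb`, which the later hunks rely on (the `[self.kb.id]` retrieval filter and the `kb_id` field on indexed docs). For orientation, a minimal usage sketch of the class after this change; the kb id, dataset name, and output directory are hypothetical, and only the constructor and `__call__` signatures come from this diff:

```python
# Minimal usage sketch (hypothetical kb id, dataset name, and output path).
from rag.benchmark import Benchmark

bench = Benchmark("my_kb_id")               # loads the KB once; kept on self.kb
bench("ms_marco_v1.1", "./bench_results")   # __call__(dataset, file_path, miracl_corpus='')
```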
```diff
@@ -48,11 +49,15 @@ def _get_retrieval(self, qrels, dataset_idxnm):
         run = defaultdict(dict)
         query_list = list(qrels.keys())
         for query in query_list:
-            sres = self._get_benchmarks(query, dataset_idxnm)
-            sim, _, _ = retrievaler.rerank(sres, query, 1 - self.vector_similarity_weight,
-                                           self.vector_similarity_weight)
-            for index, id in enumerate(sres.ids):
-                run[query][id] = sim[index]
+            ranks = retrievaler.retrieval(query, self.embd_mdl, dataset_idxnm.replace("ragflow_", ""),
+                                          [self.kb.id], 0, 30,
+                                          0.0, self.vector_similarity_weight)
+            for c in ranks["chunks"]:
+                if "vector" in c:
+                    del c["vector"]
+                run[query][c["chunk_id"]] = c["similarity"]
 
         return run
 
     def embedding(self, docs, batch_size=16):
```
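`_get_retrieval` now delegates to the full `retrievaler.retrieval` pipeline, filtered to `[self.kb.id]`, instead of the manual search-plus-rerank path, and it still builds a run dict of `query -> {chunk_id: similarity}`. As a rough, repository-independent illustration of how such a run can be scored against the qrels dict, assuming qrels maps `query -> {doc_id: relevance}`:

```python
# Hedged illustration (not code from this PR): a simple hit@k over run/qrels
# dicts shaped like the ones built in _get_retrieval above.
def hit_at_k(qrels, run, k=10):
    hits = 0
    for query, rels in qrels.items():
        scores = run.get(query, {})
        ranked = sorted(scores, key=scores.get, reverse=True)[:k]   # best-scored chunk ids first
        if any(rels.get(doc_id, 0) > 0 for doc_id in ranked):
            hits += 1
    return hits / max(len(qrels), 1)
```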
```diff
@@ -99,7 +104,8 @@ def slow_actions(es_docs, idx_nm):
             query = data.iloc[i]['query']
             for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
                 d = {
-                    "id": get_uuid()
+                    "id": get_uuid(),
+                    "kb_id": self.kb.id
                 }
                 tokenize(d, text, "english")
                 docs.append(d)
```
```diff
@@ -208,6 +214,8 @@ def save_results(self, qrels, run, texts, dataset, file_path):
             scores = sorted(scores, key=lambda kk: kk[1])
             for score in scores[:10]:
                 f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
+        json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2)
+        json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2)
         print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
 
     def __call__(self, dataset, file_path, miracl_corpus=''):
```
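The two new `json.dump` calls persist the qrels and run dicts next to the markdown report, so results can be inspected or re-scored offline. A small sketch of reading them back; the directory and dataset name are the same hypothetical values as above:

```python
# Reload the persisted qrels/run written by save_results (paths hypothetical).
import json
import os

file_path, dataset = "./bench_results", "ms_marco_v1.1"
with open(os.path.join(file_path, dataset + ".qrels.json")) as f:
    qrels = json.load(f)
with open(os.path.join(file_path, dataset + ".run.json")) as f:
    run = json.load(f)
print(len(qrels), "queries;", sum(len(v) for v in run.values()), "scored chunks")
```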
### rag/nlp/search.py (2 additions, 2 deletions)
```diff
@@ -211,8 +211,8 @@ def getFields(self, sres, flds):
                     continue
                 if not isinstance(v, type("")):
                     m[n] = str(m[n])
-                if n.find("tks") > 0:
-                    m[n] = rmSpace(m[n])
+                #if n.find("tks") > 0:
+                #    m[n] = rmSpace(m[n])
 
             if m:
                 res[d["id"]] = m
```
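Commenting out the `rmSpace` call leaves the whitespace-separated `*_tks` (token) fields untouched when they are read back. A plausible reading, given the commit title, is that stripping those spaces merged tokens and skewed the downstream term-weight computation; this diff does not show `rmSpace` or the weighting code, so the snippet below only illustrates the general failure mode:

```python
# Illustration only: term weighting that splits a *_tks field on whitespace
# sees one token instead of four once the spaces are removed.
chunk_tks = "term weight issue fix"
assert len(chunk_tks.split()) == 4
collapsed = chunk_tks.replace(" ", "")
assert len(collapsed.split()) == 1        # tokens can no longer be separated
```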
