diff --git a/rag/benchmark.py b/rag/benchmark.py index aea4ef99c5..237f2be7e9 100644 --- a/rag/benchmark.py +++ b/rag/benchmark.py @@ -34,12 +34,13 @@ class Benchmark: def __init__(self, kb_id): - e, kb = KnowledgebaseService.get_by_id(kb_id) - self.similarity_threshold = kb.similarity_threshold - self.vector_similarity_weight = kb.vector_similarity_weight - self.embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language) + e, self.kb = KnowledgebaseService.get_by_id(kb_id) + self.similarity_threshold = self.kb.similarity_threshold + self.vector_similarity_weight = self.kb.vector_similarity_weight + self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language) def _get_benchmarks(self, query, dataset_idxnm, count=16): + req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold} sres = retrievaler.search(req, search.index_name(dataset_idxnm), self.embd_mdl) return sres @@ -48,11 +49,15 @@ def _get_retrieval(self, qrels, dataset_idxnm): run = defaultdict(dict) query_list = list(qrels.keys()) for query in query_list: - sres = self._get_benchmarks(query, dataset_idxnm) - sim, _, _ = retrievaler.rerank(sres, query, 1 - self.vector_similarity_weight, - self.vector_similarity_weight) - for index, id in enumerate(sres.ids): - run[query][id] = sim[index] + + ranks = retrievaler.retrieval(query, self.embd_mdl, dataset_idxnm.replace("ragflow_", ""), + [self.kb.id], 0, 30, + 0.0, self.vector_similarity_weight) + for c in ranks["chunks"]: + if "vector" in c: + del c["vector"] + run[query][c["chunk_id"]] = c["similarity"] + return run def embedding(self, docs, batch_size=16): @@ -99,7 +104,8 @@ def slow_actions(es_docs, idx_nm): query = data.iloc[i]['query'] for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']): d = { - "id": get_uuid() + "id": get_uuid(), + "kb_id": self.kb.id } tokenize(d, text, "english") docs.append(d) @@ -208,6 +214,8 @@ def save_results(self, qrels, run, texts, dataset, file_path): scores = sorted(scores, key=lambda kk: kk[1]) for score in scores[:10]: f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n') + json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2) + json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2) print(os.path.join(file_path, dataset + '_result.md'), 'Saved!') def __call__(self, dataset, file_path, miracl_corpus=''): diff --git a/rag/nlp/search.py b/rag/nlp/search.py index a7740fafca..8bdc96393e 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -211,8 +211,8 @@ def getFields(self, sres, flds): continue if not isinstance(v, type("")): m[n] = str(m[n]) - if n.find("tks") > 0: - m[n] = rmSpace(m[n]) + #if n.find("tks") > 0: + # m[n] = rmSpace(m[n]) if m: res[d["id"]] = m