Skip to content

Commit

Permalink
Use parameter for reranker (#177)
Browse files Browse the repository at this point in the history
* Use parameter for reranker

Signed-off-by: Liangyx2 <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Liangyx2 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sun, Xuehao <[email protected]>
Co-authored-by: chen, suyue <[email protected]>
  • Loading branch information
4 people authored Jun 19, 2024
1 parent 9e91843 commit dfdd08c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
1 change: 1 addition & 0 deletions comps/cores/proto/docarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class EmbedDoc1024(BaseDoc):
class SearchedDoc(BaseDoc):
retrieved_docs: DocList[TextDoc]
initial_query: str
top_n: int = 1

class Config:
json_encoders = {np.ndarray: lambda x: x.tolist()}
Expand Down
9 changes: 9 additions & 0 deletions comps/reranks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,12 @@ curl http://localhost:8000/v1/reranking \
-d '{"initial_query":"What is Deep Learning?", "retrieved_docs": [{"text":"Deep Learning is not..."}, {"text":"Deep learning is..."}]}' \
-H 'Content-Type: application/json'
```

You can add the parameter `top_n` to specify the return number of the reranker model, default value is 1.

```bash
curl http://localhost:8000/v1/reranking \
-X POST \
-d '{"initial_query":"What is Deep Learning?", "retrieved_docs": [{"text":"Deep Learning is not..."}, {"text":"Deep learning is..."}], "top_n":2}' \
-H 'Content-Type: application/json'
```
11 changes: 7 additions & 4 deletions comps/reranks/langchain/reranking_tei_xeon.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import heapq
import json
import os
import re
Expand Down Expand Up @@ -40,9 +41,11 @@ def reranking(input: SearchedDoc) -> LLMParamsDoc:
headers = {"Content-Type": "application/json"}
response = requests.post(url, data=json.dumps(data), headers=headers)
response_data = response.json()
best_response = max(response_data, key=lambda response: response["score"])
doc = input.retrieved_docs[best_response["index"]]
if doc.text and len(re.findall("[\u4E00-\u9FFF]", doc.text)) / len(doc.text) >= 0.3:
best_response_list = heapq.nlargest(input.top_n, response_data, key=lambda x: x["score"])
context_str = ""
for best_response in best_response_list:
context_str = context_str + " " + input.retrieved_docs[best_response["index"]].text
if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3:
# chinese context
template = "仅基于以下背景回答问题:\n{context}\n问题: {question}"
else:
Expand All @@ -51,7 +54,7 @@ def reranking(input: SearchedDoc) -> LLMParamsDoc:
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)
final_prompt = prompt.format(context=doc.text, question=input.initial_query)
final_prompt = prompt.format(context=context_str, question=input.initial_query)
statistics_dict["opea_service@reranking_tgi_gaudi"].append_latency(time.time() - start, None)
return LLMParamsDoc(query=final_prompt.strip())

Expand Down

0 comments on commit dfdd08c

Please sign in to comment.