Skip to content

Commit

Permalink
feat: support xinference rerank model (#1466)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Add support for Xinference rerank models.
#1455 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
hwzhuhao authored Jul 11, 2024
1 parent 9c023b6 commit 009e18f
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 3 deletions.
6 changes: 3 additions & 3 deletions api/db/init_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def init_superuser():
"name": "Ollama",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
"status": "1",
}, {
"name": "Moonshot",
"logo": "",
Expand All @@ -123,8 +123,8 @@ def init_superuser():
}, {
"name": "Xinference",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION,TEXT RE-RANK",
"status": "1",
},{
"name": "Youdao",
"logo": "",
Expand Down
1 change: 1 addition & 0 deletions rag/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,5 @@
"BAAI": DefaultRerank,
"Jina": JinaRerank,
"Youdao": YoudaoRerank,
"Xinference": XInferenceRerank
}
18 changes: 18 additions & 0 deletions rag/llm/rerank_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,22 @@ def similarity(self, query: str, texts: list):
else: res.extend(scores)
return np.array(res), token_count

class XInferenceRerank(Base):
    """Rerank client that scores documents by POSTing to an Xinference rerank endpoint."""

    def __init__(self, model_name="", base_url=""):
        """Store the model name, endpoint URL and JSON request headers.

        Args:
            model_name: Sent as the ``model`` field of every rerank request.
            base_url: URL that ``similarity`` POSTs to as-is; nothing is
                appended to it here, so it must already be the rerank endpoint.
        """
        self.model_name = model_name
        self.base_url = base_url
        self.headers = {
            "Content-Type": "application/json",
            "accept": "application/json"
        }

    def similarity(self, query: str, texts: list):
        """Score each text in ``texts`` against ``query``.

        Args:
            query: The query string the documents are ranked against.
            texts: Candidate documents to score.

        Returns:
            A tuple ``(scores, token_count)``. ``scores`` is a numpy array
            where ``scores[i]`` is the relevance score of ``texts[i]``;
            ``token_count`` is the sum of the ``input_tokens`` and
            ``output_tokens`` reported in the response.
        """
        # Nothing to rank: skip the HTTP round trip entirely.
        if len(texts) == 0:
            return np.array([]), 0

        data = {
            "model": self.model_name,
            "query": query,
            # NOTE(review): sent as the strings "true", not JSON booleans;
            # presumably the server coerces them - confirm against the API.
            "return_documents": "true",
            "return_len": "true",
            "documents": texts
        }
        res = requests.post(self.base_url, headers=self.headers, json=data).json()

        # The service returns results ordered by relevance, not by input
        # position, so place each score at the position of the document it
        # belongs to. Fall back to the list position if an entry has no
        # "index" field.
        scores = np.zeros(len(texts), dtype=float)
        for pos, item in enumerate(res["results"]):
            scores[item.get("index", pos)] = item["relevance_score"]

        tokens = res["tokens"]
        return scores, tokens["input_tokens"] + tokens["output_tokens"]
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ const OllamaModal = ({
<Select placeholder={t('modelTypeMessage')}>
<Option value="chat">chat</Option>
<Option value="embedding">embedding</Option>
<Option value="rerank">rerank</Option>
</Select>
</Form.Item>
<Form.Item<FieldType>
Expand Down

0 comments on commit 009e18f

Please sign in to comment.