# add support for Anthropic (infiniflow#2148)
### What problem does this PR solve?

infiniflow#1853: add support for Anthropic

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <[email protected]>
Co-authored-by: Kevin Hu <[email protected]>
3 people authored Aug 29, 2024
1 parent 405fdff commit 86ae772
Showing 7 changed files with 124 additions and 5 deletions.
50 changes: 50 additions & 0 deletions conf/llm_factories.json
@@ -3240,6 +3240,56 @@
            "tags": "SPEECH2TEXT",
            "status": "1",
            "llm": []
        },
        {
            "name": "Anthropic",
            "logo": "",
            "tags": "LLM",
            "status": "1",
            "llm": [
                {
                    "llm_name": "claude-3-5-sonnet-20240620",
                    "tags": "LLM,CHAT,200k",
                    "max_tokens": 204800,
                    "model_type": "chat"
                },
                {
                    "llm_name": "claude-3-opus-20240229",
                    "tags": "LLM,CHAT,200k",
                    "max_tokens": 204800,
                    "model_type": "chat"
                },
                {
                    "llm_name": "claude-3-sonnet-20240229",
                    "tags": "LLM,CHAT,200k",
                    "max_tokens": 204800,
                    "model_type": "chat"
                },
                {
                    "llm_name": "claude-3-haiku-20240307",
                    "tags": "LLM,CHAT,200k",
                    "max_tokens": 204800,
                    "model_type": "chat"
                },
                {
                    "llm_name": "claude-2.1",
                    "tags": "LLM,CHAT,200k",
                    "max_tokens": 204800,
                    "model_type": "chat"
                },
                {
                    "llm_name": "claude-2.0",
                    "tags": "LLM,CHAT,100k",
                    "max_tokens": 102400,
                    "model_type": "chat"
                },
                {
                    "llm_name": "claude-instant-1.2",
                    "tags": "LLM,CHAT,100k",
                    "max_tokens": 102400,
                    "model_type": "chat"
                }
            ]
        }
    ]
}
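The `max_tokens` values above mirror each model's context-window tag (200k → 204800, 100k → 102400). As a quick sanity check, the new entry can be read back out of the config; this is a minimal sketch, assuming the factory objects sit under a top-level `factory_llm_infos` array, which is outside the visible hunk:

```python
import json

# Minimal sketch: list the Anthropic models registered above.
# The "factory_llm_infos" root key is an assumption not shown in this hunk.
with open("conf/llm_factories.json") as f:
    factories = json.load(f)["factory_llm_infos"]

anthropic = next(f for f in factories if f["name"] == "Anthropic")
for llm in anthropic["llm"]:
    print(llm["llm_name"], llm["max_tokens"], llm["model_type"])
```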
3 changes: 2 additions & 1 deletion rag/llm/__init__.py
@@ -104,7 +104,8 @@
    "Replicate": ReplicateChat,
    "Tencent Hunyuan": HunyuanChat,
    "XunFei Spark": SparkChat,
    "BaiduYiyan": BaiduYiyanChat
    "BaiduYiyan": BaiduYiyanChat,
    "Anthropic": AnthropicChat
}


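With the `ChatModel` registry extended, callers can resolve the new provider by its factory name. A hedged usage sketch — the API key, prompt, and model name below are placeholders, and the call shape follows the `Base` chat interface used by the other entries:

```python
from rag.llm import ChatModel

# "Anthropic" must match the "name" field in conf/llm_factories.json.
mdl = ChatModel["Anthropic"]("sk-ant-...", "claude-3-5-sonnet-20240620")
answer, used_tokens = mdl.chat(
    "You are a concise assistant.",               # system prompt
    [{"role": "user", "content": "Say hello."}],  # chat history
    {"temperature": 0.2},                         # generation config
)
print(answer, used_tokens)
```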
72 changes: 68 additions & 4 deletions rag/llm/chat_model.py
@@ -1132,7 +1132,7 @@ def __init__(
class BaiduYiyanChat(Base):
    def __init__(self, key, model_name, base_url=None):
        import qianfan

        key = json.loads(key)
        ak = key.get("yiyan_ak","")
        sk = key.get("yiyan_sk","")
@@ -1149,7 +1149,7 @@ def chat(self, system, history, gen_conf):
        if "max_tokens" in gen_conf:
            gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
        ans = ""

        try:
            response = self.client.do(
                model=self.model_name,
@@ -1159,7 +1159,7 @@ def chat(self, system, history, gen_conf):
            ).body
            ans = response['result']
            return ans, response["usage"]["total_tokens"]

        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

@@ -1173,7 +1173,7 @@ def chat_streamly(self, system, history, gen_conf):
            gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
        ans = ""
        total_tokens = 0

        try:
            response = self.client.do(
                model=self.model_name,
@@ -1193,3 +1193,67 @@ def chat_streamly(self, system, history, gen_conf):
            return ans + "\n**ERROR**: " + str(e), 0

        yield total_tokens


class AnthropicChat(Base):
    def __init__(self, key, model_name, base_url=None):
        import anthropic

        self.client = anthropic.Anthropic(api_key=key)
        self.model_name = model_name
        self.system = ""

    def chat(self, system, history, gen_conf):
        if system:
            self.system = system
        if "max_tokens" not in gen_conf:
            gen_conf["max_tokens"] = 4096

        ans = ""  # initialize so the except handler never hits an unbound name
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=history,
                system=self.system,
                stream=False,
                **gen_conf,
            )
            # The SDK returns a typed Message object; read its fields directly.
            # (Calling .json() would yield a JSON *string*, which cannot be
            # indexed with ["content"].)
            ans = response.content[0].text
            if response.stop_reason == "max_tokens":
                ans += (
                    "...\nFor the content length reason, it stopped, continue?"
                    if is_english([ans])
                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
                )
            return (
                ans,
                response.usage.input_tokens + response.usage.output_tokens,
            )
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            self.system = system
        if "max_tokens" not in gen_conf:
            gen_conf["max_tokens"] = 4096

        ans = ""
        total_tokens = 0
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=history,
                system=self.system,
                stream=True,
                **gen_conf,
            )
            # With stream=True the SDK yields typed server-sent events rather
            # than raw HTTP lines, so iterate the stream itself instead of
            # calling iter_lines() and hand-parsing "data:" payloads.
            for res in response:
                if res.type == "content_block_delta":
                    text = res.delta.text
                    ans += text
                    total_tokens += num_tokens_from_string(text)
                    yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens
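And a matching sketch for the streaming path, reusing `mdl` from the snippet above. Like the other chat models in this module, `chat_streamly` is a generator that yields progressively longer partial answers and, as its final item, the total token count:

```python
for chunk in mdl.chat_streamly(
    "You are a concise assistant.",
    [{"role": "user", "content": "Stream a short greeting."}],
    {"max_tokens": 512},
):
    if isinstance(chunk, str):
        print(chunk)                   # growing partial answer
    else:
        print("total tokens:", chunk)  # final integer after the stream ends
```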
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
anthropic==0.34.1
arxiv==2.1.3
Aspose.Slides==24.2.0
BCEmbedding==0.1.3
1 change: 1 addition & 0 deletions requirements_arm.txt
@@ -2,6 +2,7 @@ accelerate==0.27.2
aiohttp==3.9.4
aiosignal==1.3.1
annotated-types==0.6.0
anthropic==0.34.1
anyio==4.3.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
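A trivial post-install check (not part of the PR) that the pinned SDK resolves:

```python
import anthropic

# Expect "0.34.1" per the pin in both requirements files.
print(anthropic.__version__)
```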
1 change: 1 addition & 0 deletions web/src/assets/svg/llm/anthropic.svg
(New Anthropic logo added as an SVG asset; the image is not rendered in this diff view.)
1 change: 1 addition & 0 deletions web/src/pages/user-setting/setting-model/constant.ts
@@ -37,6 +37,7 @@ export const IconMap = {
  BaiduYiyan: 'yiyan',
  'Fish Audio': 'fish-audio',
  'Tencent Cloud': 'tencent-cloud',
  Anthropic: 'anthropic',
};

export const BedrockRegionList = [
