From 86ae772d17f897b3a4b1ed7d20e1a3d9e99f4c9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E8=85=BE?= <101850389+hangters@users.noreply.github.com> Date: Thu, 29 Aug 2024 13:30:06 +0800 Subject: [PATCH] add support for Anthropic (#2148) ### What problem does this PR solve? #1853 add support for Anthropic ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Zhedong Cen Co-authored-by: Kevin Hu --- conf/llm_factories.json | 50 +++++++++++++ rag/llm/__init__.py | 3 +- rag/llm/chat_model.py | 72 +++++++++++++++++-- requirements.txt | 1 + requirements_arm.txt | 1 + web/src/assets/svg/llm/anthropic.svg | 1 + .../user-setting/setting-model/constant.ts | 1 + 7 files changed, 124 insertions(+), 5 deletions(-) create mode 100644 web/src/assets/svg/llm/anthropic.svg diff --git a/conf/llm_factories.json b/conf/llm_factories.json index c2b2cc69fbd..bd0ac30ab52 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -3240,6 +3240,56 @@ "tags": "SPEECH2TEXT", "status": "1", "llm": [] + }, + { + "name": "Anthropic", + "logo": "", + "tags": "LLM", + "status": "1", + "llm": [ + { + "llm_name": "claude-3-5-sonnet-20240620", + "tags": "LLM,CHAT,200k", + "max_tokens": 204800, + "model_type": "chat" + }, + { + "llm_name": "claude-3-opus-20240229", + "tags": "LLM,CHAT,200k", + "max_tokens": 204800, + "model_type": "chat" + }, + { + "llm_name": "claude-3-sonnet-20240229", + "tags": "LLM,CHAT,200k", + "max_tokens": 204800, + "model_type": "chat" + }, + { + "llm_name": "claude-3-haiku-20240307", + "tags": "LLM,CHAT,200k", + "max_tokens": 204800, + "model_type": "chat" + }, + { + "llm_name": "claude-2.1", + "tags": "LLM,CHAT,200k", + "max_tokens": 204800, + "model_type": "chat" + }, + { + "llm_name": "claude-2.0", + "tags": "LLM,CHAT,100k", + "max_tokens": 102400, + "model_type": "chat" + }, + { + "llm_name": "claude-instant-1.2", + "tags": "LLM,CHAT,100k", + "max_tokens": 102400, + "model_type": "chat" + } + ] } ] 
} diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index adcb53f1ca9..ef37d6446fa 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -104,7 +104,8 @@ "Replicate": ReplicateChat, "Tencent Hunyuan": HunyuanChat, "XunFei Spark": SparkChat, - "BaiduYiyan": BaiduYiyanChat + "BaiduYiyan": BaiduYiyanChat, + "Anthropic": AnthropicChat } diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 64c39912ffa..3af0e0257a4 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1132,7 +1132,7 @@ def __init__( class BaiduYiyanChat(Base): def __init__(self, key, model_name, base_url=None): import qianfan - + key = json.loads(key) ak = key.get("yiyan_ak","") sk = key.get("yiyan_sk","") @@ -1149,7 +1149,7 @@ def chat(self, system, history, gen_conf): if "max_tokens" in gen_conf: gen_conf["max_output_tokens"] = gen_conf["max_tokens"] ans = "" - + try: response = self.client.do( model=self.model_name, @@ -1159,7 +1159,7 @@ def chat(self, system, history, gen_conf): ).body ans = response['result'] return ans, response["usage"]["total_tokens"] - + except Exception as e: return ans + "\n**ERROR**: " + str(e), 0 @@ -1173,7 +1173,7 @@ def chat_streamly(self, system, history, gen_conf): gen_conf["max_output_tokens"] = gen_conf["max_tokens"] ans = "" total_tokens = 0 - + try: response = self.client.do( model=self.model_name, @@ -1193,3 +1193,67 @@ def chat_streamly(self, system, history, gen_conf): return ans + "\n**ERROR**: " + str(e), 0 yield total_tokens + + +class AnthropicChat(Base): + def __init__(self, key, model_name, base_url=None): + import anthropic + + self.client = anthropic.Anthropic(api_key=key) + self.model_name = model_name + self.system = "" + + def chat(self, system, history, gen_conf): + if system: + self.system = system + if "max_tokens" not in gen_conf: + gen_conf["max_tokens"] = 4096 + + try: + response = self.client.messages.create( + model=self.model_name, + messages=history, + system=self.system, + stream=False, + **gen_conf, + 
).json() + ans = response["content"][0]["text"] + if response["stop_reason"] == "max_tokens": + ans += ( + "...\nFor the content length reason, it stopped, continue?" + if is_english([ans]) + else "······\n由于长度的原因,回答被截断了,要继续吗?" + ) + return ( + ans, + response["usage"]["input_tokens"] + response["usage"]["output_tokens"], + ) + except Exception as e: + return ans + "\n**ERROR**: " + str(e), 0 + + def chat_streamly(self, system, history, gen_conf): + if system: + self.system = system + if "max_tokens" not in gen_conf: + gen_conf["max_tokens"] = 4096 + + ans = "" + total_tokens = 0 + try: + response = self.client.messages.create( + model=self.model_name, + messages=history, + system=self.system, + stream=True, + **gen_conf, + ) + for res in response.iter_lines(): + res = res.decode("utf-8") + if "content_block_delta" in res and "data" in res: + text = json.loads(res[6:])["delta"]["text"] + ans += text + total_tokens += num_tokens_from_string(text) + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + + yield total_tokens diff --git a/requirements.txt b/requirements.txt index 094e20151fe..0bcd697af4b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +anthropic==0.34.1 arxiv==2.1.3 Aspose.Slides==24.2.0 BCEmbedding==0.1.3 diff --git a/requirements_arm.txt b/requirements_arm.txt index d064ee40ace..1207d6d8d21 100644 --- a/requirements_arm.txt +++ b/requirements_arm.txt @@ -2,6 +2,7 @@ accelerate==0.27.2 aiohttp==3.9.4 aiosignal==1.3.1 annotated-types==0.6.0 +anthropic==0.34.1 anyio==4.3.0 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 diff --git a/web/src/assets/svg/llm/anthropic.svg b/web/src/assets/svg/llm/anthropic.svg new file mode 100644 index 00000000000..249c9503cb8 --- /dev/null +++ b/web/src/assets/svg/llm/anthropic.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/src/pages/user-setting/setting-model/constant.ts b/web/src/pages/user-setting/setting-model/constant.ts index dae68f74672..33cf76d86a9 100644 --- 
a/web/src/pages/user-setting/setting-model/constant.ts +++ b/web/src/pages/user-setting/setting-model/constant.ts @@ -37,6 +37,7 @@ export const IconMap = { BaiduYiyan: 'yiyan', 'Fish Audio': 'fish-audio', 'Tencent Cloud': 'tencent-cloud', + Anthropic: 'anthropic', }; export const BedrockRegionList = [