From 122d7109dcdc83fdf5b19b32a838bbb5a0fb3310 Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Mon, 11 Nov 2024 11:54:14 +0800
Subject: [PATCH] fix: Anthropic param error (#3327)

### What problem does this PR solve?

#3263

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 rag/llm/chat_model.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index ebc48954c0..9060fd4508 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -1249,6 +1249,8 @@ def chat(self, system, history, gen_conf):
             self.system = system
         if "max_tokens" not in gen_conf:
             gen_conf["max_tokens"] = 4096
+        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
 
         ans = ""
         try:
@@ -1278,6 +1280,8 @@ def chat_streamly(self, system, history, gen_conf):
             self.system = system
         if "max_tokens" not in gen_conf:
             gen_conf["max_tokens"] = 4096
+        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
 
         ans = ""
         total_tokens = 0
@@ -1290,11 +1294,11 @@ def chat_streamly(self, system, history, gen_conf):
                 **gen_conf,
             )
             for res in response.iter_lines():
-                res = res.decode("utf-8")
-                if "content_block_delta" in res and "data" in res:
-                    text = json.loads(res[6:])["delta"]["text"]
+                if res.type == 'content_block_delta':
+                    text = res.delta.text
                     ans += text
                     total_tokens += num_tokens_from_string(text)
+                    yield ans
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
 