From 1a04e73f0dd633a7eee6792c946cc895546bb9c0 Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Sun, 26 Jan 2025 13:28:35 +0800
Subject: [PATCH] Refactor for total_tokens.

---
 rag/llm/chat_model.py      | 72 ++++++++++++++++++++------------
 rag/llm/embedding_model.py | 44 ++++++++++++++---------
 rag/llm/rerank_model.py    | 15 ++++++--
 3 files changed, 79 insertions(+), 52 deletions(-)

diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 48f234f8418..77f2714b9c2 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -53,7 +53,7 @@ def chat(self, system, history, gen_conf):
                     ans += LENGTH_NOTIFICATION_CN
                 else:
                     ans += LENGTH_NOTIFICATION_EN
-            return ans, response.usage.total_tokens
+            return ans, self.total_token_count(response)
         except openai.APIError as e:
             return "**ERROR**: " + str(e), 0
 
@@ -75,15 +75,11 @@ def chat_streamly(self, system, history, gen_conf):
                     resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
-                if not hasattr(resp, "usage") or not resp.usage:
-                    total_tokens = (
-                        total_tokens
-                        + num_tokens_from_string(resp.choices[0].delta.content)
-                    )
-                elif isinstance(resp.usage, dict):
-                    total_tokens = resp.usage.get("total_tokens", total_tokens)
+                tol = self.total_token_count(resp)
+                if not tol:
+                    total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
                 else:
-                    total_tokens = resp.usage.total_tokens
+                    total_tokens = tol
 
                 if resp.choices[0].finish_reason == "length":
                     if is_chinese(ans):
@@ -97,6 +93,17 @@ def chat_streamly(self, system, history, gen_conf):
 
         yield total_tokens
 
+    def total_token_count(self, resp):
+        try:
+            return resp.usage.total_tokens
+        except Exception:
+            pass
+        try:
+            return resp["usage"]["total_tokens"]
+        except Exception:
+            pass
+        return 0
+
 
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
@@ -182,7 +189,7 @@ def chat(self, system, history, gen_conf):
                     ans += LENGTH_NOTIFICATION_CN
                 else:
                     ans += LENGTH_NOTIFICATION_EN
-            return ans, response.usage.total_tokens
+            return ans, self.total_token_count(response)
         except openai.APIError as e:
             return "**ERROR**: " + str(e), 0
 
@@ -212,14 +219,11 @@ def chat_streamly(self, system, history, gen_conf):
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
-                total_tokens = (
-                    (
-                        total_tokens
-                        + num_tokens_from_string(resp.choices[0].delta.content)
-                    )
-                    if not hasattr(resp, "usage")
-                    else resp.usage["total_tokens"]
-                )
+                tol = self.total_token_count(resp)
+                if not tol:
+                    total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
+                else:
+                    total_tokens = tol
                 if resp.choices[0].finish_reason == "length":
                     if is_chinese([ans]):
                         ans += LENGTH_NOTIFICATION_CN
@@ -256,7 +260,7 @@ def chat(self, system, history, gen_conf):
         tk_count = 0
         if response.status_code == HTTPStatus.OK:
             ans += response.output.choices[0]['message']['content']
-            tk_count += response.usage.total_tokens
+            tk_count += self.total_token_count(response)
             if response.output.choices[0].get("finish_reason", "") == "length":
                 if is_chinese([ans]):
                     ans += LENGTH_NOTIFICATION_CN
@@ -292,7 +296,7 @@ def _chat_streamly(self, system, history, gen_conf, incremental_output=False):
         for resp in response:
             if resp.status_code == HTTPStatus.OK:
                 ans = resp.output.choices[0]['message']['content']
-                tk_count = resp.usage.total_tokens
+                tk_count = self.total_token_count(resp)
                 if resp.output.choices[0].get("finish_reason", "") == "length":
                     if is_chinese(ans):
                         ans += LENGTH_NOTIFICATION_CN
@@ -334,7 +338,7 @@ def chat(self, system, history, gen_conf):
                     ans += LENGTH_NOTIFICATION_CN
                 else:
                     ans += LENGTH_NOTIFICATION_EN
-            return ans, response.usage.total_tokens
+            return ans, self.total_token_count(response)
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
@@ -364,9 +368,9 @@ def chat_streamly(self, system, history, gen_conf):
                         ans += LENGTH_NOTIFICATION_CN
                     else:
                         ans += LENGTH_NOTIFICATION_EN
-                    tk_count = resp.usage.total_tokens
+                    tk_count = self.total_token_count(resp)
                 if resp.choices[0].finish_reason == "stop":
-                    tk_count = resp.usage.total_tokens
+                    tk_count = self.total_token_count(resp)
                 yield ans
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
@@ -569,7 +573,7 @@ def chat(self, system, history, gen_conf):
                     ans += LENGTH_NOTIFICATION_CN
                 else:
                     ans += LENGTH_NOTIFICATION_EN
-            return ans, response["usage"]["total_tokens"]
+            return ans, self.total_token_count(response)
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
@@ -603,11 +607,11 @@ def chat_streamly(self, system, history, gen_conf):
                 if "choices" in resp and "delta" in resp["choices"][0]:
                     text = resp["choices"][0]["delta"]["content"]
                     ans += text
-                    total_tokens = (
-                        total_tokens + num_tokens_from_string(text)
-                        if "usage" not in resp
-                        else resp["usage"]["total_tokens"]
-                    )
+                    tol = self.total_token_count(resp)
+                    if not tol:
+                        total_tokens += num_tokens_from_string(text)
+                    else:
+                        total_tokens = tol
                     yield ans
 
         except Exception as e:
@@ -640,7 +644,7 @@ def chat(self, system, history, gen_conf):
                     ans += LENGTH_NOTIFICATION_CN
                 else:
                     ans += LENGTH_NOTIFICATION_EN
-            return ans, response.usage.total_tokens
+            return ans, self.total_token_count(response)
         except openai.APIError as e:
             return "**ERROR**: " + str(e), 0
 
@@ -838,7 +842,7 @@ def chat_streamly(self, system, history, gen_conf):
 
         yield 0
 
-class GroqChat:
+class GroqChat(Base):
     def __init__(self, key, model_name, base_url=''):
         from groq import Groq
         self.client = Groq(api_key=key)
@@ -863,7 +867,7 @@ def chat(self, system, history, gen_conf):
                     ans += LENGTH_NOTIFICATION_CN
                 else:
                     ans += LENGTH_NOTIFICATION_EN
-            return ans, response.usage.total_tokens
+            return ans, self.total_token_count(response)
         except Exception as e:
             return ans + "\n**ERROR**: " + str(e), 0
 
@@ -1255,7 +1259,7 @@ def chat(self, system, history, gen_conf):
                 **gen_conf
            ).body
             ans = response['result']
-            return ans, response["usage"]["total_tokens"]
+            return ans, self.total_token_count(response)
         except Exception as e:
             return ans + "\n**ERROR**: " + str(e), 0
 
@@ -1283,7 +1287,7 @@ def chat_streamly(self, system, history, gen_conf):
             for resp in response:
                 resp = resp.body
                 ans += resp['result']
-                total_tokens = resp["usage"]["total_tokens"]
+                total_tokens = self.total_token_count(resp)
 
             yield ans
 
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index 9b7408ed494..893bf65effc 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -44,11 +44,23 @@ def encode(self, texts: list):
     def encode_queries(self, text: str):
         raise NotImplementedError("Please implement encode method!")
 
+    def total_token_count(self, resp):
+        try:
+            return resp.usage.total_tokens
+        except Exception:
+            pass
+        try:
+            return resp["usage"]["total_tokens"]
+        except Exception:
+            pass
+        return 0
+
 
 class DefaultEmbedding(Base):
     _model = None
     _model_name = ""
     _model_lock = threading.Lock()
+
     def __init__(self, key, model_name, **kwargs):
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!
@@ -115,13 +127,13 @@ def encode(self, texts: list):
             res = self.client.embeddings.create(input=texts[i:i + batch_size],
                                                 model=self.model_name)
             ress.extend([d.embedding for d in res.data])
-            total_tokens += res.usage.total_tokens
+            total_tokens += self.total_token_count(res)
         return np.array(ress), total_tokens
 
     def encode_queries(self, text):
         res = self.client.embeddings.create(input=[truncate(text, 8191)],
                                             model=self.model_name)
-        return np.array(res.data[0].embedding), res.usage.total_tokens
+        return np.array(res.data[0].embedding), self.total_token_count(res)
 
 
 class LocalAIEmbed(Base):
@@ -188,7 +200,7 @@ def encode(self, texts: list):
                 for e in resp["output"]["embeddings"]:
                     embds[e["text_index"]] = e["embedding"]
                 res.extend(embds)
-                token_count += resp["usage"]["total_tokens"]
+                token_count += self.total_token_count(resp)
             return np.array(res), token_count
         except Exception as e:
             raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name)
@@ -203,7 +215,7 @@ def encode_queries(self, text):
                 text_type="query"
             )
             return np.array(resp["output"]["embeddings"][0]
-                            ["embedding"]), resp["usage"]["total_tokens"]
+                            ["embedding"]), self.total_token_count(resp)
         except Exception:
             raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name)
         return np.array([]), 0
@@ -229,13 +241,13 @@ def encode(self, texts: list):
             res = self.client.embeddings.create(input=txt,
                                                 model=self.model_name)
             arr.append(res.data[0].embedding)
-            tks_num += res.usage.total_tokens
+            tks_num += self.total_token_count(res)
         return np.array(arr), tks_num
 
     def encode_queries(self, text):
         res = self.client.embeddings.create(input=text,
                                             model=self.model_name)
-        return np.array(res.data[0].embedding), res.usage.total_tokens
+        return np.array(res.data[0].embedding), self.total_token_count(res)
 
 
 class OllamaEmbed(Base):
@@ -318,13 +330,13 @@ def encode(self, texts: list):
         for i in range(0, len(texts), batch_size):
             res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
             ress.extend([d.embedding for d in res.data])
-            total_tokens += res.usage.total_tokens
+            total_tokens += self.total_token_count(res)
         return np.array(ress), total_tokens
 
     def encode_queries(self, text):
         res = self.client.embeddings.create(input=[text],
                                             model=self.model_name)
-        return np.array(res.data[0].embedding), res.usage.total_tokens
+        return np.array(res.data[0].embedding), self.total_token_count(res)
 
 
 class YoudaoEmbed(Base):
@@ -383,7 +395,7 @@ def encode(self, texts: list):
             }
             res = requests.post(self.base_url, headers=self.headers, json=data).json()
             ress.extend([d["embedding"] for d in res["data"]])
-            token_count += res["usage"]["total_tokens"]
+            token_count += self.total_token_count(res)
         return np.array(ress), token_count
 
     def encode_queries(self, text):
@@ -447,13 +459,13 @@ def encode(self, texts: list):
             res = self.client.embeddings(input=texts[i:i + batch_size],
                                          model=self.model_name)
             ress.extend([d.embedding for d in res.data])
-            token_count += res.usage.total_tokens
+            token_count += self.total_token_count(res)
         return np.array(ress), token_count
 
     def encode_queries(self, text):
         res = self.client.embeddings(input=[truncate(text, 8196)],
                                      model=self.model_name)
-        return np.array(res.data[0].embedding), res.usage.total_tokens
+        return np.array(res.data[0].embedding), self.total_token_count(res)
 
 
 class BedrockEmbed(Base):
@@ -565,7 +577,7 @@ def encode(self, texts: list):
             }
             res = requests.post(self.base_url, headers=self.headers, json=payload).json()
             ress.extend([d["embedding"] for d in res["data"]])
-            token_count += res["usage"]["total_tokens"]
+            token_count += self.total_token_count(res)
         return np.array(ress), token_count
 
     def encode_queries(self, text):
@@ -677,7 +689,7 @@ def encode(self, texts: list):
             if "data" not in res or not isinstance(res["data"], list) or len(res["data"]) != len(texts_batch):
                 raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}")
             ress.extend([d["embedding"] for d in res["data"]])
-            token_count += res["usage"]["total_tokens"]
+            token_count += self.total_token_count(res)
         return np.array(ress), token_count
 
     def encode_queries(self, text):
@@ -689,7 +701,7 @@ def encode_queries(self, text):
         res = requests.post(self.base_url, json=payload, headers=self.headers).json()
         if "data" not in res or not isinstance(res["data"], list) or len(res["data"])!= 1:
             raise ValueError(f"SILICONFLOWEmbed.encode_queries got invalid response from {self.base_url}")
-        return np.array(res["data"][0]["embedding"]), res["usage"]["total_tokens"]
+        return np.array(res["data"][0]["embedding"]), self.total_token_count(res)
 
 
 class ReplicateEmbed(Base):
@@ -727,14 +739,14 @@ def encode(self, texts: list, batch_size=16):
         res = self.client.do(model=self.model_name, texts=texts).body
         return (
             np.array([r["embedding"] for r in res["data"]]),
-            res["usage"]["total_tokens"],
+            self.total_token_count(res),
         )
 
     def encode_queries(self, text):
         res = self.client.do(model=self.model_name, texts=[text]).body
         return (
             np.array([r["embedding"] for r in res["data"]]),
-            res["usage"]["total_tokens"],
+            self.total_token_count(res),
         )
 
diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py
index 406faf2f9b2..c57e7f84369 100644
--- a/rag/llm/rerank_model.py
+++ b/rag/llm/rerank_model.py
@@ -42,6 +42,17 @@ def __init__(self, key, model_name):
     def similarity(self, query: str, texts: list):
         raise NotImplementedError("Please implement encode method!")
 
+    def total_token_count(self, resp):
+        try:
+            return resp.usage.total_tokens
+        except Exception:
+            pass
+        try:
+            return resp["usage"]["total_tokens"]
+        except Exception:
+            pass
+        return 0
+
 
 class DefaultRerank(Base):
     _model = None
@@ -115,7 +126,7 @@ def similarity(self, query: str, texts: list):
         rank = np.zeros(len(texts), dtype=float)
         for d in res["results"]:
             rank[d["index"]] = d["relevance_score"]
-        return rank, res["usage"]["total_tokens"]
+        return rank, self.total_token_count(res)
 
 
 class YoudaoRerank(DefaultRerank):
@@ -417,7 +428,7 @@ def similarity(self, query: str, texts: list):
         rank = np.zeros(len(texts), dtype=float)
         for d in res["results"]:
             rank[d["index"]] = d["relevance_score"]
-        return rank, res["usage"]["total_tokens"]
+        return rank, self.total_token_count(res)
 
 
 class VoyageRerank(Base):
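
Note on the new helper: the total_token_count method this patch adds to all three
Base classes is duck-typed. It tries attribute access first (resp.usage.total_tokens,
the shape SDK response objects use), then falls back to dict access
(resp["usage"]["total_tokens"], the shape raw JSON responses use), and returns 0
when neither is present, so streaming callers can fall back to
num_tokens_from_string(). Below is a minimal standalone sketch of that behavior;
the sample responses are illustrative stand-ins, not objects from this codebase:

    from types import SimpleNamespace

    def total_token_count(resp):
        # Same fallback chain as the helper added in this patch:
        # attribute access first, then dict access, else 0.
        try:
            return resp.usage.total_tokens
        except Exception:
            pass
        try:
            return resp["usage"]["total_tokens"]
        except Exception:
            pass
        return 0

    # SDK-style response (e.g. an openai client object) exposes usage as an attribute.
    sdk_resp = SimpleNamespace(usage=SimpleNamespace(total_tokens=42))
    # Raw-JSON response (e.g. requests.post(...).json()) exposes it as a dict.
    json_resp = {"usage": {"total_tokens": 42}}
    # A streaming chunk may carry no usage at all; the helper returns 0,
    # and the caller counts the delta text locally instead.
    chunk = SimpleNamespace(usage=None)

    assert total_token_count(sdk_resp) == 42
    assert total_token_count(json_resp) == 42
    assert total_token_count(chunk) == 0

The broad except Exception lets one helper cover every provider's response shape;
the trade-off is that a malformed usage payload is silently counted as zero tokens
rather than raising.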