Refactor for total_tokens. #4652

Merged: 1 commit, Jan 26, 2025
72 changes: 38 additions & 34 deletions rag/llm/chat_model.py
@@ -53,7 +53,7 @@ def chat(self, system, history, gen_conf):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, response.usage.total_tokens
return ans, self.total_token_count(response)
except openai.APIError as e:
return "**ERROR**: " + str(e), 0

@@ -75,15 +75,11 @@ def chat_streamly(self, system, history, gen_conf):
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content

if not hasattr(resp, "usage") or not resp.usage:
total_tokens = (
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
)
elif isinstance(resp.usage, dict):
total_tokens = resp.usage.get("total_tokens", total_tokens)
tol = self.total_token_count(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = resp.usage.total_tokens
total_tokens = tol

if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
@@ -97,6 +93,17 @@ def chat_streamly(self, system, history, gen_conf):

yield total_tokens

def total_token_count(self, resp):
try:
return resp.usage.total_tokens
except Exception:
pass
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0
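
The new `Base.total_token_count` helper centralizes the `response.usage.total_tokens` / `resp["usage"]["total_tokens"]` accesses that were previously inlined per provider. Lifted into a standalone function, the fallback chain can be exercised like this (the response objects below are hypothetical stand-ins, not any SDK's real types):

```python
# Sketch of the fallback chain added in this PR: try SDK-style attribute
# access, then dict-style access, then give up with 0 so callers can fall
# back to local token counting.
from types import SimpleNamespace

def total_token_count(resp):
    try:
        return resp.usage.total_tokens  # SDK object, e.g. an openai response
    except Exception:
        pass
    try:
        return resp["usage"]["total_tokens"]  # raw JSON dict payload
    except Exception:
        pass
    return 0  # no usage reported at all

obj_resp = SimpleNamespace(usage=SimpleNamespace(total_tokens=42))
dict_resp = {"usage": {"total_tokens": 17}}
assert total_token_count(obj_resp) == 42
assert total_token_count(dict_resp) == 17
assert total_token_count({}) == 0
```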


class GptTurbo(Base):
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
@@ -182,7 +189,7 @@ def chat(self, system, history, gen_conf):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, response.usage.total_tokens
return ans, self.total_token_count(response)
except openai.APIError as e:
return "**ERROR**: " + str(e), 0

@@ -212,14 +219,11 @@ def chat_streamly(self, system, history, gen_conf):
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
total_tokens = (
(
total_tokens
+ num_tokens_from_string(resp.choices[0].delta.content)
)
if not hasattr(resp, "usage")
else resp.usage["total_tokens"]
)
tol = self.total_token_count(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
if resp.choices[0].finish_reason == "length":
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
@@ -256,7 +260,7 @@ def chat(self, system, history, gen_conf):
tk_count = 0
if response.status_code == HTTPStatus.OK:
ans += response.output.choices[0]['message']['content']
tk_count += response.usage.total_tokens
tk_count += self.total_token_count(response)
if response.output.choices[0].get("finish_reason", "") == "length":
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
@@ -292,7 +296,7 @@ def _chat_streamly(self, system, history, gen_conf, incremental_output=False):
for resp in response:
if resp.status_code == HTTPStatus.OK:
ans = resp.output.choices[0]['message']['content']
tk_count = resp.usage.total_tokens
tk_count = self.total_token_count(resp)
if resp.output.choices[0].get("finish_reason", "") == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
@@ -334,7 +338,7 @@ def chat(self, system, history, gen_conf):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, response.usage.total_tokens
return ans, self.total_token_count(response)
except Exception as e:
return "**ERROR**: " + str(e), 0

@@ -364,9 +368,9 @@ def chat_streamly(self, system, history, gen_conf):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
tk_count = resp.usage.total_tokens
tk_count = self.total_token_count(resp)
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
tk_count = self.total_token_count(resp)
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@@ -569,7 +573,7 @@ def chat(self, system, history, gen_conf):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, response["usage"]["total_tokens"]
return ans, self.total_token_count(response)
except Exception as e:
return "**ERROR**: " + str(e), 0

@@ -603,11 +607,11 @@ def chat_streamly(self, system, history, gen_conf):
if "choices" in resp and "delta" in resp["choices"][0]:
text = resp["choices"][0]["delta"]["content"]
ans += text
total_tokens = (
total_tokens + num_tokens_from_string(text)
if "usage" not in resp
else resp["usage"]["total_tokens"]
)
tol = self.total_token_count(resp)
if not tol:
total_tokens += num_tokens_from_string(text)
else:
total_tokens = tol
yield ans

except Exception as e:
@@ -640,7 +644,7 @@ def chat(self, system, history, gen_conf):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, response.usage.total_tokens
return ans, self.total_token_count(response)
except openai.APIError as e:
return "**ERROR**: " + str(e), 0

@@ -838,7 +842,7 @@ def chat_streamly(self, system, history, gen_conf):
yield 0


class GroqChat:
class GroqChat(Base):
def __init__(self, key, model_name, base_url=''):
from groq import Groq
self.client = Groq(api_key=key)
@@ -863,7 +867,7 @@ def chat(self, system, history, gen_conf):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, response.usage.total_tokens
return ans, self.total_token_count(response)
except Exception as e:
return ans + "\n**ERROR**: " + str(e), 0

@@ -1255,7 +1259,7 @@ def chat(self, system, history, gen_conf):
**gen_conf
).body
ans = response['result']
return ans, response["usage"]["total_tokens"]
return ans, self.total_token_count(response)

except Exception as e:
return ans + "\n**ERROR**: " + str(e), 0
@@ -1283,7 +1287,7 @@ def chat_streamly(self, system, history, gen_conf):
for resp in response:
resp = resp.body
ans += resp['result']
total_tokens = resp["usage"]["total_tokens"]
total_tokens = self.total_token_count(resp)

yield ans

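The streaming paths now share one pattern: ask the helper for a provider-reported total, fall back to local estimation when a chunk carries no usage block, and let a reported cumulative total supersede the running estimate. A minimal sketch of that accumulation loop (the chunk shape and the injected helpers are assumptions, not the repo's exact types):

```python
def count_stream_tokens(chunks, total_token_count, num_tokens_from_string):
    """Accumulate token usage across stream chunks, mirroring chat_streamly."""
    total_tokens = 0
    for chunk in chunks:
        tol = total_token_count(chunk)
        if not tol:
            # No usage reported on this chunk: estimate from the delta text.
            total_tokens += num_tokens_from_string(chunk.get("text", ""))
        else:
            # Provider reports a cumulative total: it replaces the estimate.
            total_tokens = tol
    return total_tokens

# Hypothetical usage: only the final chunk carries a usage block, and
# len() stands in for the repo's tokenizer-based counter.
usage_total = lambda resp: resp.get("usage", {}).get("total_tokens", 0)
chunks = [{"text": "Hello"}, {"text": " world"},
          {"text": "", "usage": {"total_tokens": 12}}]
print(count_stream_tokens(chunks, usage_total, len))  # 12
```
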
44 changes: 28 additions & 16 deletions rag/llm/embedding_model.py
@@ -44,11 +44,23 @@ def encode(self, texts: list):
def encode_queries(self, text: str):
raise NotImplementedError("Please implement encode method!")

def total_token_count(self, resp):
try:
return resp.usage.total_tokens
except Exception:
pass
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0


class DefaultEmbedding(Base):
_model = None
_model_name = ""
_model_lock = threading.Lock()

def __init__(self, key, model_name, **kwargs):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
@@ -115,13 +127,13 @@ def encode(self, texts: list):
res = self.client.embeddings.create(input=texts[i:i + batch_size],
model=self.model_name)
ress.extend([d.embedding for d in res.data])
total_tokens += res.usage.total_tokens
total_tokens += self.total_token_count(res)
return np.array(ress), total_tokens

def encode_queries(self, text):
res = self.client.embeddings.create(input=[truncate(text, 8191)],
model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens
return np.array(res.data[0].embedding), self.total_token_count(res)


class LocalAIEmbed(Base):
@@ -188,7 +200,7 @@ def encode(self, texts: list):
for e in resp["output"]["embeddings"]:
embds[e["text_index"]] = e["embedding"]
res.extend(embds)
token_count += resp["usage"]["total_tokens"]
token_count += self.total_token_count(resp)
return np.array(res), token_count
except Exception as e:
raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name)
@@ -203,7 +215,7 @@ def encode_queries(self, text):
text_type="query"
)
return np.array(resp["output"]["embeddings"][0]
["embedding"]), resp["usage"]["total_tokens"]
["embedding"]), self.total_token_count(resp)
except Exception:
raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name)
return np.array([]), 0
@@ -229,13 +241,13 @@ def encode(self, texts: list):
res = self.client.embeddings.create(input=txt,
model=self.model_name)
arr.append(res.data[0].embedding)
tks_num += res.usage.total_tokens
tks_num += self.total_token_count(res)
return np.array(arr), tks_num

def encode_queries(self, text):
res = self.client.embeddings.create(input=text,
model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens
return np.array(res.data[0].embedding), self.total_token_count(res)


class OllamaEmbed(Base):
@@ -318,13 +330,13 @@ def encode(self, texts: list):
for i in range(0, len(texts), batch_size):
res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
ress.extend([d.embedding for d in res.data])
total_tokens += res.usage.total_tokens
total_tokens += self.total_token_count(res)
return np.array(ress), total_tokens

def encode_queries(self, text):
res = self.client.embeddings.create(input=[text],
model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens
return np.array(res.data[0].embedding), self.total_token_count(res)


class YoudaoEmbed(Base):
@@ -383,7 +395,7 @@ def encode(self, texts: list):
}
res = requests.post(self.base_url, headers=self.headers, json=data).json()
ress.extend([d["embedding"] for d in res["data"]])
token_count += res["usage"]["total_tokens"]
token_count += self.total_token_count(res)
return np.array(ress), token_count

def encode_queries(self, text):
@@ -447,13 +459,13 @@ def encode(self, texts: list):
res = self.client.embeddings(input=texts[i:i + batch_size],
model=self.model_name)
ress.extend([d.embedding for d in res.data])
token_count += res.usage.total_tokens
token_count += self.total_token_count(res)
return np.array(ress), token_count

def encode_queries(self, text):
res = self.client.embeddings(input=[truncate(text, 8196)],
model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens
return np.array(res.data[0].embedding), self.total_token_count(res)


class BedrockEmbed(Base):
@@ -565,7 +577,7 @@ def encode(self, texts: list):
}
res = requests.post(self.base_url, headers=self.headers, json=payload).json()
ress.extend([d["embedding"] for d in res["data"]])
token_count += res["usage"]["total_tokens"]
token_count += self.total_token_count(res)
return np.array(ress), token_count

def encode_queries(self, text):
@@ -677,7 +689,7 @@ def encode(self, texts: list):
if "data" not in res or not isinstance(res["data"], list) or len(res["data"]) != len(texts_batch):
raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}")
ress.extend([d["embedding"] for d in res["data"]])
token_count += res["usage"]["total_tokens"]
token_count += self.total_token_count(res)
return np.array(ress), token_count

def encode_queries(self, text):
@@ -689,7 +701,7 @@ def encode_queries(self, text):
res = requests.post(self.base_url, json=payload, headers=self.headers).json()
if "data" not in res or not isinstance(res["data"], list) or len(res["data"])!= 1:
raise ValueError(f"SILICONFLOWEmbed.encode_queries got invalid response from {self.base_url}")
return np.array(res["data"][0]["embedding"]), res["usage"]["total_tokens"]
return np.array(res["data"][0]["embedding"]), self.total_token_count(res)


class ReplicateEmbed(Base):
@@ -727,14 +739,14 @@ def encode(self, texts: list, batch_size=16):
res = self.client.do(model=self.model_name, texts=texts).body
return (
np.array([r["embedding"] for r in res["data"]]),
res["usage"]["total_tokens"],
self.total_token_count(res),
)

def encode_queries(self, text):
res = self.client.do(model=self.model_name, texts=[text]).body
return (
np.array([r["embedding"] for r in res["data"]]),
res["usage"]["total_tokens"],
self.total_token_count(res),
)


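The embedding encoders follow the same refactor: each batch response carries its own usage block, and `encode` sums them. A hedged sketch of that aggregation, assuming an OpenAI-style `client.embeddings.create` and the 16-item batching used in the diff:

```python
import numpy as np

def encode(client, model_name, texts, total_token_count, batch_size=16):
    ress, total_tokens = [], 0
    for i in range(0, len(texts), batch_size):
        res = client.embeddings.create(input=texts[i:i + batch_size],
                                       model=model_name)
        ress.extend([d.embedding for d in res.data])
        # Sum per-batch usage; a missing usage block contributes 0.
        total_tokens += total_token_count(res)
    return np.array(ress), total_tokens
```
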
15 changes: 13 additions & 2 deletions rag/llm/rerank_model.py
@@ -42,6 +42,17 @@ def __init__(self, key, model_name):
def similarity(self, query: str, texts: list):
raise NotImplementedError("Please implement encode method!")

def total_token_count(self, resp):
try:
return resp.usage.total_tokens
except Exception:
pass
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
return 0


class DefaultRerank(Base):
_model = None
@@ -115,7 +126,7 @@ def similarity(self, query: str, texts: list):
rank = np.zeros(len(texts), dtype=float)
for d in res["results"]:
rank[d["index"]] = d["relevance_score"]
return rank, res["usage"]["total_tokens"]
return rank, self.total_token_count(res)


class YoudaoRerank(DefaultRerank):
@@ -417,7 +428,7 @@ def similarity(self, query: str, texts: list):
rank = np.zeros(len(texts), dtype=float)
for d in res["results"]:
rank[d["index"]] = d["relevance_score"]
return rank, res["usage"]["total_tokens"]
return rank, self.total_token_count(res)


class VoyageRerank(Base):
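The rerank models scatter `relevance_score` values into a dense array keyed by `index` and report usage through the same helper. A minimal sketch under those assumptions:

```python
import numpy as np

def similarity_scores(res, num_texts, total_token_count):
    # Scores arrive keyed by input index; scatter into a dense float array.
    rank = np.zeros(num_texts, dtype=float)
    for d in res["results"]:
        rank[d["index"]] = d["relevance_score"]
    return rank, total_token_count(res)
```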