Skip to content

Commit

Permalink
updated local APIs to return usage info (#585)
Browse files Browse the repository at this point in the history
* updated APIs to return usage info

* tested all endpoints
  • Loading branch information
cpacker authored Dec 14, 2023
1 parent 0ca0149 commit 986567a
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 75 deletions.
65 changes: 43 additions & 22 deletions memgpt/local_llm/chat_completion_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,11 @@
from memgpt.local_llm.vllm.api import get_vllm_completion
from memgpt.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
from memgpt.local_llm.constants import DEFAULT_WRAPPER
from memgpt.local_llm.utils import get_available_wrappers
from memgpt.local_llm.utils import get_available_wrappers, count_tokens
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
from memgpt.errors import LocalLLMConnectionError, LocalLLMError
from memgpt.constants import CLI_WARNING_PREFIX

DEBUG = False
# DEBUG = True

has_shown_warning = False


Expand All @@ -38,6 +35,8 @@ def get_chat_completion(
endpoint=None,
endpoint_type=None,
):
from memgpt.utils import printd

assert context_window is not None, "Local LLM calls need the context length to be explicitly set"
assert endpoint is not None, "Local LLM calls need the endpoint (eg http://localendpoint:1234) to be explicitly set"
assert endpoint_type is not None, "Local LLM calls need the endpoint type (eg webui) to be explicitly set"
Expand Down Expand Up @@ -78,28 +77,27 @@ def get_chat_completion(
# First step: turn the message sequence into a prompt that the model expects
try:
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
if DEBUG:
print(prompt)
printd(prompt)
except Exception as e:
raise LocalLLMError(
f"Failed to convert ChatCompletion messages into prompt string with wrapper {str(llm_wrapper)} - error: {str(e)}"
)

try:
if endpoint_type == "webui":
result = get_webui_completion(endpoint, prompt, context_window, grammar=grammar_name)
result, usage = get_webui_completion(endpoint, prompt, context_window, grammar=grammar_name)
elif endpoint_type == "webui-legacy":
result = get_webui_completion_legacy(endpoint, prompt, context_window, grammar=grammar_name)
result, usage = get_webui_completion_legacy(endpoint, prompt, context_window, grammar=grammar_name)
elif endpoint_type == "lmstudio":
result = get_lmstudio_completion(endpoint, prompt, context_window)
result, usage = get_lmstudio_completion(endpoint, prompt, context_window)
elif endpoint_type == "llamacpp":
result = get_llamacpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
result, usage = get_llamacpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
elif endpoint_type == "koboldcpp":
result = get_koboldcpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
result, usage = get_koboldcpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
elif endpoint_type == "ollama":
result = get_ollama_completion(endpoint, model, prompt, context_window)
result, usage = get_ollama_completion(endpoint, model, prompt, context_window)
elif endpoint_type == "vllm":
result = get_vllm_completion(endpoint, model, prompt, context_window, user)
result, usage = get_vllm_completion(endpoint, model, prompt, context_window, user)
else:
raise LocalLLMError(
f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
Expand All @@ -109,32 +107,55 @@ def get_chat_completion(

if result is None or result == "":
raise LocalLLMError(f"Got back an empty response string from {endpoint}")
if DEBUG:
print(f"Raw LLM output:\n{result}")
printd(f"Raw LLM output:\n{result}")

try:
chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
if DEBUG:
print(json.dumps(chat_completion_result, indent=2))
printd(json.dumps(chat_completion_result, indent=2))
except Exception as e:
raise LocalLLMError(f"Failed to parse JSON from local LLM response - error: {str(e)}")

# Fill in potential missing usage information (used for tracking token use)
if not ("prompt_tokens" in usage and "completion_tokens" in usage and "total_tokens" in usage):
raise LocalLLMError(f"usage dict in response was missing fields ({usage})")

if usage["prompt_tokens"] is None:
printd(f"usage dict was missing prompt_tokens, computing on-the-fly...")
usage["prompt_tokens"] = count_tokens(prompt)

# NOTE: we should compute on-the-fly anyways since we might have to correct for errors during JSON parsing
usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result))
"""
if usage["completion_tokens"] is None:
printd(f"usage dict was missing completion_tokens, computing on-the-fly...")
# chat_completion_result is dict with 'role' and 'content'
# token counter wants a string
usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result))
"""

# NOTE: this is the token count that matters most
if usage["total_tokens"] is None:
printd(f"usage dict was missing total_tokens, computing on-the-fly...")
usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]

# unpack with response.choices[0].message.content
response = Box(
{
"model": model,
"choices": [
{
"message": chat_completion_result,
"finish_reason": "stop", # TODO vary based on backend response
# TODO vary 'finish_reason' based on backend response
# NOTE if we got this far (parsing worked), then it's probably OK to treat this as a stop
"finish_reason": "stop",
}
],
"usage": {
# TODO fix, actually use real info
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"prompt_tokens": usage["prompt_tokens"],
"completion_tokens": usage["completion_tokens"],
"total_tokens": usage["total_tokens"],
},
}
)
printd(response)
return response
25 changes: 18 additions & 7 deletions memgpt/local_llm/koboldcpp/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from ..utils import load_grammar_file, count_tokens

KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
DEBUG = False
# DEBUG = True


def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
"""See https://lite.koboldai.net/koboldcpp_api for API spec"""
from memgpt.utils import printd

prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
Expand All @@ -34,10 +34,9 @@ def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, set
URI = urljoin(endpoint.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
result = result["results"][0]["text"]
if DEBUG:
print(f"json API response.text: {result}")
result_full = response.json()
printd(f"JSON API response:\n{result_full}")
result = result_full["results"][0]["text"]
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
Expand All @@ -48,4 +47,16 @@ def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, set
# TODO handle gracefully
raise

return result
# Pass usage statistics back to main thread
# These are used to compute memory warning messages
# KoboldCpp doesn't return anything?
# https://lite.koboldai.net/koboldcpp_api#/v1/post_v1_generate
completion_tokens = None
total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}

return result, usage
23 changes: 16 additions & 7 deletions memgpt/local_llm/llamacpp/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from ..utils import load_grammar_file, count_tokens

LLAMACPP_API_SUFFIX = "/completion"
DEBUG = False
# DEBUG = True


def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
"""See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
from memgpt.utils import printd

prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
Expand All @@ -33,10 +33,9 @@ def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, sett
URI = urljoin(endpoint.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
result = result["content"]
if DEBUG:
print(f"json API response.text: {result}")
result_full = response.json()
printd(f"JSON API response:\n{result_full}")
result = result_full["content"]
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
Expand All @@ -47,4 +46,14 @@ def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, sett
# TODO handle gracefully
raise

return result
# Pass usage statistics back to main thread
# These are used to compute memory warning messages
completion_tokens = result_full.get("tokens_predicted", None)
total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
usage = {
"prompt_tokens": prompt_tokens, # can grab from "tokens_evaluated", but it's usually wrong (set to 0)
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}

return result, usage
53 changes: 38 additions & 15 deletions memgpt/local_llm/lmstudio/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,42 @@

LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions"
LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions"
DEBUG = False


# TODO move to "completions" by default, not "chat"
def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, api="chat"):
"""Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client"""
from memgpt.utils import printd

prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

# Settings for the generation, includes the prompt + stop tokens, max length, etc
request = settings
request["max_tokens"] = context_window

# Uses the ChatCompletions API style
# Seems to work better, probably because it's applying some extra settings under-the-hood?
if api == "chat":
# Uses the ChatCompletions API style
# Seems to work better, probably because it's applying some extra settings under-the-hood?
URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/"))

# Settings for the generation, includes the prompt + stop tokens, max length, etc
request = settings
request["max_tokens"] = context_window

# Put the entire completion string inside the first message
message_structure = [{"role": "user", "content": prompt}]
request["messages"] = message_structure

# Uses basic string completions (string in, string out)
# Does not work as well as ChatCompletions for some reason
elif api == "completions":
# Uses basic string completions (string in, string out)
# Does not work as well as ChatCompletions for some reason
URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/"))

# Settings for the generation, includes the prompt + stop tokens, max length, etc
request = settings
request["max_tokens"] = context_window

# Standard completions format, formatted string goes in prompt
request["prompt"] = prompt

else:
raise ValueError(api)

Expand All @@ -40,13 +52,14 @@ def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, a
try:
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
result_full = response.json()
printd(f"JSON API response:\n{result_full}")
if api == "chat":
result = result["choices"][0]["message"]["content"]
result = result_full["choices"][0]["message"]["content"]
usage = result_full.get("usage", None)
elif api == "completions":
result = result["choices"][0]["text"]
if DEBUG:
print(f"json API response.text: {result}")
result = result_full["choices"][0]["text"]
usage = result_full.get("usage", None)
else:
# Example error: msg={"error":"Context length exceeded. Tokens in context: 8000, Context length: 8000"}
if "context length" in str(response.text).lower():
Expand All @@ -62,4 +75,14 @@ def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, a
# TODO handle gracefully
raise

return result
# Pass usage statistics back to main thread
# These are used to compute memory warning messages
completion_tokens = usage.get("completion_tokens", None) if usage is not None else None
total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
usage = {
"prompt_tokens": prompt_tokens, # can grab from usage dict, but it's usually wrong (set to 0)
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}

return result, usage
24 changes: 18 additions & 6 deletions memgpt/local_llm/ollama/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from ...errors import LocalLLMError

OLLAMA_API_SUFFIX = "/api/generate"
DEBUG = False


def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMPLE, grammar=None):
"""See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server"""
from memgpt.utils import printd

prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
Expand Down Expand Up @@ -39,10 +40,10 @@ def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMP
URI = urljoin(endpoint.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
result = result["response"]
if DEBUG:
print(f"json API response.text: {result}")
# https://github.com/jmorganca/ollama/blob/main/docs/api.md
result_full = response.json()
printd(f"JSON API response:\n{result_full}")
result = result_full["response"]
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
Expand All @@ -53,4 +54,15 @@ def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMP
# TODO handle gracefully
raise

return result
# Pass usage statistics back to main thread
# These are used to compute memory warning messages
# https://github.com/jmorganca/ollama/blob/main/docs/api.md#response
completion_tokens = result_full.get("eval_count", None)
total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
usage = {
"prompt_tokens": prompt_tokens, # can also grab from "prompt_eval_count"
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}

return result, usage
23 changes: 17 additions & 6 deletions memgpt/local_llm/vllm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from ..utils import load_grammar_file, count_tokens

WEBUI_API_SUFFIX = "/v1/completions"
DEBUG = False


def get_vllm_completion(endpoint, model, prompt, context_window, user, settings={}, grammar=None):
"""https://github.com/vllm-project/vllm/blob/main/examples/api_client.py"""
from memgpt.utils import printd

prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
Expand All @@ -36,10 +37,10 @@ def get_vllm_completion(endpoint, model, prompt, context_window, user, settings=
URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
result = result["choices"][0]["text"]
if DEBUG:
print(f"json API response.text: {result}")
result_full = response.json()
printd(f"JSON API response:\n{result_full}")
result = result_full["choices"][0]["text"]
usage = result_full.get("usage", None)
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
Expand All @@ -50,4 +51,14 @@ def get_vllm_completion(endpoint, model, prompt, context_window, user, settings=
# TODO handle gracefully
raise

return result
# Pass usage statistics back to main thread
# These are used to compute memory warning messages
completion_tokens = usage.get("completion_tokens", None) if usage is not None else None
total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
usage = {
"prompt_tokens": prompt_tokens, # can grab from usage dict, but it's usually wrong (set to 0)
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}

return result, usage
Loading

0 comments on commit 986567a

Please sign in to comment.