From 1042197942df022dca2f857c8a701843b5df0e29 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Wed, 13 Dec 2023 21:11:20 -0800 Subject: [PATCH] updated local APIs to return usage info (#585) * updated APIs to return usage info * tested all endpoints --- memgpt/local_llm/chat_completion_proxy.py | 65 +++++++++++++++-------- memgpt/local_llm/koboldcpp/api.py | 25 ++++++--- memgpt/local_llm/llamacpp/api.py | 23 +++++--- memgpt/local_llm/lmstudio/api.py | 53 ++++++++++++------ memgpt/local_llm/ollama/api.py | 24 ++++++--- memgpt/local_llm/vllm/api.py | 23 +++++--- memgpt/local_llm/webui/api.py | 23 +++++--- memgpt/local_llm/webui/legacy_api.py | 21 +++++--- 8 files changed, 182 insertions(+), 75 deletions(-) diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py index b4fd794fe1..dae3d9ca07 100644 --- a/memgpt/local_llm/chat_completion_proxy.py +++ b/memgpt/local_llm/chat_completion_proxy.py @@ -15,14 +15,11 @@ from memgpt.local_llm.vllm.api import get_vllm_completion from memgpt.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper from memgpt.local_llm.constants import DEFAULT_WRAPPER -from memgpt.local_llm.utils import get_available_wrappers +from memgpt.local_llm.utils import get_available_wrappers, count_tokens from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE from memgpt.errors import LocalLLMConnectionError, LocalLLMError from memgpt.constants import CLI_WARNING_PREFIX -DEBUG = False -# DEBUG = True - has_shown_warning = False @@ -38,6 +35,8 @@ def get_chat_completion( endpoint=None, endpoint_type=None, ): + from memgpt.utils import printd + assert context_window is not None, "Local LLM calls need the context length to be explicitly set" assert endpoint is not None, "Local LLM calls need the endpoint (eg http://localendpoint:1234) to be explicitly set" assert endpoint_type is not None, "Local LLM calls need the endpoint type (eg webui) to be explicitly set" @@ -78,8 +77,7 @@ def get_chat_completion( # First step: turn the message sequence into a prompt that the model expects try: prompt = llm_wrapper.chat_completion_to_prompt(messages, functions) - if DEBUG: - print(prompt) + printd(prompt) except Exception as e: raise LocalLLMError( f"Failed to convert ChatCompletion messages into prompt string with wrapper {str(llm_wrapper)} - error: {str(e)}" @@ -87,19 +85,19 @@ def get_chat_completion( try: if endpoint_type == "webui": - result = get_webui_completion(endpoint, prompt, context_window, grammar=grammar_name) + result, usage = get_webui_completion(endpoint, prompt, context_window, grammar=grammar_name) elif endpoint_type == "webui-legacy": - result = get_webui_completion_legacy(endpoint, prompt, context_window, grammar=grammar_name) + result, usage = get_webui_completion_legacy(endpoint, prompt, context_window, grammar=grammar_name) elif endpoint_type == "lmstudio": - result = get_lmstudio_completion(endpoint, prompt, context_window) + result, usage = get_lmstudio_completion(endpoint, prompt, context_window) elif endpoint_type == "llamacpp": - result = get_llamacpp_completion(endpoint, prompt, context_window, grammar=grammar_name) + result, usage = get_llamacpp_completion(endpoint, prompt, context_window, grammar=grammar_name) elif endpoint_type == "koboldcpp": - result = get_koboldcpp_completion(endpoint, prompt, context_window, grammar=grammar_name) + result, usage = get_koboldcpp_completion(endpoint, prompt, context_window, grammar=grammar_name) elif endpoint_type == "ollama": - result = 
get_ollama_completion(endpoint, model, prompt, context_window) + result, usage = get_ollama_completion(endpoint, model, prompt, context_window) elif endpoint_type == "vllm": - result = get_vllm_completion(endpoint, model, prompt, context_window, user) + result, usage = get_vllm_completion(endpoint, model, prompt, context_window, user) else: raise LocalLLMError( f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)" @@ -109,16 +107,37 @@ def get_chat_completion( if result is None or result == "": raise LocalLLMError(f"Got back an empty response string from {endpoint}") - if DEBUG: - print(f"Raw LLM output:\n{result}") + printd(f"Raw LLM output:\n{result}") try: chat_completion_result = llm_wrapper.output_to_chat_completion_response(result) - if DEBUG: - print(json.dumps(chat_completion_result, indent=2)) + printd(json.dumps(chat_completion_result, indent=2)) except Exception as e: raise LocalLLMError(f"Failed to parse JSON from local LLM response - error: {str(e)}") + # Fill in potential missing usage information (used for tracking token use) + if not ("prompt_tokens" in usage and "completion_tokens" in usage and "total_tokens" in usage): + raise LocalLLMError(f"usage dict in response was missing fields ({usage})") + + if usage["prompt_tokens"] is None: + printd(f"usage dict was missing prompt_tokens, computing on-the-fly...") + usage["prompt_tokens"] = count_tokens(prompt) + + # NOTE: we should compute on-the-fly anyways since we might have to correct for errors during JSON parsing + usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result)) + """ + if usage["completion_tokens"] is None: + printd(f"usage dict was missing completion_tokens, computing on-the-fly...") + # chat_completion_result is dict with 'role' and 'content' + # token counter wants a string + usage["completion_tokens"] = count_tokens(json.dumps(chat_completion_result)) + """ + + # NOTE: this is the token count that matters most + if usage["total_tokens"] is None: + printd(f"usage dict was missing total_tokens, computing on-the-fly...") + usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"] + # unpack with response.choices[0].message.content response = Box( { @@ -126,15 +145,17 @@ def get_chat_completion( "choices": [ { "message": chat_completion_result, - "finish_reason": "stop", # TODO vary based on backend response + # TODO vary 'finish_reason' based on backend response + # NOTE if we got this far (parsing worked), then it's probably OK to treat this as a stop + "finish_reason": "stop", } ], "usage": { - # TODO fix, actually use real info - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0, + "prompt_tokens": usage["prompt_tokens"], + "completion_tokens": usage["completion_tokens"], + "total_tokens": usage["total_tokens"], }, } ) + printd(response) return response diff --git a/memgpt/local_llm/koboldcpp/api.py b/memgpt/local_llm/koboldcpp/api.py index 0d00001745..ecf259c0b6 100644 --- a/memgpt/local_llm/koboldcpp/api.py +++ b/memgpt/local_llm/koboldcpp/api.py @@ -6,12 +6,12 @@ from ..utils import load_grammar_file, count_tokens KOBOLDCPP_API_SUFFIX = "/api/v1/generate" -DEBUG = False -# DEBUG = True def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE): """See https://lite.koboldai.net/koboldcpp_api for API spec""" + from memgpt.utils import printd + prompt_tokens = count_tokens(prompt) if prompt_tokens > context_window: raise Exception(f"Request exceeds 
maximum context length ({prompt_tokens} > {context_window} tokens)") @@ -34,10 +34,9 @@ def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, set URI = urljoin(endpoint.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/")) response = requests.post(URI, json=request) if response.status_code == 200: - result = response.json() - result = result["results"][0]["text"] - if DEBUG: - print(f"json API response.text: {result}") + result_full = response.json() + printd(f"JSON API response:\n{result_full}") + result = result_full["results"][0]["text"] else: raise Exception( f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}." @@ -48,4 +47,16 @@ def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, set # TODO handle gracefully raise - return result + # Pass usage statistics back to main thread + # These are used to compute memory warning messages + # KoboldCpp doesn't return anything? + # https://lite.koboldai.net/koboldcpp_api#/v1/post_v1_generate + completion_tokens = None + total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None + usage = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + + return result, usage diff --git a/memgpt/local_llm/llamacpp/api.py b/memgpt/local_llm/llamacpp/api.py index 4b3c693b4d..649ec67c04 100644 --- a/memgpt/local_llm/llamacpp/api.py +++ b/memgpt/local_llm/llamacpp/api.py @@ -6,12 +6,12 @@ from ..utils import load_grammar_file, count_tokens LLAMACPP_API_SUFFIX = "/completion" -DEBUG = False -# DEBUG = True def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE): """See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server""" + from memgpt.utils import printd + prompt_tokens = count_tokens(prompt) if prompt_tokens > context_window: raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") @@ -33,10 +33,9 @@ def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, sett URI = urljoin(endpoint.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/")) response = requests.post(URI, json=request) if response.status_code == 200: - result = response.json() - result = result["content"] - if DEBUG: - print(f"json API response.text: {result}") + result_full = response.json() + printd(f"JSON API response:\n{result_full}") + result = result_full["content"] else: raise Exception( f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}." 
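Every get_*_completion touched by this patch now returns a (result, usage) tuple built the same way: prompt_tokens is always counted locally with count_tokens, completion_tokens is taken from the backend response when the server reports one, and total_tokens is only derived when both numbers are known. A minimal sketch of that shared pattern follows (the build_usage helper name is hypothetical, written here for illustration only; it is not part of the patch):

from memgpt.local_llm.utils import count_tokens


def build_usage(prompt, completion_tokens=None):
    """Hypothetical helper showing the usage dict shape each backend returns.

    completion_tokens stays None when the backend (e.g. koboldcpp) reports no
    token counts; chat_completion_proxy.py fills in the gaps afterwards.
    """
    prompt_tokens = count_tokens(prompt)
    total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
    }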
@@ -47,4 +46,14 @@ def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, sett # TODO handle gracefully raise - return result + # Pass usage statistics back to main thread + # These are used to compute memory warning messages + completion_tokens = result_full.get("tokens_predicted", None) + total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None + usage = { + "prompt_tokens": prompt_tokens, # can grab from "tokens_evaluated", but it's usually wrong (set to 0) + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + + return result, usage diff --git a/memgpt/local_llm/lmstudio/api.py b/memgpt/local_llm/lmstudio/api.py index e643b17a8e..183e440e9b 100644 --- a/memgpt/local_llm/lmstudio/api.py +++ b/memgpt/local_llm/lmstudio/api.py @@ -7,30 +7,42 @@ LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions" LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions" -DEBUG = False +# TODO move to "completions" by default, not "chat" def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, api="chat"): """Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client""" + from memgpt.utils import printd + prompt_tokens = count_tokens(prompt) if prompt_tokens > context_window: raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") - # Settings for the generation, includes the prompt + stop tokens, max length, etc - request = settings - request["max_tokens"] = context_window - + # Uses the ChatCompletions API style + # Seems to work better, probably because it's applying some extra settings under-the-hood? if api == "chat": - # Uses the ChatCompletions API style - # Seems to work better, probably because it's applying some extra settings under-the-hood? 
URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/")) + + # Settings for the generation, includes the prompt + stop tokens, max length, etc + request = settings + request["max_tokens"] = context_window + + # Put the entire completion string inside the first message message_structure = [{"role": "user", "content": prompt}] request["messages"] = message_structure + + # Uses basic string completions (string in, string out) + # Does not work as well as ChatCompletions for some reason elif api == "completions": - # Uses basic string completions (string in, string out) - # Does not work as well as ChatCompletions for some reason URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/")) + + # Settings for the generation, includes the prompt + stop tokens, max length, etc + request = settings + request["max_tokens"] = context_window + + # Standard completions format, formatted string goes in prompt request["prompt"] = prompt + else: raise ValueError(api) @@ -40,13 +52,14 @@ def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, a try: response = requests.post(URI, json=request) if response.status_code == 200: - result = response.json() + result_full = response.json() + printd(f"JSON API response:\n{result_full}") if api == "chat": - result = result["choices"][0]["message"]["content"] + result = result_full["choices"][0]["message"]["content"] + usage = result_full.get("usage", None) elif api == "completions": - result = result["choices"][0]["text"] - if DEBUG: - print(f"json API response.text: {result}") + result = result_full["choices"][0]["text"] + usage = result_full.get("usage", None) else: # Example error: msg={"error":"Context length exceeded. Tokens in context: 8000, Context length: 8000"} if "context length" in str(response.text).lower(): @@ -62,4 +75,14 @@ def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, a # TODO handle gracefully raise - return result + # Pass usage statistics back to main thread + # These are used to compute memory warning messages + completion_tokens = usage.get("completion_tokens", None) if usage is not None else None + total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None + usage = { + "prompt_tokens": prompt_tokens, # can grab from usage dict, but it's usually wrong (set to 0) + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + + return result, usage diff --git a/memgpt/local_llm/ollama/api.py b/memgpt/local_llm/ollama/api.py index d2b959f4e6..323113869f 100644 --- a/memgpt/local_llm/ollama/api.py +++ b/memgpt/local_llm/ollama/api.py @@ -7,11 +7,12 @@ from ...errors import LocalLLMError OLLAMA_API_SUFFIX = "/api/generate" -DEBUG = False def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMPLE, grammar=None): """See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server""" + from memgpt.utils import printd + prompt_tokens = count_tokens(prompt) if prompt_tokens > context_window: raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") @@ -39,10 +40,10 @@ def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMP URI = urljoin(endpoint.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/")) response = requests.post(URI, json=request) if response.status_code == 200: - result = response.json() - result = result["response"] - if DEBUG: - print(f"json API 
response.text: {result}") + # https://github.com/jmorganca/ollama/blob/main/docs/api.md + result_full = response.json() + printd(f"JSON API response:\n{result_full}") + result = result_full["response"] else: raise Exception( f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}." @@ -53,4 +54,15 @@ def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMP # TODO handle gracefully raise - return result + # Pass usage statistics back to main thread + # These are used to compute memory warning messages + # https://github.com/jmorganca/ollama/blob/main/docs/api.md#response + completion_tokens = result_full.get("eval_count", None) + total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None + usage = { + "prompt_tokens": prompt_tokens, # can also grab from "prompt_eval_count" + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + + return result, usage diff --git a/memgpt/local_llm/vllm/api.py b/memgpt/local_llm/vllm/api.py index 6329cafeb7..5033d04bd7 100644 --- a/memgpt/local_llm/vllm/api.py +++ b/memgpt/local_llm/vllm/api.py @@ -5,11 +5,12 @@ from ..utils import load_grammar_file, count_tokens WEBUI_API_SUFFIX = "/v1/completions" -DEBUG = False def get_vllm_completion(endpoint, model, prompt, context_window, user, settings={}, grammar=None): """https://github.com/vllm-project/vllm/blob/main/examples/api_client.py""" + from memgpt.utils import printd + prompt_tokens = count_tokens(prompt) if prompt_tokens > context_window: raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") @@ -36,10 +37,10 @@ def get_vllm_completion(endpoint, model, prompt, context_window, user, settings= URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/")) response = requests.post(URI, json=request) if response.status_code == 200: - result = response.json() - result = result["choices"][0]["text"] - if DEBUG: - print(f"json API response.text: {result}") + result_full = response.json() + printd(f"JSON API response:\n{result_full}") + result = result_full["choices"][0]["text"] + usage = result_full.get("usage", None) else: raise Exception( f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}." 
@@ -50,4 +51,14 @@ def get_vllm_completion(endpoint, model, prompt, context_window, user, settings= # TODO handle gracefully raise - return result + # Pass usage statistics back to main thread + # These are used to compute memory warning messages + completion_tokens = usage.get("completion_tokens", None) if usage is not None else None + total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None + usage = { + "prompt_tokens": prompt_tokens, # can grab from usage dict, but it's usually wrong (set to 0) + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + + return result, usage diff --git a/memgpt/local_llm/webui/api.py b/memgpt/local_llm/webui/api.py index 3fb88f0140..8fe9e8517d 100644 --- a/memgpt/local_llm/webui/api.py +++ b/memgpt/local_llm/webui/api.py @@ -6,11 +6,12 @@ from ..utils import load_grammar_file, count_tokens WEBUI_API_SUFFIX = "/v1/completions" -DEBUG = False def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None): """Compatibility for the new OpenAI API: https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples""" + from memgpt.utils import printd + prompt_tokens = count_tokens(prompt) if prompt_tokens > context_window: raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") @@ -33,10 +34,10 @@ def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, gram URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/")) response = requests.post(URI, json=request) if response.status_code == 200: - result = response.json() - result = result["choices"][0]["text"] - if DEBUG: - print(f"json API response.text: {result}") + result_full = response.json() + printd(f"JSON API response:\n{result_full}") + result = result_full["choices"][0]["text"] + usage = result_full.get("usage", None) else: raise Exception( f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}." 
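The vLLM and webui branches both parse an OpenAI-compatible /v1/completions body. Roughly, the JSON they receive looks like the illustration below (all values are made up, not taken from a real server), which is why the code reads choices[0]["text"], treats the usage block as optional, and ignores the server-reported prompt_tokens in favor of the local count (per the comments in the diff, some servers report it as 0):

# Illustrative OpenAI-style completions body (values invented for illustration).
example_response = {
    "id": "cmpl-abc123",
    "object": "text_completion",
    "choices": [
        {"index": 0, "text": '{"function": "send_message", ...}', "finish_reason": "stop"},
    ],
    "usage": {"prompt_tokens": 0, "completion_tokens": 180, "total_tokens": 180},
}

completion_text = example_response["choices"][0]["text"]
usage_block = example_response.get("usage", None)
completion_tokens = usage_block.get("completion_tokens", None) if usage_block is not None else None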
@@ -47,4 +48,14 @@ def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, gram # TODO handle gracefully raise - return result + # Pass usage statistics back to main thread + # These are used to compute memory warning messages + completion_tokens = usage.get("completion_tokens", None) if usage is not None else None + total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None + usage = { + "prompt_tokens": prompt_tokens, # can grab from usage dict, but it's usually wrong (set to 0) + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + + return result, usage diff --git a/memgpt/local_llm/webui/legacy_api.py b/memgpt/local_llm/webui/legacy_api.py index c84c2e7b2c..17f7fd6f0c 100644 --- a/memgpt/local_llm/webui/legacy_api.py +++ b/memgpt/local_llm/webui/legacy_api.py @@ -6,11 +6,12 @@ from ..utils import load_grammar_file, count_tokens WEBUI_API_SUFFIX = "/api/v1/generate" -DEBUG = False def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None): """See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server""" + from memgpt.utils import printd + prompt_tokens = count_tokens(prompt) if prompt_tokens > context_window: raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") @@ -31,10 +32,9 @@ def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, gram URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/")) response = requests.post(URI, json=request) if response.status_code == 200: - result = response.json() - result = result["results"][0]["text"] - if DEBUG: - print(f"json API response.text: {result}") + result_full = response.json() + printd(f"JSON API response:\n{result_full}") + result = result_full["results"][0]["text"] else: raise Exception( f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}." @@ -45,4 +45,13 @@ def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, gram # TODO handle gracefully raise - return result + # TODO correct for legacy + completion_tokens = None + total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None + usage = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + + return result, usage
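With every backend returning real counts (or None placeholders that the proxy fills in), callers of get_chat_completion can read token usage the same way they would from the OpenAI client, via the Box response noted in the proxy diff ("unpack with response.choices[0].message.content"). A short caller-side sketch, where the model name, endpoint, and message list are illustrative and the remaining keyword arguments (functions, wrapper, user, ...) are assumed to keep their defaults:

from memgpt.local_llm.chat_completion_proxy import get_chat_completion

# Illustrative arguments; the endpoint/model/wrapper depend on the local setup.
response = get_chat_completion(
    model="dolphin-2.1-mistral-7b",
    messages=[{"role": "user", "content": "hello"}],
    context_window=8192,
    endpoint="http://localhost:5000",
    endpoint_type="webui",
)

print(response.choices[0].message["content"])   # parsed ChatCompletion-style message
print(response.usage.prompt_tokens)             # counted locally via count_tokens(prompt)
print(response.usage.completion_tokens)         # recounted from the parsed JSON output
print(response.usage.total_tokens)              # what the memory-pressure warnings key off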