diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py
index 106a1a5450..d17e70372d 100644
--- a/memgpt/cli/cli_config.py
+++ b/memgpt/cli/cli_config.py
@@ -14,6 +14,8 @@
 from memgpt.constants import LLM_MAX_TOKENS
 from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
 from memgpt.local_llm.utils import get_available_wrappers
+from memgpt.openai_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
+from memgpt.server.utils import shorten_key_middle
 
 app = typer.Typer()
 
@@ -63,6 +65,18 @@ def configure_llm_endpoint(config: MemGPTConfig):
                     openai_api_key = questionary.text(
                         "Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
                     ).ask()
+                config.openai_key = openai_api_key
+                config.save()
+        else:
+            # Give the user an opportunity to overwrite the key
+            openai_api_key = None
+            default_input = shorten_key_middle(config.openai_key) if config.openai_key.startswith("sk-") else config.openai_key
+            openai_api_key = questionary.text(
+                "Enter your OpenAI API key (hit enter to use existing key):",
+                default=default_input,
+            ).ask()
+            # If the user modified it, use the new one
+            if openai_api_key != default_input:
                 config.openai_key = openai_api_key
                 config.save()
 
@@ -78,6 +92,11 @@ def configure_llm_endpoint(config: MemGPTConfig):
             raise ValueError(
                 "Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set then run `memgpt configure` again."
             )
+        else:
+            config.azure_key = azure_creds["azure_key"]
+            config.azure_endpoint = azure_creds["azure_endpoint"]
+            config.azure_version = azure_creds["azure_version"]
+            config.save()
 
         model_endpoint_type = "azure"
         model_endpoint = azure_creds["azure_endpoint"]
@@ -119,16 +138,56 @@ def configure_llm_endpoint(config: MemGPTConfig):
     return model_endpoint_type, model_endpoint
 
 
-def configure_model(config: MemGPTConfig, model_endpoint_type: str):
+def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoint: str):
     # set: model, model_wrapper
     model, model_wrapper = None, None
     if model_endpoint_type == "openai" or model_endpoint_type == "azure":
-        model_options = ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]
-        # TODO: select
-        valid_model = config.model in model_options
+        # Get the model list from the openai / azure endpoint
+        hardcoded_model_options = ["gpt-4", "gpt-4-32k", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]
+        fetched_model_options = None
+        try:
+            if model_endpoint_type == "openai":
+                fetched_model_options = openai_get_model_list(url=model_endpoint, api_key=config.openai_key)
+            elif model_endpoint_type == "azure":
+                fetched_model_options = azure_openai_get_model_list(
+                    url=model_endpoint, api_key=config.azure_key, api_version=config.azure_version
+                )
+            fetched_model_options = [obj["id"] for obj in fetched_model_options["data"] if obj["id"].startswith("gpt-")]
+        except:
+            # NOTE: if this fails, it means the user's key is probably bad
+            typer.secho(
+                f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED
+            )
+
+        # First ask if the user wants to see the full model list (some may be incompatible)
+        see_all_option_str = "[see all options]"
+        other_option_str = "[enter model name manually]"
+
+        # Check if the model we have set already is even in the list (informs our default)
+        valid_model = config.model in hardcoded_model_options
         model = questionary.select(
-            "Select default model (recommended: gpt-4):", choices=model_options, default=config.model if valid_model else model_options[0]
+            "Select default model (recommended: gpt-4):",
+            choices=hardcoded_model_options + [see_all_option_str, other_option_str],
+            default=config.model if valid_model else hardcoded_model_options[0],
         ).ask()
+
+        # If the user asked for the full list, show it
+        if model == see_all_option_str:
+            typer.secho(f"Warning: not all models shown are guaranteed to work with MemGPT", fg=typer.colors.RED)
+            model = questionary.select(
+                "Select default model (recommended: gpt-4):",
+                choices=fetched_model_options + [other_option_str],
+                default=config.model if valid_model else fetched_model_options[0],
+            ).ask()
+
+        # Finally if the user asked to manually input, allow it
+        if model == other_option_str:
+            model = ""
+            while len(model) == 0:
+                model = questionary.text(
+                    "Enter custom model name:",
+                ).ask()
+
     else:  # local models
         # ollama also needs model type
         if model_endpoint_type == "ollama":
@@ -139,24 +198,51 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str):
             ).ask()
             model = None if len(model) == 0 else model
 
+        default_model = config.model if config.model and config.model_endpoint_type == "vllm" else ""
+
         # vllm needs huggingface model tag
         if model_endpoint_type == "vllm":
-            default_model = config.model if config.model and config.model_endpoint_type == "vllm" else ""
-            model = questionary.text(
-                "Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
-                default=default_model,
-            ).ask()
-            model = None if len(model) == 0 else model
-            model_wrapper = None  # no model wrapper for vLLM
+            try:
+                # Don't filter model list for vLLM since model list is likely much smaller than OpenAI/Azure endpoint
+                # + probably has custom model names
+                model_options = openai_get_model_list(url=smart_urljoin(model_endpoint, "v1"), api_key=None)
+                model_options = [obj["id"] for obj in model_options["data"]]
+            except:
+                print(f"Failed to get model list from {model_endpoint}, using defaults")
+                model_options = None
+
+            # If we got model options from vLLM endpoint, allow selection + custom input
+            if model_options is not None:
+                other_option_str = "other (enter name)"
+                valid_model = config.model in model_options
+                model_options.append(other_option_str)
+                model = questionary.select(
+                    "Select default model:", choices=model_options, default=config.model if valid_model else model_options[0]
+                ).ask()
+
+                # If we got custom input, ask for raw input
+                if model == other_option_str:
+                    model = questionary.text(
+                        "Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
+                        default=default_model,
+                    ).ask()
+                    # TODO allow empty string for input?
+                    model = None if len(model) == 0 else model
+
+            else:
+                model = questionary.text(
+                    "Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
+                    default=default_model,
+                ).ask()
+                model = None if len(model) == 0 else model
 
         # model wrapper
-        if model_endpoint_type != "vllm":
-            available_model_wrappers = builtins.list(get_available_wrappers().keys())
-            model_wrapper = questionary.select(
-                f"Select default model wrapper (recommended: {DEFAULT_WRAPPER_NAME}):",
-                choices=available_model_wrappers,
-                default=DEFAULT_WRAPPER_NAME,
-            ).ask()
+        available_model_wrappers = builtins.list(get_available_wrappers().keys())
+        model_wrapper = questionary.select(
+            f"Select default model wrapper (recommended: {DEFAULT_WRAPPER_NAME}):",
+            choices=available_model_wrappers,
+            default=DEFAULT_WRAPPER_NAME,
+        ).ask()
 
     # set: context_window
     if str(model) not in LLM_MAX_TOKENS:
@@ -228,6 +314,7 @@ def configure_embedding_endpoint(config: MemGPTConfig):
             raise ValueError(
                 "Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set then run `memgpt configure` again."
             )
+        # TODO we need to write these out to the config once we use them if we plan to ping for embedding lists with them
 
         embedding_endpoint_type = "azure"
         embedding_endpoint = azure_creds["azure_embedding_endpoint"]
@@ -345,7 +432,9 @@ def configure():
     config = MemGPTConfig.load()
     try:
         model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
-        model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
+        model, model_wrapper, context_window = configure_model(
+            config=config, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
+        )
         embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
         default_preset, default_persona, default_human, default_agent = configure_cli(config)
         archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)
diff --git a/memgpt/openai_tools.py b/memgpt/openai_tools.py
index fc6e933c65..8e68449e1d 100644
--- a/memgpt/openai_tools.py
+++ b/memgpt/openai_tools.py
@@ -2,6 +2,7 @@
 import time
 import requests
 import time
+from typing import Callable, TypeVar, Union
 import urllib
 
 from box import Box
@@ -75,6 +76,94 @@ def clean_azure_endpoint(raw_endpoint_name):
     return endpoint_address
 
 
+def openai_get_model_list(url: str, api_key: Union[str, None]) -> dict:
+    """https://platform.openai.com/docs/api-reference/models/list"""
+    from memgpt.utils import printd
+
+    url = smart_urljoin(url, "models")
+
+    headers = {"Content-Type": "application/json"}
+    if api_key is not None:
+        headers["Authorization"] = f"Bearer {api_key}"
+
+    printd(f"Sending request to {url}")
+    try:
+        response = requests.get(url, headers=headers)
+        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
+        response = response.json()  # convert to dict from string
+        printd(f"response = {response}")
+        return response
+    except requests.exceptions.HTTPError as http_err:
+        # Handle HTTP errors (e.g., response 4XX, 5XX)
+        try:
+            response = response.json()
+        except:
+            pass
+        printd(f"Got HTTPError, exception={http_err}, response={response}")
+        raise http_err
+    except requests.exceptions.RequestException as req_err:
+        # Handle other requests-related errors (e.g., connection error)
+        try:
+            response = response.json()
+        except:
+            pass
+        printd(f"Got RequestException, exception={req_err}, response={response}")
+        raise req_err
+    except Exception as e:
+        # Handle other potential errors
+        try:
+            response = response.json()
+        except:
+            pass
+        printd(f"Got unknown Exception, exception={e}, response={response}")
+        raise e
+
+
+def azure_openai_get_model_list(url: str, api_key: Union[str, None], api_version: str) -> dict:
+    """https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""
+    from memgpt.utils import printd
+
+    # https://xxx.openai.azure.com/openai/models?api-version=xxx
+    url = smart_urljoin(url, "openai")
+    url = smart_urljoin(url, f"models?api-version={api_version}")
+
+    headers = {"Content-Type": "application/json"}
+    if api_key is not None:
+        headers["api-key"] = f"{api_key}"
+
+    printd(f"Sending request to {url}")
+    try:
+        response = requests.get(url, headers=headers)
+        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
+        response = response.json()  # convert to dict from string
+        printd(f"response = {response}")
+        return response
+    except requests.exceptions.HTTPError as http_err:
+        # Handle HTTP errors (e.g., response 4XX, 5XX)
+        try:
+            response = response.json()
+        except:
+            pass
+        printd(f"Got HTTPError, exception={http_err}, response={response}")
+        raise http_err
+    except requests.exceptions.RequestException as req_err:
+        # Handle other requests-related errors (e.g., connection error)
+        try:
+            response = response.json()
+        except:
+            pass
+        printd(f"Got RequestException, exception={req_err}, response={response}")
+        raise req_err
+    except Exception as e:
+        # Handle other potential errors
+        try:
+            response = response.json()
+        except:
+            pass
+        printd(f"Got unknown Exception, exception={e}, response={response}")
+        raise e
+
+
 def openai_chat_completions_request(url, api_key, data):
     """https://platform.openai.com/docs/guides/text-generation?lang=curl"""
     from memgpt.utils import printd