feat: pull model list for openai-compatible endpoints #630

Merged
merged 10 commits on Dec 22, 2023
129 changes: 109 additions & 20 deletions memgpt/cli/cli_config.py
@@ -14,6 +14,8 @@
from memgpt.constants import LLM_MAX_TOKENS
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
from memgpt.local_llm.utils import get_available_wrappers
from memgpt.openai_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
from memgpt.server.utils import shorten_key_middle

app = typer.Typer()

@@ -63,6 +65,18 @@ def configure_llm_endpoint(config: MemGPTConfig):
openai_api_key = questionary.text(
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
).ask()
config.openai_key = openai_api_key
config.save()
else:
# Give the user an opportunity to overwrite the key
openai_api_key = None
default_input = shorten_key_middle(config.openai_key) if config.openai_key.startswith("sk-") else config.openai_key
openai_api_key = questionary.text(
"Enter your OpenAI API key (hit enter to use existing key):",
default=default_input,
).ask()
# If the user modified it, use the new one
if openai_api_key != default_input:
config.openai_key = openai_api_key
config.save()

@@ -78,6 +92,11 @@ def configure_llm_endpoint(config: MemGPTConfig):
raise ValueError(
"Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set then run `memgpt configure` again."
)
else:
config.azure_key = azure_creds["azure_key"]
config.azure_endpoint = azure_creds["azure_endpoint"]
config.azure_version = azure_creds["azure_version"]
config.save()

model_endpoint_type = "azure"
model_endpoint = azure_creds["azure_endpoint"]
@@ -119,16 +138,56 @@ def configure_llm_endpoint(config: MemGPTConfig):
return model_endpoint_type, model_endpoint


def configure_model(config: MemGPTConfig, model_endpoint_type: str):
def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoint: str):
# set: model, model_wrapper
model, model_wrapper = None, None
if model_endpoint_type == "openai" or model_endpoint_type == "azure":
model_options = ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]
# TODO: select
valid_model = config.model in model_options
# Get the model list from the openai / azure endpoint
hardcoded_model_options = ["gpt-4", "gpt-4-32k", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]
fetched_model_options = None
try:
if model_endpoint_type == "openai":
fetched_model_options = openai_get_model_list(url=model_endpoint, api_key=config.openai_key)
elif model_endpoint_type == "azure":
fetched_model_options = azure_openai_get_model_list(
url=model_endpoint, api_key=config.azure_key, api_version=config.azure_version
)
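# Keep only chat-model ids (those starting with "gpt-"); embedding, audio, and other ids are dropped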
fetched_model_options = [obj["id"] for obj in fetched_model_options["data"] if obj["id"].startswith("gpt-")]
except:
# NOTE: if this fails, it means the user's key is probably bad
typer.secho(
f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED
)

# First ask if the user wants to see the full model list (some may be incompatible)
see_all_option_str = "[see all options]"
other_option_str = "[enter model name manually]"

# Check if the model we have set already is even in the list (informs our default)
valid_model = config.model in hardcoded_model_options
model = questionary.select(
"Select default model (recommended: gpt-4):", choices=model_options, default=config.model if valid_model else model_options[0]
"Select default model (recommended: gpt-4):",
choices=hardcoded_model_options + [see_all_option_str, other_option_str],
default=config.model if valid_model else hardcoded_model_options[0],
).ask()

# If the user asked for the full list, show it
if model == see_all_option_str:
typer.secho(f"Warning: not all models shown are guaranteed to work with MemGPT", fg=typer.colors.RED)
model = questionary.select(
"Select default model (recommended: gpt-4):",
choices=fetched_model_options + [other_option_str],
default=config.model if valid_model else fetched_model_options[0],
).ask()

# Finally if the user asked to manually input, allow it
if model == other_option_str:
model = ""
while len(model) == 0:
model = questionary.text(
"Enter custom model name:",
).ask()

else: # local models
# ollama also needs model type
if model_endpoint_type == "ollama":
@@ -139,24 +198,51 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str):
).ask()
model = None if len(model) == 0 else model

default_model = config.model if config.model and config.model_endpoint_type == "vllm" else ""

# vllm needs huggingface model tag
if model_endpoint_type == "vllm":
default_model = config.model if config.model and config.model_endpoint_type == "vllm" else ""
model = questionary.text(
"Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
default=default_model,
).ask()
model = None if len(model) == 0 else model
model_wrapper = None # no model wrapper for vLLM
try:
# Don't filter model list for vLLM since model list is likely much smaller than OpenAI/Azure endpoint
# + probably has custom model names
model_options = openai_get_model_list(url=smart_urljoin(model_endpoint, "v1"), api_key=None)
model_options = [obj["id"] for obj in model_options["data"]]
except:
print(f"Failed to get model list from {model_endpoint}, using defaults")
model_options = None

# If we got model options from vLLM endpoint, allow selection + custom input
if model_options is not None:
other_option_str = "other (enter name)"
valid_model = config.model in model_options
model_options.append(other_option_str)
model = questionary.select(
"Select default model:", choices=model_options, default=config.model if valid_model else model_options[0]
).ask()

# If we got custom input, ask for raw input
if model == other_option_str:
model = questionary.text(
"Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
default=default_model,
).ask()
# TODO allow empty string for input?
model = None if len(model) == 0 else model

else:
model = questionary.text(
"Enter HuggingFace model tag (e.g. ehartford/dolphin-2.2.1-mistral-7b):",
default=default_model,
).ask()
model = None if len(model) == 0 else model

# model wrapper
if model_endpoint_type != "vllm":
available_model_wrappers = builtins.list(get_available_wrappers().keys())
model_wrapper = questionary.select(
f"Select default model wrapper (recommended: {DEFAULT_WRAPPER_NAME}):",
choices=available_model_wrappers,
default=DEFAULT_WRAPPER_NAME,
).ask()
available_model_wrappers = builtins.list(get_available_wrappers().keys())
model_wrapper = questionary.select(
f"Select default model wrapper (recommended: {DEFAULT_WRAPPER_NAME}):",
choices=available_model_wrappers,
default=DEFAULT_WRAPPER_NAME,
).ask()

# set: context_window
if str(model) not in LLM_MAX_TOKENS:
@@ -228,6 +314,7 @@ def configure_embedding_endpoint(config: MemGPTConfig):
raise ValueError(
"Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set then run `memgpt configure` again."
)
# TODO: write these credentials out to the config once we use them, if we plan to ping for embedding model lists with them

embedding_endpoint_type = "azure"
embedding_endpoint = azure_creds["azure_embedding_endpoint"]
@@ -345,7 +432,9 @@ def configure():
config = MemGPTConfig.load()
try:
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
model, model_wrapper, context_window = configure_model(
config=config, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
)
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
default_preset, default_persona, default_human, default_agent = configure_cli(config)
archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)
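For reference, the masked-key prompt above relies on shorten_key_middle from memgpt/server/utils. A minimal sketch of what such a helper might look like; the real signature and truncation lengths may differ:

def shorten_key_middle(key: str, show: int = 4) -> str:
    # Keep the first/last few characters and mask the middle,
    # e.g. "sk-abcdef123456" -> "sk-a...3456" with show=4.
    if not key or len(key) <= 2 * show:
        return key  # too short to mask meaningfully
    return key[:show] + "..." + key[-show:]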
89 changes: 89 additions & 0 deletions memgpt/openai_tools.py
@@ -2,6 +2,7 @@
import time
import requests
import time
from typing import Callable, TypeVar, Union
import urllib

from box import Box
@@ -75,6 +76,94 @@ def clean_azure_endpoint(raw_endpoint_name):
return endpoint_address


def openai_get_model_list(url: str, api_key: Union[str, None]) -> dict:
"""https://platform.openai.com/docs/api-reference/models/list"""
from memgpt.utils import printd

url = smart_urljoin(url, "models")

headers = {"Content-Type": "application/json"}
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"

printd(f"Sending request to {url}")
try:
response = requests.get(url, headers=headers)
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string
printd(f"response = {response}")
return response
except requests.exceptions.HTTPError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
try:
response = response.json()
except:
pass
printd(f"Got HTTPError, exception={http_err}, response={response}")
raise http_err
except requests.exceptions.RequestException as req_err:
# Handle other requests-related errors (e.g., connection error)
try:
response = response.json()
except:
pass
printd(f"Got RequestException, exception={req_err}, response={response}")
raise req_err
except Exception as e:
# Handle other potential errors
try:
response = response.json()
except:
pass
printd(f"Got unknown Exception, exception={e}, response={response}")
raise e
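
# For reference, the /models route documented above returns JSON shaped like
# {"object": "list", "data": [{"id": "gpt-4", "object": "model", ...}, ...]},
# which is why callers read response["data"] and filter on each entry's "id".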


def azure_openai_get_model_list(url: str, api_key: Union[str, None], api_version: str) -> dict:
"""https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP"""
from memgpt.utils import printd

# https://xxx.openai.azure.com/openai/models?api-version=xxx
url = smart_urljoin(url, "openai")
url = smart_urljoin(url, f"models?api-version={api_version}")

headers = {"Content-Type": "application/json"}
if api_key is not None:
headers["api-key"] = f"{api_key}"

printd(f"Sending request to {url}")
try:
response = requests.get(url, headers=headers)
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string
printd(f"response = {response}")
return response
except requests.exceptions.HTTPError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
try:
response = response.json()
except:
pass
printd(f"Got HTTPError, exception={http_err}, response={response}")
raise http_err
except requests.exceptions.RequestException as req_err:
# Handle other requests-related errors (e.g., connection error)
try:
response = response.json()
except:
pass
printd(f"Got RequestException, exception={req_err}, response={response}")
raise req_err
except Exception as e:
# Handle other potential errors
try:
response = response.json()
except:
pass
printd(f"Got unknown Exception, exception={e}, response={response}")
raise e


def openai_chat_completions_request(url, api_key, data):
"""https://platform.openai.com/docs/guides/text-generation?lang=curl"""
from memgpt.utils import printd
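For orientation, a hedged usage sketch of the helpers added in this file. The URLs, key, and api-version below are placeholders, and vLLM's default port of 8000 is an assumption:

from memgpt.openai_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin

# OpenAI (or any OpenAI-compatible server): GET <base>/models
response = openai_get_model_list(url="https://api.openai.com/v1", api_key="sk-...")
chat_models = [obj["id"] for obj in response["data"] if obj["id"].startswith("gpt-")]

# vLLM serves the same route under /v1 and typically needs no key
response = openai_get_model_list(url=smart_urljoin("http://localhost:8000", "v1"), api_key=None)
vllm_models = [obj["id"] for obj in response["data"]]  # no "gpt-" filter: names are custom

# Azure: GET <endpoint>/openai/models?api-version=<version>
response = azure_openai_get_model_list(
    url="https://my-resource.openai.azure.com", api_key="azure-key", api_version="2023-05-15"
)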