From 6fa8016e74b0681405b5b4105f10047c9bd5cafe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?=
Date: Sat, 10 Aug 2024 00:31:10 +0200
Subject: [PATCH] feat: major eval revamp, openrouter support, removed `--llm`
 in favor of `--model <provider>/<model>`

---
 README.md           |   7 +-
 docs/providers.md   |  52 ++++------
 eval/agents.py      |   4 +-
 eval/evals.py       |  56 ++++++++---
 eval/main.py        | 240 +++++++++++++++++++++++++++++++++-----------
 eval/types.py       |  19 +++-
 gptme/cli.py        |  15 +--
 gptme/init.py       |  37 ++++---
 gptme/llm.py        |  34 ++++---
 gptme/models.py     |  53 ++++++++--
 gptme/server/cli.py |  12 +--
 poetry.lock         |  27 ++++-
 pyproject.toml      |   2 +
 13 files changed, 383 insertions(+), 175 deletions(-)

diff --git a/README.md b/README.md
index 34fad949..2e20c2a5 100644
--- a/README.md
+++ b/README.md
@@ -230,9 +230,10 @@ Options:
   --name TEXT              Name of conversation. Defaults to generating a random
                            name. Pass 'ask' to be prompted for a name.
-  --llm [openai|anthropic|azure|local]
-                           LLM provider to use.
-  --model TEXT             Model to use.
+  --model TEXT             Model to use, e.g. openai/gpt-4-turbo,
+                           anthropic/claude-3-5-sonnet-20240620. If
+                           only provider is given, the default model
+                           for that provider is used.
   --stream / --no-stream   Stream responses
   -v, --verbose            Verbose output.
   -y, --no-confirm         Skips all confirmation prompts.
diff --git a/docs/providers.md b/docs/providers.md
index 2ab4d107..f773f840 100644
--- a/docs/providers.md
+++ b/docs/providers.md
@@ -3,6 +3,16 @@ Providers
 We support several LLM providers, including OpenAI, Anthropic, Azure, and any OpenAI-compatible server (e.g. `ollama`, `llama-cpp-python`).
+To select a provider and model, run `gptme` with the `--model` flag set to `<provider>/<model>`, for example:
+
+```sh
+gptme --model openai/gpt-4o "hello"
+gptme --model anthropic "hello"  # if the model part is unspecified, it will fall back to the provider default
+gptme --model openrouter/meta-llama/llama-3.1-70b-instruct "hello"
+```
+
+On first startup, if `--model` is not set and no API key is set in the config or environment, you will be prompted for one. The provider will then be auto-detected from the key, and the key saved in the configuration file.
+
 ## OpenAI
 To use OpenAI, set your API key:
@@ -11,8 +21,6 @@ To use OpenAI, set your API key:
 ```sh
 export OPENAI_API_KEY="your-api-key"
 ```
-If no key is set, it will be prompted for and saved in the configuration file.
-
 ## Anthropic
 To use Anthropic, set your API key:
@@ -21,11 +29,17 @@ To use Anthropic, set your API key:
 ```sh
 export ANTHROPIC_API_KEY="your-api-key"
 ```
-If no key is set, it will be prompted for and saved in the configuration file.
+## OpenRouter
+
+To use OpenRouter, set your API key:
+
+```sh
+export OPENROUTER_API_KEY="your-api-key"
+```
 ## Local
-There are several ways to run local LLM models in a way that exposes a OpenAI API-compatible server, here we will cover two:
+There are several ways to run local LLM models in a way that exposes an OpenAI API-compatible server; here we will cover one:
 ### ollama + litellm
@@ -39,33 +53,3 @@ ollama serve
 litellm --model ollama/mistral
 export OPENAI_API_BASE="http://localhost:8000"
 ```
-
-### llama-cpp-python
-
-Here's how to use `llama-cpp-python`.
-
-You first need to install and run the [llama-cpp-python][llama-cpp-python] server. To ensure you get the most out of your hardware, make sure you build it with [the appropriate hardware acceleration][hwaccel]. For macOS, you can find detailed instructions [here][metal].
- -```sh -MODEL=~/ML/wizardcoder-python-13b-v1.0.Q4_K_M.gguf -poetry run python -m llama_cpp.server --model $MODEL --n_gpu_layers 1 # Use `--n_gpu_layer 1` if you have a M1/M2 chip -export OPENAI_API_BASE="http://localhost:8000/v1" -``` - -### Usage - -Now, simply run `gptme` with the `--llm` flag set to `local`: - -```sh -gptme --llm local "hello" -``` - -### How well does it work? - -I've had mixed results. They are not nearly as good as GPT-4, and often struggles with the tools laid out in the system prompt. However I haven't tested with models larger than 7B/13B. - -I'm hoping future models, trained better for tool-use and interactive coding (where outputs are fed back), can remedy this, even at 7B/13B model sizes. Perhaps we can fine-tune a model on (GPT-4) conversation logs to create a purpose-fit model that knows how to use the tools. - -[llama-cpp-python]: https://github.com/abetlen/llama-cpp-python -[hwaccel]: https://github.com/abetlen/llama-cpp-python#installation-with-hardware-acceleration -[metal]: https://github.com/abetlen/llama-cpp-python/blob/main/docs/install/macos.md diff --git a/eval/agents.py b/eval/agents.py index d1b4e1cb..87d8b8f5 100644 --- a/eval/agents.py +++ b/eval/agents.py @@ -10,8 +10,7 @@ class Agent: - def __init__(self, llm: str, model: str): - self.llm = llm + def __init__(self, model: str): self.model = model @abstractmethod @@ -42,7 +41,6 @@ def act(self, files: Files | None, prompt: str): [Message("user", prompt)], [prompt_sys], f"gptme-evals-{store.id}", - llm=self.llm, model=self.model, no_confirm=True, interactive=False, diff --git a/eval/evals.py b/eval/evals.py index 33c39cd7..687fa18a 100644 --- a/eval/evals.py +++ b/eval/evals.py @@ -3,6 +3,39 @@ if TYPE_CHECKING: from main import ExecTest + +def correct_output_hello(ctx): + return ctx.stdout == "Hello, human!\n" + + +def correct_file_hello(ctx): + return ctx.files["hello.py"].strip() == "print('Hello, human!')" + + +def check_prime_output(ctx): + return "541" in ctx.stdout.split() + + +def check_clean_exit(ctx): + return ctx.exit_code == 0 + + +def check_clean_working_tree(ctx): + return "nothing to commit, working tree clean" in ctx.stdout + + +def check_main_py_exists(ctx): + return "main.py" in ctx.files + + +def check_commit_exists(ctx): + return "No commits yet" not in ctx.stdout + + +def check_output_hello_ask(ctx): + return "Hello, Erik!" in ctx.stdout + + tests: list["ExecTest"] = [ { "name": "hello", @@ -10,9 +43,8 @@ "run": "python hello.py", "prompt": "Change the code in hello.py to print 'Hello, human!'", "expect": { - "correct output": lambda ctx: ctx.stdout == "Hello, human!\n", - "correct file": lambda ctx: ctx.files["hello.py"].strip() - == "print('Hello, human!')", + "correct output": correct_output_hello, + "correct file": correct_file_hello, }, }, { @@ -21,9 +53,8 @@ "run": "python hello.py", "prompt": "Patch the code in hello.py to print 'Hello, human!'", "expect": { - "correct output": lambda ctx: ctx.stdout == "Hello, human!\n", - "correct file": lambda ctx: ctx.files["hello.py"].strip() - == "print('Hello, human!')", + "correct output": correct_output_hello, + "correct file": correct_file_hello, }, }, { @@ -33,7 +64,7 @@ # TODO: work around the "don't try to execute it" part by improving gptme such that it just gives EOF to stdin in non-interactive mode "prompt": "modify hello.py to ask the user for their name and print 'Hello, !'. don't try to execute it", "expect": { - "correct output": lambda ctx: "Hello, Erik!" 
in ctx.stdout, + "correct output": check_output_hello_ask, }, }, { @@ -42,7 +73,7 @@ "run": "python prime.py", "prompt": "write a script prime.py that computes and prints the 100th prime number", "expect": { - "correct output": lambda ctx: "541" in ctx.stdout.split(), + "correct output": check_prime_output, }, }, { @@ -51,11 +82,10 @@ "run": "git status", "prompt": "initialize a git repository, write a main.py file, and commit it", "expect": { - "clean exit": lambda ctx: ctx.exit_code == 0, - "clean working tree": lambda ctx: "nothing to commit, working tree clean" - in ctx.stdout, - "main.py exists": lambda ctx: "main.py" in ctx.files, - "we have a commit": lambda ctx: "No commits yet" not in ctx.stdout, + "clean exit": check_clean_exit, + "clean working tree": check_clean_working_tree, + "main.py exists": check_main_py_exists, + "we have a commit": check_commit_exists, }, }, # Fails, gets stuck on interactive stuff diff --git a/eval/main.py b/eval/main.py index 1c97e0cd..0e6f3daa 100644 --- a/eval/main.py +++ b/eval/main.py @@ -6,12 +6,21 @@ import csv import inspect +import io import logging import subprocess import sys import time +from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass from datetime import datetime +from multiprocessing import Process, Queue from pathlib import Path +from typing import Literal, Union + +import click +from tabulate import tabulate from .agents import Agent, GPTMe from .evals import tests, tests_map @@ -23,23 +32,105 @@ ResultContext, ) + +@dataclass +class ProcessSuccess: + files: dict[str, str | bytes] + stdout: str + stderr: str + duration: float + + +@dataclass +class ProcessError: + message: str + stdout: str + stderr: str + duration: float + + +Status = Literal["success", "error"] +ProcessResult = Union[ProcessSuccess, ProcessError] + + +def act_process(agent, files, prompt, queue: "Queue[ProcessResult]"): + # Runs on a process for each eval + + # redirect stdout and stderr to streams + stdout, stderr = io.StringIO(), io.StringIO() + sys.stdout, sys.stderr = stdout, stderr + + start = time.time() + try: + files = agent.act(files, prompt) + duration = time.time() - start + queue.put(ProcessSuccess(files, stdout.getvalue(), stderr.getvalue(), duration)) + except Exception as e: + duration = time.time() - start + queue.put(ProcessError(str(e), stdout.getvalue(), stderr.getvalue(), duration)) + + +# Configure logging, including fully-qualified module names +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s - %(message)s" +) +logging.getLogger("httpx").setLevel(logging.WARNING) logger = logging.getLogger(__name__) project_dir = Path(__file__).parent.parent -def execute(test: ExecTest, agent: Agent) -> ExecResult: +def execute(test: ExecTest, agent: Agent, timeout: int) -> ExecResult: """ - Executes the code for a specific model. + Executes the code for a specific model with a timeout. 
""" print( f"Running test {test['name']} with prompt: {test['prompt']} for model: {agent.model}" ) - # generate code - gen_start = time.time() - files = agent.act(test["files"], test["prompt"]) - gen_duration = time.time() - gen_start + queue: Queue[ProcessResult] = Queue() + p = Process(target=act_process, args=(agent, test["files"], test["prompt"], queue)) + p.start() + p.join(timeout) + + if p.is_alive(): + p.terminate() + p.join() + return { + "name": test["name"], + "status": "timeout", + "results": [], + "timings": {"gen": timeout, "run": 0, "eval": 0}, + # TODO: get stdout/stderr for timeouts somehow + "stdout": "", + "stderr": "", + } + + if queue.empty(): + logger.error("Queue is empty, expected a result") + return { + "name": test["name"], + "status": "error", + "results": [], + "timings": {"gen": 0, "run": 0, "eval": 0}, + "stdout": "", + "stderr": "", + } + + result = queue.get() + stdout, stderr = result.stdout, result.stderr + + if isinstance(result, ProcessError): + return { + "name": test["name"], + "status": "error", + "results": [], + "timings": {"gen": result.duration, "run": 0, "eval": 0}, + "stdout": stdout, + "stderr": stderr, + } + else: + files = result.files # check and collect results run_start = time.time() @@ -71,48 +162,47 @@ def execute(test: ExecTest, agent: Agent) -> ExecResult: return { "name": test["name"], + "status": "success", "results": results, "timings": { - "gen": gen_duration, + "gen": result.duration, "run": run_duration, "eval": sum(r["duration"] for r in results), }, + "stdout": stdout, + "stderr": stderr, } -def main(): - models = [ - # "openai/gpt-3.5-turbo", - # "openai/gpt-4-turbo", - # "openai/gpt-4o", - "openai/gpt-4o-mini", - # "anthropic/claude-3-5-sonnet-20240620", - "anthropic/claude-3-haiku-20240307", - ] - test_name = sys.argv[1] if len(sys.argv) > 1 else None - - all_results = {} - for model in models: - print(f"\n=== Running tests for model: {model} ===") - llm, model = model.split("/") - agent = GPTMe(llm=llm, model=model) - - results = [] - if test_name: - print(f"=== Running test {test_name} ===") - result = execute(tests_map[test_name], agent) - results.append(result) - else: - print("=== Running all tests ===") - for test in tests: - result = execute(test, agent) - results.append(result) - - all_results[model] = results +def run_evals( + tests, models, timeout: int, parallel: int +) -> dict[str, list[ExecResult]]: + """ + Run evals for a list of tests. 
+ """ + model_results = defaultdict(list) + with ProcessPoolExecutor(parallel) as executor: + model_futures_to_test = { + model: { + executor.submit(execute, test, GPTMe(model=model), timeout): test + for test in tests + } + for model in models + } + for model, future_to_test in model_futures_to_test.items(): + for future in as_completed(future_to_test): + test = future_to_test[future] + try: + result = future.result() + model_results[model].append(result) + print(f"=== Completed test {test['name']} ===") + except Exception as exc: + print(f"Test {test['name']} generated an exception: {exc}") + return model_results - print("\n=== Finished ===\n") - for model, results in all_results.items(): +def print_model_results(model_results: dict[str, list[ExecResult]]): + for model, results in model_results.items(): print(f"\nResults for model: {model}") duration_total = sum( result["timings"]["gen"] @@ -137,32 +227,64 @@ def main(): checkmark = "✅" if case["passed"] else "❌" print(f" {checkmark} {case['name']}") - print("\n=== Model Comparison ===") - for test in tests: - print(f"\nTest: {test['name']}") - for model, results in all_results.items(): - result = next(r for r in results if r["name"] == test["name"]) - passed = all(case["passed"] for case in result["results"]) - checkmark = "✅" if passed else "❌" - duration = sum(result["timings"].values()) - print(f"{model}: {checkmark} {duration:.2f}s") - - all_success = all( - all(all(case["passed"] for case in result["results"]) for result in results) - for results in all_results.values() + +def print_model_results_table(model_results: dict[str, list[ExecResult]]): + table_data = [] + headers = ["Model"] + [test["name"] for test in tests] + + for model, results in model_results.items(): + row = [model] + for test in tests: + try: + result = next(r for r in results if r["name"] == test["name"]) + passed = all(case["passed"] for case in result["results"]) + checkmark = "✅" if result["status"] == "success" and passed else "❌" + duration = sum(result["timings"].values()) + reason = "timeout" if result["status"] == "timeout" else "" + if reason: + row.append(f"{checkmark} {reason}") + else: + row.append(f"{checkmark} {duration:.2f}s") + except StopIteration: + row.append("❌ N/A") + table_data.append(row) + + print(tabulate(table_data, headers=headers, tablefmt="grid")) + + +@click.command() +@click.argument("test_names", nargs=-1) +@click.option("_model", "--model", "-m", multiple=True, help="Model to use") +@click.option("--timeout", "-t", default=15, help="Timeout for code generation") +@click.option("--parallel", "-p", default=10, help="Number of parallel evals to run") +def main(test_names: list[str], _model: list[str], timeout: int, parallel: int): + models = _model or [ + "openai/gpt-4o", + "anthropic/claude-3-5-sonnet-20240620", + "openrouter/meta-llama/llama-3.1-70b-instruct", + ] + + tests_to_run = ( + [tests_map[test_name] for test_name in test_names] if test_names else tests ) - if all_success: - print("\n✅ All tests passed for all models!") - else: - print("\n❌ Some tests failed!") + + print("=== Running evals ===") + model_results = run_evals(tests_to_run, models, timeout, parallel) + print("\n=== Finished ===\n") + + print("\n\n=== Model Results ===") + print_model_results(model_results) + + print("\n\n=== Model Comparison ===") + print_model_results_table(model_results) # Write results to CSV - write_results_to_csv(all_results) + write_results_to_csv(model_results) - sys.exit(0 if all_success else 1) + sys.exit(0) -def 
write_results_to_csv(all_results): +def write_results_to_csv(model_results: dict[str, list[ExecResult]]): timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # get current commit hash and dirty status, like: a8b2ef0-dirty commit_hash = subprocess.run( @@ -186,7 +308,7 @@ def write_results_to_csv(all_results): writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() - for model, results in all_results.items(): + for model, results in model_results.items(): for result in results: passed = all(case["passed"] for case in result["results"]) writer.writerow( diff --git a/eval/types.py b/eval/types.py index 94a6060b..2d2c5370 100644 --- a/eval/types.py +++ b/eval/types.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass -from typing import TypedDict from collections.abc import Callable +from dataclasses import dataclass +from typing import Literal, TypedDict Files = dict[str, str | bytes] @@ -18,6 +18,10 @@ class ResultContext: class CaseResult(TypedDict): + """ + Result of a single test case on the execution of a prompt. + """ + name: str passed: bool code: str @@ -25,12 +29,23 @@ class CaseResult(TypedDict): class ExecResult(TypedDict): + """ + Result of executing a prompt. + """ + name: str + status: Literal["success", "error", "timeout"] results: list[CaseResult] timings: dict[str, float] + stdout: str + stderr: str class ExecTest(TypedDict): + """ + Test case for executing a prompt. + """ + name: str files: Files run: str diff --git a/gptme/cli.py b/gptme/cli.py index 549a3a94..f9540604 100644 --- a/gptme/cli.py +++ b/gptme/cli.py @@ -20,7 +20,7 @@ from .commands import CMDFIX, action_descriptions, execute_cmd from .constants import MULTIPROMPT_SEPARATOR, PROMPT_USER from .dirs import get_logs_dir -from .init import PROVIDERS, init, init_logging +from .init import init, init_logging from .llm import reply from .logmanager import LogManager, _conversations from .message import Message @@ -75,16 +75,10 @@ default="random", help="Name of conversation. Defaults to generating a random name. Pass 'ask' to be prompted for a name.", ) -@click.option( - "--llm", - default=None, - help="LLM provider to use.", - type=click.Choice(PROVIDERS), -) @click.option( "--model", default=None, - help="Model to use.", + help="Model to use, e.g. openai/gpt-4-turbo, anthropic/claude-3-5-sonnet-20240620. If only provider is given, the default model for that provider is used.", ) @click.option( "--stream/--no-stream", @@ -116,7 +110,6 @@ def main( prompts: list[str], prompt_system: str, name: str, - llm: LLMChoice, model: ModelChoice, stream: bool, verbose: bool, @@ -172,7 +165,6 @@ def main( prompt_msgs, initial_msgs, name, - llm, model, stream, no_confirm, @@ -185,7 +177,6 @@ def chat( prompt_msgs: list[Message], initial_msgs: list[Message], name: str, - llm: str | None, model: str | None, stream: bool = True, no_confirm: bool = False, @@ -201,7 +192,7 @@ def chat( Callable from other modules. 
""" # init - init(llm, model, interactive) + init(model, interactive) # (re)init shell set_shell(ShellSession()) diff --git a/gptme/init.py b/gptme/init.py index 19dba903..bb6a65e1 100644 --- a/gptme/init.py +++ b/gptme/init.py @@ -6,18 +6,16 @@ from .config import config_path, load_config, set_config_value from .dirs import get_readline_history_file -from .llm import get_recommended_model, init_llm -from .models import set_default_model +from .llm import init_llm +from .models import PROVIDERS, get_recommended_model, set_default_model from .tabcomplete import register_tabcomplete from .tools import init_tools logger = logging.getLogger(__name__) _init_done = False -PROVIDERS = ["openai", "anthropic", "azure", "local"] - -def init(provider: str | None, model: str | None, interactive: bool): +def init(model: str | None, interactive: bool): global _init_done if _init_done: logger.warning("init() called twice, ignoring") @@ -31,30 +29,38 @@ def init(provider: str | None, model: str | None, interactive: bool): config = load_config() # get from config - if not provider: - provider = config.get_env("PROVIDER") + if not model: + model = config.get_env("MODEL") - if not provider: # pragma: no cover + if not model: # pragma: no cover # auto-detect depending on if OPENAI_API_KEY or ANTHROPIC_API_KEY is set if config.get_env("OPENAI_API_KEY"): print("Found OpenAI API key, using OpenAI provider") - provider = "openai" + model = "openai" elif config.get_env("ANTHROPIC_API_KEY"): print("Found Anthropic API key, using Anthropic provider") - provider = "anthropic" + model = "anthropic" # ask user for API key elif interactive: - provider, _ = ask_for_api_key() + model, _ = ask_for_api_key() # fail - if not provider: + if not model: raise ValueError("No API key found, couldn't auto-detect provider") + if any(model.startswith(f"{provider}/") for provider in PROVIDERS): + provider, model = model.split("/", 1) + else: + provider, model = model, None + # set up API_KEY and API_BASE, needs to be done before loading history to avoid saving API_KEY init_llm(provider) if not model: - model = config.get_env("MODEL") or get_recommended_model() + model = get_recommended_model(provider) + logger.info( + "No model specified, using recommended model for provider: %s", model + ) set_default_model(model) if interactive: @@ -103,7 +109,7 @@ def _load_readline_history() -> None: # pragma: no cover def ask_for_api_key(): # pragma: no cover """Interactively ask user for API key""" - print("No API key set for OpenAI or Anthropic.") + print("No API key set for OpenAI, Anthropic, or OpenRouter.") print( """You can get one at: - OpenAI: https://platform.openai.com/account/api-keys @@ -115,6 +121,9 @@ def ask_for_api_key(): # pragma: no cover if api_key.startswith("sk-ant-"): provider = "anthropic" env_var = "ANTHROPIC_API_KEY" + elif api_key.startswith("sk-or-"): + provider = "openrouter" + env_var = "OPENROUTER_API_KEY" else: provider = "openai" env_var = "OPENAI_API_KEY" diff --git a/gptme/llm.py b/gptme/llm.py index 6020a4bb..5cefb20c 100644 --- a/gptme/llm.py +++ b/gptme/llm.py @@ -10,7 +10,7 @@ from .config import get_config from .constants import PROMPT_ASSISTANT from .message import Message -from .models import MODELS +from .models import MODELS, get_summary_model from .util import len_tokens, msgs2dicts # Optimized for code @@ -42,16 +42,18 @@ def init_llm(llm: str): api_version="2023-07-01-preview", azure_endpoint=azure_endpoint, ) - elif llm == "anthropic": api_key = config.get_env_required("ANTHROPIC_API_KEY") 
anthropic_client = Anthropic( api_key=api_key, ) - + elif llm == "openrouter": + api_key = config.get_env_required("OPENROUTER_API_KEY") + oai_client = OpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1") elif llm == "local": api_base = config.get_env_required("OPENAI_API_BASE") - oai_client = OpenAI(api_key="ollama", base_url=api_base) + api_key = config.get_env("OPENAI_API_KEY") or "ollama" + oai_client = OpenAI(api_key=api_key, base_url=api_base) else: print(f"Error: Unknown LLM: {llm}") sys.exit(1) @@ -244,14 +246,18 @@ def print_clear(): return Message("assistant", output) -def get_recommended_model() -> str: - assert oai_client or anthropic_client, "LLM not initialized" - return "gpt-4-turbo" if oai_client else "claude-3-5-sonnet-20240620" - - -def get_summary_model() -> str: - assert oai_client or anthropic_client, "LLM not initialized" - return "gpt-4o-mini" if oai_client else "claude-3-haiku-20240307" +def _client_to_provider() -> str: + if oai_client: + if "openai" in oai_client.base_url.host: + return "openai" + elif "openrouter" in oai_client.base_url.host: + return "openrouter" + else: + return "azure" + elif anthropic_client: + return "anthropic" + else: + raise ValueError("Unknown client type") def summarize(content: str) -> str: @@ -269,7 +275,7 @@ def summarize(content: str) -> str: Message("user", content=f"Summarize this:\n{content}"), ] - model = get_summary_model() + model = get_summary_model(_client_to_provider()) context_limit = MODELS["openai" if oai_client else "anthropic"][model]["context"] if len_tokens(messages) > context_limit: raise ValueError( @@ -312,5 +318,5 @@ def generate_name(msgs: list[Message]) -> str: + msgs + [Message("user", "Now, generate a name for this conversation.")] ) - name = _chat_complete(msgs, model=get_summary_model()).strip() + name = _chat_complete(msgs, model=get_summary_model(_client_to_provider())).strip() return name diff --git a/gptme/models.py b/gptme/models.py index 749db73d..7efa4098 100644 --- a/gptme/models.py +++ b/gptme/models.py @@ -29,8 +29,11 @@ class _ModelDictMeta(TypedDict): price_output: NotRequired[float] +# available providers +PROVIDERS = ["openai", "anthropic", "azure", "openrouter", "local"] + # default model -DEFAULT_MODEL: str | None = None +DEFAULT_MODEL: ModelMeta | None = None # known models metadata # TODO: can we get this from the API? 
@@ -118,34 +121,62 @@ class _ModelDictMeta(TypedDict): def set_default_model(model: str) -> None: - assert get_model(model) + modelmeta = get_model(model) + assert modelmeta global DEFAULT_MODEL - DEFAULT_MODEL = model + DEFAULT_MODEL = modelmeta def get_model(model: str | None = None) -> ModelMeta: if model is None: assert DEFAULT_MODEL, "Default model not set, set it with set_default_model()" - model = DEFAULT_MODEL - - if "/" in model: - provider, model = model.split("/") + return DEFAULT_MODEL + + if model in PROVIDERS: + provider = model + return ModelMeta( + provider, model, **MODELS[provider][get_recommended_model(provider)] + ) + if any(f"{provider}/" in model for provider in PROVIDERS): + provider, model = model.split("/", 1) if provider not in MODELS or model not in MODELS[provider]: logger.warning( - f"Model {provider}/{model} not found, using fallback model metadata" + f"Unknown model {model} from {provider}, using fallback metadata" ) - return ModelMeta(provider=provider, model=model, context=4000) + return ModelMeta(provider=provider, model=model, context=128_000) else: # try to find model in all providers for provider in MODELS: if model in MODELS[provider]: break else: - logger.warning(f"Model {model} not found, using fallback model metadata") - return ModelMeta(provider="unknown", model=model, context=4000) + logger.warning(f"Unknown model {model} not found, using fallback metadata") + return ModelMeta(provider="unknown", model=model, context=128_000) return ModelMeta( provider=provider, model=model, **MODELS[provider][model], ) + + +def get_recommended_model(provider: str) -> str: + if provider == "openai": + return "gpt-4-turbo" + elif provider == "openrouter": + return "meta-llama/llama-3.1-70b-instruct" + elif provider == "anthropic": + return "claude-3-5-sonnet-20240620" + else: + raise ValueError(f"Unknown provider {provider}") + + +def get_summary_model(provider: str) -> str: + if provider == "openai": + return "gpt-4o-mini" + elif provider == "openrouter": + return "meta-llama/llama-3.1-8b-instruct" + elif provider == "anthropic": + return "claude-3-haiku-20240307" + else: + raise ValueError(f"Unknown provider {provider}") diff --git a/gptme/server/cli.py b/gptme/server/cli.py index 4d86501e..6335b449 100644 --- a/gptme/server/cli.py +++ b/gptme/server/cli.py @@ -2,32 +2,26 @@ import click -from ..init import PROVIDERS, init, init_logging +from ..init import init, init_logging logger = logging.getLogger(__name__) @click.command("gptme-server") @click.option("-v", "--verbose", is_flag=True, help="Verbose output.") -@click.option( - "--llm", - default=None, - help="LLM provider to use.", - type=click.Choice(PROVIDERS), -) @click.option( "--model", default=None, help="Model to use by default, can be overridden in each request.", ) -def main(verbose: bool, llm: str | None, model: str | None): # pragma: no cover +def main(verbose: bool, model: str | None): # pragma: no cover """ Starts a server and web UI for gptme. Note that this is very much a work in progress, and is not yet ready for normal use. 
""" init_logging(verbose) - init(llm, model, interactive=False) + init(model, interactive=False) # if flask not installed, ask the user to install `server` extras try: diff --git a/poetry.lock b/poetry.lock index 9e079036..c852f674 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2975,6 +2975,20 @@ files = [ [package.dependencies] mpmath = ">=0.19" +[[package]] +name = "tabulate" +version = "0.9.0" +description = "Pretty-print tabular data" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"}, + {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"}, +] + +[package.extras] +widechars = ["wcwidth"] + [[package]] name = "tiktoken" version = "0.5.2" @@ -3356,6 +3370,17 @@ build = ["cmake (>=3.18)", "lit"] tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] tutorials = ["matplotlib", "pandas", "tabulate"] +[[package]] +name = "types-tabulate" +version = "0.9.0.20240106" +description = "Typing stubs for tabulate" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-tabulate-0.9.0.20240106.tar.gz", hash = "sha256:c9b6db10dd7fcf55bd1712dd3537f86ddce72a08fd62bb1af4338c7096ce947e"}, + {file = "types_tabulate-0.9.0.20240106-py3-none-any.whl", hash = "sha256:0378b7b6fe0ccb4986299496d027a6d4c218298ecad67199bbd0e2d7e9d335a1"}, +] + [[package]] name = "typing-extensions" version = "4.9.0" @@ -3455,4 +3480,4 @@ training = ["torch", "transformers"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "2cd0d3009d86179aeb1fbd50db9b23aa270855e5eeed782be4c56b28e18df38e" +content-hash = "22e411898e15765e493850be77a01c4b84851a5a0a083b1650dca7707db9e5a2" diff --git a/pyproject.toml b/pyproject.toml index 2e1198dd..e9c6166b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ python = "^3.10" click = "^8.0" python-dotenv = "^1.0.0" rich = "^13.5.2" +tabulate = "^0.9.0" pick = "^2.2.0" tiktoken = "^0.5.1" tomlkit = "^0.12.1" @@ -68,6 +69,7 @@ sphinx-book-theme = "^1.0.1" myst-parser = "^2.0.0" pyupgrade = "^3.15.0" greenlet = "*" # dependency of playwright, but needed for coverage +types-tabulate = "^0.9.0.20240106" [tool.poetry.extras] server = ["llama-cpp-python", "flask"]