fix(eval): misc improvements to evals, including reading results files
ErikBjare committed Aug 23, 2024
1 parent 6ceb9d5 commit e30bd08
Showing 2 changed files with 98 additions and 24 deletions.
2 changes: 2 additions & 0 deletions Dockerfile
@@ -28,6 +28,8 @@ COPY media ./media
RUN poetry config virtualenvs.create false \
    && poetry install --no-interaction --no-ansi -E server -E browser -E datascience

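# Install the Chromium browser for playwright (used by the browser extra installed above)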
RUN poetry run playwright install chromium

# Make port 5000 available to the world outside this container
# (assuming your Flask server runs on port 5000)
EXPOSE 5000
120 changes: 96 additions & 24 deletions gptme/eval/main.py
@@ -18,6 +18,7 @@
from dataclasses import dataclass
from datetime import datetime
from multiprocessing import Process, Queue
from multiprocessing.queues import Empty
from pathlib import Path
from typing import Union

@@ -35,6 +36,15 @@
    Status,
)

# Configure logging, including fully-qualified module names
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
)
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

project_dir = Path(__file__).parent.parent


@dataclass
class ProcessSuccess:
@@ -71,7 +81,7 @@ def error_handler(e):
        print(f"Error: {e}")
        queue.put(ProcessError(str(e), stdout.getvalue(), stderr.getvalue(), duration))
        # kill child processes
        os.killpg(0, signal.SIGKILL)
        # os.killpg(0, signal.SIGKILL)
        sys.exit(1)

    # handle SIGTERM
@@ -86,18 +96,11 @@ def sigterm_handler(*_):
    sys.stdout, sys.stderr = stdout_orig, stderr_orig
    queue.put(ProcessSuccess(files, stdout.getvalue(), stderr.getvalue(), duration))
    print("Process finished")
    # It seems that adding this prevents the queue from syncing or something, maybe SIGKILL is too harsh...
    # os.killpg(0, signal.SIGKILL)


# Configure logging, including fully-qualified module names
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
)
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

project_dir = Path(__file__).parent.parent


# TODO: rewrite to run in Docker? Would help with capturing output + process management.
def execute(test: ExecTest, agent: Agent, timeout: int) -> ExecResult:
"""
Executes the code for a specific model with a timeout.
@@ -123,7 +126,10 @@ def execute(test: ExecTest, agent: Agent, timeout: int) -> ExecResult:
        status = "timeout"
        time_gen = timeout

    if queue.empty():
    logger.info("Getting result from queue")
    try:
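        # raises Empty if no result arrives within the timeout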
        result = queue.get(timeout=1)
    except Empty:
        logger.error("Queue is empty, expected a result")
        return {
            "name": test["name"],
@@ -134,8 +140,6 @@ def execute(test: ExecTest, agent: Agent, timeout: int) -> ExecResult:
            "stderr": "",
        }

logger.info("Getting result from queue")
result = queue.get(timeout=1)
logger.info("Got result")
if status == "success":
time_gen = result.duration
@@ -229,8 +233,8 @@ def run_evals(
                    result = future.result()
                    model_results[model].append(result)
                    print(f"=== Completed test {test['name']} ===")
                except Exception as exc:
                    print(f"Test {test['name']} generated an exception: {exc}")
                except Exception:
                    logger.exception(f"Test {test['name']} generated an exception")
    return model_results


@@ -263,16 +267,18 @@ def print_model_results(model_results: dict[str, list[ExecResult]]):

def print_model_results_table(model_results: dict[str, list[ExecResult]]):
    table_data = []
    headers = ["Model"] + [test["name"] for test in tests]
    all_test_names = {
    headers = ["Model"] + list(
        {result["name"] for results in model_results.values() for result in results}
    )
    all_eval_names_or_result_files = {
        result["name"]
        for model_results in model_results.values()
        for result in model_results
    }

    for model, results in model_results.items():
        row = [model]
        for test_name in all_test_names:
        for test_name in all_eval_names_or_result_files:
            try:
                result = next(r for r in results if r["name"] == test_name)
                passed = all(case["passed"] for case in result["results"])
@@ -291,20 +297,62 @@ def print_model_results_table(model_results: dict[str, list[ExecResult]]):


@click.command()
@click.argument("test_names", nargs=-1)
@click.option("_model", "--model", "-m", multiple=True, help="Model to use")
@click.argument("eval_names_or_result_files", nargs=-1)
@click.option(
"_model",
"--model",
"-m",
multiple=True,
help="Model to use, can be massed multiple times.",
)
@click.option("--timeout", "-t", default=15, help="Timeout for code generation")
@click.option("--parallel", "-p", default=10, help="Number of parallel evals to run")
def main(test_names: list[str], _model: list[str], timeout: int, parallel: int):
def main(
    eval_names_or_result_files: list[str],
    _model: list[str],
    timeout: int,
    parallel: int,
):
"""
Run evals for gptme.
Pass test names to run, or result files to print.
"""
models = _model or [
"openai/gpt-4o",
"openai/gpt-4o-mini",
"anthropic/claude-3-5-sonnet-20240620",
"openrouter/meta-llama/llama-3.1-8b-instruct",
"openrouter/meta-llama/llama-3.1-70b-instruct",
"openrouter/meta-llama/llama-3.1-405b-instruct",
"openrouter/nousresearch/hermes-3-llama-3.1-405b",
"openrouter/microsoft/wizardlm-2-8x22b",
"openrouter/mistralai/mistral-nemo",
"openrouter/mistralai/codestral-mamba",
"openrouter/mistralai/mixtral-8x22b-instruct",
"openrouter/deepseek/deepseek-coder",
]

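    # CSV arguments are treated as result files from previous runs and printed as tables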
    results_files = [f for f in eval_names_or_result_files if f.endswith(".csv")]
    for results_file in results_files:
        p = Path(results_file)
        if p.exists():
            results = read_results_from_csv(str(p))
            print_model_results_table(results)
        else:
            print(f"File {results_file} not found")

    tests_to_run = (
        [tests_map[test_name] for test_name in test_names] if test_names else tests
        [
            tests_map[test_name]
            for test_name in eval_names_or_result_files
            if test_name not in results_files
        ]
        if eval_names_or_result_files
        else tests
    )
    if not tests_to_run:
        sys.exit(0)

print("=== Running evals ===")
model_results = run_evals(tests_to_run, models, timeout, parallel)
@@ -322,6 +370,28 @@ def main(test_names: list[str], _model: list[str], timeout: int, parallel: int):
    sys.exit(0)


def read_results_from_csv(filename: str) -> dict[str, list[ExecResult]]:
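    """Read eval results from a CSV file produced by write_results_to_csv."""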
    model_results = defaultdict(list)
    with open(filename, newline="") as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            model = row["Model"]
            result = ExecResult(
                name=row["Test"],
                status="success" if row["Passed"] == "true" else "error",
                results=[],  # We don't have detailed results in the CSV
                timings={
                    "gen": float(row["Generation Time"]),
                    "run": float(row["Run Time"]),
                    "eval": float(row["Eval Time"]),
                },
                stdout="",  # We don't have stdout in the CSV
                stderr="",  # We don't have stderr in the CSV
            )
            model_results[model].append(result)
    return dict(model_results)


def write_results_to_csv(model_results: dict[str, list[ExecResult]]):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # get current commit hash and dirty status, like: a8b2ef0-dirty
@@ -330,7 +400,9 @@ def write_results_to_csv(model_results: dict[str, list[ExecResult]]):
        text=True,
        capture_output=True,
    ).stdout.strip()
    filename = project_dir / f"eval_results_{timestamp}.csv"
    filename = project_dir / "eval_results" / f"eval_results_{timestamp}.csv"
    if not filename.parent.exists():
        filename.parent.mkdir(parents=True)

with open(filename, "w", newline="") as csvfile:
fieldnames = [

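For reference, a minimal sketch (not part of the commit) of how the new results-reading path could be used from Python; the CSV path below is hypothetical, following the eval_results/eval_results_<timestamp>.csv naming used by write_results_to_csv:

    from gptme.eval.main import print_model_results_table, read_results_from_csv

    # load a previously written results file and print the per-model summary table
    results = read_results_from_csv("eval_results/eval_results_20240823_120000.csv")
    print_model_results_table(results)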