Skip to content

Commit

Permalink
fix: made eval harness more reliable, using Manager (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare authored Sep 16, 2024
1 parent 6ab895c commit 0787f59
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 50 deletions.
129 changes: 82 additions & 47 deletions gptme/eval/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
Inspired by a document by Anton Osika and Axel Theorell.
"""

import concurrent
import concurrent.futures
import csv
import inspect
import io
import logging
import multiprocessing
import os
import signal
import subprocess
Expand All @@ -17,9 +20,9 @@
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
from datetime import datetime
from multiprocessing import Process, Queue
from multiprocessing import Manager, Process
from pathlib import Path
from queue import Empty
from queue import Empty, Queue
from typing import Union

import click
Expand Down Expand Up @@ -88,21 +91,24 @@ def act_process(agent, files, prompt, queue: "Queue[ProcessResult]"):
# Runs in a process for each eval
# each eval has a process group, so we can kill all child processes
os.setpgrp()
pgrp = os.getpgrp()

# redirect stdout and stderr to streams
stdout = StreamTee(sys.stdout)
stderr = StreamTee(sys.stderr)
sys.stdout, sys.stderr = stdout, stderr # type: ignore

def _reset_stream():
sys.stdout, sys.stderr = stdout.stream, stderr.stream

def error_handler(e):
_reset_stream()
duration = time.time() - start
sys.stdout, sys.stderr = stdout.stream, stderr.stream
print(f"Error: {e}")
queue.put(ProcessError(str(e), stdout.getvalue(), stderr.getvalue(), duration))
# kill child processes
# os.killpg(0, signal.SIGKILL)

sys.exit(1)
# kill child processes
os.killpg(pgrp, signal.SIGKILL)

# handle SIGTERM
def sigterm_handler(*_):
Expand All @@ -112,12 +118,14 @@ def sigterm_handler(*_):

start = time.time()
files = agent.act(files, prompt)

_reset_stream()
duration = time.time() - start
sys.stdout, sys.stderr = stdout.stream, stderr.stream
queue.put(ProcessSuccess(files, stdout.getvalue(), stderr.getvalue(), duration))
print("Process finished successfully")
# It seems that adding this prevents the queue from syncing or something, maybe SIGKILL is too harsh...
# os.killpg(0, signal.SIGKILL)

# kill child processes
os.killpg(pgrp, signal.SIGKILL)


# TODO: rewrite to run in Docker? Would help with capturing output + process management.
Expand All @@ -129,36 +137,39 @@ def execute(test: ExecTest, agent: Agent, timeout: int) -> ExecResult:
f'Running "{test["name"]}" with prompt "{test["prompt"]}" for model: {agent.model}'
)

queue: Queue[ProcessResult] = Queue()
p = Process(target=act_process, args=(agent, test["files"], test["prompt"], queue))
p.start()
p.join(timeout)

time_gen = 0.0
time_run = 0.0
time_eval = 0.0

status: Status = "success"
if p.is_alive():
logger.info("Timeout reached, terminating process")
p.terminate()
p.join(timeout=1)
status = "timeout"
time_gen = timeout

logger.info("Getting result from queue")
try:
result = queue.get(timeout=1)
except Empty:
logger.error("Queue is empty, expected a result")
return {
"name": test["name"],
"status": "error",
"results": [],
"timings": {"gen": time_gen, "run": time_run, "eval": time_eval},
"stdout": "",
"stderr": "",
}
with Manager() as manager:
queue = manager.Queue()
p = Process(
target=act_process, args=(agent, test["files"], test["prompt"], queue)
)
p.start()
p.join(timeout)

time_gen = 0.0
time_run = 0.0
time_eval = 0.0

status: Status = "success"
if p.is_alive():
logger.info("Timeout reached, terminating process")
p.terminate()
p.join(timeout=1)
status = "timeout"
time_gen = timeout

logger.info("Getting result from queue")
try:
result = queue.get(timeout=1)
except Empty:
logger.error("Queue is empty, expected a result")
return {
"name": test["name"],
"status": "error",
"results": [],
"timings": {"gen": time_gen, "run": time_run, "eval": time_eval},
"stdout": "",
"stderr": "",
}

logger.info("Got result")
if status != "timeout":
Expand Down Expand Up @@ -247,14 +258,36 @@ def run_evals(
for model in models
}
for model, future_to_test in model_futures_to_test.items():
for future in as_completed(future_to_test):
test = future_to_test[future]
try:
result = future.result()
model_results[model].append(result)
print(f"=== Completed test {test['name']} ===")
except Exception:
logger.exception(f"Test {test['name']} generated an exception")
try:
for future in as_completed(future_to_test, timeout=timeout + 10):
test = future_to_test[future]
try:
result = future.result(
timeout=1
) # Short timeout to quickly move to next future
model_results[model].append(result)
print(f"=== Completed test {test['name']} ===")
except concurrent.futures.TimeoutError:
logger.warning(f"Test {test['name']} timed out")
model_results[model].append(
{
"name": test["name"],
"status": "timeout",
"results": [],
"timings": {"gen": timeout, "run": 0, "eval": 0},
"stdout": "",
"stderr": "",
}
)
except Exception:
logger.exception(f"Test {test['name']} generated an exception")
except concurrent.futures.TimeoutError:
logger.warning(
f"Some tests for model {model} took too long, but did not timeout correctly"
)
# Cancel any remaining futures for this model
for future in future_to_test:
future.cancel()
return model_results


Expand Down Expand Up @@ -466,4 +499,6 @@ def write_results_to_csv(model_results: dict[str, list[ExecResult]]):


if __name__ == "__main__":
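# Force the "spawn" start method so child processes are created the same way
# on every platform (the default start method differs between operating systems).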
multiprocessing.set_start_method("spawn")
main()
12 changes: 9 additions & 3 deletions gptme/eval/suites/init_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ def check_output_erik(ctx):


def check_cargo_toml(ctx):
return "Cargo.toml" in ctx.files
return "hello_world/Cargo.toml" in ctx.files


def check_rust_binary_exists(ctx):
# check that the built binary hello_world/target/debug/hello_world exists
return "hello_world/target/debug/hello_world" in ctx.files


def check_exists_main(ctx):
Expand Down Expand Up @@ -62,10 +67,11 @@ def check_exists_main(ctx):
{
"name": "init-rust",
"files": {},
"run": "cargo build",
"prompt": "create a Rust project in the current directory",
"run": "cd hello_world; cargo build",
"prompt": "create a Rust project in a hello_world directory, write a hello world program (that doesnt take input), build it to a binary called `hello_world`, and run it",
"expect": {
"Cargo.toml exists": check_cargo_toml,
"Binary built": check_rust_binary_exists,
},
},
]

0 comments on commit 0787f59

Please sign in to comment.