Skip to content

Commit

Permalink
fix: fixed to evals, capture eval output on timeout/terminate
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Aug 14, 2024
1 parent 9ef1ec4 commit cd0862a
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions gptme/eval/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import inspect
import io
import logging
import signal
import subprocess
import sys
import time
Expand Down Expand Up @@ -54,20 +55,32 @@ class ProcessError:


def act_process(agent, files, prompt, queue: "Queue[ProcessResult]"):
# Runs on a process for each eval
# Runs in a process for each eval

# redirect stdout and stderr to streams
stdout, stderr = io.StringIO(), io.StringIO()
stdout_orig, stderr_orig = sys.stdout, sys.stderr
sys.stdout, sys.stderr = stdout, stderr

start = time.time()
try:
files = agent.act(files, prompt)
duration = time.time() - start
queue.put(ProcessSuccess(files, stdout.getvalue(), stderr.getvalue(), duration))
except Exception as e:
def error_handler(e):
duration = time.time() - start
sys.stdout, sys.stderr = stdout_orig, stderr_orig
print(f"Error: {e}")
queue.put(ProcessError(str(e), stdout.getvalue(), stderr.getvalue(), duration))
sys.exit(1)

# handle SIGTERM
def sigterm_handler(*_):
error_handler(KeyboardInterrupt("SIGTERM received"))

signal.signal(signal.SIGTERM, sigterm_handler)

start = time.time()
files = agent.act(files, prompt)
duration = time.time() - start
sys.stdout, sys.stderr = stdout_orig, stderr_orig
queue.put(ProcessSuccess(files, stdout.getvalue(), stderr.getvalue(), duration))
print("Process finished")


# Configure logging, including fully-qualified module names
Expand Down Expand Up @@ -99,6 +112,7 @@ def execute(test: ExecTest, agent: Agent, timeout: int) -> ExecResult:

status: Status = "success"
if p.is_alive():
print("Timeout reached, terminating process")
p.terminate()
p.join()
status = "timeout"
Expand All @@ -116,13 +130,14 @@ def execute(test: ExecTest, agent: Agent, timeout: int) -> ExecResult:
}

result = queue.get()
time_gen = result.duration
if status == "success":
time_gen = result.duration
stdout, stderr = result.stdout, result.stderr

if isinstance(result, ProcessError):
return {
"name": test["name"],
"status": "error",
"status": "timeout" if status == "timeout" else "error",
"results": [],
"timings": {"gen": time_gen, "run": time_run, "eval": time_eval},
"stdout": stdout,
Expand Down Expand Up @@ -241,12 +256,17 @@ def print_model_results(model_results: dict[str, list[ExecResult]]):
def print_model_results_table(model_results: dict[str, list[ExecResult]]):
table_data = []
headers = ["Model"] + [test["name"] for test in tests]
all_test_names = {
result["name"]
for model_results in model_results.values()
for result in model_results
}

for model, results in model_results.items():
row = [model]
for test in tests:
for test_name in all_test_names:
try:
result = next(r for r in results if r["name"] == test["name"])
result = next(r for r in results if r["name"] == test_name)
passed = all(case["passed"] for case in result["results"])
checkmark = "✅" if result["status"] == "success" and passed else "❌"
duration = sum(result["timings"].values())
Expand Down

0 comments on commit cd0862a

Please sign in to comment.