Skip to content

Commit

Permalink
fix: further reliability improvements to evals
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Sep 18, 2024
1 parent cd33e06 commit 622e574
Showing 1 changed file with 45 additions and 33 deletions.
78 changes: 45 additions & 33 deletions gptme/eval/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sys
import time
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from concurrent.futures import Future, ProcessPoolExecutor, as_completed
from multiprocessing import Manager, Process
from typing import TypedDict, Union

Expand Down Expand Up @@ -68,7 +68,7 @@ def run_evals(
cleanup_on_sigterm()

n_runs = len(tests) * len(models)
model_results = defaultdict(list)
model_results: dict[str, dict[str, ExecResult]] = defaultdict(dict)
parallel = min(n_runs, parallel)
with ProcessPoolExecutor(parallel) as executor:
futures = []
Expand All @@ -85,63 +85,75 @@ def run_evals(
futures.append(future)
future_to_model_test[future] = (model, test)

def _handle_future(future):
def _handle_future(future: Future):
model, test = future_to_model_test[future]
test_name = test["name"]
try:
result = future.result(timeout=0.1)
model_results[model].append(result)
except concurrent.futures.TimeoutError:
# NOTE: is this really a good description of what is happening?
# shouldn't a timeout still give a result?
logger.warning(f"Test {test_name} for model {model} timed out")
model_results[model].append(
ExecResult(
name=test_name,
status="timeout",
results=[],
timings={"gen": timeout, "run": 0, "eval": 0},
gen_stdout="",
gen_stderr="",
run_stdout="",
run_stderr="",
except Exception as e:
# TODO: we still want to get stdout/stderr from the process
gen_time = 0
if isinstance(e, concurrent.futures.TimeoutError) or isinstance(
e, concurrent.futures.CancelledError
):
status: Status = "timeout"
gen_time = timeout
else:
status = "error"
logger.exception(
f"Test {test_name} for model {model} generated an exception when trying to get result"
)
result = ExecResult(
name=test_name,
status=status,
results=[],
timings={"gen": gen_time, "run": 0, "eval": 0},
gen_stdout="",
gen_stderr="",
run_stdout="",
run_stderr="",
)
except Exception:
logger.exception(
f"Test {test_name} for model {model} generated an exception"
)
model_results[model][test_name] = result

# worst-case run time, with some buffer to account for overhead
max_timeout = timeout * len(tests) / parallel + 10
completed = set()
try:
# worst-case run time
max_timeout = timeout * len(tests) / parallel + 10
# TODO: can we do better than this? handle timeouts within futures instead?
for future in tqdm(
as_completed(futures, timeout=max_timeout),
total=n_runs,
unit="eval",
desc="Progress",
):
_handle_future(future)
completed.add(future)
except concurrent.futures.TimeoutError:
logger.warning("Timeout reached, cancelling remaining futures")
# NOTE: this should rarely happen, as `execute` should handle timeouts
logger.warning(
"Timeout reached in top-level (shouldnt happen). Cancelling remaining futures..."
)

# Cancel any remaining futures
for future in futures:
_handle_future(future)
future.cancel()
if future not in completed:
future.cancel()
_handle_future(future)

# Ensure all processes are terminated
for process in multiprocessing.active_children():
process.terminate()
process.join()

# sort model_results by test order
model_results_final: dict[str, list[ExecResult]] = defaultdict(list)
for model in model_results:
model_results[model] = sorted(
model_results[model],
# sort results by test order
model_results_final[model] = sorted(
model_results[model].values(),
key=lambda result: [test["name"] for test in tests].index(result.name),
)

return model_results
return model_results_final


# TODO: rewrite to run in Docker? Would help with capturing output + process management.
Expand Down Expand Up @@ -174,10 +186,10 @@ def execute(test: ExecTest, agent: Agent, timeout: int, parallel: bool) -> ExecR
status: Status = "success"
if p.is_alive():
logger.info("Timeout reached, terminating process")
p.terminate()
p.join(timeout=1)
status = "timeout"
time_gen = timeout
p.terminate()
p.join(timeout=1)
finally:
if p.is_alive():
p.terminate()
Expand Down

0 comments on commit 622e574

Please sign in to comment.