diff --git a/gptme/eval/run.py b/gptme/eval/run.py index 0e60441b..521d9067 100644 --- a/gptme/eval/run.py +++ b/gptme/eval/run.py @@ -9,7 +9,7 @@ import sys import time from collections import defaultdict -from concurrent.futures import ProcessPoolExecutor, as_completed +from concurrent.futures import Future, ProcessPoolExecutor, as_completed from multiprocessing import Manager, Process from typing import TypedDict, Union @@ -68,7 +68,7 @@ def run_evals( cleanup_on_sigterm() n_runs = len(tests) * len(models) - model_results = defaultdict(list) + model_results: dict[str, dict[str, ExecResult]] = defaultdict(dict) parallel = min(n_runs, parallel) with ProcessPoolExecutor(parallel) as executor: futures = [] @@ -85,36 +85,41 @@ def run_evals( futures.append(future) future_to_model_test[future] = (model, test) - def _handle_future(future): + def _handle_future(future: Future): model, test = future_to_model_test[future] test_name = test["name"] try: result = future.result(timeout=0.1) - model_results[model].append(result) - except concurrent.futures.TimeoutError: - # NOTE: is this really a good description of what is happening? - # shouldn't a timeout still give a result? 
- logger.warning(f"Test {test_name} for model {model} timed out") - model_results[model].append( - ExecResult( - name=test_name, - status="timeout", - results=[], - timings={"gen": timeout, "run": 0, "eval": 0}, - gen_stdout="", - gen_stderr="", - run_stdout="", - run_stderr="", + except Exception as e: + # TODO: we still want to get stdout/stderr from the process + gen_time = 0 + if isinstance(e, concurrent.futures.TimeoutError) or isinstance( + e, concurrent.futures.CancelledError + ): + status: Status = "timeout" + gen_time = timeout + else: + status = "error" + logger.exception( + f"Test {test_name} for model {model} generated an exception when trying to get result" ) + result = ExecResult( + name=test_name, + status=status, + results=[], + timings={"gen": gen_time, "run": 0, "eval": 0}, + gen_stdout="", + gen_stderr="", + run_stdout="", + run_stderr="", ) - except Exception: - logger.exception( - f"Test {test_name} for model {model} generated an exception" - ) + model_results[model][test_name] = result + # worst-case run time, with some buffer to account for overhead + max_timeout = timeout * len(tests) / parallel + 10 + completed = set() try: - # worse-case run time - max_timeout = timeout * len(tests) / parallel + 10 + # TODO: can we do better than this? handle timeouts within futures instead? for future in tqdm( as_completed(futures, timeout=max_timeout), total=n_runs, @@ -122,26 +127,33 @@ def _handle_future(future): desc="Progress", ): _handle_future(future) + completed.add(future) except concurrent.futures.TimeoutError: - logger.warning("Timeout reached, cancelling remaining futures") + # NOTE: this should rarely happen, as `execute` should handle timeouts + logger.warning( + "Timeout reached in top-level (shouldnt happen). Cancelling remaining futures..." 
+ ) + # Cancel any remaining futures for future in futures: - _handle_future(future) - future.cancel() + if future not in completed: + future.cancel() + _handle_future(future) # Ensure all processes are terminated for process in multiprocessing.active_children(): process.terminate() process.join() - # sort model_results by test order + model_results_final: dict[str, list[ExecResult]] = defaultdict(list) for model in model_results: - model_results[model] = sorted( - model_results[model], + # sort results by test order + model_results_final[model] = sorted( + model_results[model].values(), key=lambda result: [test["name"] for test in tests].index(result.name), ) - return model_results + return model_results_final # TODO: rewrite to run in Docker? Would help with capturing output + process management. @@ -174,10 +186,10 @@ def execute(test: ExecTest, agent: Agent, timeout: int, parallel: bool) -> ExecR status: Status = "success" if p.is_alive(): logger.info("Timeout reached, terminating process") - p.terminate() - p.join(timeout=1) status = "timeout" time_gen = timeout + p.terminate() + p.join(timeout=1) finally: if p.is_alive(): p.terminate()