From 6d00be7a46f42ed19ebe91c0527f67a454cc172c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Wed, 18 Sep 2024 16:00:32 +0200 Subject: [PATCH] fix: improved typing in gptme.evals.run --- gptme/eval/run.py | 41 +++++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/gptme/eval/run.py b/gptme/eval/run.py index f5cc0c05..165722b3 100644 --- a/gptme/eval/run.py +++ b/gptme/eval/run.py @@ -10,9 +10,8 @@ import time from collections import defaultdict from concurrent.futures import ProcessPoolExecutor, as_completed -from dataclasses import dataclass from multiprocessing import Manager, Process -from typing import Union +from typing import TypedDict, Union from .agents import Agent, GPTMe from .execenv import SimpleExecutionEnv @@ -27,16 +26,16 @@ logger = logging.getLogger(__name__) -@dataclass -class ProcessSuccess: +class ProcessSuccess(TypedDict): + status: str files: dict[str, str | bytes] stdout: str stderr: str duration: float -@dataclass -class ProcessError: +class ProcessError(TypedDict): + status: str message: str stdout: str stderr: str @@ -46,6 +45,10 @@ class ProcessError: ProcessResult = Union[ProcessSuccess, ProcessError] +class SyncedDict(TypedDict): + result: ProcessResult + + def run_evals( tests: list[ExecTest], models: list[str], timeout: int, parallel: int ) -> dict[str, list[ExecResult]]: @@ -145,15 +148,15 @@ def execute(test: ExecTest, agent: Agent, timeout: int, parallel: bool) -> ExecR time_eval = 0.0 with Manager() as manager: - result_dict = manager.dict() + sync_dict = manager.dict() p = Process( target=act_process, args=( agent, - test["files"], - test["prompt"], - result_dict, test["name"], + test["prompt"], + test["files"], + sync_dict, parallel, ), ) @@ -173,10 +176,10 @@ def execute(test: ExecTest, agent: Agent, timeout: int, parallel: bool) -> ExecR p.terminate() p.join(timeout=1) - if "result" in result_dict: - result = result_dict["result"] + if "result" in sync_dict: + result = sync_dict["result"] time_gen = max(result.get("duration", 0.0), time_gen) - status = result.get("status", "success") + status = result["status"] files = result.get("files", {}) gen_stdout = result.get("stdout", "") gen_stderr = result.get("stderr", "") @@ -264,10 +267,10 @@ def getvalue(self): def act_process( agent: Agent, - files, - prompt, - result_dict: dict, test_name: str, + prompt: str, + files: dict[str, str | bytes], + sync_dict: SyncedDict, parallel: bool, ): # Configure logging for this subprocess @@ -290,13 +293,14 @@ def error_handler(e): duration = time.time() - start if not isinstance(e, KeyboardInterrupt): subprocess_logger.error(f"Error: {e}") - result_dict["result"] = { + result_error: ProcessError = { "status": "error", "message": str(e), "stdout": stdout.getvalue(), "stderr": stderr.getvalue(), "duration": duration, } + sync_dict["result"] = result_error # kill child processes os.killpg(pgrp, signal.SIGKILL) @@ -315,13 +319,14 @@ def sigterm_handler(*_): return duration = time.time() - start - result_dict["result"] = { + result_success: ProcessSuccess = { "status": "success", "files": files, "stdout": stdout.getvalue(), "stderr": stderr.getvalue(), "duration": duration, } + sync_dict["result"] = result_success subprocess_logger.info("Success") # kill child processes