Skip to content

Commit

Permalink
fix: improved typing in gptme.eval.run
Browse files — browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Sep 18, 2024
1 parent 8c3cb77 commit 6d00be7
Showing 1 changed file with 23 additions and 18 deletions.
41 changes: 23 additions & 18 deletions gptme/eval/run.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,9 +10,8 @@
import time
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
from multiprocessing import Manager, Process
from typing import Union
from typing import TypedDict, Union

from .agents import Agent, GPTMe
from .execenv import SimpleExecutionEnv
Expand All @@ -27,16 +26,16 @@
logger = logging.getLogger(__name__)


@dataclass
class ProcessSuccess:
class ProcessSuccess(TypedDict):
status: str
files: dict[str, str | bytes]
stdout: str
stderr: str
duration: float


@dataclass
class ProcessError:
class ProcessError(TypedDict):
status: str
message: str
stdout: str
stderr: str
Expand All @@ -46,6 +45,10 @@ class ProcessError:
ProcessResult = Union[ProcessSuccess, ProcessError]


# Schema of the Manager-backed dict (created via manager.dict() in execute)
# that is handed to act_process: the subprocess stores its ProcessSuccess or
# ProcessError under the "result" key, and execute() reads it back from there.
class SyncedDict(TypedDict):
result: ProcessResult


def run_evals(
tests: list[ExecTest], models: list[str], timeout: int, parallel: int
) -> dict[str, list[ExecResult]]:
Expand Down Expand Up @@ -145,15 +148,15 @@ def execute(test: ExecTest, agent: Agent, timeout: int, parallel: bool) -> ExecR
time_eval = 0.0

with Manager() as manager:
result_dict = manager.dict()
sync_dict = manager.dict()
p = Process(
target=act_process,
args=(
agent,
test["files"],
test["prompt"],
result_dict,
test["name"],
test["prompt"],
test["files"],
sync_dict,
parallel,
),
)
Expand All @@ -173,10 +176,10 @@ def execute(test: ExecTest, agent: Agent, timeout: int, parallel: bool) -> ExecR
p.terminate()
p.join(timeout=1)

if "result" in result_dict:
result = result_dict["result"]
if "result" in sync_dict:
result = sync_dict["result"]
time_gen = max(result.get("duration", 0.0), time_gen)
status = result.get("status", "success")
status = result["status"]
files = result.get("files", {})
gen_stdout = result.get("stdout", "")
gen_stderr = result.get("stderr", "")
Expand Down Expand Up @@ -264,10 +267,10 @@ def getvalue(self):

def act_process(
agent: Agent,
files,
prompt,
result_dict: dict,
test_name: str,
prompt: str,
files: dict[str, str | bytes],
sync_dict: SyncedDict,
parallel: bool,
):
# Configure logging for this subprocess
Expand All @@ -290,13 +293,14 @@ def error_handler(e):
duration = time.time() - start
if not isinstance(e, KeyboardInterrupt):
subprocess_logger.error(f"Error: {e}")
result_dict["result"] = {
result_error: ProcessError = {
"status": "error",
"message": str(e),
"stdout": stdout.getvalue(),
"stderr": stderr.getvalue(),
"duration": duration,
}
sync_dict["result"] = result_error

# kill child processes
os.killpg(pgrp, signal.SIGKILL)
Expand All @@ -315,13 +319,14 @@ def sigterm_handler(*_):
return

duration = time.time() - start
result_dict["result"] = {
result_success: ProcessSuccess = {
"status": "success",
"files": files,
"stdout": stdout.getvalue(),
"stderr": stderr.getvalue(),
"duration": duration,
}
sync_dict["result"] = result_success
subprocess_logger.info("Success")

# kill child processes
Expand Down

0 comments on commit 6d00be7

Please sign in to comment.