From 17e63ec735af0282fb7b50457471c9179fea9ce5 Mon Sep 17 00:00:00 2001 From: Yifan Mai Date: Wed, 31 May 2023 16:25:56 -0700 Subject: [PATCH] Skip completed display JSON (#1620) --- .../benchmark/presentation/run_display.py | 28 +++++++++++++++---- src/helm/benchmark/presentation/summarize.py | 11 ++++---- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/src/helm/benchmark/presentation/run_display.py b/src/helm/benchmark/presentation/run_display.py index ac9189fcdc5..fce259add11 100644 --- a/src/helm/benchmark/presentation/run_display.py +++ b/src/helm/benchmark/presentation/run_display.py @@ -16,7 +16,7 @@ from helm.benchmark.runner import RunSpec from helm.benchmark.scenarios.scenario import Instance from helm.common.general import write -from helm.common.hierarchical_logger import htrack +from helm.common.hierarchical_logger import hlog, htrack from helm.common.request import Request from helm.common.codec import from_json, to_json @@ -141,8 +141,13 @@ def _get_metric_names_for_groups(run_group_names: Iterable[str], schema: Schema) return result +_INSTANCES_JSON_FILE_NAME = "instances.json" +_DISPLAY_PREDICTIONS_JSON_FILE_NAME = "display_predictions.json" +_DISPLAY_REQUESTS_JSON_FILE_NAME = "display_requests.json" + + @htrack(None) -def write_run_display_json(run_path: str, run_spec: RunSpec, schema: Schema): +def write_run_display_json(run_path: str, run_spec: RunSpec, schema: Schema, skip_completed: bool) -> None: """Write run JSON files that are used by the web frontend. The derived JSON files that are used by the web frontend are much more compact than @@ -159,6 +164,18 @@ def write_run_display_json(run_path: str, run_spec: RunSpec, schema: Schema): - List[DisplayPrediction] to `display_predictions.json` - List[DisplayRequest] to `display_requests.json` """ + instances_file_path = os.path.join(run_path, _INSTANCES_JSON_FILE_NAME) + display_predictions_file_path = os.path.join(run_path, _DISPLAY_PREDICTIONS_JSON_FILE_NAME) + display_requests_file_path = os.path.join(run_path, _DISPLAY_REQUESTS_JSON_FILE_NAME) + + if ( + skip_completed + and os.path.exists(instances_file_path) + and os.path.exists(display_predictions_file_path) + and os.path.exists(display_requests_file_path) + ): + hlog(f"Skipping writing display JSON for run {run_spec.name} because all output display JSON files exist.") + return scenario_state = _read_scenario_state(run_path) per_instance_stats = _read_per_instance_stats(run_path) @@ -245,13 +262,12 @@ def write_run_display_json(run_path: str, run_spec: RunSpec, schema: Schema): request=request_state.request, ) ) - write( - os.path.join(run_path, "instances.json"), + instances_file_path, to_json(list(instance_id_to_instance.values())), ) - write(os.path.join(run_path, "display_predictions.json"), to_json(predictions)) + write(display_predictions_file_path, to_json(predictions)) write( - os.path.join(run_path, "display_requests.json"), + display_requests_file_path, to_json(requests), ) diff --git a/src/helm/benchmark/presentation/summarize.py b/src/helm/benchmark/presentation/summarize.py index 23edb1af465..5bf2a25cbff 100644 --- a/src/helm/benchmark/presentation/summarize.py +++ b/src/helm/benchmark/presentation/summarize.py @@ -939,9 +939,9 @@ def write_groups(self): json.dumps(list(map(asdict_without_nones, tables)), indent=2), ) - def write_run_display_json(self) -> None: + def write_run_display_json(self, skip_completed: bool) -> None: def process(run: Run) -> None: - write_run_display_json(run.run_path, run.run_spec, self.schema) + write_run_display_json(run.run_path, run.run_spec, self.schema, skip_completed) parallel_map(process, self.runs, parallelism=self.num_threads) @@ -978,9 +978,9 @@ def main(): help="Display debugging information.", ) parser.add_argument( - "--skip-write-run-display-json", + "--skip-completed-run-display-json", action="store_true", - help="Skip write_run_display_json", + help="Skip write_run_display_json() for runs which already have all output display JSON files", ) args = parser.parse_args() @@ -997,8 +997,7 @@ def main(): summarizer.write_groups() summarizer.write_cost_report() - if not args.skip_write_run_display_json: - summarizer.write_run_display_json() + summarizer.write_run_display_json(skip_completed=args.skip_completed_run_display_json) symlink_latest(args.output_path, args.suite) hlog("Done.")