From e39b923792723ac28691f2d3d5a92eed5dbdd291 Mon Sep 17 00:00:00 2001 From: Matt Deitke Date: Sat, 12 Mar 2022 15:52:14 -0800 Subject: [PATCH] add on_test_log and args to setup --- allenact/algorithms/onpolicy_sync/runner.py | 34 +++++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/allenact/algorithms/onpolicy_sync/runner.py b/allenact/algorithms/onpolicy_sync/runner.py index a2cc651ac..1afef8171 100644 --- a/allenact/algorithms/onpolicy_sync/runner.py +++ b/allenact/algorithms/onpolicy_sync/runner.py @@ -16,6 +16,7 @@ import sys import time import traceback +from argparse import ArgumentParser from collections import defaultdict from multiprocessing.context import BaseContext from multiprocessing.process import BaseProcess @@ -91,6 +92,7 @@ def __init__( machine_id: int = 0, save_dir_fmt: SaveDirFormat = SaveDirFormat.FLAT, callbacks: str = "", + args: Optional[ArgumentParser] = None, ): self.config = config self.output_dir = output_dir @@ -140,7 +142,8 @@ def __init__( self.save_dir_fmt = save_dir_fmt - self.callbacks = self.get_callback_classes(callbacks) + self.callbacks = self.get_callback_classes(callbacks, args=args) + self.args = args @property def local_start_time_str(self) -> str: @@ -183,12 +186,17 @@ def init_context( return mp_ctx - @staticmethod - def get_callback_classes(callbacks: str) -> List[Callback]: + def get_callback_classes( + self, callbacks: str, args: Optional[ArgumentParser] = None + ) -> List[Callback]: """Get a list of Callback classes from a comma-separated list of filenames.""" if callbacks == "": return [] + setup_dict = dict(name=self.experiment_name, config=self.config, mode=self.mode) + if args is not None: + setup_dict.update(vars(args)) + callback_classes = set() files = callbacks.split(",") for i, filename in enumerate(files): @@ -205,7 +213,7 @@ def get_callback_classes(callbacks: str) -> List[Callback]: if issubclass(mod_class[1], Callback) and mod_class[1] != Callback: # NOTE: initialize the callback class inst_class = mod_class[1]() - inst_class.setup() + inst_class.setup(**setup_dict) callback_classes.add(inst_class) return callback_classes @@ -587,6 +595,9 @@ def start_test( assert ( self.machine_id == 0 ), f"Received `machine_id={self.machine_id} for test. Only one machine supported." + assert ( + checkpoint_path_dir_or_pattern is not None + ), "Must provide a --checkpoint path or pattern to test on." self.extra_tag += ( "__" * (len(self.extra_tag) > 0) + "enforced_test_expert" @@ -871,7 +882,7 @@ def process_eval_package( get_logger().info(" ".join(message)) for callback in self.callbacks: - callback.on_valid_log(metrics=metrics, step=training_steps) + callback.on_valid_log(metric_means=metrics, step=training_steps) if self.visualizer is not None: self.visualizer.log( @@ -990,7 +1001,7 @@ def _convert(key: str): means = metrics_and_train_info_tracker.means() for callback in self.callbacks: - callback.on_train_log(metrics=means, step=training_steps) + callback.on_train_log(metric_means=means, step=training_steps) for k in sorted( means.keys(), key=lambda mean_key: (mean_key.count("/"), mean_key) @@ -1053,6 +1064,7 @@ def process_test_packages( assert mode == TEST_MODE_STR training_steps = pkgs[0].training_steps + metrics = dict() all_metrics_tracker = ScalarMeanTracker() metric_dicts_list, render, checkpoint_file_name = [], {}, [] @@ -1076,6 +1088,7 @@ def process_test_packages( f"{mode}-metrics/{k}", metric_means[k], training_steps ) message.append(k + f" {metric_means[k]:.3g}") + metrics[f"{mode}-metrics/{k}"] = metric_means[k] if all_results is not None: results = copy.deepcopy(metric_means) @@ -1089,10 +1102,19 @@ def process_test_packages( log_writer.add_scalar( f"{mode}-misc/num_tasks_evaled", num_tasks, training_steps ) + metrics[f"{mode}-misc/num_tasks_evaled"] = num_tasks message.append(f"tasks {num_tasks} checkpoint {checkpoint_file_name[0]}") get_logger().info(" ".join(message)) + for callback in self.callbacks: + callback.on_test_log( + metric_means=metrics, + metrics=all_results[-1], + step=training_steps, + checkpoint=checkpoint_file_name[0], + ) + if self.visualizer is not None: self.visualizer.log( log_writer=log_writer,