
add on_test_log and args to setup
mattdeitke committed Mar 12, 2022
1 parent b9ddd05 commit e39b923
Showing 1 changed file with 28 additions and 6 deletions.
allenact/algorithms/onpolicy_sync/runner.py: 34 changes (28 additions & 6 deletions)
@@ -16,6 +16,7 @@
 import sys
 import time
 import traceback
+from argparse import ArgumentParser
 from collections import defaultdict
 from multiprocessing.context import BaseContext
 from multiprocessing.process import BaseProcess
@@ -91,6 +92,7 @@ def __init__(
         machine_id: int = 0,
         save_dir_fmt: SaveDirFormat = SaveDirFormat.FLAT,
         callbacks: str = "",
+        args: Optional[ArgumentParser] = None,
     ):
         self.config = config
         self.output_dir = output_dir
@@ -140,7 +142,8 @@ def __init__(
 
         self.save_dir_fmt = save_dir_fmt
 
-        self.callbacks = self.get_callback_classes(callbacks)
+        self.callbacks = self.get_callback_classes(callbacks, args=args)
+        self.args = args
 
     @property
     def local_start_time_str(self) -> str:
@@ -183,12 +186,17 @@ def init_context(
 
         return mp_ctx
 
-    @staticmethod
-    def get_callback_classes(callbacks: str) -> List[Callback]:
+    def get_callback_classes(
+        self, callbacks: str, args: Optional[ArgumentParser] = None
+    ) -> List[Callback]:
         """Get a list of Callback classes from a comma-separated list of filenames."""
         if callbacks == "":
             return []
 
+        setup_dict = dict(name=self.experiment_name, config=self.config, mode=self.mode)
+        if args is not None:
+            setup_dict.update(vars(args))
+
         callback_classes = set()
         files = callbacks.split(",")
         for i, filename in enumerate(files):
@@ -205,7 +213,7 @@ def get_callback_classes(callbacks: str) -> List[Callback]:
                 if issubclass(mod_class[1], Callback) and mod_class[1] != Callback:
                     # NOTE: initialize the callback class
                     inst_class = mod_class[1]()
-                    inst_class.setup()
+                    inst_class.setup(**setup_dict)
                     callback_classes.add(inst_class)
 
         return callback_classes
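
The setup(**setup_dict) call above means a user-supplied callback now receives the experiment name, config, mode, and (when args is given) every argparse field as keyword arguments. A minimal sketch of a compatible callback module follows; the file name, class name, and Callback import path are assumptions for illustration and are not part of this commit:

# my_callbacks.py (hypothetical module, passed to the runner via --callbacks my_callbacks.py)
# NOTE: the import path for Callback is an assumption; it is not shown in this diff.
from allenact.base_abstractions.callbacks import Callback


class LoggingCallback(Callback):
    def setup(self, name: str, config, mode, **kwargs) -> None:
        # The runner calls setup(**setup_dict), so the experiment name, config,
        # and mode arrive as keyword arguments; when args is provided, vars(args)
        # is merged in and its fields land in **kwargs.
        self.experiment_name = name
        self.config = config
        self.mode = mode
        self.cli_args = kwargs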
@@ -587,6 +595,9 @@ def start_test(
         assert (
             self.machine_id == 0
         ), f"Received `machine_id={self.machine_id} for test. Only one machine supported."
+        assert (
+            checkpoint_path_dir_or_pattern is not None
+        ), "Must provide a --checkpoint path or pattern to test on."
 
         self.extra_tag += (
             "__" * (len(self.extra_tag) > 0) + "enforced_test_expert"
@@ -871,7 +882,7 @@ def process_eval_package(
         get_logger().info(" ".join(message))
 
         for callback in self.callbacks:
-            callback.on_valid_log(metrics=metrics, step=training_steps)
+            callback.on_valid_log(metric_means=metrics, step=training_steps)
 
         if self.visualizer is not None:
             self.visualizer.log(
@@ -990,7 +1001,7 @@ def _convert(key: str):
         means = metrics_and_train_info_tracker.means()
 
         for callback in self.callbacks:
-            callback.on_train_log(metrics=means, step=training_steps)
+            callback.on_train_log(metric_means=means, step=training_steps)
 
         for k in sorted(
             means.keys(), key=lambda mean_key: (mean_key.count("/"), mean_key)
@@ -1053,6 +1064,7 @@ def process_test_packages(
         assert mode == TEST_MODE_STR
 
         training_steps = pkgs[0].training_steps
+        metrics = dict()
 
         all_metrics_tracker = ScalarMeanTracker()
         metric_dicts_list, render, checkpoint_file_name = [], {}, []
@@ -1076,6 +1088,7 @@
                 f"{mode}-metrics/{k}", metric_means[k], training_steps
             )
             message.append(k + f" {metric_means[k]:.3g}")
+            metrics[f"{mode}-metrics/{k}"] = metric_means[k]
 
         if all_results is not None:
             results = copy.deepcopy(metric_means)
@@ -1089,10 +1102,19 @@
         log_writer.add_scalar(
             f"{mode}-misc/num_tasks_evaled", num_tasks, training_steps
         )
+        metrics[f"{mode}-misc/num_tasks_evaled"] = num_tasks
 
         message.append(f"tasks {num_tasks} checkpoint {checkpoint_file_name[0]}")
         get_logger().info(" ".join(message))
 
+        for callback in self.callbacks:
+            callback.on_test_log(
+                metric_means=metrics,
+                metrics=all_results[-1],
+                step=training_steps,
+                checkpoint=checkpoint_file_name[0],
+            )
+
         if self.visualizer is not None:
             self.visualizer.log(
                 log_writer=log_writer,
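Together with the metric_means keyword renames in on_valid_log and on_train_log above, the new on_test_log hook passes the aggregated scalar means, the latest per-checkpoint results entry, the training step, and the checkpoint path. A rough sketch of a callback consuming these hooks, with signatures inferred from the keyword arguments used in this commit (the Callback import path is again an assumption):

# Hypothetical sketch; hook signatures are inferred from the calls in this commit.
from allenact.base_abstractions.callbacks import Callback


class PrintMetricsCallback(Callback):
    def setup(self, **kwargs) -> None:
        self.experiment_name = kwargs.get("name", "")

    def on_train_log(self, metric_means, step, **kwargs) -> None:
        print(f"[train @ {step}] {len(metric_means)} scalar means logged")

    def on_valid_log(self, metric_means, step, **kwargs) -> None:
        print(f"[valid @ {step}] {len(metric_means)} scalar means logged")

    def on_test_log(self, metric_means, metrics, step, checkpoint, **kwargs) -> None:
        # metric_means: scalar means keyed as "{mode}-metrics/..." plus
        #               "{mode}-misc/num_tasks_evaled", as built above.
        # metrics:      the latest per-checkpoint entry appended to all_results.
        # step:         training_steps recovered from the tested checkpoint.
        # checkpoint:   path of the checkpoint file that was evaluated.
        print(f"[test @ {step}] checkpoint={checkpoint}")
        for k, v in metric_means.items():
            print(f"  {k}: {v:.3g}")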

