Skip to content

Commit

Permalink
update metrics, metric means
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdeitke committed Mar 13, 2022
1 parent e39b923 commit bc49d47
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions allenact/algorithms/onpolicy_sync/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,13 +856,13 @@ def process_eval_package(

mode = pkg.mode

metrics = dict()
metric_means = dict()

if log_writer is not None:
log_writer.add_scalar(
f"{mode}-misc/num_tasks_evaled", num_tasks, training_steps
)
metrics[f"{mode}-misc/num_tasks_evaled"] = num_tasks
metric_means[f"{mode}-misc/num_tasks_evaled"] = num_tasks

message = [f"{mode} {training_steps} steps:"]
for k in sorted(metric_means.keys()):
Expand All @@ -871,7 +871,7 @@ def process_eval_package(
f"{mode}-metrics/{k}", metric_means[k], training_steps
)
message.append(f"{k} {metric_means[k]}")
metrics[f"{mode}-metrics/{k}"] = metric_means[k]
metric_means[f"{mode}-metrics/{k}"] = metric_means[k]

if all_results is not None:
results = copy.deepcopy(metric_means)
Expand All @@ -881,8 +881,11 @@ def process_eval_package(
message.append(f"tasks {num_tasks} checkpoint {checkpoint_file_name}")
get_logger().info(" ".join(message))

metrics = all_results[-1] if all_results else None
for callback in self.callbacks:
callback.on_valid_log(metric_means=metrics, step=training_steps)
callback.on_valid_log(
metric_means=metric_means, metrics=metrics, step=training_steps
)

if self.visualizer is not None:
self.visualizer.log(
Expand Down Expand Up @@ -1064,7 +1067,7 @@ def process_test_packages(
assert mode == TEST_MODE_STR

training_steps = pkgs[0].training_steps
metrics = dict()
metric_means = dict()

all_metrics_tracker = ScalarMeanTracker()
metric_dicts_list, render, checkpoint_file_name = [], {}, []
Expand All @@ -1088,7 +1091,7 @@ def process_test_packages(
f"{mode}-metrics/{k}", metric_means[k], training_steps
)
message.append(k + f" {metric_means[k]:.3g}")
metrics[f"{mode}-metrics/{k}"] = metric_means[k]
metric_means[f"{mode}-metrics/{k}"] = metric_means[k]

if all_results is not None:
results = copy.deepcopy(metric_means)
Expand All @@ -1102,14 +1105,14 @@ def process_test_packages(
log_writer.add_scalar(
f"{mode}-misc/num_tasks_evaled", num_tasks, training_steps
)
metrics[f"{mode}-misc/num_tasks_evaled"] = num_tasks
metric_means[f"{mode}-misc/num_tasks_evaled"] = num_tasks

message.append(f"tasks {num_tasks} checkpoint {checkpoint_file_name[0]}")
get_logger().info(" ".join(message))

for callback in self.callbacks:
callback.on_test_log(
metric_means=metrics,
metric_means=metric_means,
metrics=all_results[-1],
step=training_steps,
checkpoint=checkpoint_file_name[0],
Expand Down

0 comments on commit bc49d47

Please sign in to comment.