Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Callback Support #339

Merged
merged 41 commits into from
Aug 16, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
81f5f4a
add callback support
mattdeitke Mar 12, 2022
59bab94
update valid log
mattdeitke Mar 12, 2022
b9ddd05
add args to on policy runner
mattdeitke Mar 12, 2022
e39b923
add on_test_log and args to setup
mattdeitke Mar 12, 2022
bc49d47
update metrics, metric means
mattdeitke Mar 13, 2022
6fe333e
reuse metric means
mattdeitke Mar 13, 2022
b83dc70
add all metrics
mattdeitke Mar 13, 2022
52f2622
add way to pass in task_callback data
mattdeitke Mar 13, 2022
2ed3f8e
pass back pkg.metric_dicts for metrics
mattdeitke Mar 13, 2022
333dcd6
fix metrics results
mattdeitke Mar 13, 2022
5899282
reset single_process_task_callbacks_data after use
mattdeitke Mar 13, 2022
cd7bf18
send metric means when not using tensorboard
mattdeitke Apr 2, 2022
6824c9f
convert means/stds to arrays
mattdeitke Apr 3, 2022
38e5e57
add metrics to on_train_log
Apr 7, 2022
298acda
Merge branch 'main' into callbacks
mattdeitke Apr 9, 2022
f6e7abd
pass checkpoint into val
Apr 11, 2022
943a246
Merge branch 'callbacks' of https://github.com/allenai/allenact into …
Apr 11, 2022
698e56b
add after_save_project_state callback
Apr 11, 2022
3a4f2a5
add distributed_preemption_threshold to OnPolicyRunner
Apr 13, 2022
030d2ce
merge main into callbacks
Apr 28, 2022
79b7bbb
Look constraints, vit, and auxiliary losses.
Lucaweihs May 20, 2022
41a7d2c
Fixing bug that resulted in cpca loss not running when it should.
Lucaweihs May 21, 2022
2f3f83e
Fixing merge conflict in engine.py
Lucaweihs Jun 2, 2022
85a81c8
Merge branch 'main' of github.com:allenai/allenact into cpca-softmax
Lucaweihs Jun 2, 2022
7badab7
Merge branch 'cpca-softmax' into callbacks-cpca-softmax
Lucaweihs Jun 3, 2022
90fe75c
Improvements to manipulathor plugin, thor sensors, inference code, an…
Lucaweihs Jul 8, 2022
dea48a9
Improvements to error reporting and vector sampled task termination.
Lucaweihs Jul 18, 2022
d85853f
Extended timeout for TCP store access
jordis-ai2 Jul 28, 2022
085739a
Removing argument parser input to the runner (previously passed to ca…
Lucaweihs Jul 29, 2022
1fe09a4
Generalizing clip preprocessor to use varying input sizes.
Lucaweihs Jul 29, 2022
9016229
Updating mixin to check for input image heights/widths before passing…
Lucaweihs Jul 29, 2022
2b538f5
Force advance scene for habitat objectnav.
Lucaweihs Jul 29, 2022
60f2d8f
Merge branch 'callbacks-cpca-softmax' of github.com:allenai/allenact …
Lucaweihs Jul 29, 2022
7ef8015
Merge pull request #347 from allenai/callbacks-merge
Lucaweihs Jul 29, 2022
72e4fa4
Fixing `callbacks` pr requests, in particular making it so that a tas…
Lucaweihs Jul 29, 2022
bd87a6f
Typo.
Lucaweihs Jul 29, 2022
3e51c45
Grabbing debug flag from the environment.
Lucaweihs Aug 2, 2022
4812860
Improving grab of debug flag.
Lucaweihs Aug 2, 2022
d05d140
Using str2bool function.
Lucaweihs Aug 2, 2022
1cc558c
Fix to ai2thor version check.
Lucaweihs Aug 8, 2022
e89eae1
Merge pull request #350 from allenai/callbacks-cpca-softmax
Lucaweihs Aug 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 37 additions & 34 deletions allenact/algorithms/onpolicy_sync/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,20 @@
import traceback
from functools import partial
from multiprocessing.context import BaseContext
from typing import (
Optional,
Any,
Dict,
Union,
List,
cast,
Sequence,
)
from typing import Any, Dict, List, Optional, Sequence, Union, cast

import torch
import torch.distributed as dist # type: ignore
import torch.distributions # type: ignore
import torch.multiprocessing as mp # type: ignore
import torch.nn as nn
import torch.optim as optim
from allenact.algorithms.onpolicy_sync.misc import TrackingInfo, TrackingInfoType
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Throughout the code base we try to keep third party imports before allenact ones, so I would move this a few lines below.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're using PyCharm I think this would be handled by running Code -> Optimize Imports.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I will revert these import updates.

I am using isort, which works well with Black and is pretty popular for sorting and organizing imports. It autoformats on save for VSCode so it ended up changing the import order automatically.

from allenact.utils.model_utils import md5_hash_of_state_dict

# noinspection PyProtectedMember
from torch._C._distributed_c10d import ReduceOp

from allenact.algorithms.onpolicy_sync.misc import TrackingInfoType, TrackingInfo
from allenact.utils.model_utils import md5_hash_of_state_dict

try:
# noinspection PyProtectedMember,PyUnresolvedReferences
from torch.optim.lr_scheduler import _LRScheduler
Expand All @@ -44,37 +35,35 @@
from allenact.algorithms.onpolicy_sync.storage import (
ExperienceStorage,
MiniBatchStorageMixin,
StreamingStorageMixin,
RolloutStorage,
StreamingStorageMixin,
)
from allenact.algorithms.onpolicy_sync.vector_sampled_tasks import (
VectorSampledTasks,
COMPLETE_TASK_CALLBACK_KEY,
COMPLETE_TASK_METRICS_KEY,
SingleProcessVectorSampledTasks,
VectorSampledTasks,
)
from allenact.base_abstractions.distributions import TeacherForcingDistr
from allenact.base_abstractions.experiment_config import ExperimentConfig, MachineParams
from allenact.base_abstractions.misc import (
RLStepResult,
Memory,
ActorCriticOutput,
GenericAbstractLoss,
Memory,
RLStepResult,
)
from allenact.base_abstractions.distributions import TeacherForcingDistr
from allenact.utils import spaces_utils as su
from allenact.utils.experiment_utils import (
set_seed,
TrainingPipeline,
LoggingPackage,
PipelineStage,
set_deterministic_cudnn,
ScalarMeanTracker,
StageComponent,
TrainingPipeline,
set_deterministic_cudnn,
set_seed,
)
from allenact.utils.system import get_logger
from allenact.utils.tensor_utils import (
batch_observations,
detach_recursively,
)
from allenact.utils.tensor_utils import batch_observations, detach_recursively
from allenact.utils.viz_utils import VizSuite

TRAIN_MODE_STR = "train"
Expand Down Expand Up @@ -262,6 +251,7 @@ def __init__(

# Keeping track of metrics during training/inference
self.single_process_metrics: List = []
self.single_process_task_callback_data: List = []

@property
def vector_tasks(
Expand Down Expand Up @@ -572,14 +562,17 @@ def collect_step_across_all_task_samplers(

# Save after task completion metrics
for step_result in outputs:
if (
step_result.info is not None
and COMPLETE_TASK_METRICS_KEY in step_result.info
):
self.single_process_metrics.append(
step_result.info[COMPLETE_TASK_METRICS_KEY]
)
del step_result.info[COMPLETE_TASK_METRICS_KEY]
if step_result.info is not None:
if COMPLETE_TASK_METRICS_KEY in step_result.info:
self.single_process_metrics.append(
step_result.info[COMPLETE_TASK_METRICS_KEY]
)
del step_result.info[COMPLETE_TASK_METRICS_KEY]
if COMPLETE_TASK_CALLBACK_KEY in step_result.info:
self.single_process_task_callback_data.append(
step_result.info[COMPLETE_TASK_CALLBACK_KEY]
)
del step_result.info[COMPLETE_TASK_CALLBACK_KEY]

rewards: Union[List, torch.Tensor]
observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
Expand Down Expand Up @@ -1271,6 +1264,10 @@ def aggregate_and_send_logging_package(

self.aggregate_task_metrics(logging_pkg=logging_pkg)

for callback_dict in self.single_process_task_callback_data:
logging_pkg.task_callback_data.append(callback_dict)
self.single_process_task_callback_data = []

if self.mode == TRAIN_MODE_STR:
# Technically self.mode should always be "train" here (as this is the training engine),
# this conditional is defensive
Expand Down Expand Up @@ -1667,7 +1664,8 @@ def run_eval(
assert visualizer.empty()

num_paused = self.initialize_storage_and_viz(
storage_to_initialize=[rollout_storage], visualizer=visualizer,
storage_to_initialize=[rollout_storage],
visualizer=visualizer,
)
assert num_paused == 0, f"{num_paused} tasks paused when initializing eval"

Expand Down Expand Up @@ -1736,7 +1734,8 @@ def run_eval(
lengths: List[int]
if self.num_active_samplers > 0:
lengths = self.vector_tasks.command(
"sampler_attr", ["length"] * self.num_active_samplers,
"sampler_attr",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see many formatting changes - are you also using black to ensure consistency?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, note that we're using version 19.10b0 of black.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, using Black, but must be using some different preferences with it (such as the number of characters in a line). I will try with 19.10b0 :)

["length"] * self.num_active_samplers,
)
npending = sum(lengths)
else:
Expand Down Expand Up @@ -1786,6 +1785,10 @@ def run_eval(

self.aggregate_task_metrics(logging_pkg=logging_pkg)

for callback_dict in self.single_process_task_callback_data:
logging_pkg.task_callback_data.append(callback_dict)
self.single_process_task_callback_data = []

logging_pkg.viz_data = (
visualizer.read_and_reset() if visualizer is not None else None
)
Expand Down
Loading