-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Callback Support #339
Add Callback Support #339
Changes from 19 commits
81f5f4a
59bab94
b9ddd05
e39b923
bc49d47
6fe333e
b83dc70
52f2622
2ed3f8e
333dcd6
5899282
cd7bf18
6824c9f
38e5e57
298acda
f6e7abd
943a246
698e56b
3a4f2a5
030d2ce
79b7bbb
41a7d2c
2f3f83e
85a81c8
7badab7
90fe75c
dea48a9
d85853f
085739a
1fe09a4
9016229
2b538f5
60f2d8f
7ef8015
72e4fa4
bd87a6f
3e51c45
4812860
d05d140
1cc558c
e89eae1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,29 +8,20 @@ | |
import traceback | ||
from functools import partial | ||
from multiprocessing.context import BaseContext | ||
from typing import ( | ||
Optional, | ||
Any, | ||
Dict, | ||
Union, | ||
List, | ||
cast, | ||
Sequence, | ||
) | ||
from typing import Any, Dict, List, Optional, Sequence, Union, cast | ||
|
||
import torch | ||
import torch.distributed as dist # type: ignore | ||
import torch.distributions # type: ignore | ||
import torch.multiprocessing as mp # type: ignore | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
from allenact.algorithms.onpolicy_sync.misc import TrackingInfo, TrackingInfoType | ||
from allenact.utils.model_utils import md5_hash_of_state_dict | ||
|
||
# noinspection PyProtectedMember | ||
from torch._C._distributed_c10d import ReduceOp | ||
|
||
from allenact.algorithms.onpolicy_sync.misc import TrackingInfoType, TrackingInfo | ||
from allenact.utils.model_utils import md5_hash_of_state_dict | ||
|
||
try: | ||
# noinspection PyProtectedMember,PyUnresolvedReferences | ||
from torch.optim.lr_scheduler import _LRScheduler | ||
|
@@ -44,37 +35,35 @@ | |
from allenact.algorithms.onpolicy_sync.storage import ( | ||
ExperienceStorage, | ||
MiniBatchStorageMixin, | ||
StreamingStorageMixin, | ||
RolloutStorage, | ||
StreamingStorageMixin, | ||
) | ||
from allenact.algorithms.onpolicy_sync.vector_sampled_tasks import ( | ||
VectorSampledTasks, | ||
COMPLETE_TASK_CALLBACK_KEY, | ||
COMPLETE_TASK_METRICS_KEY, | ||
SingleProcessVectorSampledTasks, | ||
VectorSampledTasks, | ||
) | ||
from allenact.base_abstractions.distributions import TeacherForcingDistr | ||
from allenact.base_abstractions.experiment_config import ExperimentConfig, MachineParams | ||
from allenact.base_abstractions.misc import ( | ||
RLStepResult, | ||
Memory, | ||
ActorCriticOutput, | ||
GenericAbstractLoss, | ||
Memory, | ||
RLStepResult, | ||
) | ||
from allenact.base_abstractions.distributions import TeacherForcingDistr | ||
from allenact.utils import spaces_utils as su | ||
from allenact.utils.experiment_utils import ( | ||
set_seed, | ||
TrainingPipeline, | ||
LoggingPackage, | ||
PipelineStage, | ||
set_deterministic_cudnn, | ||
ScalarMeanTracker, | ||
StageComponent, | ||
TrainingPipeline, | ||
set_deterministic_cudnn, | ||
set_seed, | ||
) | ||
from allenact.utils.system import get_logger | ||
from allenact.utils.tensor_utils import ( | ||
batch_observations, | ||
detach_recursively, | ||
) | ||
from allenact.utils.tensor_utils import batch_observations, detach_recursively | ||
from allenact.utils.viz_utils import VizSuite | ||
|
||
TRAIN_MODE_STR = "train" | ||
|
@@ -262,6 +251,7 @@ def __init__( | |
|
||
# Keeping track of metrics during training/inference | ||
self.single_process_metrics: List = [] | ||
self.single_process_task_callback_data: List = [] | ||
|
||
@property | ||
def vector_tasks( | ||
|
@@ -572,14 +562,17 @@ def collect_step_across_all_task_samplers( | |
|
||
# Save after task completion metrics | ||
for step_result in outputs: | ||
if ( | ||
step_result.info is not None | ||
and COMPLETE_TASK_METRICS_KEY in step_result.info | ||
): | ||
self.single_process_metrics.append( | ||
step_result.info[COMPLETE_TASK_METRICS_KEY] | ||
) | ||
del step_result.info[COMPLETE_TASK_METRICS_KEY] | ||
if step_result.info is not None: | ||
if COMPLETE_TASK_METRICS_KEY in step_result.info: | ||
self.single_process_metrics.append( | ||
step_result.info[COMPLETE_TASK_METRICS_KEY] | ||
) | ||
del step_result.info[COMPLETE_TASK_METRICS_KEY] | ||
if COMPLETE_TASK_CALLBACK_KEY in step_result.info: | ||
self.single_process_task_callback_data.append( | ||
step_result.info[COMPLETE_TASK_CALLBACK_KEY] | ||
) | ||
del step_result.info[COMPLETE_TASK_CALLBACK_KEY] | ||
|
||
rewards: Union[List, torch.Tensor] | ||
observations, rewards, dones, infos = [list(x) for x in zip(*outputs)] | ||
|
@@ -1271,6 +1264,10 @@ def aggregate_and_send_logging_package( | |
|
||
self.aggregate_task_metrics(logging_pkg=logging_pkg) | ||
|
||
for callback_dict in self.single_process_task_callback_data: | ||
logging_pkg.task_callback_data.append(callback_dict) | ||
self.single_process_task_callback_data = [] | ||
|
||
if self.mode == TRAIN_MODE_STR: | ||
# Technically self.mode should always be "train" here (as this is the training engine), | ||
# this conditional is defensive | ||
|
@@ -1667,7 +1664,8 @@ def run_eval( | |
assert visualizer.empty() | ||
|
||
num_paused = self.initialize_storage_and_viz( | ||
storage_to_initialize=[rollout_storage], visualizer=visualizer, | ||
storage_to_initialize=[rollout_storage], | ||
visualizer=visualizer, | ||
) | ||
assert num_paused == 0, f"{num_paused} tasks paused when initializing eval" | ||
|
||
|
@@ -1736,7 +1734,8 @@ def run_eval( | |
lengths: List[int] | ||
if self.num_active_samplers > 0: | ||
lengths = self.vector_tasks.command( | ||
"sampler_attr", ["length"] * self.num_active_samplers, | ||
"sampler_attr", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see many formatting changes - are you also using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, note that we're using version There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, using Black, but must be using some different preferences with it (such as the number of characters in a line). I will try with |
||
["length"] * self.num_active_samplers, | ||
) | ||
npending = sum(lengths) | ||
else: | ||
|
@@ -1786,6 +1785,10 @@ def run_eval( | |
|
||
self.aggregate_task_metrics(logging_pkg=logging_pkg) | ||
|
||
for callback_dict in self.single_process_task_callback_data: | ||
logging_pkg.task_callback_data.append(callback_dict) | ||
self.single_process_task_callback_data = [] | ||
|
||
logging_pkg.viz_data = ( | ||
visualizer.read_and_reset() if visualizer is not None else None | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Throughout the code base we try to keep third party imports before allenact ones, so I would move this a few lines below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you're using PyCharm I think this would be handled by running
Code -> Optimize Imports
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I will revert these import updates.
I am using isort, which works well with Black and is pretty popular for sorting and organizing imports. It autoformats on save for VSCode so it ended up changing the import order automatically.