Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Callback Support #339

Merged
merged 41 commits into from
Aug 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
81f5f4a
add callback support
mattdeitke Mar 12, 2022
59bab94
update valid log
mattdeitke Mar 12, 2022
b9ddd05
add args to on policy runner
mattdeitke Mar 12, 2022
e39b923
add on_test_log and args to setup
mattdeitke Mar 12, 2022
bc49d47
update metrics, metric means
mattdeitke Mar 13, 2022
6fe333e
reuse metric means
mattdeitke Mar 13, 2022
b83dc70
add all metrics
mattdeitke Mar 13, 2022
52f2622
add way to pass in task_callback data
mattdeitke Mar 13, 2022
2ed3f8e
pass back pkg.metric_dicts for metrics
mattdeitke Mar 13, 2022
333dcd6
fix metrics results
mattdeitke Mar 13, 2022
5899282
reset single_process_task_callbacks_data after use
mattdeitke Mar 13, 2022
cd7bf18
send metric means when not using tensorboard
mattdeitke Apr 2, 2022
6824c9f
convert means/stds to arrays
mattdeitke Apr 3, 2022
38e5e57
add metrics to on_train_log
Apr 7, 2022
298acda
Merge branch 'main' into callbacks
mattdeitke Apr 9, 2022
f6e7abd
pass checkpoint into val
Apr 11, 2022
943a246
Merge branch 'callbacks' of https://github.com/allenai/allenact into …
Apr 11, 2022
698e56b
add after_save_project_state callback
Apr 11, 2022
3a4f2a5
add distributed_preemption_threshold to OnPolicyRunner
Apr 13, 2022
030d2ce
merge main into callbacks
Apr 28, 2022
79b7bbb
Look constraints, vit, and auxiliary losses.
Lucaweihs May 20, 2022
41a7d2c
Fixing bug that resulted in cpca loss not running when it should.
Lucaweihs May 21, 2022
2f3f83e
Fixing merge conflict in engine.py
Lucaweihs Jun 2, 2022
85a81c8
Merge branch 'main' of github.com:allenai/allenact into cpca-softmax
Lucaweihs Jun 2, 2022
7badab7
Merge branch 'cpca-softmax' into callbacks-cpca-softmax
Lucaweihs Jun 3, 2022
90fe75c
Improvements to manipulathor plugin, thor sensors, inference code, an…
Lucaweihs Jul 8, 2022
dea48a9
Improvements to error reporting and vector sampled task termination.
Lucaweihs Jul 18, 2022
d85853f
Extended timeout for TCP store access
jordis-ai2 Jul 28, 2022
085739a
Removing argument parser input to the runner (previously passed to ca…
Lucaweihs Jul 29, 2022
1fe09a4
Generalizing clip preprocessor to use varying input sizes.
Lucaweihs Jul 29, 2022
9016229
Updating mixin to check for input image heights/widths before passing…
Lucaweihs Jul 29, 2022
2b538f5
Force advance scene for habitat objectnav.
Lucaweihs Jul 29, 2022
60f2d8f
Merge branch 'callbacks-cpca-softmax' of github.com:allenai/allenact …
Lucaweihs Jul 29, 2022
7ef8015
Merge pull request #347 from allenai/callbacks-merge
Lucaweihs Jul 29, 2022
72e4fa4
Fixing `callbacks` pr requests, in particular making it so that a tas…
Lucaweihs Jul 29, 2022
bd87a6f
Typo.
Lucaweihs Jul 29, 2022
3e51c45
Grabbing debug flag from the environment.
Lucaweihs Aug 2, 2022
4812860
Improving grab of debug flag.
Lucaweihs Aug 2, 2022
d05d140
Using str2bool function.
Lucaweihs Aug 2, 2022
1cc558c
Fix to ai2thor version check.
Lucaweihs Aug 8, 2022
e89eae1
Merge pull request #350 from allenai/callbacks-cpca-softmax
Lucaweihs Aug 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 97 additions & 46 deletions allenact/algorithms/onpolicy_sync/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,9 @@
import traceback
from functools import partial
from multiprocessing.context import BaseContext
from typing import (
Optional,
Any,
Dict,
Union,
List,
cast,
Sequence,
)
from typing import Any, Dict, List, Optional, Sequence, Union, cast

import filelock
import torch
import torch.distributed as dist # type: ignore
import torch.distributions # type: ignore
Expand All @@ -27,7 +20,9 @@
# noinspection PyProtectedMember
from torch._C._distributed_c10d import ReduceOp

from allenact.algorithms.onpolicy_sync.misc import TrackingInfoType, TrackingInfo
from allenact.algorithms.onpolicy_sync.misc import TrackingInfo, TrackingInfoType
from allenact.base_abstractions.sensor import Sensor
from allenact.utils.misc_utils import str2bool
from allenact.utils.model_utils import md5_hash_of_state_dict

try:
Expand All @@ -43,37 +38,35 @@
from allenact.algorithms.onpolicy_sync.storage import (
ExperienceStorage,
MiniBatchStorageMixin,
StreamingStorageMixin,
RolloutStorage,
StreamingStorageMixin,
)
from allenact.algorithms.onpolicy_sync.vector_sampled_tasks import (
VectorSampledTasks,
COMPLETE_TASK_CALLBACK_KEY,
COMPLETE_TASK_METRICS_KEY,
SingleProcessVectorSampledTasks,
VectorSampledTasks,
)
from allenact.base_abstractions.distributions import TeacherForcingDistr
from allenact.base_abstractions.experiment_config import ExperimentConfig, MachineParams
from allenact.base_abstractions.misc import (
RLStepResult,
Memory,
ActorCriticOutput,
GenericAbstractLoss,
Memory,
RLStepResult,
)
from allenact.base_abstractions.distributions import TeacherForcingDistr
from allenact.utils import spaces_utils as su
from allenact.utils.experiment_utils import (
set_seed,
TrainingPipeline,
LoggingPackage,
PipelineStage,
set_deterministic_cudnn,
ScalarMeanTracker,
StageComponent,
TrainingPipeline,
set_deterministic_cudnn,
set_seed,
)
from allenact.utils.system import get_logger
from allenact.utils.tensor_utils import (
batch_observations,
detach_recursively,
)
from allenact.utils.tensor_utils import batch_observations, detach_recursively
from allenact.utils.viz_utils import VizSuite

try:
Expand All @@ -82,9 +75,9 @@
# noinspection PyPackageRequirements
import pydevd

DEBUGGING = True
DEBUGGING = str2bool(os.getenv("ALLENACT_DEBUG", "true"))
except ImportError:
DEBUGGING = False
DEBUGGING = str2bool(os.getenv("ALLENACT_DEBUG", "false"))

DEBUG_VST_TIMEOUT: Optional[int] = (lambda x: int(x) if x is not None else x)(
os.getenv("ALLENACT_DEBUG_VST_TIMEOUT", None)
Expand Down Expand Up @@ -115,6 +108,7 @@ def __init__(
], # to write/read (trainer/evaluator) ready checkpoints
checkpoints_dir: str,
mode: str = "train",
callback_sensors: Optional[Sequence[Sensor]] = None,
seed: Optional[int] = None,
deterministic_cudnn: bool = False,
mp_ctx: Optional[BaseContext] = None,
Expand All @@ -126,7 +120,7 @@ def __init__(
deterministic_agents: bool = False,
max_sampler_processes_per_worker: Optional[int] = None,
initial_model_state_dict: Optional[Union[Dict[str, Any], int]] = None,
try_restart_after_task_timeout: bool = False,
try_restart_after_task_error: bool = False,
**kwargs,
):
"""Initializer.
Expand Down Expand Up @@ -154,7 +148,7 @@ def __init__(
self.device = torch.device("cpu") if device == -1 else torch.device(device) # type: ignore
self.distributed_ip = distributed_ip
self.distributed_port = distributed_port
self.try_restart_after_task_timeout = try_restart_after_task_timeout
self.try_restart_after_task_error = try_restart_after_task_error

self.mode = mode.lower().strip()
assert self.mode in [
Expand All @@ -163,6 +157,7 @@ def __init__(
TEST_MODE_STR,
], 'Only "train", "valid", "test" modes supported'

self.callback_sensors = callback_sensors
self.deterministic_cudnn = deterministic_cudnn
if self.deterministic_cudnn:
set_deterministic_cudnn()
Expand Down Expand Up @@ -245,6 +240,9 @@ def __init__(
port=self.distributed_port,
world_size=self.num_workers,
is_master=self.worker_id == 0,
timeout=datetime.timedelta(
seconds=3 * (DEBUG_VST_TIMEOUT if DEBUGGING else 1 * 60) + 300
),
)
cpu_device = self.device == torch.device("cpu") # type:ignore

Expand Down Expand Up @@ -277,6 +275,7 @@ def __init__(

# Keeping track of metrics during training/inference
self.single_process_metrics: List = []
self.single_process_task_callback_data: List = []

@property
def vector_tasks(
Expand Down Expand Up @@ -310,12 +309,13 @@ def vector_tasks(
self._vector_tasks = VectorSampledTasks(
make_sampler_fn=self.config.make_sampler_fn,
sampler_fn_args=self.get_sampler_fn_args(seeds),
callback_sensors=self.callback_sensors,
multiprocessing_start_method="forkserver"
if self.mp_ctx is None
else None,
mp_ctx=self.mp_ctx,
max_processes=self.max_sampler_processes_per_worker,
read_timeout=DEBUG_VST_TIMEOUT if DEBUGGING else 5 * 60,
read_timeout=DEBUG_VST_TIMEOUT if DEBUGGING else 1 * 60,
)
return self._vector_tasks

Expand Down Expand Up @@ -588,14 +588,17 @@ def collect_step_across_all_task_samplers(

# Save after task completion metrics
for step_result in outputs:
if (
step_result.info is not None
and COMPLETE_TASK_METRICS_KEY in step_result.info
):
self.single_process_metrics.append(
step_result.info[COMPLETE_TASK_METRICS_KEY]
)
del step_result.info[COMPLETE_TASK_METRICS_KEY]
if step_result.info is not None:
if COMPLETE_TASK_METRICS_KEY in step_result.info:
self.single_process_metrics.append(
step_result.info[COMPLETE_TASK_METRICS_KEY]
)
del step_result.info[COMPLETE_TASK_METRICS_KEY]
if COMPLETE_TASK_CALLBACK_KEY in step_result.info:
self.single_process_task_callback_data.append(
step_result.info[COMPLETE_TASK_CALLBACK_KEY]
)
del step_result.info[COMPLETE_TASK_CALLBACK_KEY]

rewards: Union[List, torch.Tensor]
observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
Expand Down Expand Up @@ -856,6 +859,36 @@ def deterministic_seeds(self) -> None:
) # use the latest seed for workers and update rng state
self.vector_tasks.set_seeds(seeds)

def save_error_data(self, batch: Dict[str, Any]) -> str:
    """Write the current training state, together with `batch`, to an error file.

    The file is placed in `self.checkpoints_dir`; its name encodes the
    experiment name, the current pipeline stage index, and the total number
    of training steps. The write happens while holding a file lock
    (`error.lock` in the same directory, 60 second timeout) and is skipped
    when a file with that name already exists.

    Args:
        batch: The batch to store next to the model, optimizer, training
            pipeline, and (when one is set) learning-rate scheduler state.

    Returns:
        The path of the error file, whether or not this call wrote it.
    """
    checkpoints_dir = self.checkpoints_dir
    error_path = os.path.join(
        checkpoints_dir,
        "error_for_exp_{}__stage_{:02d}__steps_{:012d}.pt".format(
            self.experiment_name,
            self.training_pipeline.current_stage_index,
            self.training_pipeline.total_steps,
        ),
    )
    lock_path = os.path.join(checkpoints_dir, "error.lock")

    with filelock.FileLock(lock_path, timeout=60):
        # Another caller may already have written this exact file; in that
        # case leave it untouched and just report where it is.
        if os.path.exists(error_path):
            return error_path

        error_state = {
            "model_state_dict": self.actor_critic.state_dict(),  # type:ignore
            # Total steps including the current stage
            "total_steps": self.training_pipeline.total_steps,
            "optimizer_state_dict": self.optimizer.state_dict(),  # type: ignore
            "training_pipeline_state_dict": self.training_pipeline.state_dict(),
            "trainer_seed": self.seed,
            "batch": batch,
        }

        scheduler = self.lr_scheduler
        if scheduler is not None:
            error_state["scheduler_state"] = cast(
                _LRScheduler, scheduler
            ).state_dict()

        torch.save(error_state, error_path)

    return error_path

def checkpoint_save(self, pipeline_stage_index: Optional[int] = None) -> str:
model_path = os.path.join(
self.checkpoints_dir,
Expand Down Expand Up @@ -1124,12 +1157,21 @@ def single_batch_generator(streaming_storage: StreamingStorageMixin):
bsize = batch["bsize"]

if actor_critic_output_for_batch is None:
actor_critic_output_for_batch, _ = self.actor_critic(
observations=batch["observations"],
memory=batch["memory"],
prev_actions=batch["prev_actions"],
masks=batch["masks"],
)

try:
actor_critic_output_for_batch, _ = self.actor_critic(
observations=batch["observations"],
memory=batch["memory"],
prev_actions=batch["prev_actions"],
masks=batch["masks"],
)
except ValueError:
save_path = self.save_error_data(batch=batch)
get_logger().error(
f"Encountered a value error! Likely because of nans in the output/input."
f" Saving all error information to {save_path}."
)
raise

loss_return = loss.loss(
step_count=self.step_count,
Expand Down Expand Up @@ -1287,6 +1329,10 @@ def aggregate_and_send_logging_package(

self.aggregate_task_metrics(logging_pkg=logging_pkg)

for callback_dict in self.single_process_task_callback_data:
logging_pkg.task_callback_data.append(callback_dict)
self.single_process_task_callback_data = []

if self.mode == TRAIN_MODE_STR:
# Technically self.mode should always be "train" here (as this is the training engine),
# this conditional is defensive
Expand Down Expand Up @@ -1450,9 +1496,9 @@ def run_pipeline(self, valid_on_initial_weights: bool = False):
rollout_storage_uuid=self.training_pipeline.rollout_storage_uuid,
uuid_to_storage=uuid_to_storage,
)
except TimeoutError:
except (TimeoutError, EOFError) as e:
if (
not self.try_restart_after_task_timeout
not self.try_restart_after_task_error
) or self.mode != TRAIN_MODE_STR:
# Apparently you can just call `raise` here and doing so will just raise the exception as though
# it was not caught (so the stacktrace isn't messed up)
Expand All @@ -1465,9 +1511,10 @@ def run_pipeline(self, valid_on_initial_weights: bool = False):
else:
get_logger().warning(
f"[{self.mode} worker {self.worker_id}] `vector_tasks` appears to have crashed during"
f" training as it has timed out. You have set `try_restart_after_task_timeout` to `True` so"
f" we will attempt to restart these tasks from the beginning. USE THIS FEATURE AT YOUR OWN"
f" RISK. Timeout exception:\n{traceback.format_exc()}."
f" training due to an {type(e).__name__} error. You have set"
f" `try_restart_after_task_error` to `True` so we will attempt to restart these tasks from"
f" the beginning. USE THIS FEATURE AT YOUR OWN"
f" RISK. Exception:\n{traceback.format_exc()}."
)
self.vector_tasks.close()
self._vector_tasks = None
Expand Down Expand Up @@ -1852,6 +1899,10 @@ def run_eval(

self.aggregate_task_metrics(logging_pkg=logging_pkg)

for callback_dict in self.single_process_task_callback_data:
logging_pkg.task_callback_data.append(callback_dict)
self.single_process_task_callback_data = []

logging_pkg.viz_data = (
visualizer.read_and_reset() if visualizer is not None else None
)
Expand Down
Loading