Merge pull request #339 from allenai/callbacks
Add Callback Support
Lucaweihs authored Aug 16, 2022
2 parents a709009 + e89eae1 commit b5c7192
Showing 32 changed files with 1,049 additions and 241 deletions.
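For orientation before the engine.py diff below: the core change in this file is a new plumbing path for per-task callback data. Completed tasks may attach a dict under COMPLETE_TASK_CALLBACK_KEY to their step results; the engine buffers those dicts in single_process_task_callback_data and flushes them into the outgoing LoggingPackage as task_callback_data (both when training and in run_eval). The sketch below is a simplified illustration of that flow, not the actual engine code: _EngineSketch is a made-up stand-in class, and the literal value given to COMPLETE_TASK_CALLBACK_KEY here is assumed (the real constant lives in vector_sampled_tasks.py).

    # Simplified sketch of the callback-data flow added in this diff (hypothetical
    # stand-in class; only the attribute/key names mirror the real engine).
    from typing import Any, Dict, List, Optional

    COMPLETE_TASK_CALLBACK_KEY = "task_callback_data"  # assumed literal; the real constant is imported from vector_sampled_tasks.py


    class _EngineSketch:
        def __init__(self) -> None:
            # Mirrors single_process_task_callback_data on the engine class.
            self.single_process_task_callback_data: List[Dict[str, Any]] = []

        def collect_step(self, info: Optional[Dict[str, Any]]) -> None:
            # Mirrors collect_step_across_all_task_samplers: pop callback data off
            # a finished task's step info and buffer it locally.
            if info is not None and COMPLETE_TASK_CALLBACK_KEY in info:
                self.single_process_task_callback_data.append(
                    info.pop(COMPLETE_TASK_CALLBACK_KEY)
                )

        def flush_into(self, task_callback_data: List[Dict[str, Any]]) -> None:
            # Mirrors aggregate_and_send_logging_package / run_eval: move buffered
            # dicts into the outgoing logging package and reset the local buffer.
            task_callback_data.extend(self.single_process_task_callback_data)
            self.single_process_task_callback_data = []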
143 changes: 97 additions & 46 deletions allenact/algorithms/onpolicy_sync/engine.py
@@ -8,16 +8,9 @@
import traceback
from functools import partial
from multiprocessing.context import BaseContext
from typing import (
Optional,
Any,
Dict,
Union,
List,
cast,
Sequence,
)
from typing import Any, Dict, List, Optional, Sequence, Union, cast

import filelock
import torch
import torch.distributed as dist # type: ignore
import torch.distributions # type: ignore
@@ -27,7 +20,9 @@
# noinspection PyProtectedMember
from torch._C._distributed_c10d import ReduceOp

from allenact.algorithms.onpolicy_sync.misc import TrackingInfoType, TrackingInfo
from allenact.algorithms.onpolicy_sync.misc import TrackingInfo, TrackingInfoType
from allenact.base_abstractions.sensor import Sensor
from allenact.utils.misc_utils import str2bool
from allenact.utils.model_utils import md5_hash_of_state_dict

try:
@@ -43,37 +38,35 @@
from allenact.algorithms.onpolicy_sync.storage import (
ExperienceStorage,
MiniBatchStorageMixin,
StreamingStorageMixin,
RolloutStorage,
StreamingStorageMixin,
)
from allenact.algorithms.onpolicy_sync.vector_sampled_tasks import (
VectorSampledTasks,
COMPLETE_TASK_CALLBACK_KEY,
COMPLETE_TASK_METRICS_KEY,
SingleProcessVectorSampledTasks,
VectorSampledTasks,
)
from allenact.base_abstractions.distributions import TeacherForcingDistr
from allenact.base_abstractions.experiment_config import ExperimentConfig, MachineParams
from allenact.base_abstractions.misc import (
RLStepResult,
Memory,
ActorCriticOutput,
GenericAbstractLoss,
Memory,
RLStepResult,
)
from allenact.base_abstractions.distributions import TeacherForcingDistr
from allenact.utils import spaces_utils as su
from allenact.utils.experiment_utils import (
set_seed,
TrainingPipeline,
LoggingPackage,
PipelineStage,
set_deterministic_cudnn,
ScalarMeanTracker,
StageComponent,
TrainingPipeline,
set_deterministic_cudnn,
set_seed,
)
from allenact.utils.system import get_logger
from allenact.utils.tensor_utils import (
batch_observations,
detach_recursively,
)
from allenact.utils.tensor_utils import batch_observations, detach_recursively
from allenact.utils.viz_utils import VizSuite

try:
@@ -82,9 +75,9 @@
# noinspection PyPackageRequirements
import pydevd

DEBUGGING = True
DEBUGGING = str2bool(os.getenv("ALLENACT_DEBUG", "true"))
except ImportError:
DEBUGGING = False
DEBUGGING = str2bool(os.getenv("ALLENACT_DEBUG", "false"))

DEBUG_VST_TIMEOUT: Optional[int] = (lambda x: int(x) if x is not None else x)(
os.getenv("ALLENACT_DEBUG_VST_TIMEOUT", None)
@@ -115,6 +108,7 @@ def __init__(
], # to write/read (trainer/evaluator) ready checkpoints
checkpoints_dir: str,
mode: str = "train",
callback_sensors: Optional[Sequence[Sensor]] = None,
seed: Optional[int] = None,
deterministic_cudnn: bool = False,
mp_ctx: Optional[BaseContext] = None,
@@ -126,7 +120,7 @@ def __init__(
deterministic_agents: bool = False,
max_sampler_processes_per_worker: Optional[int] = None,
initial_model_state_dict: Optional[Union[Dict[str, Any], int]] = None,
try_restart_after_task_timeout: bool = False,
try_restart_after_task_error: bool = False,
**kwargs,
):
"""Initializer.
@@ -154,7 +148,7 @@ def __init__(
self.device = torch.device("cpu") if device == -1 else torch.device(device) # type: ignore
self.distributed_ip = distributed_ip
self.distributed_port = distributed_port
self.try_restart_after_task_timeout = try_restart_after_task_timeout
self.try_restart_after_task_error = try_restart_after_task_error

self.mode = mode.lower().strip()
assert self.mode in [
@@ -163,6 +157,7 @@
TEST_MODE_STR,
], 'Only "train", "valid", "test" modes supported'

self.callback_sensors = callback_sensors
self.deterministic_cudnn = deterministic_cudnn
if self.deterministic_cudnn:
set_deterministic_cudnn()
@@ -245,6 +240,9 @@ def __init__(
port=self.distributed_port,
world_size=self.num_workers,
is_master=self.worker_id == 0,
timeout=datetime.timedelta(
seconds=3 * (DEBUG_VST_TIMEOUT if DEBUGGING else 1 * 60) + 300
),
)
cpu_device = self.device == torch.device("cpu") # type:ignore

@@ -277,6 +275,7 @@ def __init__(

# Keeping track of metrics during training/inference
self.single_process_metrics: List = []
self.single_process_task_callback_data: List = []

@property
def vector_tasks(
@@ -310,12 +309,13 @@ def vector_tasks(
self._vector_tasks = VectorSampledTasks(
make_sampler_fn=self.config.make_sampler_fn,
sampler_fn_args=self.get_sampler_fn_args(seeds),
callback_sensors=self.callback_sensors,
multiprocessing_start_method="forkserver"
if self.mp_ctx is None
else None,
mp_ctx=self.mp_ctx,
max_processes=self.max_sampler_processes_per_worker,
read_timeout=DEBUG_VST_TIMEOUT if DEBUGGING else 5 * 60,
read_timeout=DEBUG_VST_TIMEOUT if DEBUGGING else 1 * 60,
)
return self._vector_tasks

@@ -588,14 +588,17 @@ def collect_step_across_all_task_samplers(

# Save after task completion metrics
for step_result in outputs:
if (
step_result.info is not None
and COMPLETE_TASK_METRICS_KEY in step_result.info
):
self.single_process_metrics.append(
step_result.info[COMPLETE_TASK_METRICS_KEY]
)
del step_result.info[COMPLETE_TASK_METRICS_KEY]
if step_result.info is not None:
if COMPLETE_TASK_METRICS_KEY in step_result.info:
self.single_process_metrics.append(
step_result.info[COMPLETE_TASK_METRICS_KEY]
)
del step_result.info[COMPLETE_TASK_METRICS_KEY]
if COMPLETE_TASK_CALLBACK_KEY in step_result.info:
self.single_process_task_callback_data.append(
step_result.info[COMPLETE_TASK_CALLBACK_KEY]
)
del step_result.info[COMPLETE_TASK_CALLBACK_KEY]

rewards: Union[List, torch.Tensor]
observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
@@ -856,6 +859,36 @@ def deterministic_seeds(self) -> None:
) # use the latest seed for workers and update rng state
self.vector_tasks.set_seeds(seeds)

def save_error_data(self, batch: Dict[str, Any]) -> str:
model_path = os.path.join(
self.checkpoints_dir,
"error_for_exp_{}__stage_{:02d}__steps_{:012d}.pt".format(
self.experiment_name,
self.training_pipeline.current_stage_index,
self.training_pipeline.total_steps,
),
)
with filelock.FileLock(
os.path.join(self.checkpoints_dir, "error.lock"), timeout=60
):
if not os.path.exists(model_path):
save_dict = {
"model_state_dict": self.actor_critic.state_dict(), # type:ignore
"total_steps": self.training_pipeline.total_steps, # Total steps including current stage
"optimizer_state_dict": self.optimizer.state_dict(), # type: ignore
"training_pipeline_state_dict": self.training_pipeline.state_dict(),
"trainer_seed": self.seed,
"batch": batch,
}

if self.lr_scheduler is not None:
save_dict["scheduler_state"] = cast(
_LRScheduler, self.lr_scheduler
).state_dict()

torch.save(save_dict, model_path)
return model_path

def checkpoint_save(self, pipeline_stage_index: Optional[int] = None) -> str:
model_path = os.path.join(
self.checkpoints_dir,
@@ -1124,12 +1157,21 @@ def single_batch_generator(streaming_storage: StreamingStorageMixin):
bsize = batch["bsize"]

if actor_critic_output_for_batch is None:
actor_critic_output_for_batch, _ = self.actor_critic(
observations=batch["observations"],
memory=batch["memory"],
prev_actions=batch["prev_actions"],
masks=batch["masks"],
)

try:
actor_critic_output_for_batch, _ = self.actor_critic(
observations=batch["observations"],
memory=batch["memory"],
prev_actions=batch["prev_actions"],
masks=batch["masks"],
)
except ValueError:
save_path = self.save_error_data(batch=batch)
get_logger().error(
f"Encountered a value error! Likely because of nans in the output/input."
f" Saving all error information to {save_path}."
)
raise

loss_return = loss.loss(
step_count=self.step_count,
@@ -1287,6 +1329,10 @@ def aggregate_and_send_logging_package(

self.aggregate_task_metrics(logging_pkg=logging_pkg)

for callback_dict in self.single_process_task_callback_data:
logging_pkg.task_callback_data.append(callback_dict)
self.single_process_task_callback_data = []

if self.mode == TRAIN_MODE_STR:
# Technically self.mode should always be "train" here (as this is the training engine),
# this conditional is defensive
@@ -1450,9 +1496,9 @@ def run_pipeline(self, valid_on_initial_weights: bool = False):
rollout_storage_uuid=self.training_pipeline.rollout_storage_uuid,
uuid_to_storage=uuid_to_storage,
)
except TimeoutError:
except (TimeoutError, EOFError) as e:
if (
not self.try_restart_after_task_timeout
not self.try_restart_after_task_error
) or self.mode != TRAIN_MODE_STR:
# Apparently you can just call `raise` here and doing so will just raise the exception as though
# it was not caught (so the stacktrace isn't messed up)
@@ -1465,9 +1511,10 @@ def run_pipeline(self, valid_on_initial_weights: bool = False):
else:
get_logger().warning(
f"[{self.mode} worker {self.worker_id}] `vector_tasks` appears to have crashed during"
f" training as it has timed out. You have set `try_restart_after_task_timeout` to `True` so"
f" we will attempt to restart these tasks from the beginning. USE THIS FEATURE AT YOUR OWN"
f" RISK. Timeout exception:\n{traceback.format_exc()}."
f" training due to an {type(e).__name__} error. You have set"
f" `try_restart_after_task_error` to `True` so we will attempt to restart these tasks from"
f" the beginning. USE THIS FEATURE AT YOUR OWN"
f" RISK. Exception:\n{traceback.format_exc()}."
)
self.vector_tasks.close()
self._vector_tasks = None
@@ -1852,6 +1899,10 @@ def run_eval(

self.aggregate_task_metrics(logging_pkg=logging_pkg)

for callback_dict in self.single_process_task_callback_data:
logging_pkg.task_callback_data.append(callback_dict)
self.single_process_task_callback_data = []

logging_pkg.viz_data = (
visualizer.read_and_reset() if visualizer is not None else None
)
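A second behavior change visible in this diff: if the actor-critic forward pass raises a ValueError (e.g. from NaNs in the inputs or outputs), the new save_error_data method dumps the model, optimizer, training-pipeline state, trainer seed, and the offending batch to an error_for_exp_*.pt file under the checkpoints directory before re-raising. Below is a hedged sketch of inspecting such a dump offline, assuming only the dict keys shown in the diff; the file path is hypothetical and the batch layout may be nested rather than the flat dict assumed here.

    # Sketch: inspecting an error dump written by save_error_data.
    import torch

    error_file = "checkpoints/error_for_exp_my_exp__stage_00__steps_000000012345.pt"  # hypothetical path
    dump = torch.load(error_file, map_location="cpu")

    # Keys per the diff: 'model_state_dict', 'total_steps', 'optimizer_state_dict',
    # 'training_pipeline_state_dict', 'trainer_seed', 'batch', and (if a scheduler
    # was in use) 'scheduler_state'.
    print(sorted(dump.keys()))

    # The offending batch is stored verbatim, so NaNs can be located directly
    # (assumes observations is a flat dict of tensors; adapt if it is nested).
    for name, value in dump["batch"]["observations"].items():
        if torch.is_tensor(value) and value.is_floating_point() and torch.isnan(value).any():
            print(f"NaNs found in observation '{name}'")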