Fixing merge conflict in engine.py
Lucaweihs committed Jun 2, 2022
2 parents 030d2ce + a709009 commit 2f3f83e
Showing 25 changed files with 406 additions and 124 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
@@ -28,6 +28,7 @@ jobs:
python -m pip install --editable="./allenact_plugins[all]"
python -m pip install -r allenact_plugins/babyai_plugin/extra_requirements.txt # Required as babyai is not on PyPI
python -m pip install compress_pickle # Needed for some mapping tests
python -m pip install -U protobuf==3.20.1 # Required until tensorboardX is fixed: https://github.com/lanpa/tensorboardX/issues/663
pip list
- name: Test with pytest
68 changes: 61 additions & 7 deletions allenact/algorithms/onpolicy_sync/engine.py
@@ -66,6 +66,20 @@
from allenact.utils.tensor_utils import batch_observations, detach_recursively
from allenact.utils.viz_utils import VizSuite

try:
# When debugging we don't want to timeout in the VectorSampledTasks

# noinspection PyPackageRequirements
import pydevd

DEBUGGING = True
except ImportError:
DEBUGGING = False

DEBUG_VST_TIMEOUT: Optional[int] = (lambda x: int(x) if x is not None else x)(
os.getenv("ALLENACT_DEBUG_VST_TIMEOUT", None)
)

TRAIN_MODE_STR = "train"
VALID_MODE_STR = "valid"
TEST_MODE_STR = "test"
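A minimal standalone sketch of how this environment-variable parse behaves (not part of the diff; the helper name is illustrative):

import os
from typing import Optional

def _parse_optional_int_env(name: str) -> Optional[int]:
    raw = os.getenv(name, None)  # None when the variable is unset
    return int(raw) if raw is not None else None

# Launching with e.g. `ALLENACT_DEBUG_VST_TIMEOUT=600 python main.py ...` yields 600 here;
# leaving the variable unset keeps the value None, which the timeout-wrapped read
# functions below treat as "wait indefinitely".
DEBUG_VST_TIMEOUT = _parse_optional_int_env("ALLENACT_DEBUG_VST_TIMEOUT")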
@@ -102,6 +116,7 @@ def __init__(
deterministic_agents: bool = False,
max_sampler_processes_per_worker: Optional[int] = None,
initial_model_state_dict: Optional[Union[Dict[str, Any], int]] = None,
try_restart_after_task_timeout: bool = False,
**kwargs,
):
"""Initializer.
@@ -129,6 +144,7 @@ def __init__(
self.device = torch.device("cpu") if device == -1 else torch.device(device) # type: ignore
self.distributed_ip = distributed_ip
self.distributed_port = distributed_port
self.try_restart_after_task_timeout = try_restart_after_task_timeout

self.mode = mode.lower().strip()
assert self.mode in [
@@ -235,7 +251,7 @@ def __init__(
# During testing, we sometimes found that the default timeout was too short,
# resulting in runs terminating unexpectedly, so we increase it here.
timeout=datetime.timedelta(minutes=3000)
if self.mode == TEST_MODE_STR
if (self.mode == TEST_MODE_STR or DEBUGGING)
else dist.default_pg_timeout,
)
self.is_distributed = True
@@ -290,6 +306,7 @@ def vector_tasks(
else None,
mp_ctx=self.mp_ctx,
max_processes=self.max_sampler_processes_per_worker,
read_timeout=DEBUG_VST_TIMEOUT if DEBUGGING else 5 * 60,
)
return self._vector_tasks

@@ -1421,11 +1438,48 @@ def run_pipeline(self, valid_on_initial_weights: bool = False):
for k, v in self.training_pipeline.current_stage_storage.items()
}

for step in range(cur_stage_training_settings.num_steps):
num_paused = self.collect_step_across_all_task_samplers(
rollout_storage_uuid=self.training_pipeline.rollout_storage_uuid,
uuid_to_storage=uuid_to_storage,
)
vector_tasks_already_restarted = False
step = -1
while step < cur_stage_training_settings.num_steps - 1:
step += 1

try:
num_paused = self.collect_step_across_all_task_samplers(
rollout_storage_uuid=self.training_pipeline.rollout_storage_uuid,
uuid_to_storage=uuid_to_storage,
)
except TimeoutError:
if (
not self.try_restart_after_task_timeout
) or self.mode != TRAIN_MODE_STR:
# A bare `raise` here re-raises the caught exception as though it had never
# been caught, so the original traceback is preserved.
raise
elif vector_tasks_already_restarted:
raise RuntimeError(
f"[{self.mode} worker {self.worker_id}] `vector_tasks` has timed out twice in the same"
f" rollout. This suggests that this error was not recoverable. Timeout exception:\n{traceback.format_exc()}"
)
else:
get_logger().warning(
f"[{self.mode} worker {self.worker_id}] `vector_tasks` appears to have crashed during"
f" training as it has timed out. You have set `try_restart_after_task_timeout` to `True` so"
f" we will attempt to restart these tasks from the beginning. USE THIS FEATURE AT YOUR OWN"
f" RISK. Timeout exception:\n{traceback.format_exc()}."
)
self.vector_tasks.close()
self._vector_tasks = None

vector_tasks_already_restarted = True
for (
storage
) in self.training_pipeline.current_stage_storage.values():
storage.after_updates()
self.initialize_storage_and_viz(
storage_to_initialize=list(uuid_to_storage.values())
)
step = -1
continue

# A more informative error message should already have been given in
# `collect_step_across_all_task_samplers` if `num_paused != 0` here but this serves
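Pulled out of context, the control flow added above amounts to the following sketch (not part of the diff; `collect_step` and `restart_tasks` are hypothetical stand-ins for the engine methods): time out once, tear the task samplers down, restart the rollout from step 0, and only fail hard on a second timeout within the same rollout.

def run_rollout(num_steps, collect_step, restart_tasks, allow_restart=True):
    already_restarted = False
    step = -1
    while step < num_steps - 1:
        step += 1
        try:
            collect_step(step)
        except TimeoutError:
            if not allow_restart or already_restarted:
                raise  # re-raise with the original traceback intact
            already_restarted = True
            restart_tasks()  # e.g. close and recreate the vector tasks and storage
            step = -1  # begin the rollout again from the first step
            continue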
@@ -1605,7 +1659,7 @@ def train(
get_logger().error(
f"[{self.mode} worker {self.worker_id}] Encountered {type(e).__name__}, exiting."
)
get_logger().exception(traceback.format_exc())
get_logger().error(traceback.format_exc())
finally:
if training_completed_successfully:
if self.worker_id == 0:
13 changes: 11 additions & 2 deletions allenact/algorithms/onpolicy_sync/runner.py
@@ -320,9 +320,9 @@ def handler(_signo, _frame):
except Exception:
get_logger().error(
f"Error occurred when closing the RL engine used by work {mode}-{id}."
f" We cannot recover from this and will simply exit. The exception:"
f" We cannot recover from this and will simply exit. The exception:\n"
f"{traceback.format_exc()}"
)
get_logger().exception(traceback.format_exc())
sys.exit(1)
sys.exit(0)
else:
@@ -472,6 +472,15 @@ def start_train(

distributed_port = 0 if num_workers == 1 else self.get_port()

if (
num_workers > 1
and "NCCL_ASYNC_ERROR_HANDLING" not in os.environ
and "NCCL_BLOCKING_WAIT" not in os.environ
):
# This ensures the NCCL distributed backend will throw errors
# if we timeout at a call to `barrier()`
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"

worker_ids = self.local_worker_ids(TRAIN_MODE_STR)

model_hash = None
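A self-contained sketch of the same guard (not part of the diff; assumes PyTorch with the NCCL backend), showing that the environment variables must be set before the process group is initialized for `barrier()` timeouts to surface as errors rather than hangs:

import os

def enable_nccl_error_handling() -> None:
    if (
        "NCCL_ASYNC_ERROR_HANDLING" not in os.environ
        and "NCCL_BLOCKING_WAIT" not in os.environ
    ):
        # Collectives (e.g. torch.distributed.barrier()) will then raise on timeout
        # instead of blocking indefinitely.
        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"

# enable_nccl_error_handling() must run before torch.distributed.init_process_group(...)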
9 changes: 6 additions & 3 deletions allenact/algorithms/onpolicy_sync/storage.py
@@ -199,7 +199,7 @@ def initialize(
self.full_size + 1, num_samplers, action_flat_dim, device=self.device
)

assert self.step == 0, "Must call `after_update` before calling `initialize`"
assert self.step == 0, "Must call `after_updates` before calling `initialize`"
self.insert_observations(observations=observations, time_step=0)
self.prev_actions[0].zero_() # Have to zero previous actions
self.masks[0].zero_() # Have to zero masks
@@ -529,8 +529,11 @@ def after_updates(self, **kwargs):
for key in storage:
storage[key][0][0].copy_(storage[key][0][-1])

self.masks[0].copy_(self.masks[-1])
self.prev_actions[0].copy_(self.prev_actions[-1])
if self._masks_full is not None:
self.masks[0].copy_(self.masks[-1])

if self._prev_actions_full is not None:
self.prev_actions[0].copy_(self.prev_actions[-1])

self._before_update_called = False
self._advantages = None
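The copies guarded above are the usual "carry the final timestep into slot 0" rollout trick; a toy sketch (not part of the diff, shapes illustrative):

import torch

masks = torch.zeros(5, 2, 1)   # (num_steps + 1, num_samplers, 1)
masks[-1] = 1.0                # suppose the rollout ended with episodes still running
masks[0].copy_(masks[-1])      # the next rollout then starts from where this one ended
assert torch.equal(masks[0], masks[-1])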
85 changes: 59 additions & 26 deletions allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py
@@ -144,6 +144,7 @@ class VectorSampledTasks:
_mp_ctx: BaseContext
_connection_read_fns: List[Callable[[], Any]]
_connection_write_fns: List[Callable[[Any], None]]
_read_timeout: Optional[float]

def __init__(
self,
@@ -154,12 +155,16 @@ def __init__(
mp_ctx: Optional[BaseContext] = None,
should_log: bool = True,
max_processes: Optional[int] = None,
read_timeout: Optional[
float
] = 60, # Seconds to wait for a task to return a response before timing out
) -> None:

self._is_waiting = False
self._is_closed = True
self.should_log = should_log
self.max_processes = max_processes
self.read_timeout = read_timeout

assert (
sampler_fn_args is not None and len(sampler_fn_args) > 0
@@ -195,7 +200,8 @@ def __init__(
for args in sampler_fn_args:
args["mp_ctx"] = self._mp_ctx
(
self._connection_read_fns,
connection_poll_fns,
connection_read_fns,
self._connection_write_fns,
) = self._spawn_workers( # noqa
make_sampler_fn=make_sampler_fn,
Expand All @@ -204,6 +210,13 @@ def __init__(
],
)

self._connection_read_fns = [
self._create_read_function_with_timeout(
read_fn=read_fn, poll_fn=poll_fn, timeout=self.read_timeout
)
for read_fn, poll_fn in zip(connection_read_fns, connection_poll_fns)
]

self._is_closed = False

for write_fn in self._connection_write_fns:
@@ -234,6 +247,25 @@ def __init__(
space for read_fn in self._connection_read_fns for space in read_fn()
]

@staticmethod
def _create_read_function_with_timeout(
*,
read_fn: Callable[[], Any],
poll_fn: Callable[[float], bool],
timeout: Optional[float],
) -> Callable[[], Any]:
def read_with_timeout(timeout_to_use: Optional[float] = timeout):
if timeout_to_use is not None:
# noinspection PyArgumentList
if not poll_fn(timeout=timeout_to_use):
raise TimeoutError(
f"Did not recieve output from `VectorSampledTask` worker for {timeout_to_use} seconds."
)

return read_fn()

return read_with_timeout
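The wrapper above relies on the standard poll-then-recv idiom for multiprocessing pipes; a minimal standalone sketch (not part of the diff):

import multiprocessing as mp

parent_conn, child_conn = mp.Pipe(duplex=True)
child_conn.send("hello")
if parent_conn.poll(5.0):      # wait up to 5 seconds for data to arrive
    print(parent_conn.recv())  # -> "hello"
else:
    raise TimeoutError("Did not receive output within 5 seconds.")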

def _reset_sampler_index_to_process_ind_and_subprocess_ind(self):
self.sampler_index_to_process_ind_and_subprocess_ind = [
[i, j]
@@ -297,7 +329,7 @@ def _task_sampling_loop_worker(
"""process worker for creating and interacting with the
Tasks/TaskSampler."""

ptitle("VectorSampledTask: {}".format(worker_id))
ptitle(f"VectorSampledTask: {worker_id}")

sp_vector_sampled_tasks = SingleProcessVectorSampledTasks(
make_sampler_fn=make_sampler_fn,
Expand All @@ -307,7 +339,7 @@ def _task_sampling_loop_worker(
)

if parent_pipe is not None:
parent_pipe.close()
parent_pipe.close() # Means this pipe will close when the calling process closes it
try:
while True:
read_input = connection_read_fn()
@@ -368,7 +400,9 @@ def _task_sampling_loop_worker(
if should_log:
get_logger().info(f"Worker {worker_id} KeyboardInterrupt")
except Exception as e:
get_logger().error(traceback.format_exc())
get_logger().error(
f"Worker {worker_id} encountered an exception:\n{traceback.format_exc()}"
)
raise e
finally:
if child_pipe is not None:
Expand All @@ -380,52 +414,50 @@ def _spawn_workers(
self,
make_sampler_fn: Callable[..., TaskSampler],
sampler_fn_args_list: Sequence[Sequence[Dict[str, Any]]],
) -> Tuple[List[Callable[[], Any]], List[Callable[[Any], None]]]:
) -> Tuple[
List[Callable[[], bool]], List[Callable[[], Any]], List[Callable[[Any], None]]
]:
parent_connections, worker_connections = zip(
*[self._mp_ctx.Pipe(duplex=True) for _ in range(self._num_processes)]
)
self._workers = []
k = 0
id: Union[int, str]
for id, stuff in enumerate(
for id, (worker_conn, parent_conn, current_sampler_fn_args_list) in enumerate(
zip(worker_connections, parent_connections, sampler_fn_args_list)
):
worker_conn, parent_conn, current_sampler_fn_args_list = stuff # type: ignore

if len(current_sampler_fn_args_list) != 1:
id = "{}({}-{})".format(
id, k, k + len(current_sampler_fn_args_list) - 1
)
id = f"{id}({k}-{k + len(current_sampler_fn_args_list) - 1})"
k += len(current_sampler_fn_args_list)

if self.should_log:
get_logger().info(
"Starting {}-th VectorSampledTask worker with args {}".format(
id, current_sampler_fn_args_list
)
f"Starting {id}-th VectorSampledTask worker with args {current_sampler_fn_args_list}"
)

ps = self._mp_ctx.Process( # type: ignore
target=self._task_sampling_loop_worker,
args=(
id,
worker_conn.recv,
worker_conn.send,
make_sampler_fn,
current_sampler_fn_args_list,
self._auto_resample_when_done,
self.should_log,
worker_conn,
parent_conn,
kwargs=dict(
worker_id=id,
connection_read_fn=worker_conn.recv,
connection_write_fn=worker_conn.send,
make_sampler_fn=make_sampler_fn,
sampler_fn_args_list=current_sampler_fn_args_list,
auto_resample_when_done=self._auto_resample_when_done,
should_log=self.should_log,
child_pipe=worker_conn,
parent_pipe=parent_conn,
),
)
self._workers.append(ps)
ps.daemon = True
ps.start()
worker_conn.close()
worker_conn.close() # Means this pipe will close when the child process closes it
time.sleep(
0.1
) # Useful to ensure things don't lock up when spawning many envs
return (
[p.poll for p in parent_connections],
[p.recv for p in parent_connections],
[p.send for p in parent_connections],
)
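A standalone sketch of the spawning pattern used here (not part of the diff; the worker function is illustrative): each parent connection contributes its bound `poll`/`recv`/`send` methods, and the parent closes its copy of the child end so the pipe is torn down when the child exits.

import multiprocessing as mp

def echo_worker(conn):
    conn.send(conn.recv())

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    parent_conn, child_conn = ctx.Pipe(duplex=True)
    proc = ctx.Process(target=echo_worker, kwargs=dict(conn=child_conn))
    proc.daemon = True
    proc.start()
    child_conn.close()  # the parent drops its handle to the child's end of the pipe
    poll_fn, read_fn, write_fn = parent_conn.poll, parent_conn.recv, parent_conn.send
    write_fn("ping")
    assert poll_fn(5.0) and read_fn() == "ping"
    proc.join()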
@@ -593,7 +625,8 @@ def close(self) -> None:
if self._is_waiting:
for read_fn in self._connection_read_fns:
try:
read_fn()
# noinspection PyArgumentList
read_fn(0) # Time out immediately
except Exception:
pass
