[RLlib] Enhance node-failure tolerance. (ray-project#50007)
Signed-off-by: Anson Qian <[email protected]>
sven1977 authored and anson627 committed Jan 31, 2025
1 parent 5417a75 commit 9d9b87d
Showing 9 changed files with 110 additions and 77 deletions.
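
For context, the node-failure handling exercised by this commit is driven by the EnvRunner fault-tolerance options used in the updated tests below. A minimal sketch of such a configuration (the environment name and worker count are illustrative only; the option names match those appearing in the test diffs further down):

    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .environment("CartPole-v1")
        .env_runners(
            num_env_runners=2,
            # Sanity-check each remote EnvRunner right after it is (re)built.
            validate_env_runners_after_construction=True,
        )
        .fault_tolerance(
            # Recreate EnvRunners whose actor (or whole node) died.
            restart_failed_env_runners=True,
            # Not required when restarting; True would merely tolerate losses.
            ignore_env_runner_failures=False,
        )
    )
    algo = config.build()

Failed remote EnvRunners are then probed and restored at the start of each training (and evaluation) iteration via the renamed `Algorithm.restore_env_runners()` shown in the first file below.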
71 changes: 39 additions & 32 deletions rllib/algorithms/algorithm.py
@@ -134,8 +134,8 @@
NUM_EPISODES,
NUM_EPISODES_LIFETIME,
NUM_TRAINING_STEP_CALLS_PER_ITERATION,
RESTORE_WORKERS_TIMER,
RESTORE_EVAL_WORKERS_TIMER,
RESTORE_ENV_RUNNERS_TIMER,
RESTORE_EVAL_ENV_RUNNERS_TIMER,
SYNCH_ENV_CONNECTOR_STATES_TIMER,
SYNCH_EVAL_ENV_CONNECTOR_STATES_TIMER,
SYNCH_WORKER_WEIGHTS_TIMER,
@@ -1685,8 +1685,7 @@ def _env_runner_remote(worker, num, round, iter):
return env_runner_results, env_steps, agent_steps, all_batches

@OverrideToImplementCustomLogic
@DeveloperAPI
def restore_workers(self, workers: EnvRunnerGroup) -> None:
def restore_env_runners(self, env_runner_group: EnvRunnerGroup) -> None:
"""Try bringing back unhealthy EnvRunners and - if successful - sync with local.
Algorithms that use custom EnvRunners may override this method to
@@ -1695,33 +1694,31 @@ def restore_workers(self, workers: EnvRunnerGroup) -> None:
after such a restart of a (previously failed) worker.
Args:
workers: The EnvRunnerGroup to restore. This may be the training or the
evaluation EnvRunnerGroup.
env_runner_group: The EnvRunnerGroup to restore. This may be the training or
the evaluation EnvRunnerGroup.
"""
# If `workers` is None, or
# 1. `workers` (EnvRunnerGroup) does not have a local worker, and
# If `env_runner_group` is None, or
# 1. `env_runner_group` (EnvRunnerGroup) does not have a local worker, and
# 2. `self.env_runner_group` (EnvRunnerGroup used for training) does not have a
# local EnvRunner -> we don't have a local worker to get state from, so we can't
# recover remote EnvRunners in this case.
if not workers or (
not workers.local_env_runner and not self.env_runner_group.local_env_runner
# local EnvRunner -> we don't have an EnvRunner to get state from, so we can't
# recover remote EnvRunner actors in this case.
if not env_runner_group or (
not env_runner_group.local_env_runner and not self.env_runner
):
return

# This is really cheap, since probe_unhealthy_workers() is a no-op
# This is really cheap, since probe_unhealthy_env_runners() is a no-op
# if there are no unhealthy workers.
restored = workers.probe_unhealthy_workers()
restored = env_runner_group.probe_unhealthy_env_runners()

if restored:
# Count the restored workers.
self._counters["total_num_restored_workers"] += len(restored)

from_worker = (
workers.local_env_runner or self.env_runner_group.local_env_runner
)
from_env_runner = env_runner_group.local_env_runner or self.env_runner
# Get the state of the correct (reference) worker. For example the local
# worker of an EnvRunnerGroup.
state = from_worker.get_state()
state = from_env_runner.get_state()
state_ref = ray.put(state)

def _sync_env_runner(er):
@@ -1734,7 +1731,9 @@ def _sync_env_runner(er):

elif self.config.is_multi_agent:

multi_rl_module_spec = MultiRLModuleSpec.from_module(from_worker.module)
multi_rl_module_spec = MultiRLModuleSpec.from_module(
from_env_runner.module
)

def _sync_env_runner(er): # noqa
# Remove modules, if necessary.
@@ -1752,7 +1751,7 @@ def _sync_env_runner(er): # noqa

# By default, entire local EnvRunner state is synced after restoration
# to bring the previously failed EnvRunner up to date.
workers.foreach_env_runner(
env_runner_group.foreach_env_runner(
func=_sync_env_runner,
remote_worker_ids=restored,
# Don't update the local EnvRunner, b/c it's the one we are synching
@@ -1768,9 +1767,11 @@ def _sync_env_runner(er): # noqa
callbacks_functions=self.config.callbacks_on_env_runners_recreated,
kwargs=dict(
algorithm=self,
env_runner_group=workers,
env_runner_group=env_runner_group,
env_runner_indices=restored,
is_evaluation=workers.local_env_runner.config.in_evaluation,
is_evaluation=(
env_runner_group.local_env_runner.config.in_evaluation
),
),
)
# TODO (sven): Deprecate this call.
@@ -1779,9 +1780,11 @@ def _sync_env_runner(er): # noqa
callbacks_objects=self.callbacks,
kwargs=dict(
algorithm=self,
worker_set=workers,
worker_set=env_runner_group,
worker_ids=restored,
is_evaluation=workers.local_env_runner.config.in_evaluation,
is_evaluation=(
env_runner_group.local_env_runner.config.in_evaluation
),
),
)
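
As the docstring above states, Algorithms running custom EnvRunners can override this hook to sync additional state after a failed EnvRunner has been restarted. A hypothetical sketch of such an override (the subclass name, the extra state, and the custom sync method are illustrative, not part of this change):

    from ray.rllib.algorithms.ppo import PPO
    from ray.rllib.env.env_runner_group import EnvRunnerGroup


    class MyAlgorithm(PPO):
        def restore_env_runners(self, env_runner_group: EnvRunnerGroup) -> None:
            # Let the base implementation probe, restart, and re-sync EnvRunners.
            super().restore_env_runners(env_runner_group)
            # Then push any custom, algorithm-specific state on top.
            extra_state = {"my_custom_key": 123}  # hypothetical
            env_runner_group.foreach_env_runner(
                # Hypothetical method on a custom EnvRunner subclass.
                func=lambda er: er.sync_my_custom_state(extra_state),
            )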

@@ -3371,8 +3374,8 @@ def _run_one_training_iteration(self) -> Tuple[ResultDict, "TrainIterCtx"]:
# when we have reached `min_time_s_per_iteration`).
while not train_iter_ctx.should_stop(has_run_once):
# Before training step, try to bring failed workers back.
with self.metrics.log_time((TIMERS, RESTORE_WORKERS_TIMER)):
self.restore_workers(self.env_runner_group)
with self.metrics.log_time((TIMERS, RESTORE_ENV_RUNNERS_TIMER)):
self.restore_env_runners(self.env_runner_group)

# Try to train one step.
with self.metrics.log_time((TIMERS, TRAINING_STEP_TIMER)):
@@ -3447,11 +3450,11 @@ def _run_one_evaluation(
"""
if self.eval_env_runner_group is not None:
if self.config.enable_env_runner_and_connector_v2:
with self.metrics.log_time((TIMERS, RESTORE_EVAL_WORKERS_TIMER)):
self.restore_workers(self.eval_env_runner_group)
with self.metrics.log_time((TIMERS, RESTORE_EVAL_ENV_RUNNERS_TIMER)):
self.restore_env_runners(self.eval_env_runner_group)
else:
with self._timers[RESTORE_EVAL_WORKERS_TIMER]:
self.restore_workers(self.eval_env_runner_group)
with self._timers["restore_eval_workers"]:
self.restore_env_runners(self.eval_env_runner_group)

# Run `self.evaluate()` only once per training iteration.
if self.config.enable_env_runner_and_connector_v2:
@@ -4001,8 +4004,8 @@ def _run_one_training_iteration_old_api_stack(self):
training_step_results = None
with TrainIterCtx(algo=self) as train_iter_ctx:
while not train_iter_ctx.should_stop(training_step_results):
with self._timers[RESTORE_WORKERS_TIMER]:
self.restore_workers(self.env_runner_group)
with self._timers["restore_workers"]:
self.restore_env_runners(self.env_runner_group)

with self._timers[TRAINING_STEP_TIMER]:
training_step_results = self.training_step()
@@ -4131,6 +4134,10 @@ def _compile_iteration_results_old_api_stack(

return results

@Deprecated(new="Algorithm.restore_env_runners", error=False)
def restore_workers(self, *args, **kwargs):
return self.restore_env_runners(*args, **kwargs)

@Deprecated(
new="Algorithm.env_runner_group",
error=False,
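With this shim in place, both spellings keep working while call sites (for example, the tests below) migrate to the new name; a usage sketch:

    algo.restore_env_runners(algo.env_runner_group)  # new API
    algo.restore_workers(algo.env_runner_group)      # deprecated alias, still works but emits a deprecation warning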
2 changes: 1 addition & 1 deletion rllib/algorithms/tests/test_callbacks_on_algorithm.py
@@ -75,7 +75,7 @@ def test_on_env_runners_recreated_callback(self):
print(algo.train())
time.sleep(15.0)

algo.restore_workers(algo.env_runner_group)
algo.restore_env_runners(algo.env_runner_group)
# After training, the `on_workers_recreated` callback should have captured
# the exact worker IDs recreated (the exact number of times) as the actor
# manager itself. This confirms that the callback is triggered correctly,
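For reference, a hypothetical callback that consumes exactly the kwargs forwarded by `restore_env_runners()` above (algorithm, env_runner_group, env_runner_indices, is_evaluation); the class and attribute names are illustrative, and the `DefaultCallbacks` base class is assumed:

    from ray.rllib.algorithms.callbacks import DefaultCallbacks


    class CountRecreatedEnvRunners(DefaultCallbacks):
        def __init__(self):
            super().__init__()
            self.recreated_counts = {}

        def on_env_runners_recreated(
            self,
            *,
            algorithm,
            env_runner_group,
            env_runner_indices,
            is_evaluation,
            **kwargs,
        ):
            # Tally how often each (training or evaluation) EnvRunner index came back.
            for idx in env_runner_indices:
                key = ("eval" if is_evaluation else "train", idx)
                self.recreated_counts[key] = self.recreated_counts.get(key, 0) + 1

The test above asserts that the worker IDs and counts captured this way agree with what the actor manager itself restarted.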
12 changes: 6 additions & 6 deletions rllib/algorithms/tests/test_env_runner_failures.py
@@ -331,8 +331,8 @@ def _do_test_failing_recover(self, config, multi_agent=False):
for _ in range(2):
algo.train()
time.sleep(15.0)
algo.restore_workers(algo.env_runner_group)
algo.restore_workers(algo.eval_env_runner_group)
algo.restore_env_runners(algo.env_runner_group)
algo.restore_env_runners(algo.eval_env_runner_group)

self.assertEqual(algo.env_runner_group.num_healthy_remote_workers(), 1)
self.assertEqual(algo.eval_env_runner_group.num_healthy_remote_workers(), 1)
@@ -563,7 +563,7 @@ def test_workers_failing_recover(self):

algo.train()
time.sleep(15.0)
algo.restore_workers(algo.env_runner_group)
algo.restore_env_runners(algo.env_runner_group)

# After training, still 2 healthy workers.
self.assertEqual(algo.env_runner_group.num_healthy_remote_workers(), 2)
@@ -645,8 +645,8 @@ def test_modules_are_restored_on_recovered_worker(self):

algo.train()
time.sleep(15.0)
algo.restore_workers(algo.env_runner_group)
algo.restore_workers(algo.eval_env_runner_group)
algo.restore_env_runners(algo.env_runner_group)
algo.restore_env_runners(algo.eval_env_runner_group)

# Everything healthy again. And all workers have been restarted.
self.assertEqual(algo.env_runner_group.num_healthy_remote_workers(), 2)
@@ -730,7 +730,7 @@ def test_eval_workers_failing_recover(self):

algo.train()
time.sleep(15.0)
algo.restore_workers(algo.eval_env_runner_group)
algo.restore_env_runners(algo.eval_env_runner_group)

# Everything still healthy. And all workers are restarted.
self.assertEqual(algo.eval_env_runner_group.num_healthy_remote_workers(), 2)
31 changes: 28 additions & 3 deletions rllib/algorithms/tests/test_node_failures.py
@@ -3,6 +3,7 @@
import ray
from ray._private.test_utils import get_other_nodes
from ray.cluster_utils import Cluster
from ray.rllib.algorithms.appo import APPOConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.utils.metrics import (
@@ -68,6 +69,23 @@ def test_node_failure_ignore(self):

def test_node_failure_recreate_env_runners(self):
# We recreate failed EnvRunners and continue training.
config = (
APPOConfig()
.environment("CartPole-v1")
.learners(num_learners=0)
.experimental(_validate_config=False)
.env_runners(
num_env_runners=6,
validate_env_runners_after_construction=True,
)
.fault_tolerance(
restart_failed_env_runners=True,
ignore_env_runner_failures=False, # True also ok here; we restart.
)
)

self._train(config=config, iters=20, min_reward=300.0, preempt_freq=5)

config = (
PPOConfig()
.environment("CartPole-v1")
@@ -77,11 +95,11 @@ def test_node_failure_recreate_env_runners(self):
)
.fault_tolerance(
restart_failed_env_runners=True,
ignore_env_runner_failures=False, # True also ok here we recreate.
ignore_env_runner_failures=False, # True also ok here; we restart.
)
)

self._train(config=config, iters=30, min_reward=450.0, preempt_freq=5)
self._train(config=config, iters=20, min_reward=300.0, preempt_freq=5)

def test_node_failure_expect_crash(self):
# We do not ignore EnvRunner failures and expect to crash upon failure.
@@ -135,7 +153,13 @@ def _train(self, *, config, iters, min_reward, preempt_freq):
# node, which are always safe from preemption).
if (i - 1) % preempt_freq == 0:
if config.restart_failed_env_runners:
self.assertEqual(healthy_env_runners, 4)
# For async algos that call `restore_env_runners()` several times
# per iteration, the failed env runners may have already been
# restored.
if isinstance(config, APPOConfig):
self.assertIn(healthy_env_runners, [4, 6])
else:
self.assertEqual(healthy_env_runners, 4)
elif config.ignore_env_runner_failures:
self.assertIn(healthy_env_runners, [2, 4])
# After the 0th iteration, in which we already killed one node, if
@@ -167,6 +191,7 @@ def _train(self, *, config, iters, min_reward, preempt_freq):
dashboard_host="0.0.0.0",
)

algo.stop()
self.assertGreaterEqual(best_return, min_reward)


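The node-failure tests above run against a simulated multi-node cluster and periodically preempt non-head nodes. A rough sketch of that pattern, using the same Ray test utilities this file imports (node sizes, counts, and timings are illustrative):

    import time

    import ray
    from ray._private.test_utils import get_other_nodes
    from ray.cluster_utils import Cluster

    # Head node plus two worker nodes; remote EnvRunners land on the workers.
    cluster = Cluster()
    cluster.add_node(num_cpus=2)  # first node added becomes the head
    for _ in range(2):
        cluster.add_node(num_cpus=2)
    cluster.wait_for_nodes()
    ray.init(address=cluster.address)

    # ... build an algorithm with restart_failed_env_runners=True and train ...

    # Preempt one non-head node. EnvRunners on it die and should be brought back
    # by Algorithm.restore_env_runners() on a subsequent iteration.
    victim = get_other_nodes(cluster, exclude_head=True)[0]
    cluster.remove_node(victim)
    time.sleep(5.0)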
13 changes: 12 additions & 1 deletion rllib/algorithms/utils.py
@@ -72,7 +72,18 @@ def __init__(self, config: AlgorithmConfig, rl_module_spec):
)

def get_batch(self, episode_refs: List[ray.ObjectRef]):
episodes: List[EpisodeType] = tree.flatten(ray.get(episode_refs))
episodes: List[EpisodeType] = []
# It's possible that individual refs are invalid because the EnvRunner
# that produced them has crashed or its entire node has gone down.
# In this case, try each ref individually and collect only valid results.
try:
episodes = tree.flatten(ray.get(episode_refs))
except ray.exceptions.OwnerDiedError:
for ref in episode_refs:
try:
episodes.extend(ray.get(ref))
except ray.exceptions.OwnerDiedError:
pass

env_steps = sum(len(e) for e in episodes)

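The same defensive pattern appears twice in this commit: here and in `Learner._update_from_batch_or_episodes()` below. The common case stays a single batched `ray.get()`; only when that raises `OwnerDiedError` (the EnvRunner that produced a ref, or its node, is gone) are the refs resolved one by one and the dead ones dropped. A generic, hypothetical helper illustrating the idea (the commit inlines this logic rather than factoring it out):

    from typing import Any, List

    import ray
    import tree  # dm_tree


    def get_valid_episodes(episode_refs: List[ray.ObjectRef]) -> List[Any]:
        """Resolves episode refs, dropping those whose owning EnvRunner has died."""
        try:
            # Fast path: one batched ray.get() over all refs.
            return tree.flatten(ray.get(episode_refs))
        except ray.exceptions.OwnerDiedError:
            # Slow path: resolve each ref individually and skip the dead ones.
            episodes: List[Any] = []
            for ref in episode_refs:
                try:
                    episodes.extend(ray.get(ref))
                except ray.exceptions.OwnerDiedError:
                    pass
            return episodes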
19 changes: 15 additions & 4 deletions rllib/core/learner/learner.py
@@ -1338,11 +1338,22 @@ def _update_from_batch_or_episodes(
# actual batch/episodes objects).
if isinstance(batch, ray.ObjectRef):
batch = ray.get(batch)
if isinstance(episodes, ray.ObjectRef) or (
isinstance(episodes, list) and isinstance(episodes[0], ray.ObjectRef)
):
if isinstance(episodes, ray.ObjectRef):
episodes = ray.get(episodes)
episodes = tree.flatten(episodes)
elif isinstance(episodes, list) and isinstance(episodes[0], ray.ObjectRef):
# It's possible that individual refs are invalid because the EnvRunner
# that produced them has crashed or its entire node has gone down.
# In this case, try each ref individually and collect only valid results.
try:
episodes = tree.flatten(ray.get(episodes))
except ray.exceptions.OwnerDiedError:
episode_refs = episodes
episodes = []
for ref in episode_refs:
try:
episodes.extend(ray.get(ref))
except ray.exceptions.OwnerDiedError:
pass

# Call the learner connector on the given `episodes` (if we have one).
if episodes is not None and self._learner_connector is not None:
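Either way, the update then proceeds with whatever episodes could still be fetched; the learner connector simply receives a somewhat smaller batch on iterations where refs were lost, rather than the whole training step failing.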
2 changes: 1 addition & 1 deletion rllib/core/learner/learner_group.py
@@ -857,7 +857,7 @@ def foreach_learner(
remote_actor_ids: List[int] = None,
timeout_seconds: Optional[float] = None,
return_obj_refs: bool = False,
mark_healthy: bool = True,
mark_healthy: bool = False,
**kwargs,
) -> RemoteCallResults:
"""Calls the given function on each Learner L with the args: (L, \*\*kwargs).
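Flipping the default of `mark_healthy` to False means a routine `foreach_learner()` call that happens to succeed no longer re-admits a previously failed Learner as healthy as a side effect; callers that really intend to restore Learners now have to opt in explicitly. A hedged usage sketch (the lambda is illustrative only):

    # Only restoration-style calls should re-mark previously unhealthy Learners:
    learner_group.foreach_learner(
        lambda learner: learner.get_state(),
        mark_healthy=True,  # explicit opt-in; the default is now False
    )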
(The diffs for the remaining 2 of the 9 changed files did not load and are not shown here.)
