Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[data] Remove ray.kill in ActorPoolMapOperator #47752

Merged
merged 7 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -488,9 +488,7 @@ def pending_to_running(self, ready_ref: ray.ObjectRef) -> bool:
already been killed.
"""
if ready_ref not in self._pending_actors:
# We assume that there was a race between killing the actor and the actor
# ready future resolving. Since we can rely on ray.kill() eventually killing
# the actor, we can safely drop this reference.
# The actor has been removed from the pool before becoming running.
return False
actor = self._pending_actors.pop(ready_ref)
self._num_tasks_in_flight[actor] = 0
Expand Down Expand Up @@ -548,7 +546,7 @@ def return_actor(self, actor: ray.actor.ActorHandle):

self._num_tasks_in_flight[actor] -= 1
if self._should_kill_idle_actors and self._num_tasks_in_flight[actor] == 0:
self._kill_running_actor(actor)
self._remove_actor(actor)

def get_pending_actor_refs(self) -> List[ray.ObjectRef]:
return list(self._pending_actors.keys())
Expand Down Expand Up @@ -585,7 +583,9 @@ def kill_inactive_actor(self) -> bool:
def _maybe_kill_pending_actor(self) -> bool:
if self._pending_actors:
# At least one pending actor, so kill first one.
self._kill_pending_actor(next(iter(self._pending_actors.keys())))
ready_ref = next(iter(self._pending_actors.keys()))
self._remove_actor(self._pending_actors[ready_ref])
del self._pending_actors[ready_ref]
return True
# No pending actors, so indicate to the caller that no actors were killed.
return False
Expand All @@ -594,7 +594,7 @@ def _maybe_kill_idle_actor(self) -> bool:
for actor, tasks_in_flight in self._num_tasks_in_flight.items():
if tasks_in_flight == 0:
# At least one idle actor, so kill first one found.
self._kill_running_actor(actor)
self._remove_actor(actor)
return True
# No idle actors, so indicate to the caller that no actors were killed.
return False
Expand All @@ -619,9 +619,9 @@ def kill_all_actors(self):
self._kill_all_running_actors()

def _kill_all_pending_actors(self):
pending_actor_refs = list(self._pending_actors.keys())
for ref in pending_actor_refs:
self._kill_pending_actor(ref)
for _, actor in self._pending_actors.items():
self._remove_actor(actor)
self._pending_actors.clear()

def _kill_all_idle_actors(self):
idle_actors = [
Expand All @@ -630,23 +630,25 @@ def _kill_all_idle_actors(self):
if tasks_in_flight == 0
]
for actor in idle_actors:
self._kill_running_actor(actor)
self._remove_actor(actor)
self._should_kill_idle_actors = True

def _kill_all_running_actors(self):
actors = list(self._num_tasks_in_flight.keys())
for actor in actors:
self._kill_running_actor(actor)

def _kill_running_actor(self, actor: ray.actor.ActorHandle):
"""Kill the provided actor and remove it from the pool."""
ray.kill(actor)
del self._num_tasks_in_flight[actor]

def _kill_pending_actor(self, ready_ref: ray.ObjectRef):
"""Kill the provided pending actor and remove it from the pool."""
actor = self._pending_actors.pop(ready_ref)
ray.kill(actor)
self._remove_actor(actor)

def _remove_actor(self, actor: ray.actor.ActorHandle):
"""Remove the given actor from the pool."""
# NOTE: we remove references to the actor and let ref counting
# garbage collect the actor, instead of using ray.kill.
# Because otherwise the actor cannot be restarted upon lineage reconstruction.
for state_dict in [
self._num_tasks_in_flight,
self._actor_locations,
]:
if actor in state_dict:
del state_dict[actor]

def _get_location(self, bundle: RefBundle) -> Optional[NodeIdStr]:
"""Ask Ray for the node id of the given bundle.
Expand Down
108 changes: 60 additions & 48 deletions python/ray/data/tests/test_actor_pool_map_operator.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import collections
import time
import unittest
from typing import Any, Optional, Tuple

import pytest

import ray
from ray._private.test_utils import wait_for_condition
from ray.actor import ActorHandle
from ray.data._internal.compute import ActorPoolStrategy
from ray.data._internal.execution.operators.actor_pool_map_operator import _ActorPool
from ray.data._internal.execution.util import make_ref_bundles
from ray.tests.conftest import * # noqa
from ray.types import ObjectRef


@ray.remote
Expand All @@ -22,17 +25,19 @@ def get_location(self) -> str:

class TestActorPool(unittest.TestCase):
def setup_class(self):
self._last_created_actor_and_ready_ref = (None, None)
self._last_created_actor_and_ready_ref: Optional[
Tuple[ActorHandle, ObjectRef[Any]]
] = None
self._actor_node_id = "node1"
ray.init(num_cpus=4)

def teardown_class(self):
ray.shutdown()

def _create_actor_fn(self):
def _create_actor_fn(self) -> Tuple[ActorHandle, ObjectRef[Any]]:
actor = PoolWorker.remote(self._actor_node_id)
ready_ref = actor.get_location.remote()
self._last_created_actor_and_ready_ref = (actor, ready_ref)
self._last_created_actor_and_ready_ref = actor, ready_ref
return actor, ready_ref

def _create_actor_pool(
Expand All @@ -51,21 +56,33 @@ def _create_actor_pool(
)
return pool

def _add_pending_actor(self, pool: _ActorPool, node_id="node1"):
def _add_pending_actor(
self, pool: _ActorPool, node_id="node1"
) -> Tuple[ActorHandle, ObjectRef[Any]]:
self._actor_node_id = node_id
assert pool.scale_up(1) == 1
assert self._last_created_actor_and_ready_ref is not None
actor, ready_ref = self._last_created_actor_and_ready_ref
self._last_created_actor_and_ready_ref = None
return actor, ready_ref

def _wait_for_actor_ready(self, pool: _ActorPool, ready_ref):
ray.get(ready_ref)
pool.pending_to_running(ready_ref)

def _add_ready_actor(self, pool: _ActorPool, node_id="node1"):
def _add_ready_actor(self, pool: _ActorPool, node_id="node1") -> ActorHandle:
actor, ready_ref = self._add_pending_actor(pool, node_id)
self._wait_for_actor_ready(pool, ready_ref)
return actor

def _wait_for_actor_dead(self, actor_id: str):
def _check_actor_dead():
nonlocal actor_id
actor_info = ray.state.actors(actor_id)
return actor_info["State"] == "DEAD"

wait_for_condition(_check_actor_dead)

def test_basic_config(self):
pool = self._create_actor_pool(
min_size=1,
Expand Down Expand Up @@ -217,11 +234,10 @@ def test_kill_inactive_pending_actor(self):
assert killed
# Check that actor is not in pool.
assert pool.get_pending_actor_refs() == []
# Check that actor was killed.
# Wait a second to let actor killing happen.
time.sleep(1)
with pytest.raises(ray.exceptions.RayActorError):
ray.get(actor.get_location.remote())
# Check that actor is dead.
actor_id = actor._actor_id.hex()
del actor
self._wait_for_actor_dead(actor_id)
# Check that the per-state pool sizes are as expected.
assert pool.current_size() == 0
assert pool.num_pending_actors() == 0
Expand All @@ -240,11 +256,10 @@ def test_kill_inactive_idle_actor(self):
assert killed
# Check that actor is not in pool.
assert pool.pick_actor() is None
# Check that actor was killed.
# Wait a second to let actor killing happen.
time.sleep(1)
with pytest.raises(ray.exceptions.RayActorError):
ray.get(actor.get_location.remote())
# Check that actor is dead.
actor_id = actor._actor_id.hex()
del actor
self._wait_for_actor_dead(actor_id)
# Check that the per-state pool sizes are as expected.
assert pool.current_size() == 0
assert pool.num_pending_actors() == 0
Expand Down Expand Up @@ -283,11 +298,10 @@ def test_kill_inactive_pending_over_idle(self):
pool.return_actor(idle_actor)
# Check that the pending actor is not in pool.
assert pool.get_pending_actor_refs() == []
# Check that actor was killed.
# Wait a second to let actor killing happen.
time.sleep(1)
with pytest.raises(ray.exceptions.RayActorError):
ray.get(pending_actor.get_location.remote())
# Check that actor is dead.
actor_id = pending_actor._actor_id.hex()
del pending_actor
self._wait_for_actor_dead(actor_id)
# Check that the per-state pool sizes are as expected.
assert pool.current_size() == 1
assert pool.num_pending_actors() == 0
Expand All @@ -307,11 +321,10 @@ def test_kill_all_inactive_pending_actor_killed(self):
# Check that actor is no longer in the pool as pending, to protect against
# ready/killed races.
assert not pool.pending_to_running(ready_ref)
# Check that actor was killed.
# Wait a few seconds to let actor killing happen.
time.sleep(1)
with pytest.raises(ray.exceptions.RayActorError):
ray.get(actor.get_location.remote())
# Check that actor is dead.
actor_id = actor._actor_id.hex()
del actor
self._wait_for_actor_dead(actor_id)
# Check that the per-state pool sizes are as expected.
assert pool.current_size() == 0
assert pool.num_pending_actors() == 0
Expand All @@ -328,11 +341,10 @@ def test_kill_all_inactive_idle_actor_killed(self):
pool.kill_all_inactive_actors()
# Check that actor is not in pool.
assert pool.pick_actor() is None
# Check that actor was killed.
# Wait a few seconds to let actor killing happen.
time.sleep(1)
with pytest.raises(ray.exceptions.RayActorError):
ray.get(actor.get_location.remote())
# Check that actor is dead.
actor_id = actor._actor_id.hex()
del actor
self._wait_for_actor_dead(actor_id)
# Check that the per-state pool sizes are as expected.
assert pool.current_size() == 0
assert pool.num_pending_actors() == 0
Expand Down Expand Up @@ -369,11 +381,10 @@ def test_kill_all_inactive_future_idle_actors_killed(self):
pool.return_actor(actor)
# Check that actor is not in pool.
assert pool.pick_actor() is None
# Check that actor was killed.
# Wait a few seconds to let actor killing happen.
time.sleep(1)
with pytest.raises(ray.exceptions.RayActorError):
ray.get(actor.get_location.remote())
# Check that actor is dead.
actor_id = actor._actor_id.hex()
del actor
self._wait_for_actor_dead(actor_id)
# Check that the per-state pool sizes are as expected.
assert pool.current_size() == 0
assert pool.num_pending_actors() == 0
Expand Down Expand Up @@ -418,11 +429,10 @@ def test_kill_all_inactive_mixture(self):
pool.return_actor(actor1)
# Check that actor is not in pool.
assert pool.pick_actor() is None
# Check that actor was killed.
# Wait a few seconds to let actor killing happen.
time.sleep(1)
with pytest.raises(ray.exceptions.RayActorError):
ray.get(actor1.get_location.remote())
# Check that actor is dead.
actor_id = actor1._actor_id.hex()
del actor1
self._wait_for_actor_dead(actor_id)
# Check that the per-state pool sizes are as expected.
assert pool.current_size() == 0
assert pool.num_pending_actors() == 0
Expand All @@ -442,13 +452,15 @@ def test_all_actors_killed(self):
pool.kill_all_actors()
# Check that the pool is empty.
assert pool.pick_actor() is None
# Check that both actors were killed.
# Wait a few seconds to let actor killing happen.
time.sleep(1)
with pytest.raises(ray.exceptions.RayActorError):
ray.get(idle_actor.get_location.remote())
with pytest.raises(ray.exceptions.RayActorError):
ray.get(active_actor.get_location.remote())

# Check that both actors are dead
actor_id = active_actor._actor_id.hex()
del active_actor
self._wait_for_actor_dead(actor_id)
actor_id = idle_actor._actor_id.hex()
del idle_actor
self._wait_for_actor_dead(actor_id)

# Check that the per-state pool sizes are as expected.
assert pool.current_size() == 0
assert pool.num_pending_actors() == 0
Expand Down
15 changes: 15 additions & 0 deletions python/ray/data/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
import pytest

import ray
from ray._private.test_utils import wait_for_condition
from ray.data._internal.execution.interfaces.ref_bundle import (
_ref_bundles_iterator_to_block_refs_list,
)
from ray.data._internal.execution.operators.actor_pool_map_operator import _MapWorker
from ray.data.context import DataContext
from ray.data.exceptions import UserCodeException
from ray.data.tests.conftest import * # noqa
Expand Down Expand Up @@ -76,6 +78,19 @@ def test_basic_actors(shutdown_only):
concurrency=(8, 4),
)

# Make sure all actors are dead after dataset execution finishes.
def _all_actors_dead():
actor_table = ray.state.actors()
actors = {
id: actor_info
for actor_info in actor_table.values()
if actor_info["ActorClassName"] == _MapWorker.__name__
}
assert len(actors) > 0
return all(actor_info["State"] == "DEAD" for actor_info in actors.values())

wait_for_condition(_all_actors_dead)


def test_callable_classes(shutdown_only):
ray.init(num_cpus=2)
Expand Down
Loading