ENH add get_worker_rank with unique rank #285

Open · wants to merge 10 commits into master
7 changes: 5 additions & 2 deletions loky/__init__.py
@@ -1,8 +1,9 @@
r"""The :mod:`loky` module manages a pool of worker that can be re-used across time.
It provides a robust and dynamic implementation os the
r"""The :mod:`loky` module manages a pool of worker that can be re-used across
time. It provides a robust and dynamic implementation of the
:class:`ProcessPoolExecutor` and a function :func:`get_reusable_executor` which
hide the pool management under the hood.
"""

from concurrent.futures import (
ALL_COMPLETED,
FIRST_COMPLETED,
@@ -20,6 +21,7 @@
from .reusable_executor import get_reusable_executor
from .cloudpickle_wrapper import wrap_non_picklable_objects
from .process_executor import BrokenProcessPool, ProcessPoolExecutor
from .process_executor import get_worker_rank


__all__ = [
@@ -37,6 +39,7 @@
"FIRST_EXCEPTION",
"ALL_COMPLETED",
"wrap_non_picklable_objects",
"get_worker_rank",
"set_loky_pickler",
]

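For reference, a minimal usage sketch of the API this diff exposes, assuming this branch is installed (the pool size and task function are illustrative, not part of the diff):

```python
from loky import get_reusable_executor, get_worker_rank

def task(i):
    # Only valid inside a worker; raises RuntimeError in the main process.
    rank, world = get_worker_rank()
    return i, rank, world

if __name__ == "__main__":
    executor = get_reusable_executor(max_workers=4)
    for i, rank, world in executor.map(task, range(8)):
        print(f"task {i} ran on worker {rank} of {world}")
```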
73 changes: 71 additions & 2 deletions loky/process_executor.py
@@ -115,6 +115,33 @@ def _get_memory_usage(pid, force_gc=False):
except ImportError:
_USE_PSUTIL = False

# Mechanism to obtain the rank of a worker and the total number of workers in
# the executor.
_WORKER_RANK = None
_WORKER_WORLD = None


def get_worker_rank():
"""Returns the rank of the worker and the number of workers in the executor

This helper function should only be called in a worker, else it will throw
a RuntimeError.
"""
if _WORKER_RANK is None:
raise RuntimeError(
"get_worker_id, should only be called in a worker, not in the "
"main process."
)
return _WORKER_RANK, _WORKER_WORLD


def set_worker_rank(pid, rank_mapper):
"""Set worker's rank and world size from the process pid and an rank_mapper."""
global _WORKER_RANK, _WORKER_WORLD
if pid in rank_mapper:
_WORKER_RANK = rank_mapper[pid]
_WORKER_WORLD = rank_mapper["world"]


class _ThreadWakeup:
def __init__(self):
@@ -277,11 +304,12 @@ def __init__(self, work_id, exception=None, result=None):


class _CallItem:
def __init__(self, work_id, fn, args, kwargs):
def __init__(self, work_id, fn, args, kwargs, rank_mapper):
self.work_id = work_id
self.fn = fn
self.args = args
self.kwargs = kwargs
self.rank_mapper = rank_mapper

# Store the current loky_pickler so it is correctly set in the worker
self.loky_pickler = get_loky_pickler_name()
@@ -384,6 +412,7 @@ def _process_worker(
timeout,
worker_exit_lock,
current_depth,
rank_mapper,
):
"""Evaluates calls from call_queue and places the results in result_queue.

@@ -403,6 +432,8 @@
worker_exit_lock: Lock to avoid flagging the executor as broken on
workers timeout.
current_depth: Nested parallelism level, to avoid infinite spawning.
rank_mapper: Initial values for the worker rank and world size, as a
dict with the keys None and 'world'.
"""
if initializer is not None:
try:
@@ -420,6 +451,13 @@
_last_memory_leak_check = None
pid = os.getpid()

# Passing an initial value is necessary as some jobs can be sent and
# serialized before this worker is created. In that case, no rank is
# available for this pid in call_item.rank_mapper, so this initial rank
# is the correct one. At spawn time, the main process does not know the
# worker's pid yet, so it passes the rank under the None key.
set_worker_rank(None, rank_mapper)

mp.util.debug(f"Worker started with timeout={timeout}")
while True:
try:
@@ -447,6 +485,7 @@
if call_item is None:
# Notify queue management thread about worker shutdown
result_queue.put(pid)

is_clean = worker_exit_lock.acquire(True, timeout=30)

# Early notify any loky executor running in this worker process
@@ -459,6 +498,10 @@
else:
mp.util.info("Main process did not release worker_exit")
return

# If the executor has been resized, this new rank_mapper might contain
# new rank/world info. Correct the values before running the task.
set_worker_rank(pid, call_item.rank_mapper)
try:
r = call_item()
except BaseException as e:
@@ -583,6 +626,10 @@ def weakref_cb(
# of new processes or shut down
self.processes_management_lock = executor._processes_management_lock

# A dict mapping the workers' pids to their ranks. It also contains the
# current size of the executor, associated with the 'world' key.
self.rank_mapper = executor._rank_mapper

super().__init__(name="ExecutorManagerThread")
if sys.version_info < (3, 9):
self.daemon = True
@@ -634,6 +681,7 @@ def add_call_item_to_queue(self):
work_item.fn,
work_item.args,
work_item.kwargs,
self.rank_mapper,
),
block=True,
)
@@ -727,6 +775,7 @@ def process_result_item(self, result_item):
# itself: we should not mark the executor as broken.
with self.processes_management_lock:
p = self.processes.pop(result_item, None)
del self.rank_mapper[result_item]

# p can be None if the executor is concurrently shutting down.
if p is not None:
@@ -1014,7 +1063,6 @@ class TerminatedWorkerError(BrokenProcessPool):


class ShutdownExecutorError(RuntimeError):

"""
Raised when a ProcessPoolExecutor is shutdown while a future was in the
running or pending state.
@@ -1128,6 +1176,10 @@ def __init__(
# Finally setup the queues for interprocess communication
self._setup_queues(job_reducers, result_reducers)

# A dict mapping the workers' pids to their ranks. The current size of
# the executor is associated with the 'world' key.
self._rank_mapper = {"world": max_workers}

mp.util.debug("ProcessPoolExecutor is setup")

def _setup_queues(self, job_reducers, result_reducers, queue_size=None):
Expand Down Expand Up @@ -1184,8 +1236,16 @@ def _start_executor_manager_thread(self):
)

def _adjust_process_count(self):
# Compute available worker ranks for newly spawned workers
given_ranks = set(
v for k, v in self._rank_mapper.items() if k != "world"
)
all_ranks = set(range(self._max_workers))
available_ranks = all_ranks - given_ranks

while len(self._processes) < self._max_workers:
worker_exit_lock = self._context.BoundedSemaphore(1)
rank = available_ranks.pop()
args = (
self._call_queue,
self._result_queue,
@@ -1195,6 +1255,7 @@ def _adjust_process_count(self):
self._timeout,
worker_exit_lock,
_CURRENT_DEPTH + 1,
{None: rank, "world": self._max_workers},
)
worker_exit_lock.acquire()
try:
@@ -1208,6 +1269,14 @@
p._worker_exit_lock = worker_exit_lock
p.start()
self._processes[p.pid] = p
self._rank_mapper[p.pid] = rank

# Reassign ranks that are too high to ranks that are still available.
# They will be passed to the workers when sending the tasks with
# the _CallItem.
for pid, rank in list(self._rank_mapper.items()):
if pid != "world" and rank >= self._max_workers:
self._rank_mapper[pid] = available_ranks.pop()
mp.util.debug(
f"Adjusted process count to {self._max_workers}: "
f"{[(p.name, pid) for pid, p in self._processes.items()]}"
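To make the rank bookkeeping above concrete, here is a hand-traced sketch of how the shared `_rank_mapper` dict evolves (the pids are made up; the steps follow the code in this diff):

```python
# ProcessPoolExecutor.__init__ with max_workers=2:
rank_mapper = {"world": 2}

# _adjust_process_count spawns two workers and records their ranks:
rank_mapper[1234] = 0   # pid 1234 -> rank 0
rank_mapper[1235] = 1   # pid 1235 -> rank 1

# A worker times out: process_result_item drops its entry,
# freeing rank 1 for the replacement worker:
del rank_mapper[1235]

# The next _adjust_process_count computes the free ranks by set
# difference, as in the diff above:
given_ranks = {v for k, v in rank_mapper.items() if k != "world"}
available_ranks = set(range(rank_mapper["world"])) - given_ranks
assert available_ranks == {1}
```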
4 changes: 4 additions & 0 deletions loky/reusable_executor.py
@@ -236,10 +236,14 @@ def _resize(self, max_workers):
# then no processes have been spawned and we can just
# update _max_workers and return
self._max_workers = max_workers
self._rank_mapper["world"] = max_workers
return

self._wait_job_completion()

# Set the new size to be broadcast to the workers
self._rank_mapper["world"] = max_workers

# Some process might have returned due to timeout so check how many
# children are still alive. Use the _process_management_lock to
# ensure that no process are spawned or timeout during the resize.
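A small sketch of the resize behavior this hunk enables, assuming a reusable executor on this branch (the assertion reflects the intended semantics rather than a test from this diff):

```python
from loky import get_reusable_executor, get_worker_rank

executor = get_reusable_executor(max_workers=4)
# Shrinking the pool updates the broadcast 'world' size; tasks submitted
# afterwards should observe the new (rank, world) pair.
executor = get_reusable_executor(max_workers=2, reuse=True)
rank, world = executor.submit(get_worker_rank).result()
assert world == 2 and rank in (0, 1)
```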
28 changes: 28 additions & 0 deletions tests/_test_process_executor.py
@@ -16,7 +16,9 @@
from math import sqrt
from pickle import PicklingError
from threading import Thread
import multiprocessing as mp
from collections import defaultdict

from concurrent import futures
from concurrent.futures._base import (
PENDING,
@@ -1125,6 +1127,32 @@ def test_child_env_executor(self):

executor.shutdown(wait=True)

@staticmethod
def _worker_rank(x):
time.sleep(0.2)
rank, world = loky.get_worker_rank()
return dict(
pid=os.getpid(),
name=mp.current_process().name,
rank=rank,
world=world,
)

@pytest.mark.parametrize("max_workers", [1, 5, 13])
@pytest.mark.parametrize("timeout", [None, 0.01])
def test_workers_rank(self, max_workers, timeout):
executor = self.executor_type(max_workers, timeout=timeout)
results = executor.map(self._worker_rank, range(max_workers * 5))
workers_rank = {}
for f in results:
assert f["world"] == max_workers
rank = workers_rank.get(f["pid"], None)
assert rank is None or rank == f["rank"]
workers_rank[f["pid"]] = f["rank"]
msg = ", ".join(f"{k}, {v}" for k, v in executor._rank_mapper.items())
assert set(workers_rank.values()) == set(range(max_workers)), msg
executor.shutdown(wait=True, kill_workers=True)

def test_viztracer_profiler(self):
# Check that the viztracer profiler is initialized in workers when
# installed.
8 changes: 7 additions & 1 deletion tests/test_loky_module.py
@@ -1,15 +1,16 @@
import multiprocessing as mp
import os
import sys
import shutil
import tempfile
import warnings
import multiprocessing as mp
from subprocess import check_output

import pytest

import loky
from loky import cpu_count
from loky import get_worker_rank
from loky.backend.context import _cpu_count_user, _MAX_WINDOWS_WORKERS


@@ -215,3 +216,8 @@ def test_only_physical_cores_with_user_limitation():
if cpu_count_user < cpu_count_mp:
assert cpu_count() == cpu_count_user
assert cpu_count(only_physical_cores=True) == cpu_count_user


def test_worker_rank_in_worker_only():
with pytest.raises(RuntimeError):
get_worker_rank()
41 changes: 39 additions & 2 deletions tests/test_reusable_executor.py
@@ -1,21 +1,25 @@
import os
import subprocess
import sys
import gc
import sys
import time
import ctypes
from tempfile import NamedTemporaryFile
import pytest
import warnings
import threading
import subprocess
from time import sleep
import multiprocessing as mp
from multiprocessing import util, current_process
from pickle import PicklingError, UnpicklingError
from tempfile import NamedTemporaryFile

import cloudpickle
from packaging.version import Version

import loky
from loky import cpu_count
from loky import get_worker_rank
from loky import get_reusable_executor
from loky.process_executor import _RemoteTraceback, TerminatedWorkerError
from loky.process_executor import BrokenProcessPool, ShutdownExecutorError
@@ -679,6 +683,39 @@ def test_resize_after_timeout(self):
expected_msg = "A worker stopped"
assert expected_msg in recorded_warnings[0].message.args[0]

@staticmethod
def _worker_rank(x):
time.sleep(0.2)
rank, world = get_worker_rank()
return dict(
pid=os.getpid(),
name=mp.current_process().name,
rank=rank,
world=world,
)

def test_workers_rank_resize(self):
executor = get_reusable_executor(max_workers=2)

with warnings.catch_warnings(record=True):
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
for size in [12, 2, 1, 12, 6, 1, 8, 5]:
executor = get_reusable_executor(max_workers=size, reuse=True)
results = executor.map(self._worker_rank, range(size * 5))
executor.map(sleep, [0.01] * 6)
workers_rank = {}
for f in results:
assert f["world"] == size
rank = workers_rank.get(f["pid"], None)
assert rank is None or rank == f["rank"]
workers_rank[f["pid"]] = f["rank"]
msg = ", ".join(
f"{k}: {v}" for k, v in executor._rank_mapper.items()
)
assert set(workers_rank.values()) == set(range(size)), msg


class TestGetReusableExecutor(ReusableExecutorMixin):
def test_invalid_process_number(self):