Unify ThreadPoolExecutor usage in all Dashboard(Agent)Head. (#47160)
Previously, each Dashboard(Agent)HeadModule could have its own
ThreadPoolExecutor. This PR creates a unified TPE in Dashboard(Agent)Head
and uses it everywhere. It also adds an asyncio yield in
DataOrganizer.organize() so that organizing a large amount of data does
not block the event loop for long periods.
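
As context, a minimal sketch of the pattern this commit applies throughout (the `Head` and `Module` classes below are illustrative stand-ins, not the actual Ray classes): one shared ThreadPoolExecutor owned by the head object, with modules offloading blocking work onto it via run_in_executor instead of each creating a private pool.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


class Head:
    """Owns the single shared pool, as Dashboard(Agent)Head now does."""

    def __init__(self) -> None:
        # max_workers=1 keeps background jobs from starving the
        # event-loop thread (see the consts.py defaults below).
        self.thread_pool_executor = ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="head_tpe"
        )


class Module:
    """A module borrows the head's pool instead of creating its own."""

    def __init__(self, head: Head) -> None:
        self._head = head

    async def run(self) -> str:
        loop = asyncio.get_running_loop()
        # Offload the blocking call; the await also yields the event loop.
        return await loop.run_in_executor(
            self._head.thread_pool_executor, self._blocking_work
        )

    def _blocking_work(self) -> str:
        return "done"


print(asyncio.run(Module(Head()).run()))
```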


Signed-off-by: Ruiyang Wang <[email protected]>
rynewang authored Aug 27, 2024
1 parent 579995c commit 6c7da02
Showing 15 changed files with 115 additions and 55 deletions.
8 changes: 7 additions & 1 deletion python/ray/_private/state_api_test_utils.py
@@ -3,6 +3,7 @@
from copy import deepcopy
from collections import defaultdict
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
import logging
import numpy as np
@@ -324,7 +325,12 @@ def get_state_api_manager(gcs_address: str) -> StateAPIManager:
state_api_data_source_client = StateDataSourceClient(
gcs_channel.channel(), gcs_aio_client
)
return StateAPIManager(state_api_data_source_client)
return StateAPIManager(
state_api_data_source_client,
thread_pool_executor=ThreadPoolExecutor(
thread_name_prefix="state_api_test_utils"
),
)


def summarize_worker_startup_time():
5 changes: 5 additions & 0 deletions python/ray/dashboard/agent.py
@@ -7,6 +7,7 @@
import pathlib
import signal
import sys
from concurrent.futures import ThreadPoolExecutor

import ray
import ray._private.ray_constants as ray_constants
@@ -47,6 +48,10 @@ def __init__(
# Public attributes are accessible for all agent modules.
self.ip = node_ip_address
self.minimal = minimal
self.thread_pool_executor = ThreadPoolExecutor(
max_workers=dashboard_consts.RAY_AGENT_THREAD_POOL_MAX_WORKERS,
thread_name_prefix="dashboard_agent_tpe",
)

assert gcs_address is not None
self.gcs_address = gcs_address
6 changes: 6 additions & 0 deletions python/ray/dashboard/consts.py
@@ -49,6 +49,12 @@
# Example: "your.module.ray_cluster_activity_hook".
RAY_CLUSTER_ACTIVITY_HOOK = "RAY_CLUSTER_ACTIVITY_HOOK"

# Work in the thread pool should not starve the main event loop, so we default to 1.
RAY_DASHBOARD_THREAD_POOL_MAX_WORKERS = env_integer(
"RAY_DASHBOARD_THREAD_POOL_MAX_WORKERS", 1
)
RAY_AGENT_THREAD_POOL_MAX_WORKERS = env_integer("RAY_AGENT_THREAD_POOL_MAX_WORKERS", 1)

# The number of candidate agents
CANDIDATE_AGENT_NUMBER = max(env_integer("CANDIDATE_AGENT_NUMBER", 1), 1)
# When the head receives a JobSubmitRequest, there may not be any agent available,
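
As a usage note (a sketch, assuming env_integer reads the variable when the module is imported): the pool sizes above can be overridden at deploy time without code changes.

```python
import os

# Hypothetical override; it must be set in the process environment
# before ray.dashboard.consts is imported, since env_integer reads
# the variable at import time.
os.environ["RAY_DASHBOARD_THREAD_POOL_MAX_WORKERS"] = "4"
```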
29 changes: 23 additions & 6 deletions python/ray/dashboard/datacenter.py
@@ -3,6 +3,7 @@
from typing import Any, List, Optional

import ray.dashboard.consts as dashboard_consts
from ray._private.utils import get_or_create_event_loop
from ray.dashboard.utils import (
Dict,
MutableNotificationDict,
@@ -13,6 +14,7 @@
logger = logging.getLogger(__name__)


# NOT thread safe. Every assignment must be on the main event loop thread.
class DataSource:
# {node id hex(str): node stats(dict of GetNodeStatsReply
# in node_manager.proto)}
@@ -61,12 +63,29 @@ async def purge():

@classmethod
@async_loop_forever(dashboard_consts.RAY_DASHBOARD_STATS_UPDATING_INTERVAL)
async def organize(cls):
async def organize(cls, thread_pool_executor):
"""
Organizes data: reads from (node_physical_stats, node_stats) and updates
(node_workers, node_worker_stats).

This method is not really async, but DataSource is not thread safe, so we
need to make sure it runs on the main event loop thread. To avoid blocking
the main event loop, we yield after processing each node.
"""
node_workers = {}
core_worker_stats = {}
# await inside for loop, so we create a copy of keys().
# nodes may change during processing, so we create a copy of keys().
for node_id in list(DataSource.nodes.keys()):
workers = await cls.get_node_workers(node_id)
node_physical_stats = DataSource.node_physical_stats.get(node_id, {})
node_stats = DataSource.node_stats.get(node_id, {})
# Offloads the blocking operation to a thread pool executor. This also
# yields to the event loop.
workers = await get_or_create_event_loop().run_in_executor(
thread_pool_executor,
cls.merge_workers_for_node,
node_physical_stats,
node_stats,
)
for worker in workers:
stats = worker.get("coreWorkerStats", {})
if stats:
@@ -77,10 +96,8 @@ async def organize(cls):
DataSource.core_worker_stats.reset(core_worker_stats)

@classmethod
async def get_node_workers(cls, node_id):
def merge_workers_for_node(cls, node_physical_stats, node_stats):
workers = []
node_physical_stats = DataSource.node_physical_stats.get(node_id, {})
node_stats = DataSource.node_stats.get(node_id, {})
# Merge coreWorkerStats (node stats) to workers (node physical stats)
pid_to_worker_stats = {}
pid_to_language = {}
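
A hedged sketch of the cooperative loop organize() now runs (the data and merge function below are stand-ins for DataSource and merge_workers_for_node): snapshot the keys up front, then await run_in_executor once per node, which both offloads the merge and yields the event loop between iterations.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

# Stand-in data; in Ray these live in DataSource and may mutate
# between iterations, hence the key snapshot below.
node_physical_stats = {"node-1": {"pid": 100}, "node-2": {"pid": 200}}
node_stats = {"node-1": {"cpu": 0.5}, "node-2": {"cpu": 0.9}}


def merge_workers_for_node(physical, stats):
    # Pure and potentially slow; safe to run off the event loop
    # because it only touches its arguments.
    return {**physical, **stats}


async def organize(executor: ThreadPoolExecutor):
    loop = asyncio.get_running_loop()
    merged = {}
    # Copy the keys: the dicts may change while we await in the loop.
    for node_id in list(node_physical_stats.keys()):
        merged[node_id] = await loop.run_in_executor(
            executor,
            merge_workers_for_node,
            node_physical_stats.get(node_id, {}),
            node_stats.get(node_id, {}),
        )
    return merged


print(asyncio.run(organize(ThreadPoolExecutor(max_workers=1))))
```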
10 changes: 9 additions & 1 deletion python/ray/dashboard/head.py
@@ -1,5 +1,6 @@
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Optional, Set

@@ -96,6 +97,13 @@ def __init__(
self._modules_to_load = modules_to_load
self._modules_loaded = False

# A TPE that runs background, compute-heavy, latency-tolerant jobs, typically
# state updates.
self._thread_pool_executor = ThreadPoolExecutor(
max_workers=dashboard_consts.RAY_DASHBOARD_THREAD_POOL_MAX_WORKERS,
thread_name_prefix="dashboard_head_tpe",
)

self.gcs_address = None
assert gcs_address is not None
self.gcs_address = gcs_address
@@ -312,7 +320,7 @@ async def _async_notify():
self._gcs_check_alive(),
_async_notify(),
DataOrganizer.purge(),
DataOrganizer.organize(),
DataOrganizer.organize(self._thread_pool_executor),
]
for m in modules:
concurrent_tasks.append(m.run(self.server))
6 changes: 1 addition & 5 deletions python/ray/dashboard/modules/event/event_agent.py
@@ -2,7 +2,6 @@
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Union

import ray._private.ray_constants as ray_constants
@@ -26,9 +25,6 @@ def __init__(self, dashboard_agent):
self._stub: Union[event_pb2_grpc.ReportEventServiceStub, None] = None
self._cached_events = asyncio.Queue(event_consts.EVENT_AGENT_CACHE_SIZE)
self._gcs_aio_client = dashboard_agent.gcs_aio_client
self.monitor_thread_pool_executor = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="event_monitor"
)
# Total number of events created by this agent.
self.total_event_reported = 0
# Total number of event report requests sent.
@@ -111,7 +107,7 @@ async def run(self, server):
self._monitor = monitor_events(
self._event_dir,
lambda data: create_task(self._cached_events.put(data)),
self.monitor_thread_pool_executor,
self._dashboard_agent.thread_pool_executor,
)

await asyncio.gather(
6 changes: 1 addition & 5 deletions python/ray/dashboard/modules/event/event_head.py
@@ -3,7 +3,6 @@
import os
import time
from collections import OrderedDict, defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Union

import aiohttp.web
@@ -31,9 +30,6 @@ def __init__(self, dashboard_head):
self._event_dir = os.path.join(self._dashboard_head.log_dir, "events")
os.makedirs(self._event_dir, exist_ok=True)
self._monitor: Union[asyncio.Task, None] = None
self.monitor_thread_pool_executor = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="event_monitor"
)
self.total_report_events_count = 0
self.total_events_received = 0
self.module_started = time.monotonic()
Expand Down Expand Up @@ -111,7 +107,7 @@ async def run(self, server):
self._monitor = monitor_events(
self._event_dir,
lambda data: self._update_events(parse_event_strings(data)),
self.monitor_thread_pool_executor,
self._dashboard_head._thread_pool_executor,
)

@staticmethod
6 changes: 1 addition & 5 deletions python/ray/dashboard/modules/job/job_head.py
@@ -3,7 +3,6 @@
import json
import logging
import traceback
from concurrent.futures import ThreadPoolExecutor
from random import sample
from typing import AsyncIterator, List, Optional

@@ -163,9 +162,6 @@ def __init__(self, dashboard_head):
super().__init__(dashboard_head)
self._gcs_aio_client = dashboard_head.gcs_aio_client
self._job_info_client = None
self._upload_package_thread_pool_executor = ThreadPoolExecutor(
thread_name_prefix="job_head.upload_package"
)

# It contains all `JobAgentSubmissionClient` that
# `JobHead` has ever used, and will not be deleted
@@ -317,7 +313,7 @@ async def upload_package(self, req: Request):
try:
data = await req.read()
await get_or_create_event_loop().run_in_executor(
self._upload_package_thread_pool_executor,
self._dashboard_head._thread_pool_executor,
upload_package_to_gcs,
package_uri,
data,
56 changes: 33 additions & 23 deletions python/ray/dashboard/modules/node/node_head.py
@@ -16,6 +16,7 @@
from ray import NodeID
from ray._private import ray_constants
from ray._private.ray_constants import DEBUG_AUTOSCALING_ERROR, DEBUG_AUTOSCALING_STATUS
from ray._private.utils import get_or_create_event_loop
from ray.autoscaler._private.util import (
LoadMetricsSummary,
get_per_node_breakdown_as_dict,
@@ -395,31 +396,40 @@ async def _update_node_stats(self):
return_exceptions=True,
)

for node_info, reply in zip(nodes, replies):
node_id, _ = node_info
if isinstance(reply, asyncio.CancelledError):
pass
elif isinstance(reply, grpc.RpcError):
if reply.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
logger.exception(
f"Cannot reach the node, {node_id}, after timeout {TIMEOUT}. "
"This node may have been overloaded, terminated, or "
"the network is slow."
)
elif reply.code() == grpc.StatusCode.UNAVAILABLE:
logger.exception(
f"Cannot reach the node, {node_id}. "
"The node may have been terminated."
)
else:
def postprocess(nodes, replies):
"""Pure function reorganizing the data into {node_id: stats}."""
new_node_stats = {}
for node_info, reply in zip(nodes, replies):
node_id, _ = node_info
if isinstance(reply, asyncio.CancelledError):
pass
elif isinstance(reply, grpc.RpcError):
if reply.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
logger.exception(
f"Cannot reach the node, {node_id}, after timeout "
f" {TIMEOUT}. This node may have been overloaded, "
"terminated, or the network is slow."
)
elif reply.code() == grpc.StatusCode.UNAVAILABLE:
logger.exception(
f"Cannot reach the node, {node_id}. "
"The node may have been terminated."
)
else:
logger.exception(f"Error updating node stats of {node_id}.")
logger.exception(reply)
elif isinstance(reply, Exception):
logger.exception(f"Error updating node stats of {node_id}.")
logger.exception(reply)
elif isinstance(reply, Exception):
logger.exception(f"Error updating node stats of {node_id}.")
logger.exception(reply)
else:
reply_dict = node_stats_to_dict(reply)
DataSource.node_stats[node_id] = reply_dict
else:
new_node_stats[node_id] = node_stats_to_dict(reply)
return new_node_stats

new_node_stats = await get_or_create_event_loop().run_in_executor(
self._dashboard_head._thread_pool_executor, postprocess, nodes, replies
)
for node_id, new_stat in new_node_stats.items():
DataSource.node_stats[node_id] = new_stat

async def run(self, server):
self.get_all_node_info = GetAllNodeInfo(self._dashboard_head)
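
The refactor above also illustrates the thread-safety split: reply handling becomes a pure function executed on the shared TPE, while the write into the not-thread-safe DataSource stays on the event-loop thread. A minimal sketch under those assumptions (shared_state stands in for DataSource.node_stats):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

shared_state = {}  # stand-in for DataSource.node_stats; event-loop thread only


def postprocess(nodes, replies):
    # Pure: reads only its arguments and returns a new dict, so it is
    # safe to run on a worker thread.
    return {
        node_id: reply
        for node_id, reply in zip(nodes, replies)
        if not isinstance(reply, Exception)
    }


async def update_node_stats(executor: ThreadPoolExecutor):
    nodes = ["node-a", "node-b"]
    replies = [{"cpu": 0.7}, RuntimeError("node unreachable")]
    new_stats = await asyncio.get_running_loop().run_in_executor(
        executor, postprocess, nodes, replies
    )
    # The mutation happens back on the event-loop thread, keeping the
    # non-thread-safe store single-threaded.
    shared_state.update(new_stats)


asyncio.run(update_node_stats(ThreadPoolExecutor(max_workers=1)))
print(shared_state)  # {'node-a': {'cpu': 0.7}}
```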
4 changes: 3 additions & 1 deletion python/ray/dashboard/modules/reporter/reporter_agent.py
@@ -1225,7 +1225,9 @@ async def _run_loop(self, publisher):
# NOTE: Stats collection is executed inside the thread-pool
# executor (TPE) to avoid blocking the Dashboard's event-loop
json_payload = await loop.run_in_executor(
None, self._compose_stats_payload, autoscaler_status_json_bytes
self._dashboard_agent.thread_pool_executor,
self._compose_stats_payload,
autoscaler_status_json_bytes,
)

await publisher.publish_resource_usage(self._key, json_payload)
10 changes: 8 additions & 2 deletions python/ray/dashboard/modules/reporter/reporter_head.py
@@ -609,7 +609,11 @@ async def run(self, server):
gcs_channel, self._dashboard_head.gcs_aio_client
)
# Set up the state API in order to fetch task information.
self._state_api = StateAPIManager(self._state_api_data_source_client)
# TODO(ryw): unify the StateAPIManager in reporter_head and state_head.
self._state_api = StateAPIManager(
self._state_api_data_source_client,
self._dashboard_head._thread_pool_executor,
)

# Need daemon True to avoid dashboard hangs at exit.
self.service_discovery.daemon = True
@@ -635,7 +639,9 @@

# NOTE: Every iteration is executed inside the thread-pool executor
# (TPE) to avoid blocking the Dashboard's event-loop
parsed_data = await loop.run_in_executor(None, json.loads, data)
parsed_data = await loop.run_in_executor(
self._dashboard_head._thread_pool_executor, json.loads, data
)

node_id = key.split(":")[-1]
DataSource.node_physical_stats[node_id] = parsed_data
5 changes: 4 additions & 1 deletion python/ray/dashboard/modules/state/state_head.py
@@ -548,7 +548,10 @@ async def run(self, server):
self._state_api_data_source_client = StateDataSourceClient(
gcs_channel, self._dashboard_head.gcs_aio_client
)
self._state_api = StateAPIManager(self._state_api_data_source_client)
self._state_api = StateAPIManager(
self._state_api_data_source_client,
self._dashboard_head._thread_pool_executor,
)
self._log_api = LogsManager(self._state_api_data_source_client)

@staticmethod
9 changes: 6 additions & 3 deletions python/ray/dashboard/state_aggregator.py
@@ -140,10 +140,13 @@ class StateAPIManager:
the entries.
"""

def __init__(self, state_data_source_client: StateDataSourceClient):
def __init__(
self,
state_data_source_client: StateDataSourceClient,
thread_pool_executor: ThreadPoolExecutor,
):
self._client = state_data_source_client

self._thread_pool_executor = ThreadPoolExecutor(thread_name_prefix="state_head")
self._thread_pool_executor = thread_pool_executor

@property
def data_source_client(self):
5 changes: 4 additions & 1 deletion python/ray/tests/test_state_api.py
@@ -4,6 +4,7 @@
import sys
import signal
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple
from unittest.mock import MagicMock, AsyncMock

@@ -130,7 +131,9 @@
@pytest.fixture
def state_api_manager():
data_source_client = AsyncMock(StateDataSourceClient)
manager = StateAPIManager(data_source_client)
manager = StateAPIManager(
data_source_client, thread_pool_executor=ThreadPoolExecutor()
)
yield manager


5 changes: 4 additions & 1 deletion python/ray/tests/test_state_api_summary.py
@@ -6,6 +6,7 @@
import random
import sys
from dataclasses import asdict
from concurrent.futures import ThreadPoolExecutor

from ray.util.state import (
summarize_tasks,
@@ -42,7 +43,9 @@
@pytest.fixture
def state_api_manager():
data_source_client = AsyncMock(StateDataSourceClient)
manager = StateAPIManager(data_source_client)
manager = StateAPIManager(
data_source_client, thread_pool_executor=ThreadPoolExecutor()
)
yield manager


