Unify ThreadPoolExecutor usage in all Dashboard(Agent)Head. #47160

Merged: 10 commits, Aug 27, 2024
8 changes: 7 additions & 1 deletion python/ray/_private/state_api_test_utils.py
@@ -3,6 +3,7 @@
 from copy import deepcopy
 from collections import defaultdict
 import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass, field
 import logging
 import numpy as np
@@ -324,7 +325,12 @@ def get_state_api_manager(gcs_address: str) -> StateAPIManager:
     state_api_data_source_client = StateDataSourceClient(
         gcs_channel.channel(), gcs_aio_client
     )
-    return StateAPIManager(state_api_data_source_client)
+    return StateAPIManager(
+        state_api_data_source_client,
+        thread_pool_executor=ThreadPoolExecutor(
+            thread_name_prefix="state_api_test_utils"
+        ),
+    )


 def summarize_worker_startup_time():

Review thread on the ThreadPoolExecutor(...) line:

Contributor: Provided that we're planning on offloading CPU-bound tasks that nevertheless still hold the GIL, we should limit the number of threads in the TPE (by default, a TPE provisions # of CPUs + 4 threads).

Contributor (author): changed to 1

Contributor (author): changed in head.py, I mean
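To ground the reviewer's sizing point, here is a small sketch contrasting ThreadPoolExecutor's default with an explicit cap (default sizing per CPython 3.8+; the prefix name is illustrative):

```python
from concurrent.futures import ThreadPoolExecutor

# Unbounded construction: since CPython 3.8, max_workers defaults to
# min(32, os.cpu_count() + 4), e.g. 12 threads on an 8-core machine.
default_pool = ThreadPoolExecutor()
print(default_pool._max_workers)  # private attribute, printed only for illustration

# Bounded construction, as the review suggests: CPU-bound work that holds the
# GIL gains nothing from extra threads, so cap the pool explicitly.
bounded_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="state_api")
print(bounded_pool._max_workers)  # 1
```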
5 changes: 5 additions & 0 deletions python/ray/dashboard/agent.py
@@ -7,6 +7,7 @@
 import pathlib
 import signal
 import sys
+from concurrent.futures import ThreadPoolExecutor

 import ray
 import ray._private.ray_constants as ray_constants

@@ -47,6 +48,10 @@ def __init__(
         # Public attributes are accessible for all agent modules.
         self.ip = node_ip_address
         self.minimal = minimal
+        self.thread_pool_executor = ThreadPoolExecutor(
+            max_workers=dashboard_consts.RAY_AGENT_THREAD_POOL_MAX_WORKERS,
+            thread_name_prefix="dashboard_agent_tpe",
+        )

         assert gcs_address is not None
         self.gcs_address = gcs_address

Review thread on the ThreadPoolExecutor(...) line:

Contributor: Please check my comment above. We need to
  • isolate these TPEs to non-critical operations (like refreshing stats/data used in the UI), and
  • limit the concurrency of these TPEs (to 2-4).
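As a side note on thread_name_prefix: named pool threads are much easier to attribute in stack dumps or py-spy output. A minimal sketch (the prefix mirrors the one added above):

```python
import threading
from concurrent.futures import ThreadPoolExecutor

tpe = ThreadPoolExecutor(max_workers=2, thread_name_prefix="dashboard_agent_tpe")

def whoami() -> str:
    # Pool threads are named "dashboard_agent_tpe_0", "dashboard_agent_tpe_1", ...
    return threading.current_thread().name

print(tpe.submit(whoami).result())
```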
6 changes: 6 additions & 0 deletions python/ray/dashboard/consts.py
@@ -49,6 +49,12 @@
 # Example: "your.module.ray_cluster_activity_hook".
 RAY_CLUSTER_ACTIVITY_HOOK = "RAY_CLUSTER_ACTIVITY_HOOK"

+# Work in the thread pool should not starve the main thread loop, so we
+# default to 1.
+RAY_DASHBOARD_THREAD_POOL_MAX_WORKERS = env_integer(
+    "RAY_DASHBOARD_THREAD_POOL_MAX_WORKERS", 1
+)
+RAY_AGENT_THREAD_POOL_MAX_WORKERS = env_integer("RAY_AGENT_THREAD_POOL_MAX_WORKERS", 1)
+
 # The number of candidate agents
 CANDIDATE_AGENT_NUMBER = max(env_integer("CANDIDATE_AGENT_NUMBER", 1), 1)
 # when head receive JobSubmitRequest, maybe not any agent is available,
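For illustration, a self-contained stand-in for the env_integer helper used above (Ray's actual implementation may differ):

```python
import os

def env_integer(key: str, default: int) -> int:
    # Stand-in: read an integer from the environment, else fall back to the default.
    value = os.environ.get(key)
    return int(value) if value is not None else default

# Operators can widen a pool without a code change, e.g.:
#   RAY_DASHBOARD_THREAD_POOL_MAX_WORKERS=4 ray start --head
print(env_integer("RAY_DASHBOARD_THREAD_POOL_MAX_WORKERS", 1))
```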
29 changes: 23 additions & 6 deletions python/ray/dashboard/datacenter.py
@@ -3,6 +3,7 @@
 from typing import Any, List, Optional

 import ray.dashboard.consts as dashboard_consts
+from ray._private.utils import get_or_create_event_loop
 from ray.dashboard.utils import (
     Dict,
     MutableNotificationDict,
@@ -13,6 +14,7 @@
 logger = logging.getLogger(__name__)


+# NOT thread safe. Every assignment must be on the main event loop thread.
 class DataSource:
     # {node id hex(str): node stats(dict of GetNodeStatsReply
     # in node_manager.proto)}
@@ -64,12 +66,29 @@ async def purge():

     @classmethod
     @async_loop_forever(dashboard_consts.RAY_DASHBOARD_STATS_UPDATING_INTERVAL)
-    async def organize(cls):
+    async def organize(cls, thread_pool_executor):
+        """
+        Organizes data: reads from (node_physical_stats, node_stats) and updates
+        (node_workers, node_worker_stats).
+
+        This method is not really async, but DataSource is not thread safe, so
+        we need to make sure it runs on the main event loop thread. To avoid
+        blocking the main event loop, we yield after each node is processed.
+        """
         node_workers = {}
         core_worker_stats = {}
-        # await inside for loop, so we create a copy of keys().
+        # nodes may change during processing, so we create a copy of keys().
         for node_id in list(DataSource.nodes.keys()):
-            workers = await cls.get_node_workers(node_id)
+            node_physical_stats = DataSource.node_physical_stats.get(node_id, {})
+            node_stats = DataSource.node_stats.get(node_id, {})
+            # Offload the blocking operation to a thread pool executor. This
+            # also yields to the event loop.
+            workers = await get_or_create_event_loop().run_in_executor(
+                thread_pool_executor,
+                cls.merge_workers_for_node,
+                node_physical_stats,
+                node_stats,
+            )
             for worker in workers:
                 stats = worker.get("coreWorkerStats", {})
                 if stats:
@@ -80,10 +99,8 @@
         DataSource.core_worker_stats.reset(core_worker_stats)

Review thread on DataSource.core_worker_stats.reset(...):

Collaborator: Is DataSource thread safe, given that it's now accessed from another thread as well?

Contributor (author): not sure...

Contributor: +1. These custom data structures aren't thread-safe. If we're planning to run this on a TPE, we need to cover these with locks.

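For context, one hedged sketch of what "cover these with locks" could look like; this is illustrative only, not Ray's actual DataSource:

```python
import threading

class LockedDict:
    """A sketch of a lock-guarded dict for cross-thread access (illustrative)."""

    def __init__(self):
        self._data = {}
        self._lock = threading.Lock()

    def reset(self, new_data: dict) -> None:
        # Replace the contents atomically with respect to other accessors.
        with self._lock:
            self._data = dict(new_data)

    def snapshot(self) -> dict:
        # Return a copy so callers never iterate shared state unlocked.
        with self._lock:
            return dict(self._data)
```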
     @classmethod
-    async def get_node_workers(cls, node_id):
+    def merge_workers_for_node(cls, node_physical_stats, node_stats):
         workers = []
-        node_physical_stats = DataSource.node_physical_stats.get(node_id, {})
-        node_stats = DataSource.node_stats.get(node_id, {})
         # Merge coreWorkerStats (node stats) to workers (node physical stats)
         pid_to_worker_stats = {}
         pid_to_language = {}
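The offload-and-yield pattern used by organize(), reduced to a self-contained sketch (all names here are illustrative):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

def merge_one_node(physical_stats: dict, stats: dict) -> dict:
    # Stand-in for the blocking, CPU-bound merge work.
    return {**physical_stats, **stats}

async def organize_sketch(nodes: dict, tpe: ThreadPoolExecutor) -> dict:
    loop = asyncio.get_running_loop()
    results = {}
    # Snapshot the keys: the dict may change while we await inside the loop.
    for node_id in list(nodes.keys()):
        physical, stats = nodes[node_id]
        # Offload the merge to the TPE; the await also yields to the event
        # loop, so one slow node cannot starve other coroutines.
        results[node_id] = await loop.run_in_executor(
            tpe, merge_one_node, physical, stats
        )
    return results

nodes = {"node-1": ({"cpu": 8}, {"workers": 3})}
print(asyncio.run(organize_sketch(nodes, ThreadPoolExecutor(max_workers=1))))
```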
10 changes: 9 additions & 1 deletion python/ray/dashboard/head.py
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 from typing import Optional, Set

@@ -96,6 +97,13 @@ def __init__(
         self._modules_to_load = modules_to_load
         self._modules_loaded = False

+        # A TPE holding background, compute-heavy, latency-tolerant jobs,
+        # typically state updates.
+        self._thread_pool_executor = ThreadPoolExecutor(
+            max_workers=dashboard_consts.RAY_DASHBOARD_THREAD_POOL_MAX_WORKERS,
+            thread_name_prefix="dashboard_head_tpe",
+        )

         self.gcs_address = None
         assert gcs_address is not None
         self.gcs_address = gcs_address

Review thread on the ThreadPoolExecutor(...) line:

Contributor: Check comment above.

@@ -312,7 +320,7 @@ async def _async_notify():
             self._gcs_check_alive(),
             _async_notify(),
             DataOrganizer.purge(),
-            DataOrganizer.organize(),
+            DataOrganizer.organize(self._thread_pool_executor),
         ]
         for m in modules:
             concurrent_tasks.append(m.run(self.server))
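The wiring pattern above, reduced to a runnable sketch (names are illustrative; the real head passes the same executor into several long-running tasks):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

async def organize(tpe: ThreadPoolExecutor) -> None:
    # Periodic background job: heavy work on the TPE, then yield and sleep.
    loop = asyncio.get_running_loop()
    for _ in range(3):  # bounded here so the sketch terminates
        await loop.run_in_executor(tpe, sum, range(10_000))
        await asyncio.sleep(0.1)

async def main() -> None:
    tpe = ThreadPoolExecutor(max_workers=1, thread_name_prefix="dashboard_head_tpe")
    # One shared, bounded TPE threaded through every background task.
    await asyncio.gather(organize(tpe))

asyncio.run(main())
```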
6 changes: 1 addition & 5 deletions python/ray/dashboard/modules/event/event_agent.py
@@ -2,7 +2,6 @@
 import logging
 import os
 import time
-from concurrent.futures import ThreadPoolExecutor
 from typing import Union

 import ray._private.ray_constants as ray_constants
@@ -26,9 +25,6 @@ def __init__(self, dashboard_agent):
         self._stub: Union[event_pb2_grpc.ReportEventServiceStub, None] = None
         self._cached_events = asyncio.Queue(event_consts.EVENT_AGENT_CACHE_SIZE)
         self._gcs_aio_client = dashboard_agent.gcs_aio_client
-        self.monitor_thread_pool_executor = ThreadPoolExecutor(
-            max_workers=1, thread_name_prefix="event_monitor"
-        )
         # Total number of event created from this agent.
         self.total_event_reported = 0
         # Total number of event report request sent.
@@ -111,7 +107,7 @@ async def run(self, server):
         self._monitor = monitor_events(
             self._event_dir,
             lambda data: create_task(self._cached_events.put(data)),
-            self.monitor_thread_pool_executor,
+            self._dashboard_agent.thread_pool_executor,
         )

         await asyncio.gather(
6 changes: 1 addition & 5 deletions python/ray/dashboard/modules/event/event_head.py
@@ -3,7 +3,6 @@
 import os
 import time
 from collections import OrderedDict, defaultdict
-from concurrent.futures import ThreadPoolExecutor
 from typing import Union

 import aiohttp.web
@@ -31,9 +30,6 @@ def __init__(self, dashboard_head):
         self._event_dir = os.path.join(self._dashboard_head.log_dir, "events")
         os.makedirs(self._event_dir, exist_ok=True)
         self._monitor: Union[asyncio.Task, None] = None
-        self.monitor_thread_pool_executor = ThreadPoolExecutor(
-            max_workers=1, thread_name_prefix="event_monitor"
-        )
         self.total_report_events_count = 0
         self.total_events_received = 0
         self.module_started = time.monotonic()
@@ -111,7 +107,7 @@ async def run(self, server):
         self._monitor = monitor_events(
             self._event_dir,
             lambda data: self._update_events(parse_event_strings(data)),
-            self.monitor_thread_pool_executor,
+            self._dashboard_head._thread_pool_executor,
         )

     @staticmethod
6 changes: 1 addition & 5 deletions python/ray/dashboard/modules/job/job_head.py
@@ -3,7 +3,6 @@
 import json
 import logging
 import traceback
-from concurrent.futures import ThreadPoolExecutor
 from random import sample
 from typing import AsyncIterator, List, Optional

@@ -163,9 +162,6 @@ def __init__(self, dashboard_head):
         super().__init__(dashboard_head)
         self._gcs_aio_client = dashboard_head.gcs_aio_client
         self._job_info_client = None
-        self._upload_package_thread_pool_executor = ThreadPoolExecutor(
-            thread_name_prefix="job_head.upload_package"
-        )

         # It contains all `JobAgentSubmissionClient` that
         # `JobHead` has ever used, and will not be deleted
@@ -317,7 +313,7 @@ async def upload_package(self, req: Request):
         try:
             data = await req.read()
             await get_or_create_event_loop().run_in_executor(
-                self._upload_package_thread_pool_executor,
+                self._dashboard_head._thread_pool_executor,
                 upload_package_to_gcs,
                 package_uri,
                 data,

Review thread on the executor argument:

Collaborator: Is it critical enough to deserve its own thread pool?

Contributor (author): I don't think so, because it's not latency critical either.
56 changes: 33 additions & 23 deletions python/ray/dashboard/modules/node/node_head.py
@@ -16,6 +16,7 @@
 from ray import NodeID
 from ray._private import ray_constants
 from ray._private.ray_constants import DEBUG_AUTOSCALING_ERROR, DEBUG_AUTOSCALING_STATUS
+from ray._private.utils import get_or_create_event_loop
 from ray.autoscaler._private.util import (
     LoadMetricsSummary,
     get_per_node_breakdown_as_dict,
@@ -399,31 +400,40 @@ async def _update_node_stats(self):
             return_exceptions=True,
         )

-        for node_info, reply in zip(nodes, replies):
-            node_id, _ = node_info
-            if isinstance(reply, asyncio.CancelledError):
-                pass
-            elif isinstance(reply, grpc.RpcError):
-                if reply.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
-                    logger.exception(
-                        f"Cannot reach the node, {node_id}, after timeout {TIMEOUT}. "
-                        "This node may have been overloaded, terminated, or "
-                        "the network is slow."
-                    )
-                elif reply.code() == grpc.StatusCode.UNAVAILABLE:
-                    logger.exception(
-                        f"Cannot reach the node, {node_id}. "
-                        "The node may have been terminated."
-                    )
-                else:
-                    logger.exception(f"Error updating node stats of {node_id}.")
-                    logger.exception(reply)
-            elif isinstance(reply, Exception):
-                logger.exception(f"Error updating node stats of {node_id}.")
-                logger.exception(reply)
-            else:
-                reply_dict = node_stats_to_dict(reply)
-                DataSource.node_stats[node_id] = reply_dict
+        def postprocess(nodes, replies):
+            """Pure function reorganizing the data into {node_id: stats}."""
+            new_node_stats = {}
+            for node_info, reply in zip(nodes, replies):
+                node_id, _ = node_info
+                if isinstance(reply, asyncio.CancelledError):
+                    pass
+                elif isinstance(reply, grpc.RpcError):
+                    if reply.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
+                        logger.exception(
+                            f"Cannot reach the node, {node_id}, after timeout "
+                            f"{TIMEOUT}. This node may have been overloaded, "
+                            "terminated, or the network is slow."
+                        )
+                    elif reply.code() == grpc.StatusCode.UNAVAILABLE:
+                        logger.exception(
+                            f"Cannot reach the node, {node_id}. "
+                            "The node may have been terminated."
+                        )
+                    else:
+                        logger.exception(f"Error updating node stats of {node_id}.")
+                        logger.exception(reply)
+                elif isinstance(reply, Exception):
+                    logger.exception(f"Error updating node stats of {node_id}.")
+                    logger.exception(reply)
+                else:
+                    new_node_stats[node_id] = node_stats_to_dict(reply)
+            return new_node_stats
+
+        new_node_stats = await get_or_create_event_loop().run_in_executor(
+            self._dashboard_head._thread_pool_executor, postprocess, nodes, replies
+        )
+        for node_id, new_stat in new_node_stats.items():
+            DataSource.node_stats[node_id] = new_stat

     async def run(self, server):
         self.get_all_node_info = GetAllNodeInfo(self._dashboard_head)
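The shape of this refactor, reduced to a sketch: do the pure, GIL-bound work on the TPE, and apply results to shared state only from the event loop thread (names are illustrative, not the actual NodeHead):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

shared_state: dict = {}  # only ever written from the event loop thread

def postprocess(replies: list) -> dict:
    # Pure: touches no shared state, so it is safe to run on a TPE thread.
    return {i: str(reply) for i, reply in enumerate(replies)}

async def update(replies: list, tpe: ThreadPoolExecutor) -> None:
    loop = asyncio.get_running_loop()
    fresh = await loop.run_in_executor(tpe, postprocess, replies)
    # Back on the event loop thread: mutate shared state single-threaded.
    shared_state.update(fresh)

asyncio.run(update(["a", "b"], ThreadPoolExecutor(max_workers=1)))
print(shared_state)
```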
4 changes: 3 additions & 1 deletion python/ray/dashboard/modules/reporter/reporter_agent.py
@@ -1225,7 +1225,9 @@ async def _run_loop(self, publisher):
             # NOTE: Stats collection is executed inside the thread-pool
             # executor (TPE) to avoid blocking the Dashboard's event-loop
             json_payload = await loop.run_in_executor(
-                None, self._compose_stats_payload, autoscaler_status_json_bytes
+                self._dashboard_agent.thread_pool_executor,
+                self._compose_stats_payload,
+                autoscaler_status_json_bytes,
             )

             await publisher.publish_resource_usage(self._key, json_payload)
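Worth noting the semantics being changed here: passing None as the executor uses the event loop's default ThreadPoolExecutor, while passing an executor pins the work to a pool the caller controls. A minimal sketch (the prefix name is illustrative):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

def work() -> int:
    return 42

async def main() -> None:
    loop = asyncio.get_running_loop()
    # None -> the loop's process-wide default TPE (default-sized).
    print(await loop.run_in_executor(None, work))
    # Explicit executor -> bounded, named, and shared deliberately.
    tpe = ThreadPoolExecutor(max_workers=1, thread_name_prefix="reporter_tpe")
    print(await loop.run_in_executor(tpe, work))

asyncio.run(main())
```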
10 changes: 8 additions & 2 deletions python/ray/dashboard/modules/reporter/reporter_head.py
@@ -609,7 +609,11 @@ async def run(self, server):
             gcs_channel, self._dashboard_head.gcs_aio_client
         )
         # Set up the state API in order to fetch task information.
-        self._state_api = StateAPIManager(self._state_api_data_source_client)
+        # TODO(ryw): unify the StateAPIManager in reporter_head and state_head.
+        self._state_api = StateAPIManager(
+            self._state_api_data_source_client,
+            self._dashboard_head._thread_pool_executor,
+        )

         # Need daemon True to avoid dashboard hangs at exit.
         self.service_discovery.daemon = True
@@ -635,7 +639,9 @@

             # NOTE: Every iteration is executed inside the thread-pool executor
             # (TPE) to avoid blocking the Dashboard's event-loop
-            parsed_data = await loop.run_in_executor(None, json.loads, data)
+            parsed_data = await loop.run_in_executor(
+                self._dashboard_head._thread_pool_executor, json.loads, data
+            )

             node_id = key.split(":")[-1]
             DataSource.node_physical_stats[node_id] = parsed_data
5 changes: 4 additions & 1 deletion python/ray/dashboard/modules/state/state_head.py
@@ -548,7 +548,10 @@ async def run(self, server):
         self._state_api_data_source_client = StateDataSourceClient(
             gcs_channel, self._dashboard_head.gcs_aio_client
         )
-        self._state_api = StateAPIManager(self._state_api_data_source_client)
+        self._state_api = StateAPIManager(
+            self._state_api_data_source_client,
+            self._dashboard_head._thread_pool_executor,
+        )
         self._log_api = LogsManager(self._state_api_data_source_client)

     @staticmethod
9 changes: 6 additions & 3 deletions python/ray/dashboard/state_aggregator.py
@@ -140,10 +140,13 @@ class StateAPIManager:
     the entries.
     """

-    def __init__(self, state_data_source_client: StateDataSourceClient):
+    def __init__(
+        self,
+        state_data_source_client: StateDataSourceClient,
+        thread_pool_executor: ThreadPoolExecutor,
+    ):
         self._client = state_data_source_client
-
-        self._thread_pool_executor = ThreadPoolExecutor(thread_name_prefix="state_head")
+        self._thread_pool_executor = thread_pool_executor

     @property
     def data_source_client(self):
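The dependency-injection shape of this change, as a minimal sketch (the class here is a stand-in for StateAPIManager, not Ray's implementation):

```python
from concurrent.futures import ThreadPoolExecutor

class Manager:
    # Stand-in: the executor is injected rather than constructed per-instance,
    # so the caller controls pool size and sharing.
    def __init__(self, thread_pool_executor: ThreadPoolExecutor):
        self._thread_pool_executor = thread_pool_executor

# One bounded pool, shared by every manager on the head.
shared_tpe = ThreadPoolExecutor(max_workers=1, thread_name_prefix="dashboard_head_tpe")
state_api = Manager(shared_tpe)
reporter_api = Manager(shared_tpe)
```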
5 changes: 4 additions & 1 deletion python/ray/tests/test_state_api.py
@@ -4,6 +4,7 @@
 import sys
 import signal
 from collections import Counter
+from concurrent.futures import ThreadPoolExecutor
 from typing import List, Tuple
 from unittest.mock import MagicMock, AsyncMock

@@ -130,7 +131,9 @@
 @pytest.fixture
 def state_api_manager():
     data_source_client = AsyncMock(StateDataSourceClient)
-    manager = StateAPIManager(data_source_client)
+    manager = StateAPIManager(
+        data_source_client, thread_pool_executor=ThreadPoolExecutor()
+    )
     yield manager
5 changes: 4 additions & 1 deletion python/ray/tests/test_state_api_summary.py
@@ -6,6 +6,7 @@
 import random
 import sys
 from dataclasses import asdict
+from concurrent.futures import ThreadPoolExecutor

 from ray.util.state import (
     summarize_tasks,
@@ -42,7 +43,9 @@
 @pytest.fixture
 def state_api_manager():
     data_source_client = AsyncMock(StateDataSourceClient)
-    manager = StateAPIManager(data_source_client)
+    manager = StateAPIManager(
+        data_source_client, thread_pool_executor=ThreadPoolExecutor()
+    )
     yield manager