Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[autoscaler] Update autoscaler to use heartbeat batches. #3409

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"ray.core.generated.ClientTableData",
"ray.core.generated.GcsTableEntry",
"ray.core.generated.HeartbeatTableData",
"ray.core.generated.HeartbeatBatchTableData",
"ray.core.generated.DriverTableData",
"ray.core.generated.ErrorTableData",
"ray.core.generated.ProfileTableData",
Expand Down
7 changes: 5 additions & 2 deletions python/ray/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ray.core.generated.ErrorTableData import ErrorTableData
from ray.core.generated.ProfileTableData import ProfileTableData
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData
from ray.core.generated.DriverTableData import DriverTableData
from ray.core.generated.ObjectTableData import ObjectTableData
from ray.core.generated.ray.protocol.Task import Task
Expand All @@ -20,14 +21,16 @@

__all__ = [
"GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData",
"DriverTableData", "ProfileTableData", "ObjectTableData", "Task",
"TablePrefix", "TablePubsub", "construct_error_message"
"HeartbeatBatchTableData", "DriverTableData", "ProfileTableData",
"ObjectTableData", "Task", "TablePrefix", "TablePubsub",
"construct_error_message"
]

FUNCTION_PREFIX = "RemoteFunction:"

# xray heartbeats
XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii")
XRAY_HEARTBEAT_BATCH_CHANNEL = str(TablePubsub.HEARTBEAT_BATCH).encode("ascii")

# xray driver updates
XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii")
Expand Down
77 changes: 36 additions & 41 deletions python/ray/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ def __init__(self,
# Setup subscriptions to the primary Redis server and the Redis shards.
self.primary_subscribe_client = self.redis.pubsub(
ignore_subscribe_messages=True)
self.shard_subscribe_clients = []
for redis_client in self.state.redis_clients:
subscribe_client = redis_client.pubsub(
ignore_subscribe_messages=True)
self.shard_subscribe_clients.append(subscribe_client)
# Keep a mapping from local scheduler client ID to IP address to use
# for updating the load metrics.
self.local_scheduler_id_to_ip_map = {}
Expand Down Expand Up @@ -90,49 +85,50 @@ def __init__(self,
str(e)))
self.issue_gcs_flushes = False

def subscribe(self, channel, primary=True):
"""Subscribe to the given channel.
def subscribe(self, channel):
"""Subscribe to the given channel on the primary Redis shard.

Args:
channel (str): The channel to subscribe to.
primary: If True, then we only subscribe to the primary Redis
shard. Otherwise we subscribe to all of the other shards but
not the primary.

Raises:
Exception: An exception is raised if the subscription fails.
"""
if primary:
self.primary_subscribe_client.subscribe(channel)
else:
for subscribe_client in self.shard_subscribe_clients:
subscribe_client.subscribe(channel)
self.primary_subscribe_client.subscribe(channel)

def xray_heartbeat_handler(self, unused_channel, data):
"""Handle an xray heartbeat message from Redis."""
def xray_heartbeat_batch_handler(self, unused_channel, data):
"""Handle an xray heartbeat batch message from Redis."""

gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
data, 0)
heartbeat_data = gcs_entries.Entries(0)
message = ray.gcs_utils.HeartbeatTableData.GetRootAsHeartbeatTableData(
heartbeat_data, 0)
num_resources = message.ResourcesAvailableLabelLength()
static_resources = {}
dynamic_resources = {}
for i in range(num_resources):
dyn = message.ResourcesAvailableLabel(i)
static = message.ResourcesTotalLabel(i)
dynamic_resources[dyn] = message.ResourcesAvailableCapacity(i)
static_resources[static] = message.ResourcesTotalCapacity(i)

# Update the load metrics for this local scheduler.
client_id = ray.utils.binary_to_hex(message.ClientId())
ip = self.local_scheduler_id_to_ip_map.get(client_id)
if ip:
self.load_metrics.update(ip, static_resources, dynamic_resources)
else:
print("Warning: could not find ip for client {} in {}.".format(
client_id, self.local_scheduler_id_to_ip_map))

message = (ray.gcs_utils.HeartbeatBatchTableData.
GetRootAsHeartbeatBatchTableData(heartbeat_data, 0))

for j in range(message.BatchLength()):
heartbeat_message = message.Batch(j)

num_resources = heartbeat_message.ResourcesAvailableLabelLength()
static_resources = {}
dynamic_resources = {}
for i in range(num_resources):
dyn = heartbeat_message.ResourcesAvailableLabel(i)
static = heartbeat_message.ResourcesTotalLabel(i)
dynamic_resources[dyn] = (
heartbeat_message.ResourcesAvailableCapacity(i))
static_resources[static] = (
heartbeat_message.ResourcesTotalCapacity(i))

# Update the load metrics for this local scheduler.
client_id = ray.utils.binary_to_hex(heartbeat_message.ClientId())
ip = self.local_scheduler_id_to_ip_map.get(client_id)
if ip:
self.load_metrics.update(ip, static_resources,
dynamic_resources)
else:
print("Warning: could not find ip for client {} in {}.".format(
client_id, self.local_scheduler_id_to_ip_map))

def _xray_clean_up_entries_for_driver(self, driver_id):
"""Remove this driver's object/task entries from redis.
Expand Down Expand Up @@ -222,8 +218,7 @@ def process_messages(self, max_messages=10000):
max_messages: The maximum number of messages to process before
returning.
"""
subscribe_clients = (
[self.primary_subscribe_client] + self.shard_subscribe_clients)
subscribe_clients = [self.primary_subscribe_client]
for subscribe_client in subscribe_clients:
for _ in range(max_messages):
message = subscribe_client.get_message()
Expand All @@ -237,9 +232,9 @@ def process_messages(self, max_messages=10000):

# Determine the appropriate message handler.
message_handler = None
if channel == ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL:
if channel == ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL:
# Similar functionality as local scheduler info channel
message_handler = self.xray_heartbeat_handler
message_handler = self.xray_heartbeat_batch_handler
elif channel == ray.gcs_utils.XRAY_DRIVER_CHANNEL:
# Handles driver death.
message_handler = self.xray_driver_removed_handler
Expand Down Expand Up @@ -299,7 +294,7 @@ def run(self):
clients and cleaning up state accordingly.
"""
# Initialize the subscription channel.
self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL, primary=False)
self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL)
self.subscribe(ray.gcs_utils.XRAY_DRIVER_CHANNEL)

# TODO(rkn): If there were any dead clients at startup, we should clean
Expand Down
6 changes: 3 additions & 3 deletions test/failure_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,10 +500,10 @@ def test_warning_monitor_died(shutdown_only):
# addition to the monitor.
fake_id = 20 * b"\x00"
malformed_message = "asdf"
redis_client = ray.worker.global_state.redis_clients[0]
redis_client = ray.worker.global_worker.redis_client
redis_client.execute_command(
"RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.HEARTBEAT,
ray.gcs_utils.TablePubsub.HEARTBEAT, fake_id, malformed_message)
"RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.HEARTBEAT_BATCH,
ray.gcs_utils.TablePubsub.HEARTBEAT_BATCH, fake_id, malformed_message)

wait_for_errors(ray_constants.MONITOR_DIED_ERROR, 1)

Expand Down