Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] Add and adopt least load policy as default policy. #4439

Merged
merged 6 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/serving/sky-serve.rst
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,9 @@ Under the hood, :code:`sky serve up`:
#. Meanwhile, the controller provisions replica VMs which later run the services;
#. Once any replica is ready, the requests sent to the Service Endpoint will be distributed to one of the endpoint replicas.

.. note::
SkyServe uses least-load load balancing to distribute traffic across the replicas. It keeps track of the number of in-flight requests on each replica and routes the next request to the replica with the least load.

After the controller is provisioned, you'll see the following in :code:`sky serve status` output:

.. image:: ../images/sky-serve-status-output-provisioning.png
Expand Down
11 changes: 11 additions & 0 deletions examples/serve/minimal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# A minimal example of a serve application.

service:
readiness_probe: /
replicas: 1

resources:
ports: 8080
cpus: 2+

run: python3 -m http.server 8080
14 changes: 13 additions & 1 deletion sky/serve/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,17 @@ def update(
with ux_utils.print_exception_no_traceback():
raise RuntimeError(prompt)

original_lb_policy = service_record['load_balancing_policy']
assert task.service is not None, 'Service section not found.'
if original_lb_policy != task.service.load_balancing_policy:
logger.warning(
f'{colorama.Fore.YELLOW}Current load balancing policy '
f'{original_lb_policy!r} is different from the new policy '
f'{task.service.load_balancing_policy!r}. Updating the load '
'balancing policy is not supported yet and it will be ignored. '
'The service will continue to use the current load balancing '
f'policy.{colorama.Style.RESET_ALL}')

with rich_utils.safe_status(
ux_utils.spinner_message('Initializing service')):
controller_utils.maybe_translate_local_file_mounts_and_sync_up(
Expand Down Expand Up @@ -581,9 +592,10 @@ def status(
'status': (sky.ServiceStatus) service status,
'controller_port': (Optional[int]) controller port,
'load_balancer_port': (Optional[int]) load balancer port,
'policy': (Optional[str]) load balancer policy description,
'policy': (Optional[str]) autoscaling policy description,
'requested_resources_str': (str) str representation of
requested resources,
'load_balancing_policy': (str) load balancing policy name,
'replica_info': (List[Dict[str, Any]]) replica information,
}

Expand Down
12 changes: 10 additions & 2 deletions sky/serve/load_balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def __init__(self,
# Use the registry to create the load balancing policy
self._load_balancing_policy = lb_policies.LoadBalancingPolicy.make(
load_balancing_policy_name)
logger.info('Starting load balancer with policy '
f'{load_balancing_policy_name}.')
self._request_aggregator: serve_utils.RequestsAggregator = (
serve_utils.RequestTimestamp())
# TODO(tian): httpx.Client has a resource limit of 100 max connections
Expand Down Expand Up @@ -128,6 +130,7 @@ async def _proxy_request_to(
encountered if anything goes wrong.
"""
logger.info(f'Proxy request to {url}')
self._load_balancing_policy.pre_execute_hook(url, request)
try:
# We defer the get of the client here on purpose, for case when the
# replica is ready in `_proxy_with_retries` but refreshed before
Expand All @@ -147,11 +150,16 @@ async def _proxy_request_to(
content=await request.body(),
timeout=constants.LB_STREAM_TIMEOUT)
proxy_response = await client.send(proxy_request, stream=True)

async def background_func():
await proxy_response.aclose()
self._load_balancing_policy.post_execute_hook(url, request)

return fastapi.responses.StreamingResponse(
content=proxy_response.aiter_raw(),
status_code=proxy_response.status_code,
headers=proxy_response.headers,
background=background.BackgroundTask(proxy_response.aclose))
background=background.BackgroundTask(background_func))
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.error(f'Error when proxy request to {url}: '
f'{common_utils.format_exception(e)}')
Expand Down Expand Up @@ -263,7 +271,7 @@ def run_load_balancer(controller_addr: str,
parser.add_argument(
'--load-balancing-policy',
choices=available_policies,
default='round_robin',
default=lb_policies.DEFAULT_LB_POLICY,
help=f'The load balancing policy to use. Available policies: '
f'{", ".join(available_policies)}.')
args = parser.parse_args()
Expand Down
70 changes: 65 additions & 5 deletions sky/serve/load_balancing_policies.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""LoadBalancingPolicy: Policy to select endpoint."""
import collections
import random
import threading
import typing
from typing import List, Optional
from typing import Dict, List, Optional

from sky import sky_logging

Expand All @@ -13,6 +15,10 @@
# Define a registry for load balancing policies
LB_POLICIES = {}
DEFAULT_LB_POLICY = None
# Prior to #4439, the default policy was round_robin. We store the legacy
# default policy here to maintain backwards compatibility. Remove this after
# 2 minor releases, i.e., 0.9.0.
LEGACY_DEFAULT_POLICY = 'round_robin'


def _request_repr(request: 'fastapi.Request') -> str:
Expand All @@ -38,11 +44,17 @@ def __init_subclass__(cls, name: str, default: bool = False):
DEFAULT_LB_POLICY = name

@classmethod
def make(cls, policy_name: Optional[str] = None) -> 'LoadBalancingPolicy':
"""Create a load balancing policy from a name."""
def make_policy_name(cls, policy_name: Optional[str]) -> str:
"""Return the policy name."""
assert DEFAULT_LB_POLICY is not None, 'No default policy set.'
if policy_name is None:
policy_name = DEFAULT_LB_POLICY
return DEFAULT_LB_POLICY
return policy_name

@classmethod
def make(cls, policy_name: Optional[str] = None) -> 'LoadBalancingPolicy':
"""Create a load balancing policy from a name."""
policy_name = cls.make_policy_name(policy_name)
if policy_name not in LB_POLICIES:
raise ValueError(f'Unknown load balancing policy: {policy_name}')
return LB_POLICIES[policy_name]()
Expand All @@ -65,8 +77,16 @@ def select_replica(self, request: 'fastapi.Request') -> Optional[str]:
def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
raise NotImplementedError

def pre_execute_hook(self, replica_url: str,
request: 'fastapi.Request') -> None:
pass

def post_execute_hook(self, replica_url: str,
request: 'fastapi.Request') -> None:
pass


class RoundRobinPolicy(LoadBalancingPolicy, name='round_robin', default=True):
class RoundRobinPolicy(LoadBalancingPolicy, name='round_robin'):
"""Round-robin load balancing policy."""

def __init__(self) -> None:
Expand All @@ -90,3 +110,43 @@ def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
ready_replica_url = self.ready_replicas[self.index]
self.index = (self.index + 1) % len(self.ready_replicas)
return ready_replica_url


class LeastLoadPolicy(LoadBalancingPolicy, name='least_load', default=True):
    """Least load load balancing policy.

    Routes each request to the ready replica with the fewest in-flight
    requests. The per-replica count is incremented in `pre_execute_hook`
    and decremented in `post_execute_hook`.
    """

    def __init__(self) -> None:
        super().__init__()
        # Replica URL -> number of requests currently in flight on it.
        self.load_map: Dict[str, int] = collections.defaultdict(int)
        # Held whenever `load_map` is read or modified.
        self.lock = threading.Lock()

    def set_ready_replicas(self, ready_replicas: List[str]) -> None:
        """Replace the ready replica list and resync the load map.

        Replicas that are no longer ready are dropped from the load map,
        replicas that stay ready keep their current in-flight count, and
        newly ready replicas start at zero.

        Args:
            ready_replicas: URLs of the replicas that are currently ready.
        """
        if set(self.ready_replicas) == set(ready_replicas):
            return
        with self.lock:
            still_ready = set(ready_replicas)
            # Prune BEFORE swapping in the new list. Comparing the new list
            # against itself would never find a stale replica, so the load
            # map would grow forever. Iterating over the map (not the old
            # list) also clears entries created for untracked replicas.
            for url in list(self.load_map):
                if url not in still_ready:
                    del self.load_map[url]
            self.ready_replicas = ready_replicas
            for replica in ready_replicas:
                # Keep the count of surviving replicas; new ones start at 0.
                self.load_map.setdefault(replica, 0)

    def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
        """Return the ready replica with the lowest in-flight count.

        Returns None when there is no ready replica. Ties are resolved in
        favor of the replica that appears first in the ready list.
        """
        del request  # Unused.
        if not self.ready_replicas:
            return None
        with self.lock:
            # `.get` is used instead of indexing so that the lookup never
            # inserts a key into the defaultdict.
            return min(self.ready_replicas,
                       key=lambda replica: self.load_map.get(replica, 0))

    def pre_execute_hook(self, replica_url: str,
                         request: 'fastapi.Request') -> None:
        """Record that a request is about to be sent to `replica_url`."""
        del request  # Unused.
        with self.lock:
            self.load_map[replica_url] += 1

    def post_execute_hook(self, replica_url: str,
                          request: 'fastapi.Request') -> None:
        """Record that a request sent to `replica_url` has finished."""
        del request  # Unused.
        with self.lock:
            # The replica may have been removed (or removed and re-added,
            # which resets its count) while this request was in flight.
            # Never re-create its entry or push the count below zero, as a
            # negative count would make the replica look permanently idle.
            if self.load_map.get(replica_url, 0) > 0:
                self.load_map[replica_url] -= 1
20 changes: 15 additions & 5 deletions sky/serve/serve_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import colorama

from sky.serve import constants
from sky.serve import load_balancing_policies as lb_policies
from sky.utils import db_utils

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -76,6 +77,8 @@ def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None:
db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services',
'active_versions',
f'TEXT DEFAULT {json.dumps([])!r}')
db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services',
'load_balancing_policy', 'TEXT DEFAULT NULL')
_UNIQUE_CONSTRAINT_FAILED_ERROR_MSG = 'UNIQUE constraint failed: services.name'


Expand Down Expand Up @@ -241,7 +244,8 @@ def from_replica_statuses(


def add_service(name: str, controller_job_id: int, policy: str,
requested_resources_str: str, status: ServiceStatus) -> bool:
requested_resources_str: str, load_balancing_policy: str,
status: ServiceStatus) -> bool:
"""Add a service in the database.

Returns:
Expand All @@ -254,10 +258,10 @@ def add_service(name: str, controller_job_id: int, policy: str,
"""\
INSERT INTO services
(name, controller_job_id, status, policy,
requested_resources_str)
VALUES (?, ?, ?, ?, ?)""",
requested_resources_str, load_balancing_policy)
VALUES (?, ?, ?, ?, ?, ?)""",
(name, controller_job_id, status.value, policy,
requested_resources_str))
requested_resources_str, load_balancing_policy))

except sqlite3.IntegrityError as e:
if str(e) != _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG:
Expand Down Expand Up @@ -324,7 +328,12 @@ def set_service_load_balancer_port(service_name: str,
def _get_service_from_row(row) -> Dict[str, Any]:
(current_version, name, controller_job_id, controller_port,
load_balancer_port, status, uptime, policy, _, _, requested_resources_str,
_, active_versions) = row[:13]
_, active_versions, load_balancing_policy) = row[:14]
if load_balancing_policy is None:
# This column was added in #4439, and entries written since then always
# hold a str value. If it is None, the row is a legacy entry that is
# using the legacy default policy.
load_balancing_policy = lb_policies.LEGACY_DEFAULT_POLICY
return {
'name': name,
'controller_job_id': controller_job_id,
Expand All @@ -341,6 +350,7 @@ def _get_service_from_row(row) -> Dict[str, Any]:
# integers in json format. This is mainly for display purpose.
'active_versions': json.loads(active_versions),
'requested_resources_str': requested_resources_str,
'load_balancing_policy': load_balancing_policy,
}


Expand Down
8 changes: 6 additions & 2 deletions sky/serve/serve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,9 @@ def format_service_table(service_records: List[Dict[str, Any]],
'NAME', 'VERSION', 'UPTIME', 'STATUS', 'REPLICAS', 'ENDPOINT'
]
if show_all:
service_columns.extend(['POLICY', 'REQUESTED_RESOURCES'])
service_columns.extend([
'AUTOSCALING_POLICY', 'LOAD_BALANCING_POLICY', 'REQUESTED_RESOURCES'
])
service_table = log_utils.create_table(service_columns)

replica_infos = []
Expand All @@ -832,6 +834,7 @@ def format_service_table(service_records: List[Dict[str, Any]],
endpoint = get_endpoint(record)
policy = record['policy']
requested_resources_str = record['requested_resources_str']
load_balancing_policy = record['load_balancing_policy']

service_values = [
service_name,
Expand All @@ -842,7 +845,8 @@ def format_service_table(service_records: List[Dict[str, Any]],
endpoint,
]
if show_all:
service_values.extend([policy, requested_resources_str])
service_values.extend(
[policy, load_balancing_policy, requested_resources_str])
service_table.add_row(service_values)

replica_table = _format_replica_table(replica_infos, show_all)
Expand Down
1 change: 1 addition & 0 deletions sky/serve/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int):
controller_job_id=job_id,
policy=service_spec.autoscaling_policy_str(),
requested_resources_str=backend_utils.get_task_resources_str(task),
load_balancing_policy=service_spec.load_balancing_policy,
status=serve_state.ServiceStatus.CONTROLLER_INIT)
# Directly throw an error here. See sky/serve/api.py::up
# for more details.
Expand Down
6 changes: 4 additions & 2 deletions sky/serve/service_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from sky import serve
from sky.serve import constants
from sky.serve import load_balancing_policies as lb_policies
from sky.utils import common_utils
from sky.utils import schemas
from sky.utils import ux_utils
Expand Down Expand Up @@ -327,5 +328,6 @@ def use_ondemand_fallback(self) -> bool:
return self._use_ondemand_fallback

@property
def load_balancing_policy(self) -> Optional[str]:
return self._load_balancing_policy
def load_balancing_policy(self) -> str:
return lb_policies.LoadBalancingPolicy.make_policy_name(
self._load_balancing_policy)
1 change: 1 addition & 0 deletions tests/skyserve/load_balancer/service.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ service:
initial_delay_seconds: 180
replica_policy:
min_replicas: 3
load_balancing_policy: round_robin

resources:
ports: 8080
Expand Down
1 change: 1 addition & 0 deletions tests/skyserve/update/new.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ service:
path: /health
initial_delay_seconds: 100
replicas: 2
load_balancing_policy: round_robin

resources:
ports: 8081
Expand Down
1 change: 1 addition & 0 deletions tests/skyserve/update/old.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ service:
path: /health
initial_delay_seconds: 20
replicas: 2
load_balancing_policy: round_robin

resources:
ports: 8080
Expand Down
Loading