Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SkyServe] Add service schema and change to new service YAML #2267

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions sky/serve/autoscalers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import time

from typing import Optional

from sky.serve.infra_providers import InfraProvider
from sky.serve.load_balancers import LoadBalancer

Expand Down Expand Up @@ -78,10 +80,10 @@ class RequestRateAutoscaler(Autoscaler):

def __init__(self,
*args,
query_interval: int = 10,
upper_threshold: int = 10,
lower_threshold: int = 2,
min_nodes: int = 1,
max_nodes: Optional[int] = None,
upper_threshold: Optional[int] = None,
lower_threshold: Optional[int] = None,
cooldown: int = 60,
**kwargs):
"""
Expand All @@ -95,10 +97,11 @@ def __init__(self,
:param kwargs:
"""
super().__init__(*args, **kwargs)
self.query_interval = query_interval
self.min_nodes = min_nodes
self.max_nodes = max_nodes or min_nodes
cblmemo marked this conversation as resolved.
Show resolved Hide resolved
self.query_interval = 60 # Therefore thresholds represent queries per minute.
self.upper_threshold = upper_threshold
self.lower_threshold = lower_threshold
self.min_nodes = min_nodes
self.cooldown = cooldown
self.last_scale_operation = 0 # Time of last scale operation.

Expand Down Expand Up @@ -132,14 +135,15 @@ def evaluate_scaling(self):
scaled = True
# Bootstrap case
logger.info(f'Number of nodes: {num_nodes}')
if num_nodes == 0 and requests_per_node > 0:
if num_nodes < self.min_nodes:
logger.info(f'Bootstrapping autoscaler.')
self.scale_up(1)
self.last_scale_operation = current_time
elif requests_per_node > self.upper_threshold:
self.scale_up(1)
self.last_scale_operation = current_time
elif requests_per_node < self.lower_threshold:
elif self.upper_threshold is not None and requests_per_node > self.upper_threshold:
if self.infra_provider.total_servers() < self.max_nodes:
self.scale_up(1)
self.last_scale_operation = current_time
elif self.lower_threshold is not None and requests_per_node < self.lower_threshold:
if self.infra_provider.total_servers() > self.min_nodes:
self.scale_down(1)
self.last_scale_operation = current_time
Expand Down
87 changes: 69 additions & 18 deletions sky/serve/common.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,83 @@
import yaml
from typing import Optional, Dict

from sky.backends import backend_utils
from sky.utils import schemas
from sky.utils import ux_utils


class SkyServiceSpec:

def __init__(self, yaml_path: str):
with open(yaml_path, 'r') as f:
self.task = yaml.safe_load(f)
if 'service' not in self.task:
raise ValueError('Task YAML must have a "service" section')
if 'port' not in self.task['service']:
raise ValueError('Task YAML must have a "port" section')
if 'readiness_probe' not in self.task['service']:
raise ValueError('Task YAML must have a "readiness_probe" section')
self._readiness_path = self.get_readiness_path()
self._app_port = self.get_app_port()

def get_readiness_path(self):
def __init__(
self,
readiness_path: str,
readiness_timeout: int,
app_port: int,
min_replica: int,
max_replica: Optional[int] = None,
qpm_upper_threshold: Optional[int] = None,
qpm_lower_threshold: Optional[int] = None,
):
if max_replica is not None and max_replica < min_replica:
with ux_utils.print_exception_no_traceback():
raise ValueError(
'max_replica must be greater than or equal to min_replica')
# TODO: check if the path is valid
return f':{self.task["service"]["port"]}{self.task["service"]["readiness_probe"]}'

def get_app_port(self):
self._readiness_path = f':{app_port}{readiness_path}'
self._readiness_timeout = readiness_timeout
# TODO: check if the port is valid
return f'{self.task["service"]["port"]}'
self._app_port = str(app_port)
self._min_replica = min_replica
self._max_replica = max_replica
self._qpm_upper_threshold = qpm_upper_threshold
self._qpm_lower_threshold = qpm_lower_threshold

@classmethod
def from_yaml_config(cls, config: Optional[Dict[str, str]]):
if config is None:
return None

backend_utils.validate_schema(config, schemas.get_service_schema(),
'Invalid service YAML:')

service_config = {}
service_config['readiness_path'] = config['readiness_probe']['path']
service_config['readiness_timeout'] = config['readiness_probe'][
'readiness_timeout']
service_config['app_port'] = config['port']
service_config['min_replica'] = config['replica_policy']['min_replica']
service_config['max_replica'] = config['replica_policy'].get(
'max_replica', None)
service_config['qpm_upper_threshold'] = config['replica_policy'].get(
'qpm_upper_threshold', None)
service_config['qpm_lower_threshold'] = config['replica_policy'].get(
'qpm_lower_threshold', None)

return SkyServiceSpec(**service_config)

@property
def readiness_path(self):
return self._readiness_path

@property
def readiness_timeout(self):
return self._readiness_timeout

@property
def app_port(self):
return self._app_port

@property
def min_replica(self):
return self._min_replica

@property
def max_replica(self):
return self._max_replica

@property
def qpm_upper_threshold(self):
return self._qpm_upper_threshold

@property
def qpm_lower_threshold(self):
return self._qpm_lower_threshold
32 changes: 18 additions & 14 deletions sky/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import time
import threading
import yaml

from typing import Optional

Expand Down Expand Up @@ -83,37 +84,40 @@ def get_server_ips():
type=int,
help='Port to run the controller',
default=8082)
parser.add_argument('--min-nodes',
type=int,
default=1,
help='Minimum nodes to keep running')
args = parser.parse_args()

# ======= Infra Provider =========
# infra_provider = DummyInfraProvider()
infra_provider = SkyPilotInfraProvider(args.task_yaml)

# ======= Load Balancer =========
service_spec = SkyServiceSpec(args.task_yaml)
with open(args.task_yaml, 'r') as f:
task = yaml.safe_load(f)
if 'service' not in task:
raise ValueError('Task YAML must have a "service" section')
service_config = task['service']
service_spec = SkyServiceSpec.from_yaml_config(service_config)
# Select the load balancing policy: RoundRobinLoadBalancer or LeastLoadedLoadBalancer
load_balancer = RoundRobinLoadBalancer(
infra_provider=infra_provider,
endpoint_path=service_spec.readiness_path)
endpoint_path=service_spec.readiness_path,
readiness_timeout=service_spec.readiness_timeout)
# load_balancer = LeastLoadedLoadBalancer(n=5)
# autoscaler = LatencyThresholdAutoscaler(load_balancer,
# upper_threshold=0.5, # 500ms
# lower_threshold=0.1) # 100ms

# ======= Autoscaler =========
# Create an autoscaler with the RequestRateAutoscaler policy. Thresholds are defined as requests per node in the defined interval.
autoscaler = RequestRateAutoscaler(infra_provider,
load_balancer,
frequency=5,
query_interval=60,
lower_threshold=0,
upper_threshold=1,
min_nodes=args.min_nodes,
cooldown=60)
autoscaler = RequestRateAutoscaler(
infra_provider,
load_balancer,
frequency=5,
min_nodes=service_spec.min_replica,
max_nodes=service_spec.max_replica,
upper_threshold=service_spec.qpm_upper_threshold,
lower_threshold=service_spec.qpm_lower_threshold,
cooldown=60)

# ======= Controller =========
# Create a controller object and run it.
Expand Down
3 changes: 3 additions & 0 deletions sky/serve/examples/http_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

PORT = 8081


class MyHttpRequestHandler(http.server.SimpleHTTPRequestHandler):

def do_GET(self):
# Return 200 for all paths
# Therefore, readiness_probe will return 200 at path '/health'
Expand All @@ -23,6 +25,7 @@ def do_GET(self):
self.wfile.write(bytes(html, 'utf8'))
return


Handler = MyHttpRequestHandler

with socketserver.TCPServer(("", PORT), Handler) as httpd:
Expand Down
7 changes: 6 additions & 1 deletion sky/serve/examples/http_server/task.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,9 @@ run: python3 server.py

service:
port: 8081
readiness_probe: /health
readiness_probe:
path: /health
readiness_timeout: 12000
replica_policy:
min_replica: 1
max_replica: 1
6 changes: 3 additions & 3 deletions sky/serve/load_balancers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@

class LoadBalancer:

def __init__(self, infra_provider, endpoint_path, post_data=None):
def __init__(self, infra_provider, endpoint_path, readiness_timeout, post_data=None):
self.available_servers = []
self.request_count = 0
self.request_timestamps = deque()
self.infra_provider = infra_provider
self.endpoint_path = endpoint_path
self.readiness_timeout = readiness_timeout
self.post_data = post_data

def increment_request_count(self, count=1):
Expand All @@ -37,7 +38,6 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.servers_queue = deque()
self.first_unhealthy_time = {}
self.timeout = 18000
logger.info(f'Endpoint path: {self.endpoint_path}')

def probe_endpoints(self, endpoint_ips):
Expand Down Expand Up @@ -101,7 +101,7 @@ def probe_endpoint(endpoint_ip):
if server not in self.first_unhealthy_time:
self.first_unhealthy_time[server] = time.time()
elif time.time() - self.first_unhealthy_time[
server] > self.timeout: # cooldown before terminating a dead server to avoid hysterisis
server] > self.readiness_timeout: # cooldown before terminating a dead server to avoid hysterisis
servers_to_terminate.append(server)
self.infra_provider.terminate_servers(servers_to_terminate)

Expand Down
9 changes: 8 additions & 1 deletion sky/serve/redirector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
import logging
import yaml
from collections import deque

from sky.serve.common import SkyServiceSpec
Expand Down Expand Up @@ -113,7 +114,13 @@ def serve(self):
help='Controller address (ip:port).')
args = parser.parse_args()

service_spec = SkyServiceSpec(args.task_yaml)
with open(args.task_yaml, 'r') as f:
task = yaml.safe_load(f)
if 'service' not in task:
raise ValueError('Task YAML must have a "service" section')
service_config = task['service']
service_spec = SkyServiceSpec.from_yaml_config(service_config)

redirector = SkyServeRedirector(controller_url=args.controller_addr,
service_spec=service_spec,
port=args.port)
Expand Down
9 changes: 6 additions & 3 deletions sky/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sky.backends import backend_utils
from sky.data import storage as storage_lib
from sky.data import data_utils
from sky.serve import common
from sky.skylet import constants
from sky.utils import schemas
from sky.utils import ux_utils
Expand Down Expand Up @@ -365,10 +366,12 @@ def from_yaml_config(
resources = config.pop('resources', None)
resources = sky.Resources.from_yaml_config(resources)

# FIXME: find a better way to exclude unused fields.
config.pop('service', None)

task.set_resources({resources})

service = config.pop('service', None)
service = common.SkyServiceSpec.from_yaml_config(service)
task.service = service

assert not config, f'Invalid task args: {config.keys()}'
return task

Expand Down
50 changes: 50 additions & 0 deletions sky/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,52 @@ def get_storage_schema():
}


def get_service_schema():
return {
'$schema': 'http://json-schema.org/draft-07/schema#',
'type': 'object',
'required': ['port', 'readiness_probe', 'replica_policy'],
'additionalProperties': False,
'properties': {
'port': {
'type': 'integer',
},
'readiness_probe': {
'type': 'object',
'required': ['path', 'timeout'],
'additionalProperties': False,
'properties': {
'path': {
'type': 'string',
},
'readiness_timeout': {
'type': 'number',
},
}
},
'replica_policy': {
'type': 'object',
'required': ['min_replica'],
'additionalProperties': False,
'properties': {
'min_replica': {
'type': 'integer',
},
'max_replica': {
'type': 'integer',
},
'qpm_upper_threshold': {
'type': 'number',
},
'qpm_lower_threshold': {
'type': 'number',
},
}
}
}
}


def get_task_schema():
return {
'$schema': 'https://json-schema.org/draft/2020-12/schema',
Expand Down Expand Up @@ -170,6 +216,10 @@ def get_task_schema():
'file_mounts': {
'type': 'object',
},
# service config is validated separately using SERVICE_SCHEMA
'service': {
'type': 'object',
},
'setup': {
'type': 'string',
},
Expand Down