diff --git a/sky/serve/autoscalers.py b/sky/serve/autoscalers.py index a9075beedfb..ecc673c7c03 100644 --- a/sky/serve/autoscalers.py +++ b/sky/serve/autoscalers.py @@ -1,6 +1,8 @@ import logging import time +from typing import Optional + from sky.serve.infra_providers import InfraProvider from sky.serve.load_balancers import LoadBalancer @@ -78,10 +80,10 @@ class RequestRateAutoscaler(Autoscaler): def __init__(self, *args, - query_interval: int = 10, - upper_threshold: int = 10, - lower_threshold: int = 2, min_nodes: int = 1, + max_nodes: Optional[int] = None, + upper_threshold: Optional[int] = None, + lower_threshold: Optional[int] = None, cooldown: int = 60, **kwargs): """ @@ -95,10 +97,11 @@ def __init__(self, :param kwargs: """ super().__init__(*args, **kwargs) - self.query_interval = query_interval + self.min_nodes = min_nodes + self.max_nodes = max_nodes or min_nodes + self.query_interval = 60 # Therefore thresholds represent queries per minute. self.upper_threshold = upper_threshold self.lower_threshold = lower_threshold - self.min_nodes = min_nodes self.cooldown = cooldown self.last_scale_operation = 0 # Time of last scale operation. @@ -132,14 +135,15 @@ def evaluate_scaling(self): scaled = True # Bootstrap case logger.info(f'Number of nodes: {num_nodes}') - if num_nodes == 0 and requests_per_node > 0: + if num_nodes < self.min_nodes: logger.info(f'Bootstrapping autoscaler.') self.scale_up(1) self.last_scale_operation = current_time - elif requests_per_node > self.upper_threshold: - self.scale_up(1) - self.last_scale_operation = current_time - elif requests_per_node < self.lower_threshold: + elif self.upper_threshold is not None and requests_per_node > self.upper_threshold: + if self.infra_provider.total_servers() < self.max_nodes: + self.scale_up(1) + self.last_scale_operation = current_time + elif self.lower_threshold is not None and requests_per_node < self.lower_threshold: if self.infra_provider.total_servers() > self.min_nodes: self.scale_down(1) self.last_scale_operation = current_time diff --git a/sky/serve/common.py b/sky/serve/common.py index 5f211dc5a10..5c791daa363 100644 --- a/sky/serve/common.py +++ b/sky/serve/common.py @@ -1,32 +1,83 @@ -import yaml +from typing import Optional, Dict + +from sky.backends import backend_utils +from sky.utils import schemas +from sky.utils import ux_utils class SkyServiceSpec: - def __init__(self, yaml_path: str): - with open(yaml_path, 'r') as f: - self.task = yaml.safe_load(f) - if 'service' not in self.task: - raise ValueError('Task YAML must have a "service" section') - if 'port' not in self.task['service']: - raise ValueError('Task YAML must have a "port" section') - if 'readiness_probe' not in self.task['service']: - raise ValueError('Task YAML must have a "readiness_probe" section') - self._readiness_path = self.get_readiness_path() - self._app_port = self.get_app_port() - - def get_readiness_path(self): + def __init__( + self, + readiness_path: str, + readiness_timeout: int, + app_port: int, + min_replica: int, + max_replica: Optional[int] = None, + qpm_upper_threshold: Optional[int] = None, + qpm_lower_threshold: Optional[int] = None, + ): + if max_replica is not None and max_replica < min_replica: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + 'max_replica must be greater than or equal to min_replica') # TODO: check if the path is valid - return f':{self.task["service"]["port"]}{self.task["service"]["readiness_probe"]}' - - def get_app_port(self): + self._readiness_path = f':{app_port}{readiness_path}' + self._readiness_timeout = readiness_timeout # TODO: check if the port is valid - return f'{self.task["service"]["port"]}' + self._app_port = str(app_port) + self._min_replica = min_replica + self._max_replica = max_replica + self._qpm_upper_threshold = qpm_upper_threshold + self._qpm_lower_threshold = qpm_lower_threshold + + @classmethod + def from_yaml_config(cls, config: Optional[Dict[str, str]]): + if config is None: + return None + + backend_utils.validate_schema(config, schemas.get_service_schema(), + 'Invalid service YAML:') + + service_config = {} + service_config['readiness_path'] = config['readiness_probe']['path'] + service_config['readiness_timeout'] = config['readiness_probe'][ + 'readiness_timeout'] + service_config['app_port'] = config['port'] + service_config['min_replica'] = config['replica_policy']['min_replica'] + service_config['max_replica'] = config['replica_policy'].get( + 'max_replica', None) + service_config['qpm_upper_threshold'] = config['replica_policy'].get( + 'qpm_upper_threshold', None) + service_config['qpm_lower_threshold'] = config['replica_policy'].get( + 'qpm_lower_threshold', None) + + return SkyServiceSpec(**service_config) @property def readiness_path(self): return self._readiness_path + @property + def readiness_timeout(self): + return self._readiness_timeout + @property def app_port(self): return self._app_port + + @property + def min_replica(self): + return self._min_replica + + @property + def max_replica(self): + return self._max_replica + + @property + def qpm_upper_threshold(self): + return self._qpm_upper_threshold + + @property + def qpm_lower_threshold(self): + return self._qpm_lower_threshold diff --git a/sky/serve/controller.py b/sky/serve/controller.py index dfe5a953764..d10f89ab957 100644 --- a/sky/serve/controller.py +++ b/sky/serve/controller.py @@ -9,6 +9,7 @@ import time import threading +import yaml from typing import Optional @@ -83,10 +84,6 @@ def get_server_ips(): type=int, help='Port to run the controller', default=8082) - parser.add_argument('--min-nodes', - type=int, - default=1, - help='Minimum nodes to keep running') args = parser.parse_args() # ======= Infra Provider ========= @@ -94,11 +91,17 @@ def get_server_ips(): infra_provider = SkyPilotInfraProvider(args.task_yaml) # ======= Load Balancer ========= - service_spec = SkyServiceSpec(args.task_yaml) + with open(args.task_yaml, 'r') as f: + task = yaml.safe_load(f) + if 'service' not in task: + raise ValueError('Task YAML must have a "service" section') + service_config = task['service'] + service_spec = SkyServiceSpec.from_yaml_config(service_config) # Select the load balancing policy: RoundRobinLoadBalancer or LeastLoadedLoadBalancer load_balancer = RoundRobinLoadBalancer( infra_provider=infra_provider, - endpoint_path=service_spec.readiness_path) + endpoint_path=service_spec.readiness_path, + readiness_timeout=service_spec.readiness_timeout) # load_balancer = LeastLoadedLoadBalancer(n=5) # autoscaler = LatencyThresholdAutoscaler(load_balancer, # upper_threshold=0.5, # 500ms @@ -106,14 +109,15 @@ def get_server_ips(): # ======= Autoscaler ========= # Create an autoscaler with the RequestRateAutoscaler policy. Thresholds are defined as requests per node in the defined interval. - autoscaler = RequestRateAutoscaler(infra_provider, - load_balancer, - frequency=5, - query_interval=60, - lower_threshold=0, - upper_threshold=1, - min_nodes=args.min_nodes, - cooldown=60) + autoscaler = RequestRateAutoscaler( + infra_provider, + load_balancer, + frequency=5, + min_nodes=service_spec.min_replica, + max_nodes=service_spec.max_replica, + upper_threshold=service_spec.qpm_upper_threshold, + lower_threshold=service_spec.qpm_lower_threshold, + cooldown=60) # ======= Controller ========= # Create a controller object and run it. diff --git a/sky/serve/examples/http_server/server.py b/sky/serve/examples/http_server/server.py index 4ea616b148e..303b117d26d 100644 --- a/sky/serve/examples/http_server/server.py +++ b/sky/serve/examples/http_server/server.py @@ -3,7 +3,9 @@ PORT = 8081 + class MyHttpRequestHandler(http.server.SimpleHTTPRequestHandler): + def do_GET(self): # Return 200 for all paths # Therefore, readiness_probe will return 200 at path '/health' @@ -23,6 +25,7 @@ def do_GET(self): self.wfile.write(bytes(html, 'utf8')) return + Handler = MyHttpRequestHandler with socketserver.TCPServer(("", PORT), Handler) as httpd: diff --git a/sky/serve/examples/http_server/task.yaml b/sky/serve/examples/http_server/task.yaml index b82fbb29f75..d0fe866f259 100644 --- a/sky/serve/examples/http_server/task.yaml +++ b/sky/serve/examples/http_server/task.yaml @@ -9,4 +9,9 @@ run: python3 server.py service: port: 8081 - readiness_probe: /health + readiness_probe: + path: /health + readiness_timeout: 12000 + replica_policy: + min_replica: 1 + max_replica: 1 diff --git a/sky/serve/load_balancers.py b/sky/serve/load_balancers.py index 4c0c95c6ff7..90f94ceca41 100644 --- a/sky/serve/load_balancers.py +++ b/sky/serve/load_balancers.py @@ -12,12 +12,13 @@ class LoadBalancer: - def __init__(self, infra_provider, endpoint_path, post_data=None): + def __init__(self, infra_provider, endpoint_path, readiness_timeout, post_data=None): self.available_servers = [] self.request_count = 0 self.request_timestamps = deque() self.infra_provider = infra_provider self.endpoint_path = endpoint_path + self.readiness_timeout = readiness_timeout self.post_data = post_data def increment_request_count(self, count=1): @@ -37,7 +38,6 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.servers_queue = deque() self.first_unhealthy_time = {} - self.timeout = 18000 logger.info(f'Endpoint path: {self.endpoint_path}') def probe_endpoints(self, endpoint_ips): @@ -101,7 +101,7 @@ def probe_endpoint(endpoint_ip): if server not in self.first_unhealthy_time: self.first_unhealthy_time[server] = time.time() elif time.time() - self.first_unhealthy_time[ - server] > self.timeout: # cooldown before terminating a dead server to avoid hysterisis + server] > self.readiness_timeout: # cooldown before terminating a dead server to avoid hysterisis servers_to_terminate.append(server) self.infra_provider.terminate_servers(servers_to_terminate) diff --git a/sky/serve/redirector.py b/sky/serve/redirector.py index 5c3df12c62c..3e130058abb 100644 --- a/sky/serve/redirector.py +++ b/sky/serve/redirector.py @@ -1,5 +1,6 @@ import time import logging +import yaml from collections import deque from sky.serve.common import SkyServiceSpec @@ -113,7 +114,13 @@ def serve(self): help='Controller address (ip:port).') args = parser.parse_args() - service_spec = SkyServiceSpec(args.task_yaml) + with open(args.task_yaml, 'r') as f: + task = yaml.safe_load(f) + if 'service' not in task: + raise ValueError('Task YAML must have a "service" section') + service_config = task['service'] + service_spec = SkyServiceSpec.from_yaml_config(service_config) + redirector = SkyServeRedirector(controller_url=args.controller_addr, service_spec=service_spec, port=args.port) diff --git a/sky/task.py b/sky/task.py index 2105d0feeda..2fadbc7395f 100644 --- a/sky/task.py +++ b/sky/task.py @@ -15,6 +15,7 @@ from sky.backends import backend_utils from sky.data import storage as storage_lib from sky.data import data_utils +from sky.serve import common from sky.skylet import constants from sky.utils import schemas from sky.utils import ux_utils @@ -365,10 +366,12 @@ def from_yaml_config( resources = config.pop('resources', None) resources = sky.Resources.from_yaml_config(resources) - # FIXME: find a better way to exclude unused fields. - config.pop('service', None) - task.set_resources({resources}) + + service = config.pop('service', None) + service = common.SkyServiceSpec.from_yaml_config(service) + task.service = service + assert not config, f'Invalid task args: {config.keys()}' return task diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 8126a97da9d..5c95edcd336 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -143,6 +143,52 @@ def get_storage_schema(): } +def get_service_schema(): + return { + '$schema': 'http://json-schema.org/draft-07/schema#', + 'type': 'object', + 'required': ['port', 'readiness_probe', 'replica_policy'], + 'additionalProperties': False, + 'properties': { + 'port': { + 'type': 'integer', + }, + 'readiness_probe': { + 'type': 'object', + 'required': ['path', 'timeout'], + 'additionalProperties': False, + 'properties': { + 'path': { + 'type': 'string', + }, + 'readiness_timeout': { + 'type': 'number', + }, + } + }, + 'replica_policy': { + 'type': 'object', + 'required': ['min_replica'], + 'additionalProperties': False, + 'properties': { + 'min_replica': { + 'type': 'integer', + }, + 'max_replica': { + 'type': 'integer', + }, + 'qpm_upper_threshold': { + 'type': 'number', + }, + 'qpm_lower_threshold': { + 'type': 'number', + }, + } + } + } + } + + def get_task_schema(): return { '$schema': 'https://json-schema.org/draft/2020-12/schema', @@ -170,6 +216,10 @@ def get_task_schema(): 'file_mounts': { 'type': 'object', }, + # service config is validated separately using SERVICE_SCHEMA + 'service': { + 'type': 'object', + }, 'setup': { 'type': 'string', },