From 4208da21d919b8e0ab0df5d6fac715804dd28ad3 Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll <51814063+Yicheng-Lu-llll@users.noreply.github.com> Date: Fri, 15 Mar 2024 01:21:13 -0500 Subject: [PATCH] Add Ray Autoscaler to the Flyte-Ray plugin (#1937) Signed-off-by: Yicheng-Lu-llll --- .../flytekitplugins/ray/models.py | 40 +++++++++++++++++-- .../flytekit-ray/flytekitplugins/ray/task.py | 6 +++ plugins/flytekit-ray/tests/test_ray.py | 9 ++++- 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/plugins/flytekit-ray/flytekitplugins/ray/models.py b/plugins/flytekit-ray/flytekitplugins/ray/models.py index 080f1239b4..06e36af186 100644 --- a/plugins/flytekit-ray/flytekitplugins/ray/models.py +++ b/plugins/flytekit-ray/flytekitplugins/ray/models.py @@ -10,14 +10,14 @@ def __init__( self, group_name: str, replicas: int, - min_replicas: typing.Optional[int] = 0, + min_replicas: typing.Optional[int] = None, max_replicas: typing.Optional[int] = None, ray_start_params: typing.Optional[typing.Dict[str, str]] = None, ): self._group_name = group_name self._replicas = replicas - self._min_replicas = min_replicas - self._max_replicas = max_replicas if max_replicas else replicas + self._max_replicas = max(replicas, max_replicas) if max_replicas is not None else replicas + self._min_replicas = min(replicas, min_replicas) if min_replicas is not None else replicas self._ray_start_params = ray_start_params @property @@ -127,10 +127,14 @@ class RayCluster(_common.FlyteIdlEntity): """ def __init__( - self, worker_group_spec: typing.List[WorkerGroupSpec], head_group_spec: typing.Optional[HeadGroupSpec] = None + self, + worker_group_spec: typing.List[WorkerGroupSpec], + head_group_spec: typing.Optional[HeadGroupSpec] = None, + enable_autoscaling: bool = False, ): self._head_group_spec = head_group_spec self._worker_group_spec = worker_group_spec + self._enable_autoscaling = enable_autoscaling @property def head_group_spec(self) -> HeadGroupSpec: @@ -148,6 +152,14 @@ def worker_group_spec(self) -> typing.List[WorkerGroupSpec]: """ return self._worker_group_spec + @property + def enable_autoscaling(self) -> bool: + """ + Whether to enable autoscaling. + :rtype: bool + """ + return self._enable_autoscaling + def to_flyte_idl(self) -> _ray_pb2.RayCluster: """ :rtype: flyteidl.plugins._ray_pb2.RayCluster @@ -155,6 +167,7 @@ def to_flyte_idl(self) -> _ray_pb2.RayCluster: return _ray_pb2.RayCluster( head_group_spec=self.head_group_spec.to_flyte_idl() if self.head_group_spec else None, worker_group_spec=[wg.to_flyte_idl() for wg in self.worker_group_spec], + enable_autoscaling=self.enable_autoscaling, ) @classmethod @@ -166,6 +179,7 @@ def from_flyte_idl(cls, proto): return cls( head_group_spec=HeadGroupSpec.from_flyte_idl(proto.head_group_spec) if proto.head_group_spec else None, worker_group_spec=[WorkerGroupSpec.from_flyte_idl(wg) for wg in proto.worker_group_spec], + enable_autoscaling=proto.enable_autoscaling, ) @@ -178,9 +192,13 @@ def __init__( self, ray_cluster: RayCluster, runtime_env: typing.Optional[str], + ttl_seconds_after_finished: typing.Optional[int] = None, + shutdown_after_job_finishes: bool = False, ): self._ray_cluster = ray_cluster self._runtime_env = runtime_env + self._ttl_seconds_after_finished = ttl_seconds_after_finished + self._shutdown_after_job_finishes = shutdown_after_job_finishes @property def ray_cluster(self) -> RayCluster: @@ -190,10 +208,22 @@ def ray_cluster(self) -> RayCluster: def runtime_env(self) -> typing.Optional[str]: return self._runtime_env + @property + def ttl_seconds_after_finished(self) -> typing.Optional[int]: + # ttl_seconds_after_finished specifies the number of seconds after which the RayCluster will be deleted after the RayJob finishes. + return self._ttl_seconds_after_finished + + @property + def shutdown_after_job_finishes(self) -> bool: + # shutdown_after_job_finishes specifies whether the RayCluster should be deleted after the RayJob finishes. + return self._shutdown_after_job_finishes + def to_flyte_idl(self) -> _ray_pb2.RayJob: return _ray_pb2.RayJob( ray_cluster=self.ray_cluster.to_flyte_idl(), runtime_env=self.runtime_env, + ttl_seconds_after_finished=self.ttl_seconds_after_finished, + shutdown_after_job_finishes=self.shutdown_after_job_finishes, ) @classmethod @@ -201,4 +231,6 @@ def from_flyte_idl(cls, proto: _ray_pb2.RayJob): return cls( ray_cluster=RayCluster.from_flyte_idl(proto.ray_cluster) if proto.ray_cluster else None, runtime_env=proto.runtime_env, + ttl_seconds_after_finished=proto.ttl_seconds_after_finished, + shutdown_after_job_finishes=proto.shutdown_after_job_finishes, ) diff --git a/plugins/flytekit-ray/flytekitplugins/ray/task.py b/plugins/flytekit-ray/flytekitplugins/ray/task.py index f0ee542b32..76688d74cd 100644 --- a/plugins/flytekit-ray/flytekitplugins/ray/task.py +++ b/plugins/flytekit-ray/flytekitplugins/ray/task.py @@ -34,8 +34,11 @@ class WorkerNodeConfig: class RayJobConfig: worker_node_config: typing.List[WorkerNodeConfig] head_node_config: typing.Optional[HeadNodeConfig] = None + enable_autoscaling: bool = False runtime_env: typing.Optional[dict] = None address: typing.Optional[str] = None + shutdown_after_job_finishes: bool = False + ttl_seconds_after_finished: typing.Optional[int] = None class RayFunctionTask(PythonFunctionTask): @@ -67,9 +70,12 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params) for c in cfg.worker_node_config ], + enable_autoscaling=cfg.enable_autoscaling if cfg.enable_autoscaling else False, ), # Use base64 to encode runtime_env dict and convert it to byte string runtime_env=base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode(), + ttl_seconds_after_finished=cfg.ttl_seconds_after_finished, + shutdown_after_job_finishes=cfg.shutdown_after_job_finishes, ) return MessageToDict(ray_job.to_flyte_idl()) diff --git a/plugins/flytekit-ray/tests/test_ray.py b/plugins/flytekit-ray/tests/test_ray.py index 8bcebf7937..0c0ada1944 100644 --- a/plugins/flytekit-ray/tests/test_ray.py +++ b/plugins/flytekit-ray/tests/test_ray.py @@ -10,8 +10,11 @@ from flytekit.configuration import Image, ImageConfig, SerializationSettings config = RayJobConfig( - worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3)], + worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10)], runtime_env={"pip": ["numpy"]}, + enable_autoscaling=True, + shutdown_after_job_finishes=True, + ttl_seconds_after_finished=20, ) @@ -37,8 +40,10 @@ def t1(a: int) -> str: ) ray_job_pb = RayJob( - ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec("test_group", 3)]), + ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec("test_group", 3, 0, 10)], enable_autoscaling=True), runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(), + shutdown_after_job_finishes=True, + ttl_seconds_after_finished=20, ).to_flyte_idl() assert t1.get_custom(settings) == MessageToDict(ray_job_pb)