Skip to content

Commit

Permalink
Add Ray Autoscaler to the Flyte-Ray plugin (#1937)
Browse files Browse the repository at this point in the history
Signed-off-by: Yicheng-Lu-llll <[email protected]>
  • Loading branch information
Yicheng-Lu-llll authored Mar 15, 2024
1 parent 4767fd8 commit 4208da2
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 6 deletions.
40 changes: 36 additions & 4 deletions plugins/flytekit-ray/flytekitplugins/ray/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ def __init__(
self,
group_name: str,
replicas: int,
min_replicas: typing.Optional[int] = 0,
min_replicas: typing.Optional[int] = None,
max_replicas: typing.Optional[int] = None,
ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
):
self._group_name = group_name
self._replicas = replicas
self._min_replicas = min_replicas
self._max_replicas = max_replicas if max_replicas else replicas
self._max_replicas = max(replicas, max_replicas) if max_replicas is not None else replicas
self._min_replicas = min(replicas, min_replicas) if min_replicas is not None else replicas
self._ray_start_params = ray_start_params

@property
Expand Down Expand Up @@ -127,10 +127,14 @@ class RayCluster(_common.FlyteIdlEntity):
"""

def __init__(
self, worker_group_spec: typing.List[WorkerGroupSpec], head_group_spec: typing.Optional[HeadGroupSpec] = None
self,
worker_group_spec: typing.List[WorkerGroupSpec],
head_group_spec: typing.Optional[HeadGroupSpec] = None,
enable_autoscaling: bool = False,
):
self._head_group_spec = head_group_spec
self._worker_group_spec = worker_group_spec
self._enable_autoscaling = enable_autoscaling

@property
def head_group_spec(self) -> HeadGroupSpec:
Expand All @@ -148,13 +152,22 @@ def worker_group_spec(self) -> typing.List[WorkerGroupSpec]:
"""
return self._worker_group_spec

@property
def enable_autoscaling(self) -> bool:
"""
Whether to enable autoscaling.
:rtype: bool
"""
return self._enable_autoscaling

def to_flyte_idl(self) -> _ray_pb2.RayCluster:
"""
:rtype: flyteidl.plugins._ray_pb2.RayCluster
"""
return _ray_pb2.RayCluster(
head_group_spec=self.head_group_spec.to_flyte_idl() if self.head_group_spec else None,
worker_group_spec=[wg.to_flyte_idl() for wg in self.worker_group_spec],
enable_autoscaling=self.enable_autoscaling,
)

@classmethod
Expand All @@ -166,6 +179,7 @@ def from_flyte_idl(cls, proto):
return cls(
head_group_spec=HeadGroupSpec.from_flyte_idl(proto.head_group_spec) if proto.head_group_spec else None,
worker_group_spec=[WorkerGroupSpec.from_flyte_idl(wg) for wg in proto.worker_group_spec],
enable_autoscaling=proto.enable_autoscaling,
)


Expand All @@ -178,9 +192,13 @@ def __init__(
self,
ray_cluster: RayCluster,
runtime_env: typing.Optional[str],
ttl_seconds_after_finished: typing.Optional[int] = None,
shutdown_after_job_finishes: bool = False,
):
self._ray_cluster = ray_cluster
self._runtime_env = runtime_env
self._ttl_seconds_after_finished = ttl_seconds_after_finished
self._shutdown_after_job_finishes = shutdown_after_job_finishes

@property
def ray_cluster(self) -> RayCluster:
Expand All @@ -190,15 +208,29 @@ def ray_cluster(self) -> RayCluster:
def runtime_env(self) -> typing.Optional[str]:
return self._runtime_env

@property
def ttl_seconds_after_finished(self) -> typing.Optional[int]:
# ttl_seconds_after_finished specifies the number of seconds after which the RayCluster will be deleted after the RayJob finishes.
return self._ttl_seconds_after_finished

@property
def shutdown_after_job_finishes(self) -> bool:
# shutdown_after_job_finishes specifies whether the RayCluster should be deleted after the RayJob finishes.
return self._shutdown_after_job_finishes

def to_flyte_idl(self) -> _ray_pb2.RayJob:
return _ray_pb2.RayJob(
ray_cluster=self.ray_cluster.to_flyte_idl(),
runtime_env=self.runtime_env,
ttl_seconds_after_finished=self.ttl_seconds_after_finished,
shutdown_after_job_finishes=self.shutdown_after_job_finishes,
)

@classmethod
def from_flyte_idl(cls, proto: _ray_pb2.RayJob):
return cls(
ray_cluster=RayCluster.from_flyte_idl(proto.ray_cluster) if proto.ray_cluster else None,
runtime_env=proto.runtime_env,
ttl_seconds_after_finished=proto.ttl_seconds_after_finished,
shutdown_after_job_finishes=proto.shutdown_after_job_finishes,
)
6 changes: 6 additions & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ class WorkerNodeConfig:
class RayJobConfig:
worker_node_config: typing.List[WorkerNodeConfig]
head_node_config: typing.Optional[HeadNodeConfig] = None
enable_autoscaling: bool = False
runtime_env: typing.Optional[dict] = None
address: typing.Optional[str] = None
shutdown_after_job_finishes: bool = False
ttl_seconds_after_finished: typing.Optional[int] = None


class RayFunctionTask(PythonFunctionTask):
Expand Down Expand Up @@ -67,9 +70,12 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params)
for c in cfg.worker_node_config
],
enable_autoscaling=cfg.enable_autoscaling if cfg.enable_autoscaling else False,
),
# Use base64 to encode runtime_env dict and convert it to byte string
runtime_env=base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode(),
ttl_seconds_after_finished=cfg.ttl_seconds_after_finished,
shutdown_after_job_finishes=cfg.shutdown_after_job_finishes,
)
return MessageToDict(ray_job.to_flyte_idl())

Expand Down
9 changes: 7 additions & 2 deletions plugins/flytekit-ray/tests/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
from flytekit.configuration import Image, ImageConfig, SerializationSettings

config = RayJobConfig(
worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3)],
worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10)],
runtime_env={"pip": ["numpy"]},
enable_autoscaling=True,
shutdown_after_job_finishes=True,
ttl_seconds_after_finished=20,
)


Expand All @@ -37,8 +40,10 @@ def t1(a: int) -> str:
)

ray_job_pb = RayJob(
ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec("test_group", 3)]),
ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec("test_group", 3, 0, 10)], enable_autoscaling=True),
runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(),
shutdown_after_job_finishes=True,
ttl_seconds_after_finished=20,
).to_flyte_idl()

assert t1.get_custom(settings) == MessageToDict(ray_job_pb)
Expand Down

0 comments on commit 4208da2

Please sign in to comment.