diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py
index 9a334f98f6..f64b7d23dc 100644
--- a/flytekit/core/resources.py
+++ b/flytekit/core/resources.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, fields
-from typing import Any, List, Optional, Union
+from typing import List, Optional, Union
 
 from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements
 from mashumaro.mixins.json import DataClassJSONMixin
@@ -103,11 +103,11 @@ def convert_resources_to_resource_model(
 
 
 def pod_spec_from_resources(
-    k8s_pod_name: str,
+    primary_container_name: Optional[str] = None,
     requests: Optional[Resources] = None,
     limits: Optional[Resources] = None,
     k8s_gpu_resource_key: str = "nvidia.com/gpu",
-) -> dict[str, Any]:
+) -> V1PodSpec:
     def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resource_key: str):
         if resources is None:
             return None
@@ -133,10 +133,10 @@ def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resour
     requests = requests or limits
     limits = limits or requests
 
-    k8s_pod = V1PodSpec(
+    pod_spec = V1PodSpec(
         containers=[
             V1Container(
-                name=k8s_pod_name,
+                name=primary_container_name,
                 resources=V1ResourceRequirements(
                     requests=requests,
                     limits=limits,
@@ -145,4 +145,4 @@ def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resour
         ]
     )
 
-    return k8s_pod.to_dict()
+    return pod_spec
diff --git a/flytekit/models/task.py b/flytekit/models/task.py
index 960555fd9b..9693390458 100644
--- a/flytekit/models/task.py
+++ b/flytekit/models/task.py
@@ -8,6 +8,7 @@
 from flyteidl.core import tasks_pb2 as _core_task
 from google.protobuf import json_format as _json_format
 from google.protobuf import struct_pb2 as _struct
+from kubernetes.client import ApiClient
 
 from flytekit.models import common as _common
 from flytekit.models import interface as _interface
@@ -16,6 +17,9 @@
 from flytekit.models.core import identifier as _identifier
 from flytekit.models.documentation import Documentation
 
+if typing.TYPE_CHECKING:
+    from flytekit import PodTemplate
+
 
 class Resources(_common.FlyteIdlEntity):
     class ResourceName(object):
@@ -1042,6 +1046,22 @@ def from_flyte_idl(cls, pb2_object: _core_task.K8sPod):
             else None,
         )
 
+    def to_pod_template(self) -> "PodTemplate":
+        from flytekit import PodTemplate
+
+        return PodTemplate(
+            labels=self.metadata.labels,
+            annotations=self.metadata.annotations,
+            pod_spec=self.pod_spec,
+        )
+
+    @classmethod
+    def from_pod_template(cls, pod_template: "PodTemplate") -> "K8sPod":
+        return cls(
+            metadata=K8sObjectMetadata(labels=pod_template.labels, annotations=pod_template.annotations),
+            pod_spec=ApiClient().sanitize_for_serialization(pod_template.pod_spec),
+        )
+
 
 class Sql(_common.FlyteIdlEntity):
     class Dialect(object):
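Illustrative usage, not part of the diff: a minimal sketch of how the reworked helpers compose, assuming a flytekit build that includes these changes; the container name and resource values below are made up.

from flytekit import PodTemplate
from flytekit.core.resources import Resources, pod_spec_from_resources
from flytekit.models.task import K8sPod

# pod_spec_from_resources now returns a V1PodSpec object instead of a dict.
pod_spec = pod_spec_from_resources(
    primary_container_name="primary",
    requests=Resources(cpu="1", mem="1Gi"),
    limits=Resources(cpu="2", mem="2Gi"),
)

# from_pod_template serializes the V1PodSpec via the kubernetes ApiClient;
# to_pod_template converts the IDL K8sPod back into a PodTemplate.
k8s_pod = K8sPod.from_pod_template(PodTemplate(pod_spec=pod_spec))
round_tripped = k8s_pod.to_pod_template()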
diff --git a/plugins/flytekit-ray/flytekitplugins/ray/task.py b/plugins/flytekit-ray/flytekitplugins/ray/task.py
index 98a653a990..c87c86276d 100644
--- a/plugins/flytekit-ray/flytekitplugins/ray/task.py
+++ b/plugins/flytekit-ray/flytekitplugins/ray/task.py
@@ -14,20 +14,30 @@
 )
 from google.protobuf.json_format import MessageToDict
 
-from flytekit import lazy_module
+from flytekit import PodTemplate, Resources, lazy_module
 from flytekit.configuration import SerializationSettings
 from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager
 from flytekit.core.python_function_task import PythonFunctionTask
+from flytekit.core.resources import pod_spec_from_resources
 from flytekit.extend import TaskPlugins
 from flytekit.models.task import K8sPod
 
 ray = lazy_module("ray")
+_RAY_HEAD_CONTAINER_NAME = "ray-head"
+_RAY_WORKER_CONTAINER_NAME = "ray-worker"
 
 
 @dataclass
 class HeadNodeConfig:
     ray_start_params: typing.Optional[typing.Dict[str, str]] = None
-    k8s_pod: typing.Optional[K8sPod] = None
+    pod_template: typing.Optional[PodTemplate] = None
+    requests: Optional[Resources] = None
+    limits: Optional[Resources] = None
+
+    def __post_init__(self):
+        if self.pod_template:
+            if self.requests and self.limits:
+                raise ValueError("Cannot specify both pod_template and requests/limits")
 
 
 @dataclass
@@ -37,7 +47,14 @@ class WorkerNodeConfig:
     min_replicas: typing.Optional[int] = None
     max_replicas: typing.Optional[int] = None
     ray_start_params: typing.Optional[typing.Dict[str, str]] = None
-    k8s_pod: typing.Optional[K8sPod] = None
+    pod_template: typing.Optional[PodTemplate] = None
+    requests: Optional[Resources] = None
+    limits: Optional[Resources] = None
+
+    def __post_init__(self):
+        if self.pod_template:
+            if self.requests and self.limits:
+                raise ValueError("Cannot specify both pod_template and requests/limits")
 
 
 @dataclass
@@ -83,25 +100,49 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
 
     def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
         cfg = self._task_config
-
         # Deprecated: runtime_env is removed KubeRay >= 1.1.0. It is replaced by runtime_env_yaml
         runtime_env = base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode() if cfg.runtime_env else None
-
         runtime_env_yaml = yaml.dump(cfg.runtime_env) if cfg.runtime_env else None
 
+        if cfg.head_node_config.requests or cfg.head_node_config.limits:
+            head_pod_template = PodTemplate(
+                pod_spec=pod_spec_from_resources(
+                    primary_container_name=_RAY_HEAD_CONTAINER_NAME,
+                    requests=cfg.head_node_config.requests,
+                    limits=cfg.head_node_config.limits,
+                )
+            )
+        else:
+            head_pod_template = cfg.head_node_config.pod_template
+
+        worker_group_spec: typing.List[WorkerGroupSpec] = []
+        for c in cfg.worker_node_config:
+            if c.requests or c.limits:
+                worker_pod_template = PodTemplate(
+                    pod_spec=pod_spec_from_resources(
+                        primary_container_name=_RAY_WORKER_CONTAINER_NAME,
+                        requests=c.requests,
+                        limits=c.limits,
+                    )
+                )
+            else:
+                worker_pod_template = c.pod_template
+            k8s_pod = K8sPod.from_pod_template(worker_pod_template) if worker_pod_template else None
+            worker_group_spec.append(
+                WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params, k8s_pod)
+            )
+
         ray_job = RayJob(
             ray_cluster=RayCluster(
                 head_group_spec=(
-                    HeadGroupSpec(cfg.head_node_config.ray_start_params, cfg.head_node_config.k8s_pod)
+                    HeadGroupSpec(
+                        cfg.head_node_config.ray_start_params,
+                        K8sPod.from_pod_template(head_pod_template) if head_pod_template else None,
+                    )
                     if cfg.head_node_config
                     else None
                 ),
-                worker_group_spec=[
-                    WorkerGroupSpec(
-                        c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params, c.k8s_pod
-                    )
-                    for c in cfg.worker_node_config
-                ],
+                worker_group_spec=worker_group_spec,
                 enable_autoscaling=(cfg.enable_autoscaling if cfg.enable_autoscaling else False),
             ),
             runtime_env=runtime_env,
diff --git a/plugins/flytekit-ray/setup.py b/plugins/flytekit-ray/setup.py
index 18b95498ee..2237be2030 100644
--- a/plugins/flytekit-ray/setup.py
+++ b/plugins/flytekit-ray/setup.py
@@ -4,7 +4,7 @@
 
 microlib_name = f"flytekitplugins-{PLUGIN_NAME}"
 
-plugin_requires = ["ray[default]", "flytekit>=1.3.0b2,<2.0.0", "flyteidl>=1.13.6"]
+plugin_requires = ["ray[default]", "flytekit>1.14.5", "flyteidl>=1.13.6"]
 
 __version__ = "0.0.0+develop"
 
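Illustrative config, not part of the diff: a sketch of the new plugin surface, assuming this version of flytekitplugins-ray. A head node can be sized with plain requests/limits, while a worker group can still pass a full PodTemplate; supplying pod_template together with requests and limits on the same node config raises the ValueError from __post_init__. The task name, group name, replica count, and resource values are made up.

from flytekit import PodTemplate, Resources, task
from flytekitplugins.ray import HeadNodeConfig
from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig

ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(
        requests=Resources(cpu="1", mem="2Gi"),
        limits=Resources(cpu="2", mem="4Gi"),
    ),
    worker_node_config=[
        WorkerNodeConfig(
            group_name="workers",
            replicas=2,
            pod_template=PodTemplate(primary_container_name="ray-worker"),
        )
    ],
)


@task(task_config=ray_config)
def train() -> None:
    # The serialized task carries a RayJob whose head and worker group specs
    # are built from the config above by get_custom().
    ...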
"0.0.0+develop" diff --git a/plugins/flytekit-ray/tests/test_ray.py b/plugins/flytekit-ray/tests/test_ray.py index c943067013..8fd8d432a9 100644 --- a/plugins/flytekit-ray/tests/test_ray.py +++ b/plugins/flytekit-ray/tests/test_ray.py @@ -3,6 +3,8 @@ import ray import yaml + +from flytekit.core.resources import pod_spec_from_resources from flytekitplugins.ray import HeadNodeConfig from flytekitplugins.ray.models import ( HeadGroupSpec, @@ -13,10 +15,17 @@ from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig from google.protobuf.json_format import MessageToDict -from flytekit import PythonFunctionTask, task +from flytekit import PythonFunctionTask, task, PodTemplate, Resources from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.models.task import K8sPod + +pod_template=PodTemplate( + primary_container_name="primary", + labels={"lKeyA": "lValA"}, + annotations={"aKeyA": "aValA"}, + ) + config = RayJobConfig( worker_node_config=[ WorkerNodeConfig( @@ -24,10 +33,10 @@ replicas=3, min_replicas=0, max_replicas=10, - k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}), + pod_template=pod_template, ) ], - head_node_config=HeadNodeConfig(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})), + head_node_config=HeadNodeConfig(requests=Resources(cpu="1", mem="1Gi"), limits=Resources(cpu="2", mem="2Gi")), runtime_env={"pip": ["numpy"]}, enable_autoscaling=True, shutdown_after_job_finishes=True, @@ -55,6 +64,13 @@ def t1(a: int) -> str: image_config=ImageConfig(default_image=default_img, images=[default_img]), env={}, ) + head_pod_template = PodTemplate( + pod_spec=pod_spec_from_resources( + primary_container_name="ray-head", + requests=Resources(cpu="1", mem="1Gi"), + limits=Resources(cpu="2", mem="2Gi"), + ) + ) ray_job_pb = RayJob( ray_cluster=RayCluster( @@ -64,10 +80,10 @@ def t1(a: int) -> str: replicas=3, min_replicas=0, max_replicas=10, - k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}), + k8s_pod=K8sPod.from_pod_template(pod_template), ) ], - head_group_spec=HeadGroupSpec(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})), + head_group_spec=HeadGroupSpec(k8s_pod=K8sPod.from_pod_template(head_pod_template)), enable_autoscaling=True, ), runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(), diff --git a/tests/flytekit/unit/core/test_resources.py b/tests/flytekit/unit/core/test_resources.py index 1c09a111e3..115605b055 100644 --- a/tests/flytekit/unit/core/test_resources.py +++ b/tests/flytekit/unit/core/test_resources.py @@ -110,12 +110,12 @@ def test_resources_round_trip(): def test_pod_spec_from_resources_requests_limits_set(): requests = Resources(cpu="1", mem="1Gi", gpu="1", ephemeral_storage="1Gi") limits = Resources(cpu="4", mem="2Gi", gpu="1", ephemeral_storage="1Gi") - k8s_pod_name = "foo" + primary_container_name = "foo" expected_pod_spec = V1PodSpec( containers=[ V1Container( - name=k8s_pod_name, + name=primary_container_name, resources=V1ResourceRequirements( requests={ "cpu": "1", @@ -133,19 +133,19 @@ def test_pod_spec_from_resources_requests_limits_set(): ) ] ) - pod_spec = pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits) - assert expected_pod_spec == V1PodSpec(**pod_spec) + pod_spec = pod_spec_from_resources(primary_container_name=primary_container_name, requests=requests, limits=limits) + assert expected_pod_spec == pod_spec def test_pod_spec_from_resources_requests_set(): requests = Resources(cpu="1", mem="1Gi") limits = None - k8s_pod_name = "foo" + 
primary_container_name = "foo" expected_pod_spec = V1PodSpec( containers=[ V1Container( - name=k8s_pod_name, + name=primary_container_name, resources=V1ResourceRequirements( requests={"cpu": "1", "memory": "1Gi"}, limits={"cpu": "1", "memory": "1Gi"}, @@ -153,5 +153,5 @@ def test_pod_spec_from_resources_requests_set(): ) ] ) - pod_spec = pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits) - assert expected_pod_spec == V1PodSpec(**pod_spec) + pod_spec = pod_spec_from_resources(primary_container_name=primary_container_name, requests=requests, limits=limits) + assert expected_pod_spec == pod_spec