diff --git a/plugins/flytekit-ray/flytekitplugins/ray/task.py b/plugins/flytekit-ray/flytekitplugins/ray/task.py index f0ee542b32..eeb166592a 100644 --- a/plugins/flytekit-ray/flytekitplugins/ray/task.py +++ b/plugins/flytekit-ray/flytekitplugins/ray/task.py @@ -4,7 +4,12 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, Optional -from flytekitplugins.ray.models import HeadGroupSpec, RayCluster, RayJob, WorkerGroupSpec +from flytekitplugins.ray.models import ( + HeadGroupSpec, + RayCluster, + RayJob, + WorkerGroupSpec, +) from google.protobuf.json_format import MessageToDict from flytekit import lazy_module @@ -46,11 +51,19 @@ class RayFunctionTask(PythonFunctionTask): _RAY_TASK_TYPE = "ray" def __init__(self, task_config: RayJobConfig, task_function: Callable, **kwargs): - super().__init__(task_config=task_config, task_type=self._RAY_TASK_TYPE, task_function=task_function, **kwargs) + super().__init__( + task_config=task_config, + task_type=self._RAY_TASK_TYPE, + task_function=task_function, + **kwargs, + ) self._task_config = task_config def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: - ray.init(address=self._task_config.address) + ray.init( + address=self._task_config.address, + runtime_env={"working_dir": "/root"}, + ) return user_params def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: @@ -62,9 +75,19 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] ray_job = RayJob( ray_cluster=RayCluster( - head_group_spec=HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None, + head_group_spec=( + HeadGroupSpec(cfg.head_node_config.ray_start_params) + if cfg.head_node_config + else None + ), worker_group_spec=[ - WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params) + WorkerGroupSpec( + c.group_name, + c.replicas, + c.min_replicas, + c.max_replicas, + c.ray_start_params, + ) for c in cfg.worker_node_config ], ),