Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
Signed-off-by: Jan Fiedler <[email protected]>
  • Loading branch information
fiedlerNr9 committed Jul 25, 2024
1 parent 8b77a18 commit f82bc2c
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions plugins/flytekit-ray/flytekitplugins/ray/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

from flytekitplugins.ray.models import HeadGroupSpec, RayCluster, RayJob, WorkerGroupSpec
from flytekitplugins.ray.models import (
HeadGroupSpec,
RayCluster,
RayJob,
WorkerGroupSpec,
)
from google.protobuf.json_format import MessageToDict

from flytekit import lazy_module
Expand Down Expand Up @@ -46,11 +51,19 @@ class RayFunctionTask(PythonFunctionTask):
_RAY_TASK_TYPE = "ray"

def __init__(self, task_config: RayJobConfig, task_function: Callable, **kwargs):
super().__init__(task_config=task_config, task_type=self._RAY_TASK_TYPE, task_function=task_function, **kwargs)
super().__init__(
task_config=task_config,
task_type=self._RAY_TASK_TYPE,
task_function=task_function,
**kwargs,
)
self._task_config = task_config

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
ray.init(address=self._task_config.address)
ray.init(
address=self._task_config.address,
runtime_env={"working_dir": "/root"},
)
return user_params

def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
Expand All @@ -62,9 +75,19 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]

ray_job = RayJob(
ray_cluster=RayCluster(
head_group_spec=HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None,
head_group_spec=(
HeadGroupSpec(cfg.head_node_config.ray_start_params)
if cfg.head_node_config
else None
),
worker_group_spec=[
WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params)
WorkerGroupSpec(
c.group_name,
c.replicas,
c.min_replicas,
c.max_replicas,
c.ray_start_params,
)
for c in cfg.worker_node_config
],
),
Expand Down

0 comments on commit f82bc2c

Please sign in to comment.