Skip to content

Commit

Permalink
Support PodTemplate in ArrayNode (flyteorg#2088)
Browse files Browse the repository at this point in the history
* added _cmd_prefix handling

Signed-off-by: Daniel Rammer <[email protected]>

* fixed typing imports

Signed-off-by: Daniel Rammer <[email protected]>

* add get_config

Signed-off-by: Kevin Su <[email protected]>

* updating get_config to use underlying functions config

Signed-off-by: Daniel Rammer <[email protected]>

---------

Signed-off-by: Daniel Rammer <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
hamersaw and pingsutw authored Jan 5, 2024
1 parent 5841a1e commit b2f3b77
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(
).hexdigest()
self._name = f"{mod}.map_{f}_{h}-arraynode"

self._cmd_prefix: Optional[List[str]] = None
self._concurrency: Optional[int] = concurrency
self._min_successes: Optional[int] = min_successes
self._min_success_ratio: Optional[float] = min_success_ratio
Expand Down Expand Up @@ -149,6 +150,9 @@ def prepare_target(self):
def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
return ArrayJob(parallelism=self._concurrency, min_success_ratio=self._min_success_ratio).to_dict()

def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]:
return self.python_function_task.get_config(settings)

def get_container(self, settings: SerializationSettings) -> Container:
with self.prepare_target():
return self.python_function_task.get_container(settings)
Expand Down Expand Up @@ -185,11 +189,13 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
*mt.loader_args(settings, self),
]

# TODO: add support for ContainerTask
# if self._cmd_prefix:
# return self._cmd_prefix + container_args
if self._cmd_prefix:
return self._cmd_prefix + container_args
return container_args

def set_command_prefix(self, cmd: Optional[List[str]]):
self._cmd_prefix = cmd

def __call__(self, *args, **kwargs):
"""
This call method modifies the kwargs and adds kwargs from partial.
Expand Down

0 comments on commit b2f3b77

Please sign in to comment.