Skip to content

Commit

Permalink
Add support for min_success_ratio for local map_task execution (flyteorg#1884)
Browse files Browse the repository at this point in the history

* Add min_success_ratio logic in map_task._raw_execute. Add test.

Signed-off-by: Chao-Heng Lee <[email protected]>

* Also update array_node_map_task with the same min_success_ratio logic.

Signed-off-by: Chao-Heng Lee <[email protected]>

* Log an error when the number of successful tasks falls below the required minimum.

Signed-off-by: Chao-Heng Lee <[email protected]>

---------

Signed-off-by: Chao-Heng Lee <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
2 people authored and ringohoffman committed Nov 24, 2023
1 parent f591e2d commit d6b771d
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 12 deletions.
28 changes: 22 additions & 6 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import functools
import hashlib
import logging
import math
import os # TODO: use flytekit logger
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Set, Union, cast
Expand All @@ -14,6 +15,7 @@
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.utils import timeit
from flytekit.exceptions import scopes as exception_scopes
from flytekit.loggers import logger
from flytekit.models.array_job import ArrayJob
from flytekit.models.core.workflow import NodeMetadata
from flytekit.models.interface import Variable
Expand Down Expand Up @@ -267,25 +269,39 @@ def _raw_execute(self, **kwargs) -> Any:
outputs_expected = False
outputs = []

mapped_input_value_len = 0
mapped_tasks_count = 0
if self._run_task.interface.inputs.items():
for k in self._run_task.interface.inputs.keys():
v = kwargs[k]
if isinstance(v, list) and k not in self.bound_inputs:
mapped_input_value_len = len(v)
mapped_tasks_count = len(v)
break

for i in range(mapped_input_value_len):
failed_count = 0
min_successes = mapped_tasks_count
if self._min_successes:
min_successes = self._min_successes
elif self._min_success_ratio:
min_successes = math.ceil(min_successes * self._min_success_ratio)

for i in range(mapped_tasks_count):
single_instance_inputs = {}
for k in self.interface.inputs.keys():
v = kwargs[k]
if isinstance(v, list) and k not in self._bound_inputs:
single_instance_inputs[k] = kwargs[k][i]
else:
single_instance_inputs[k] = kwargs[k]
o = exception_scopes.user_entry_point(self.python_function_task.execute)(**single_instance_inputs)
if outputs_expected:
outputs.append(o)
try:
o = exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs)
if outputs_expected:
outputs.append(o)
except Exception as exc:
outputs.append(None)
failed_count += 1
if mapped_tasks_count - failed_count < min_successes:
logger.error("The number of successful tasks is lower than the minimum ratio")
raise exc

return outputs

Expand Down
26 changes: 20 additions & 6 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import functools
import hashlib
import logging
import math
import os
import typing
from contextlib import contextmanager
Expand All @@ -20,6 +21,7 @@
from flytekit.core.tracker import TrackedInstance
from flytekit.core.utils import timeit
from flytekit.exceptions import scopes as exception_scopes
from flytekit.loggers import logger
from flytekit.models.array_job import ArrayJob
from flytekit.models.interface import Variable
from flytekit.models.task import Container, K8sPod, Sql
Expand Down Expand Up @@ -263,25 +265,37 @@ def _raw_execute(self, **kwargs) -> Any:
outputs_expected = False
outputs = []

mapped_input_value_len = 0
mapped_tasks_count = 0
if self._run_task.interface.inputs.items():
for k in self._run_task.interface.inputs.keys():
v = kwargs[k]
if isinstance(v, list) and k not in self.bound_inputs:
mapped_input_value_len = len(v)
mapped_tasks_count = len(v)
break

for i in range(mapped_input_value_len):
failed_count = 0
min_successes = mapped_tasks_count
if self._min_success_ratio:
min_successes = math.ceil(min_successes * self._min_success_ratio)

for i in range(mapped_tasks_count):
single_instance_inputs = {}
for k in self.interface.inputs.keys():
v = kwargs[k]
if isinstance(v, list) and k not in self.bound_inputs:
single_instance_inputs[k] = kwargs[k][i]
else:
single_instance_inputs[k] = kwargs[k]
o = exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs)
if outputs_expected:
outputs.append(o)
try:
o = exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs)
if outputs_expected:
outputs.append(o)
except Exception as exc:
outputs.append(None)
failed_count += 1
if mapped_tasks_count - failed_count < min_successes:
logger.error("The number of successful tasks is lower than the minimum ratio")
raise exc

return outputs

Expand Down
28 changes: 28 additions & 0 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import typing
from collections import OrderedDict
from typing import List

Expand Down Expand Up @@ -254,3 +255,30 @@ def task3(c: str, a: int, b: float) -> str:
m3 = array_node_map_task(functools.partial(task3, c=param_c))(a=param_a, b=param_b)

assert m1 == m2 == m3 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"]


@pytest.mark.parametrize(
    "min_success_ratio, should_raise_error",
    [
        # No ratio given, or a ratio of 1: the single failing input is expected to surface.
        (None, True),
        (1, True),
        # Ratios that one failure out of four inputs still satisfies.
        (0.75, False),
        (0.5, False),
    ],
)
def test_raw_execute_with_min_success_ratio(min_success_ratio, should_raise_error):
    """Run ``array_node_map_task`` locally over four inputs, one of which fails.

    The mapped task raises ``ValueError`` for the input ``2``. Depending on the
    parametrized ``min_success_ratio``, the workflow call is expected either to
    raise that ``ValueError`` or to return the results with ``None`` in the
    slot of the failed input.
    """

    @task
    def some_task1(inputs: int) -> int:
        # Every value except 2 is echoed back; 2 is the deliberately failing input.
        if inputs != 2:
            return inputs
        raise ValueError("Unexpected inputs: 2")

    @workflow
    def my_wf1() -> typing.List[typing.Optional[int]]:
        mapped = array_node_map_task(some_task1, min_success_ratio=min_success_ratio)
        return mapped(inputs=[1, 2, 3, 4])

    if not should_raise_error:
        # The failure is tolerated: its position in the output holds None.
        assert my_wf1() == [1, None, 3, 4]
        return

    with pytest.raises(ValueError):
        my_wf1()
27 changes: 27 additions & 0 deletions tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,30 @@ def task3(c: str, a: int, b: float) -> str:
m3 = map_task(functools.partial(task3, c=param_c))(a=param_a, b=param_b)

assert m1 == m2 == m3 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"]


@pytest.mark.parametrize(
    "min_success_ratio, should_raise_error",
    [
        # No ratio given, or a ratio of 1: the single failing input is expected to surface.
        (None, True),
        (1, True),
        # Ratios that one failure out of four inputs still satisfies.
        (0.75, False),
        (0.5, False),
    ],
)
def test_raw_execute_with_min_success_ratio(min_success_ratio, should_raise_error):
    """Run ``map_task`` locally over four inputs, one of which fails.

    The mapped task raises ``ValueError`` for the input ``2``. Depending on the
    parametrized ``min_success_ratio``, the workflow call is expected either to
    raise that ``ValueError`` or to return the results with ``None`` in the
    slot of the failed input.
    """

    @task
    def some_task1(inputs: int) -> int:
        # Every value except 2 is echoed back; 2 is the deliberately failing input.
        if inputs != 2:
            return inputs
        raise ValueError("Unexpected inputs: 2")

    @workflow
    def my_wf1() -> typing.List[typing.Optional[int]]:
        mapped = map_task(some_task1, min_success_ratio=min_success_ratio)
        return mapped(inputs=[1, 2, 3, 4])

    if not should_raise_error:
        # The failure is tolerated: its position in the output holds None.
        assert my_wf1() == [1, None, 3, 4]
        return

    with pytest.raises(ValueError):
        my_wf1()

0 comments on commit d6b771d

Please sign in to comment.