Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for min_success_ratio for local map_task execution #1884

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import functools
import hashlib
import logging
import math
import os # TODO: use flytekit logger
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Set, Union, cast
Expand Down Expand Up @@ -268,25 +269,38 @@
outputs_expected = False
outputs = []

mapped_input_value_len = 0
mapped_tasks_count = 0

Check warning on line 272 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L272

Added line #L272 was not covered by tests
if self._run_task.interface.inputs.items():
for k in self._run_task.interface.inputs.keys():
v = kwargs[k]
if isinstance(v, list) and k not in self.bound_inputs:
mapped_input_value_len = len(v)
mapped_tasks_count = len(v)

Check warning on line 277 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L277

Added line #L277 was not covered by tests
break

for i in range(mapped_input_value_len):
failed_count = 0
min_successes = mapped_tasks_count

Check warning on line 281 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L280-L281

Added lines #L280 - L281 were not covered by tests
if self._min_successes:
min_successes = self._min_successes

Check warning on line 283 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L283

Added line #L283 was not covered by tests
elif self._min_success_ratio:
min_successes = math.ceil(min_successes * self._min_success_ratio)

Check warning on line 285 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L285

Added line #L285 was not covered by tests

for i in range(mapped_tasks_count):
single_instance_inputs = {}
for k in self.interface.inputs.keys():
v = kwargs[k]
if isinstance(v, list) and k not in self._bound_inputs:
single_instance_inputs[k] = kwargs[k][i]
else:
single_instance_inputs[k] = kwargs[k]
o = exception_scopes.user_entry_point(self.python_function_task.execute)(**single_instance_inputs)
if outputs_expected:
outputs.append(o)
try:
o = exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs)

Check warning on line 296 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L295-L296

Added lines #L295 - L296 were not covered by tests
if outputs_expected:
outputs.append(o)
except Exception as exc:
outputs.append(None)
failed_count += 1

Check warning on line 301 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L298-L301

Added lines #L298 - L301 were not covered by tests
if mapped_tasks_count - failed_count < min_successes:
raise exc

Check warning on line 303 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L303

Added line #L303 was not covered by tests

return outputs

Expand Down
24 changes: 18 additions & 6 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import functools
import hashlib
import logging
import math
import os
import typing
from contextlib import contextmanager
Expand Down Expand Up @@ -263,25 +264,36 @@
outputs_expected = False
outputs = []

mapped_input_value_len = 0
mapped_tasks_count = 0

Check warning on line 267 in flytekit/core/map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/map_task.py#L267

Added line #L267 was not covered by tests
if self._run_task.interface.inputs.items():
for k in self._run_task.interface.inputs.keys():
v = kwargs[k]
if isinstance(v, list) and k not in self.bound_inputs:
mapped_input_value_len = len(v)
mapped_tasks_count = len(v)

Check warning on line 272 in flytekit/core/map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/map_task.py#L272

Added line #L272 was not covered by tests
break

for i in range(mapped_input_value_len):
failed_count = 0
min_successes = mapped_tasks_count

Check warning on line 276 in flytekit/core/map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/map_task.py#L275-L276

Added lines #L275 - L276 were not covered by tests
if self._min_success_ratio:
min_successes = math.ceil(min_successes * self._min_success_ratio)

Check warning on line 278 in flytekit/core/map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/map_task.py#L278

Added line #L278 was not covered by tests

for i in range(mapped_tasks_count):
single_instance_inputs = {}
for k in self.interface.inputs.keys():
v = kwargs[k]
if isinstance(v, list) and k not in self.bound_inputs:
single_instance_inputs[k] = kwargs[k][i]
else:
single_instance_inputs[k] = kwargs[k]
o = exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs)
if outputs_expected:
outputs.append(o)
try:
o = exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs)

Check warning on line 289 in flytekit/core/map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/map_task.py#L288-L289

Added lines #L288 - L289 were not covered by tests
if outputs_expected:
outputs.append(o)
except Exception as exc:
outputs.append(None)
failed_count += 1

Check warning on line 294 in flytekit/core/map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/map_task.py#L291-L294

Added lines #L291 - L294 were not covered by tests
if mapped_tasks_count - failed_count < min_successes:
raise exc

Check warning on line 296 in flytekit/core/map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/map_task.py#L296

Added line #L296 was not covered by tests
pingsutw marked this conversation as resolved.
Show resolved Hide resolved

return outputs

Expand Down
28 changes: 28 additions & 0 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import typing
from collections import OrderedDict
from typing import List

Expand Down Expand Up @@ -254,3 +255,30 @@ def task3(c: str, a: int, b: float) -> str:
m3 = array_node_map_task(functools.partial(task3, c=param_c))(a=param_a, b=param_b)

assert m1 == m2 == m3 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"]


@pytest.mark.parametrize(
"min_success_ratio, should_raise_error",
[
(None, True),
(1, True),
(0.75, False),
(0.5, False),
],
)
def test_raw_execute_with_min_success_ratio(min_success_ratio, should_raise_error):
@task
def some_task1(inputs: int) -> int:
if inputs == 2:
raise ValueError("Unexpected inputs: 2")
return inputs

@workflow
def my_wf1() -> typing.List[typing.Optional[int]]:
return array_node_map_task(some_task1, min_success_ratio=min_success_ratio)(inputs=[1, 2, 3, 4])

if should_raise_error:
with (pytest.raises(ValueError)):
my_wf1()
else:
assert my_wf1() == [1, None, 3, 4]
27 changes: 27 additions & 0 deletions tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,30 @@ def task3(c: str, a: int, b: float) -> str:
m3 = map_task(functools.partial(task3, c=param_c))(a=param_a, b=param_b)

assert m1 == m2 == m3 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"]


@pytest.mark.parametrize(
"min_success_ratio, should_raise_error",
[
(None, True),
(1, True),
(0.75, False),
(0.5, False),
],
)
def test_raw_execute_with_min_success_ratio(min_success_ratio, should_raise_error):
@task
def some_task1(inputs: int) -> int:
if inputs == 2:
raise ValueError("Unexpected inputs: 2")
return inputs

@workflow
def my_wf1() -> typing.List[typing.Optional[int]]:
return map_task(some_task1, min_success_ratio=min_success_ratio)(inputs=[1, 2, 3, 4])

if should_raise_error:
with (pytest.raises(ValueError)):
my_wf1()
else:
assert my_wf1() == [1, None, 3, 4]
Loading