Endpoint support for automatic task retries (#1342)
This PR adds automatic retries to the `GlobusComputeEngine` via a new kwarg, `max_retries_on_system_failure`.
If set, the engine will automatically resubmit any tasks that fail due to infrastructure problems.
This is designed to handle *only* infrastructure-level failures such as `ManagerLost` (often caused by walltime-truncated batch jobs)
and does not handle task failures (e.g., a `KeyError` raised during function execution). If the task still fails to complete after
the maximum allowed retries, the full exception history is reported.

The core functionality is implemented via `GlobusComputeEngineBase._retry_table`, which tracks each task's exception history,
and the engine-specific error handler `GlobusComputeEngine._handle_task_exception`.
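
For orientation, here is a condensed sketch of how those pieces interact (simplified from the diff below, not the verbatim implementation):

```python
# Simplified sketch of the retry flow; see the full diff below for details.
# Each submitted task gets an entry in GlobusComputeEngineBase._retry_table:
#   {"retry_count": 0, "packed_task": <original task bytes>, "exception_history": []}
def _handle_task_exception(self, task_id, execution_begin, exception):
    retry_info = self._retry_table[task_id]
    if retry_info["retry_count"] < self.max_retries_on_system_failure:
        # Infrastructure failure and retries remain: resubmit the same packed task
        retry_info["retry_count"] += 1
        retry_info["exception_history"].append(exception)
        self.submit(task_id, retry_info["packed_task"])
        return b""  # empty bytes -> no result message yet; the task is retrying
    # Retries exhausted: pack the accumulated exception history into the Result
    return super()._handle_task_exception(
        task_id=task_id, execution_begin=execution_begin, exception=exception
    )
```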

By default, `GlobusComputeEngine.max_retries_on_system_failure` is set to 0, since retrying compute-intensive
tasks could unintentionally waste the user's resource allocation. Here's a YAML config example that uses this
engine setting:

```yaml
engine:
    type: GlobusComputeEngine
    max_retries_on_system_failure: 2
```
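
For programmatic setups, the same option can be passed directly to the engine constructor. A minimal sketch, modeled on the test fixture below (the address and provider values are illustrative placeholders, not recommendations):

```python
from globus_compute_endpoint.engines import GlobusComputeEngine
from parsl.providers import LocalProvider

# Illustrative values only; adjust address/provider to your site's configuration.
engine = GlobusComputeEngine(
    address="127.0.0.1",
    provider=LocalProvider(init_blocks=0, min_blocks=0, max_blocks=1),
    max_retries_on_system_failure=2,  # retry infrastructure failures up to twice
)
```
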
yadudoc authored Dec 1, 2023
1 parent 420cc0b commit 24bcd99
Showing 6 changed files with 268 additions and 23 deletions.
14 changes: 14 additions & 0 deletions changelog.d/20231031_230517_yadudoc1729_task_retries_1.rst
@@ -0,0 +1,14 @@
New Functionality
^^^^^^^^^^^^^^^^^

- ``GlobusComputeEngine`` can now be configured to automatically retry tasks that fail due to
  node failures (e.g., nodes lost because a batch job reached its walltime). This option
  is set to 0 by default to avoid unintentionally wasting resources by retrying tasks.
  Traceback history from all prior attempts is supplied if the last retry attempt fails.
  Here's a snippet from config.yaml:

.. code-block:: yaml

   engine:
       type: GlobusComputeEngine
       max_retries_on_system_failure: 2
@@ -84,6 +84,7 @@ class EngineModel(BaseConfigModel):
worker_ports: t.Optional[t.Tuple[int, int]]
worker_port_range: t.Optional[t.Tuple[int, int]]
interchange_port_range: t.Optional[t.Tuple[int, int]]
max_retries_on_system_failure: t.Optional[int]

_validate_type = _validate_import("type", engines)
_validate_provider = _validate_params("provider")
87 changes: 67 additions & 20 deletions compute_endpoint/globus_compute_endpoint/engines/base.py
@@ -21,6 +21,7 @@
)

logger = logging.getLogger(__name__)
_EXC_HISTORY_TMPL = "+" * 68 + "\nTraceback from attempt: {ndx}\n{exc}\n" + "-" * 68


class ReportingThread:
@@ -76,11 +77,13 @@ def __init__(
self,
*args: object,
endpoint_id: t.Optional[uuid.UUID] = None,
max_retries_on_system_failure: int = 0,
**kwargs: object,
):
self._shutdown_event = threading.Event()
self.endpoint_id = endpoint_id

self.max_retries_on_system_failure = max_retries_on_system_failure
self._retry_table: t.Dict[str, t.Dict] = {}
# remove these unused vars that we are adding to just keep
# endpoint interchange happy
self.container_type: t.Optional[str] = None
@@ -115,6 +118,53 @@ def _status_report(self, shutdown_event: threading.Event, heartbeat_period: floa
packed = messagepack.pack(status_report)
self.results_passthrough.put({"message": packed})

def _handle_task_exception(
self,
task_id: str,
execution_begin: TaskTransition,
exception: BaseException,
) -> bytes:
"""Repackage task exception to messagepack'ed bytes
Parameters
----------
task_id: str
execution_begin: TaskTransition
exception: Exception object from the task failure
Returns
-------
bytes
"""
code, user_message = get_result_error_details(exception)
error_details = {"code": code, "user_message": user_message}
execution_end = TaskTransition(
timestamp=time.time_ns(),
actor=ActorName.INTERCHANGE,
state=TaskState.EXEC_END,
)
exception_string = ""
for index, prev_exc in enumerate(
self._retry_table[task_id]["exception_history"]
):
templated_history = _EXC_HISTORY_TMPL.format(
ndx=index + 1, exc=get_error_string(exc=prev_exc)
)
exception_string += templated_history

final = _EXC_HISTORY_TMPL.format(
ndx="final attempt", exc=get_error_string(exc=exception)
)
exception_string += final

result_message = dict(
task_id=task_id,
data=exception_string,
exception=exception_string,
error_details=error_details,
task_statuses=[execution_begin, execution_end], # only timings we have
)
return messagepack.pack(Result(**result_message))

def _setup_future_done_callback(self, task_id: str, future: Future) -> None:
"""
Set up the done() callback for the provided future.
@@ -133,27 +183,18 @@ def _setup_future_done_callback(self, task_id: str, future: Future) -> None:
)

def _done_cb(f: Future):
if f.exception():
exc = f.exception()
code, user_message = get_result_error_details(exc)
error_details = {"code": code, "user_message": user_message}
exec_end = TaskTransition(
timestamp=time.time_ns(),
actor=ActorName.INTERCHANGE,
state=TaskState.EXEC_END,
)
result_message = dict(
task_id=task_id,
data=get_error_string(exc=exc),
exception=get_error_string(exc=exc),
error_details=error_details,
task_statuses=[exec_beg, exec_end], # only transition info we have
)
packed = messagepack.pack(Result(**result_message))
else:
try:
packed = f.result()
except Exception as exception:
packed = self._handle_task_exception(
task_id=task_id, execution_begin=exec_beg, exception=exception
)

self.results_passthrough.put({"task_id": task_id, "message": packed})
if packed:
# _handle_task_exception can return empty bytestring
# when it retries task, indicating there's no task status update
self.results_passthrough.put({"task_id": task_id, "message": packed})
self._retry_table.pop(task_id, None)

future.add_done_callback(_done_cb)

Expand All @@ -178,6 +219,12 @@ def submit(self, task_id: str, packed_task: bytes) -> Future:
future
"""

if task_id not in self._retry_table:
self._retry_table[task_id] = {
"retry_count": 0,
"packed_task": packed_task,
"exception_history": [],
}
future = self._submit(execute_task, task_id, packed_task)
self._setup_future_done_callback(task_id, future)
return future
29 changes: 27 additions & 2 deletions compute_endpoint/globus_compute_endpoint/engines/globus_compute.py
@@ -24,19 +24,24 @@ def __init__(
self,
*args,
label: str = "GlobusComputeEngine",
max_retries_on_system_failure: int = 0,
strategy: t.Optional[SimpleStrategy] = SimpleStrategy(),
executor: t.Optional[HighThroughputExecutor] = None,
**kwargs,
):
self.run_dir = os.getcwd()
self.label = label
self._status_report_thread = ReportingThread(target=self.report_status, args=[])
super().__init__(*args, **kwargs)
super().__init__(
*args, max_retries_on_system_failure=max_retries_on_system_failure, **kwargs
)
self.strategy = strategy
self.max_workers_per_node = 1
if executor is None:
executor = HighThroughputExecutor( # type: ignore
*args, label=label, **kwargs
*args,
label=label,
**kwargs,
)
self.executor = executor

@@ -213,6 +218,26 @@ def scale_in(self, blocks: int):
logger.info(f"Scaling in {blocks} blocks")
return self.executor.scale_in(blocks=blocks)

def _handle_task_exception(
self,
task_id: str,
execution_begin: TaskTransition,
exception: BaseException,
) -> bytes:
result_bytes = b""
retry_info = self._retry_table[task_id]
if retry_info["retry_count"] < self.max_retries_on_system_failure:
retry_info["retry_count"] += 1
retry_info["exception_history"].append(exception)
self.submit(task_id, retry_info["packed_task"])
else:
# This is a terminal state
result_bytes = super()._handle_task_exception(
task_id=task_id, execution_begin=execution_begin, exception=exception
)

return result_bytes

@property
def scaling_enabled(self) -> bool:
"""Indicates whether scaling is possible"""
@@ -0,0 +1,139 @@
import uuid
from queue import Queue

import pytest
from globus_compute_common import messagepack
from globus_compute_endpoint.engines import GlobusComputeEngine
from globus_compute_endpoint.strategies import SimpleStrategy
from globus_compute_sdk.serialize import ComputeSerializer
from parsl.executors.high_throughput.interchange import ManagerLost
from parsl.providers import LocalProvider
from tests.utils import ez_pack_function, kill_manager, succeed_after_n_runs


@pytest.fixture
def gc_engine_with_retries(tmp_path):
ep_id = uuid.uuid4()
engine = GlobusComputeEngine(
address="127.0.0.1",
max_workers=1,
heartbeat_period=1,
heartbeat_threshold=1,
max_retries_on_system_failure=0,
provider=LocalProvider(
init_blocks=0,
min_blocks=0,
max_blocks=1,
),
strategy=SimpleStrategy(interval=0.1, max_idletime=0),
)
engine._status_report_thread.reporting_period = 1
queue = Queue()
engine.start(endpoint_id=ep_id, run_dir=tmp_path, results_passthrough=queue)
yield engine
engine.shutdown()


def test_gce_kill_manager(gc_engine_with_retries):
engine = gc_engine_with_retries
engine.max_retries_on_system_failure = 0
queue = engine.results_passthrough
task_id = uuid.uuid1()
serializer = ComputeSerializer()

# Confirm error message for ManagerLost
task_body = ez_pack_function(serializer, kill_manager, (), {})
task_message = messagepack.pack(
messagepack.message_types.Task(task_id=task_id, task_buffer=task_body)
)

future = engine.submit(task_id, task_message)

with pytest.raises(ManagerLost):
future.result()

flag = False
for _i in range(4):
q_msg = queue.get(timeout=2)
assert isinstance(q_msg, dict)

packed_result_q = q_msg["message"]
result = messagepack.unpack(packed_result_q)
if isinstance(result, messagepack.message_types.Result):
assert result.task_id == task_id
if result.error_details and "ManagerLost" in result.data:
flag = True
break

assert flag, "Result message missing"


def test_success_after_1_fail(gc_engine_with_retries, tmp_path):
engine = gc_engine_with_retries
engine.max_retries_on_system_failure = 2
fail_count = 1
queue = engine.results_passthrough
task_id = uuid.uuid1()
serializer = ComputeSerializer()
task_body = ez_pack_function(
serializer, succeed_after_n_runs, (tmp_path,), {"fail_count": fail_count}
)
task_message = messagepack.pack(
messagepack.message_types.Task(task_id=task_id, task_buffer=task_body)
)
engine.submit(task_id, task_message)

flag = False
for _i in range(10):
q_msg = queue.get(timeout=5)
assert isinstance(q_msg, dict)

packed_result_q = q_msg["message"]
result = messagepack.unpack(packed_result_q)
if isinstance(result, messagepack.message_types.Result):
assert result.task_id == task_id
assert result.error_details is None
flag = True
break

assert flag, "Expected result packet, but none received"


def test_repeated_fail(gc_engine_with_retries, tmp_path):
engine = gc_engine_with_retries
engine.max_retries_on_system_failure = 2
fail_count = 3
queue = engine.results_passthrough
task_id = uuid.uuid1()
serializer = ComputeSerializer()
task_body = ez_pack_function(
serializer, succeed_after_n_runs, (tmp_path,), {"fail_count": fail_count}
)
task_message = messagepack.pack(
messagepack.message_types.Task(task_id=task_id, task_buffer=task_body)
)
engine.submit(task_id, task_message)

flag = False
for _i in range(10):
q_msg = queue.get(timeout=5)
assert isinstance(q_msg, dict)

packed_result_q = q_msg["message"]
result = messagepack.unpack(packed_result_q)
if isinstance(result, messagepack.message_types.Result):
assert result.task_id == task_id
assert result.error_details
assert "ManagerLost" in result.data
count = result.data.count("Traceback from attempt")
assert count == fail_count, "Got incorrect # of failure reports"
assert "final attempt" in result.data
flag = True
break

assert flag, "Expected ManagerLost in failed result.data, but none received"


def test_default_retries_is_0():
engine = GlobusComputeEngine(address="127.0.0.1")
assert engine.max_retries_on_system_failure == 0, "Users must knowingly opt-in"
21 changes: 20 additions & 1 deletion compute_endpoint/tests/utils.py
@@ -1,4 +1,5 @@
import itertools
import pathlib
import sys
import time
import types
@@ -101,8 +102,26 @@ def kill_manager():
import signal

manager_pid = os.getppid()
os.kill(manager_pid, signal.SIGKILL)
manager_pgid = os.getpgid(manager_pid)
os.killpg(manager_pgid, signal.SIGKILL)


def div_zero(x: int):
return x / 0


def succeed_after_n_runs(dirpath: pathlib.Path, fail_count: int = 1):
import os
import signal
from glob import glob

prior_run_count = len(glob(os.path.join(dirpath, "foo.*.txt")))
with open(os.path.join(dirpath, f"foo.{prior_run_count+1}.txt"), "w+") as f:
f.write(f"Hello at {time} counter={prior_run_count+1}")

if prior_run_count < fail_count:
manager_pid = os.getppid()
manager_pgid = os.getpgid(manager_pid)
os.killpg(manager_pgid, signal.SIGKILL)

return f"Success on attempt: {prior_run_count+1}"
