Skip to content

Commit

Permalink
Merge pull request #570 from opsani/OPTSERV-1269-servox-async-logic-h…
Browse files Browse the repository at this point in the history
…ardening

First pass: update gather methods to prevent task leaks on error
  • Loading branch information
linkous8 authored Apr 25, 2024
2 parents 3c1a3f1 + 04b9811 commit f3ade34
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 165 deletions.
6 changes: 3 additions & 3 deletions servo/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ async def assemble(
)

# Attach all connectors to the servo
await asyncio.gather(
*list(map(lambda s: s.dispatch_event(servo.servo.Events.attach, s), servos))
)
async with asyncio.TaskGroup() as tg:
for s in servos:
_ = tg.create_task(s.dispatch_event(servo.servo.Events.attach, s))

return assembly

Expand Down
18 changes: 8 additions & 10 deletions servo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,24 +1123,22 @@ def print_callback(input: str) -> None:
# gather() expects an event loop to exist at invocation time, which is not compatible with the run_async
# execution model. Wrap the gather in a standard async function to work around this.
async def gather_checks():
return await asyncio.gather(
*list(
map(
lambda s: s.check_servo(print_callback),
context.assembly.servos,
)
),
)
async with asyncio.TaskGroup() as tg:
tasks = [
tg.create_task(s.check_servo(print_callback))
for s in context.assembly.servos
]
return (t.result() for t in tasks)

results = run_async(gather_checks())
ready = functools.reduce(lambda x, y: x and y, results)

except servo.ConnectorNotFoundError as e:
except* servo.ConnectorNotFoundError as e:
typer.echo(
"A connector named within the checks config was not found in the current Assembly"
)
raise typer.Exit(1) from e
except servo.EventHandlersNotFoundError as e:
except* servo.EventHandlersNotFoundError as e:
typer.echo(
"At least one configured connector must respond to the Check event (Note the servo "
"responds to checks so this error should never raise unless something is well and truly wrong"
Expand Down
115 changes: 50 additions & 65 deletions servo/connectors/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,27 +1141,15 @@ async def create_tuning_pod(self) -> V1Pod:
)
)
progress.start()

task = asyncio.create_task(PodHelper.wait_until_ready(tuning_pod))
task.add_done_callback(lambda _: progress.complete())
gather_task = asyncio.gather(
task,
progress.watch(progress_logger),
)

try:
await asyncio.wait_for(gather_task, timeout=self.timeout.total_seconds())
async with asyncio.timeout(delay=self.timeout.total_seconds()):
async with asyncio.TaskGroup() as tg:
task = tg.create_task(PodHelper.wait_until_ready(tuning_pod))
task.add_done_callback(lambda _: progress.complete())
_ = tg.create_task(progress.watch(progress_logger))

except asyncio.TimeoutError:
servo.logger.error(f"Timed out waiting for Tuning Pod to become ready...")
servo.logger.debug(f"Cancelling Task: {task}, progress: {progress}")
for t in {task, gather_task}:
t.cancel()
with contextlib.suppress(asyncio.CancelledError):
await t
servo.logger.debug(f"Cancelled Task: {t}, progress: {progress}")

# get latest status of tuning pod for raise_for_status
await self.raise_for_status()

# Hydrate local state
Expand Down Expand Up @@ -1631,49 +1619,46 @@ async def apply(self, adjustments: List[servo.Adjustment]) -> None:
# TODO: Run sanity checks to look for out of band changes

async def raise_for_status(self) -> None:
handle_error_tasks = []

def _raise_for_task(task: asyncio.Task, optimization: BaseOptimization) -> None:
if task.done() and not task.cancelled():
if exception := task.exception():
handle_error_tasks.append(
asyncio.create_task(optimization.handle_error(exception))
)

tasks = []
for optimization in self.optimizations:
task = asyncio.create_task(optimization.raise_for_status())
task.add_done_callback(
functools.partial(_raise_for_task, optimization=optimization)
)
tasks.append(task)

for future in asyncio.as_completed(
tasks, timeout=self.config.timeout.total_seconds()
):
try:
await future
except Exception as error:
servo.logger.exception(f"Optimization failed with error: {error}")
# TODO: first handle_error_task to raise will likely interrupt other tasks.
# Gather with return_exceptions=True and aggregate resulting exceptions into group before raising
async with asyncio.TaskGroup() as tg:

def _raise_for_task(
task: asyncio.Task, optimization: BaseOptimization
) -> None:
if task.done() and not task.cancelled():
if exception := task.exception():
_ = tg.create_task(optimization.handle_error(exception))

tasks = []
for optimization in self.optimizations:
task = asyncio.create_task(optimization.raise_for_status())
task.add_done_callback(
functools.partial(_raise_for_task, optimization=optimization)
)
tasks.append(task)

# TODO: first handler to raise will likely interrupt other tasks.
# Gather with return_exceptions=True and aggregate resulting exceptions before raising
await asyncio.gather(*handle_error_tasks)
for future in asyncio.as_completed(
tasks, timeout=self.config.timeout.total_seconds()
):
try:
await future
except Exception as error:
servo.logger.exception(f"Optimization failed with error: {error}")

async def is_ready(self):
if self.optimizations:
self.logger.debug(
f"Checking for readiness of {len(self.optimizations)} optimizations"
)
try:
results = await asyncio.wait_for(
asyncio.gather(
*list(map(lambda a: a.is_ready(), self.optimizations)),
),
timeout=self.config.timeout.total_seconds(),
)
async with asyncio.timeout(delay=self.config.timeout.total_seconds()):
async with asyncio.TaskGroup() as tg:
results = [
tg.create_task(o.is_ready()) for o in self.optimizations
]

return all(results)
return all((r.result() for r in results))

except asyncio.TimeoutError:
return False
Expand Down Expand Up @@ -2297,15 +2282,13 @@ async def adjust(
progress=p.progress,
)
progress = servo.EventProgress(timeout=self.config.timeout)
future = asyncio.create_task(state.apply(adjustments))
future.add_done_callback(lambda _: progress.trigger())

# Catch-all for spaghettified non-EventError usage
try:
await asyncio.gather(
future,
progress.watch(progress_logger),
)
async with asyncio.TaskGroup() as tg:
future = tg.create_task(state.apply(adjustments))
future.add_done_callback(lambda _: progress.trigger())
_ = tg.create_task(progress.watch(progress_logger))

# Handle settlement
settlement = control.settlement or self.config.settlement
Expand All @@ -2325,7 +2308,7 @@ async def readiness_monitor() -> None:
# Raise a specific exception if the optimization defines one
try:
await state.raise_for_status()
except servo.AdjustmentRejectedError as e:
except* servo.AdjustmentRejectedError as e:
# Update rejections with start-failed to indicate the initial rollout was successful
if e.reason == "start-failed":
e.reason = "unstable"
Expand All @@ -2350,6 +2333,11 @@ async def readiness_monitor() -> None:
)

description = state.to_description()
except ExceptionGroup as eg:
if any(isinstance(sub_e, servo.EventError) for sub_e in eg.exceptions):
raise
else:
raise servo.AdjustmentFailedError(str(eg.message)) from eg
except servo.EventError: # this is recognized by the runner
raise
except Exception as e:
Expand Down Expand Up @@ -2383,13 +2371,10 @@ async def _create_optimizations(self) -> KubernetesOptimizations:
)
progress = servo.EventProgress(timeout=self.config.timeout)
try:
future = asyncio.create_task(KubernetesOptimizations.create(self.config))
future.add_done_callback(lambda _: progress.trigger())

await asyncio.gather(
future,
progress.watch(progress_logger),
)
async with asyncio.TaskGroup() as tg:
future = tg.create_task(KubernetesOptimizations.create(self.config))
future.add_done_callback(lambda _: progress.trigger())
_ = tg.create_task(progress.watch(progress_logger))

return future.result()
except Exception as e:
Expand Down
35 changes: 17 additions & 18 deletions servo/connectors/prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,28 +994,25 @@ async def measure(
),
)
fast_fail_progress = servo.EventProgress(timeout=measurement_duration)
gather_tasks = [
asyncio.create_task(progress.watch(self.observe)),
asyncio.create_task(
async with asyncio.TaskGroup() as tg:
_ = tg.create_task(progress.watch(self.observe))
_ = tg.create_task(
fast_fail_progress.watch(
fast_fail_observer.observe, every=self.config.fast_fail.period
fast_fail_observer.observe,
every=self.config.fast_fail.period,
)
),
]
try:
await asyncio.gather(*gather_tasks)
except:
[task.cancel() for task in gather_tasks]
await asyncio.gather(*gather_tasks, return_exceptions=True)
raise
)

else:
await progress.watch(self.observe)

# Capture the measurements
self.logger.info(f"Querying Prometheus for {len(metrics__)} metrics...")
readings = await asyncio.gather(
*list(map(lambda m: self._query_prometheus(m, start, end), metrics__))
)
async with asyncio.TaskGroup() as tg:
q_tasks = [
tg.create_task(self._query_prometheus(m, start, end)) for m in metrics__
]
readings = (qt.result() for qt in q_tasks)
all_readings = (
functools.reduce(lambda x, y: x + y, readings) if readings else []
)
Expand Down Expand Up @@ -1077,9 +1074,11 @@ async def _query_slo_metrics(
self, start: datetime, end: datetime, metrics: List[PrometheusMetric]
) -> Dict[str, List[servo.TimeSeries]]:
"""Query prometheus for the provided metrics and return mapping of metric names to their corresponding readings"""
readings = await asyncio.gather(
*list(map(lambda m: self._query_prometheus(m, start, end), metrics))
)
async with asyncio.TaskGroup() as tg:
q_tasks = [
tg.create_task(self._query_prometheus(m, start, end)) for m in metrics
]
readings = (qt.result() for qt in q_tasks)
return dict(map(lambda tup: (tup[0].name, tup[1]), zip(metrics, readings)))


Expand Down
79 changes: 45 additions & 34 deletions servo/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ def decorator(fn: EventCallable) -> EventCallable:
if preposition == Preposition.before:
# 'before' event takes same args as 'on' event, but returns None
ref_signature = ref_signature.replace(return_annotation="None")

servo.utilities.inspect.assert_equal_callable_descriptors(
servo.utilities.inspect.CallableDescriptor(
signature=ref_signature,
Expand Down Expand Up @@ -891,6 +892,7 @@ async def run_event_handlers(
value=error,
)

# TODO: refactor to use ExceptionGroups with events retrievable from exceptions
if return_exceptions:
results.append(error.__event_result__)
else:
Expand Down Expand Up @@ -1006,22 +1008,34 @@ async def run(self) -> List[EventResult]:

# Invoke the before event handlers
if self._prepositions & Preposition.before:
for connector in self._connectors:
try:
results = await connector.run_event_handlers(
self.event,
Preposition.before,
*self._args,
return_exceptions=False,
**self._kwargs,
)
try:
async with asyncio.TaskGroup() as tg:
for connector in self._connectors:
_ = tg.create_task(
connector.run_event_handlers(
self.event,
Preposition.before,
*self._args,
return_exceptions=False,
**self._kwargs,
)
)

except servo.errors.EventCancelledError as error:
except ExceptionGroup as eg:
cancelled_errs = [
sub_e
for sub_e in eg.exceptions
if isinstance(sub_e, servo.errors.EventCancelledError)
]
if cancelled_errs:
# Return an empty result set
canceller_names = (ce.connector.name for ce in cancelled_errs)
servo.logger.warning(
f'event cancelled by before event handler on connector "{connector.name}": {error}'
f'event cancelled by before event handler on connector "{", ".join(canceller_names)}": {eg.exceptions}'
)
return []
else:
raise

# Invoke the on event handlers and gather results
if self._prepositions & Preposition.on:
Expand All @@ -1038,36 +1052,33 @@ async def run(self) -> List[EventResult]:
if results:
break
else:
group = asyncio.gather(
*list(
map(
lambda c: c.run_event_handlers(
self.event,
Preposition.on,
return_exceptions=self._return_exceptions,
*self._args,
**self._kwargs,
),
self._connectors,
)
),
tasks = (
c.run_event_handlers(
self.event,
Preposition.on,
return_exceptions=self._return_exceptions,
*self._args,
**self._kwargs,
)
for c in self._connectors
)
results = await group
if self._return_exceptions:
results = await asyncio.gather(*tasks)
else:
async with asyncio.TaskGroup() as tg:
tg_tasks = [tg.create_task(t) for t in tasks]
results = (tt.result() for tt in tg_tasks)

results = list(filter(lambda r: r is not None, results))
results = functools.reduce(lambda x, y: x + y, results, [])

# Invoke the after event handlers
if self._prepositions & Preposition.after:
await asyncio.gather(
*list(
map(
lambda c: c.run_event_handlers(
self.event, Preposition.after, results
),
self._connectors,
async with asyncio.TaskGroup() as tg:
for c in self._connectors:
_ = tg.create_task(
c.run_event_handlers(self.event, Preposition.after, results)
)
)
)

if self.channel:
await self.channel.close()
Expand Down
Loading

0 comments on commit f3ade34

Please sign in to comment.