Skip to content

Commit

Permalink
Add AWS deferrable BatchOperator (#29300)
Browse files Browse the repository at this point in the history
This PR donates the following BatchOperator deferrable developed in [astronomer-providers](https://github.com/astronomer/astronomer-providers) repo to apache airflow.
  • Loading branch information
rajaths010494 authored Apr 5, 2023
1 parent dc4dd91 commit 77c272e
Show file tree
Hide file tree
Showing 8 changed files with 866 additions and 3 deletions.
239 changes: 238 additions & 1 deletion airflow/providers/amazon/aws/hooks/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@
"""
from __future__ import annotations

import asyncio
from random import uniform
from time import sleep
from typing import Any

import botocore.client
import botocore.exceptions
import botocore.waiter

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook
from airflow.typing_compat import Protocol, runtime_checkable


Expand Down Expand Up @@ -544,3 +546,238 @@ def exp(tries):
delay = 1 + pow(tries * 0.6, 2)
delay = min(max_interval, delay)
return uniform(delay / 3, delay)


class BatchClientAsyncHook(BatchClientHook, AwsBaseAsyncHook):
"""
Async client for AWS Batch services.
:param job_id: the job ID, usually unknown (None) until the
submit_job operation gets the jobId defined by AWS Batch
:param waiters: an :py:class:`.BatchWaiters` object (see note below);
if None, polling is used with max_retries and status_retries.
.. note::
Several methods use a default random delay to check or poll for job status, i.e.
``random.sample()``
Using a random interval helps to avoid AWS API throttle limits
when many concurrent tasks request job-descriptions.
To modify the global defaults for the range of jitter allowed when a
random delay is used to check Batch job status, modify these defaults, e.g.:
BatchClient.DEFAULT_DELAY_MIN = 0
BatchClient.DEFAULT_DELAY_MAX = 5
When explicit delay values are used, a 1 second random jitter is applied to the
delay . It is generally recommended that random jitter is added to API requests.
A convenience method is provided for this, e.g. to get a random delay of
10 sec +/- 5 sec: ``delay = BatchClient.add_jitter(10, width=5, minima=0)``
"""

def __init__(self, job_id: str | None, waiters: Any = None, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.job_id = job_id
self.waiters = waiters

async def monitor_job(self) -> dict[str, str] | None:
"""
Monitor an AWS Batch job
monitor_job can raise an exception or an AirflowTaskTimeout can be raised if execution_timeout
is given while creating the task. These exceptions should be handled in taskinstance.py
instead of here like it was previously done
:raises: AirflowException
"""
if not self.job_id:
raise AirflowException("AWS Batch job - job_id was not found")

if self.waiters:
self.waiters.wait_for_job(self.job_id)
return None
else:
await self.wait_for_job(self.job_id)
await self.check_job_success(self.job_id)
success_msg = f"AWS Batch job ({self.job_id}) succeeded"
self.log.info(success_msg)
return {"status": "success", "message": success_msg}

async def check_job_success(self, job_id: str) -> bool: # type: ignore[override]
"""
Check the final status of the Batch job; return True if the job
'SUCCEEDED', else raise an AirflowException
:param job_id: a Batch job ID
:raises: AirflowException
"""
job = await self.get_job_description(job_id)
job_status = job.get("status")
if job_status == self.SUCCESS_STATE:
self.log.info("AWS Batch job (%s) succeeded: %s", job_id, job)
return True

if job_status == self.FAILURE_STATE:
raise AirflowException(f"AWS Batch job ({job_id}) failed: {job}")

if job_status in self.INTERMEDIATE_STATES:
raise AirflowException(f"AWS Batch job ({job_id}) is not complete: {job}")

raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}")

@staticmethod
async def delay(delay: int | float | None = None) -> None: # type: ignore[override]
"""
Pause execution for ``delay`` seconds.
:param delay: a delay to pause execution using ``time.sleep(delay)``;
a small 1 second jitter is applied to the delay.
.. note::
This method uses a default random delay, i.e.
``random.sample()``;
using a random interval helps to avoid AWS API throttle limits
when many concurrent tasks request job-descriptions.
"""
if delay is None:
delay = uniform(BatchClientHook.DEFAULT_DELAY_MIN, BatchClientHook.DEFAULT_DELAY_MAX)
else:
delay = BatchClientAsyncHook.add_jitter(delay)
await asyncio.sleep(delay)

async def wait_for_job( # type: ignore[override]
self, job_id: str, delay: int | float | None = None
) -> None:
"""
Wait for Batch job to complete.
:param job_id: a Batch job ID
:param delay: a delay before polling for job status
:raises: AirflowException
"""
await self.delay(delay)
await self.poll_for_job_running(job_id, delay)
await self.poll_for_job_complete(job_id, delay)
self.log.info("AWS Batch job (%s) has completed", job_id)

async def poll_for_job_complete( # type: ignore[override]
self, job_id: str, delay: int | float | None = None
) -> None:
"""
Poll for job completion. The status that indicates job completion
are: 'SUCCEEDED'|'FAILED'.
So the status options that this will wait for are the transitions from:
'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED'
:param job_id: a Batch job ID
:param delay: a delay before polling for job status
:raises: AirflowException
"""
await self.delay(delay)
complete_status = [self.SUCCESS_STATE, self.FAILURE_STATE]
await self.poll_job_status(job_id, complete_status)

async def poll_for_job_running( # type: ignore[override]
self, job_id: str, delay: int | float | None = None
) -> None:
"""
Poll for job running. The status that indicates a job is running or
already complete are: 'RUNNING'|'SUCCEEDED'|'FAILED'.
So the status options that this will wait for are the transitions from:
'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'|'SUCCEEDED'|'FAILED'
The completed status options are included for cases where the status
changes too quickly for polling to detect a RUNNING status that moves
quickly from STARTING to RUNNING to completed (often a failure).
:param job_id: a Batch job ID
:param delay: a delay before polling for job status
:raises: AirflowException
"""
await self.delay(delay)
running_status = [self.RUNNING_STATE, self.SUCCESS_STATE, self.FAILURE_STATE]
await self.poll_job_status(job_id, running_status)

async def get_job_description(self, job_id: str) -> dict[str, str]: # type: ignore[override]
"""
Get job description (using status_retries).
:param job_id: a Batch job ID
:raises: AirflowException
"""
retries = 0
async with await self.get_client_async() as client:
while True:
try:
response = client.describe_jobs(jobs=[job_id])
return self.parse_job_description(job_id, response)

except botocore.exceptions.ClientError as err:
error = err.response.get("Error", {})
if error.get("Code") == "TooManyRequestsException":
pass # allow it to retry, if possible
else:
raise AirflowException(f"AWS Batch job ({job_id}) description error: {err}")

retries += 1
if retries >= self.status_retries:
raise AirflowException(
f"AWS Batch job ({job_id}) description error: exceeded status_retries "
f"({self.status_retries})"
)

pause = self.exponential_delay(retries)
self.log.info(
"AWS Batch job (%s) description retry (%d of %d) in the next %.2f seconds",
job_id,
retries,
self.status_retries,
pause,
)
await self.delay(pause)

async def poll_job_status(self, job_id: str, match_status: list[str]) -> bool: # type: ignore[override]
"""
Poll for job status using an exponential back-off strategy (with max_retries).
The Batch job status polled are:
'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
:param job_id: a Batch job ID
:param match_status: a list of job status to match
:raises: AirflowException
"""
retries = 0
while True:
job = await self.get_job_description(job_id)
job_status = job.get("status")
self.log.info(
"AWS Batch job (%s) check status (%s) in %s",
job_id,
job_status,
match_status,
)
if job_status in match_status:
return True

if retries >= self.max_retries:
raise AirflowException(f"AWS Batch job ({job_id}) status checks exceed max_retries")

retries += 1
pause = self.exponential_delay(retries)
self.log.info(
"AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds",
job_id,
retries,
self.max_retries,
pause,
)
await self.delay(pause)
37 changes: 37 additions & 0 deletions airflow/providers/amazon/aws/operators/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
BatchJobQueueLink,
)
from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
from airflow.providers.amazon.aws.utils import trim_none_values

if TYPE_CHECKING:
Expand Down Expand Up @@ -71,6 +72,7 @@ class BatchOperator(BaseOperator):
Override the region_name in connection (if provided)
:param tags: collection of tags to apply to the AWS Batch job submission
if None, no tags are submitted
:param deferrable: Run operator in the deferrable mode.
.. note::
Any custom waiters must return a waiter for these calls:
Expand Down Expand Up @@ -125,6 +127,7 @@ def __init__(
region_name: str | None = None,
tags: dict | None = None,
wait_for_completion: bool = True,
deferrable: bool = False,
**kwargs,
):

Expand All @@ -139,6 +142,8 @@ def __init__(
self.waiters = waiters
self.tags = tags or {}
self.wait_for_completion = wait_for_completion
self.deferrable = deferrable

self.hook = BatchClientHook(
max_retries=max_retries,
status_retries=status_retries,
Expand All @@ -154,11 +159,43 @@ def execute(self, context: Context):
"""
self.submit_job(context)

if self.deferrable:
self.defer(
timeout=self.execution_timeout,
trigger=BatchOperatorTrigger(
job_id=self.job_id,
job_name=self.job_name,
job_definition=self.job_definition,
job_queue=self.job_queue,
overrides=self.overrides,
array_properties=self.array_properties,
parameters=self.parameters,
waiters=self.waiters,
tags=self.tags,
max_retries=self.hook.max_retries,
status_retries=self.hook.status_retries,
aws_conn_id=self.hook.aws_conn_id,
region_name=self.hook.region_name,
),
method_name="execute_complete",
)

if self.wait_for_completion:
self.monitor_job(context)

return self.job_id

def execute_complete(self, context: Context, event: dict[str, Any]):
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if "status" in event and event["status"] == "error":
raise AirflowException(event["message"])
self.log.info(event["message"])
return self.job_id

def on_kill(self):
response = self.hook.client.terminate_job(jobId=self.job_id, reason="Task killed by the user")
self.log.info("AWS Batch job (%s) terminated: %s", self.job_id, response)
Expand Down
Loading

0 comments on commit 77c272e

Please sign in to comment.