add deferrable mode to glue operator
vandonr-amz committed Apr 28, 2023
1 parent 783aa9c commit f03dc5a
Showing 7 changed files with 304 additions and 23 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/base_aws.py
@@ -603,7 +603,7 @@ def get_client_type(
         """Get the underlying boto3 client using boto3 session"""
         client_type = self.client_type
         session = self.get_session(region_name=region_name, deferrable=deferrable)
-        if not isinstance(session, boto3.session.Session):
+        if isinstance(session, AioSession):
             return session.create_client(
                 client_type,
                 endpoint_url=self.conn_config.endpoint_url,
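The flipped check makes the intent explicit: only an aiobotocore `AioSession` takes the async client path. A minimal sketch of the same dispatch, assuming `AioSession` comes from aiobotocore as in the hook's async support (`create_glue_client` is a hypothetical helper, not the hook's API):

```python
# Sketch of the session-type dispatch, assuming aiobotocore's AioSession.
# create_glue_client is a hypothetical helper, not part of the hook.
from __future__ import annotations

import boto3
from aiobotocore.session import AioSession


def create_glue_client(session, region_name: str | None = None):
    if isinstance(session, AioSession):
        # aiobotocore returns an async context manager; callers must
        # enter it with `async with` to get the actual client.
        return session.create_client("glue", region_name=region_name)
    # A plain boto3 session yields a ready-to-use synchronous client.
    return session.client("glue", region_name=region_name)
```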
87 changes: 66 additions & 21 deletions airflow/providers/amazon/aws/hooks/glue.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations

+import asyncio
 import time

 import boto3
@@ -194,6 +195,12 @@ def get_job_state(self, job_name: str, run_id: str) -> str:
         job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True)
         return job_run["JobRun"]["JobRunState"]

+    async def async_get_job_state(self, job_name: str, run_id: str) -> str:
+        """The async version of get_job_state."""
+        async with self.async_conn as client:
+            job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
+        return job_run["JobRun"]["JobRunState"]
+
     def print_job_logs(
         self,
         job_name: str,
@@ -264,33 +271,71 @@ def job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> d
         :param verbose: If True, more Glue Job Run logs show in the Airflow Task Logs. (default: False)
         :return: Dict of JobRunState and JobRunId
         """
-        failed_states = ["FAILED", "TIMEOUT"]
-        finished_states = ["SUCCEEDED", "STOPPED"]
         next_log_tokens = self.LogContinuationTokens()
         while True:
-            if verbose:
-                self.print_job_logs(
-                    job_name=job_name,
-                    run_id=run_id,
-                    continuation_tokens=next_log_tokens,
-                )
-
             job_run_state = self.get_job_state(job_name, run_id)
-            if job_run_state in finished_states:
-                self.log.info("Exiting Job %s Run State: %s", run_id, job_run_state)
-                return {"JobRunState": job_run_state, "JobRunId": run_id}
-            if job_run_state in failed_states:
-                job_error_message = f"Exiting Job {run_id} Run State: {job_run_state}"
-                self.log.info(job_error_message)
-                raise AirflowException(job_error_message)
+            ret = self._handle_state(job_run_state, job_name, run_id, verbose, next_log_tokens)
+            if ret:
+                return ret
             else:
-                self.log.info(
-                    "Polling for AWS Glue Job %s current run state with status %s",
-                    job_name,
-                    job_run_state,
-                )
                 time.sleep(self.JOB_POLL_INTERVAL)

+    async def async_job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> dict[str, str]:
+        """
+        Waits until the Glue job with job_name completes or fails and returns the final state if finished.
+        Raises AirflowException when the job fails.
+
+        :param job_name: unique job name per AWS account
+        :param run_id: The job-run ID of the predecessor job run
+        :param verbose: If True, more Glue Job Run logs show in the Airflow Task Logs. (default: False)
+        :return: Dict of JobRunState and JobRunId
+        """
+        next_log_tokens = self.LogContinuationTokens()
+        while True:
+            job_run_state = await self.async_get_job_state(job_name, run_id)
+            ret = self._handle_state(job_run_state, job_name, run_id, verbose, next_log_tokens)
+            if ret:
+                return ret
+            else:
+                await asyncio.sleep(self.JOB_POLL_INTERVAL)
+
+    def _handle_state(
+        self,
+        state: str,
+        job_name: str,
+        run_id: str,
+        verbose: bool,
+        next_log_tokens: GlueJobHook.LogContinuationTokens,
+    ) -> dict | None:
+        """
+        Process one poll result for the sync and async completion loops.
+
+        This method exists mostly to avoid duplicating code between those two
+        callers; it carries no business logic of its own.
+        """
+        failed_states = ["FAILED", "TIMEOUT"]
+        finished_states = ["SUCCEEDED", "STOPPED"]
+
+        if verbose:
+            self.print_job_logs(
+                job_name=job_name,
+                run_id=run_id,
+                continuation_tokens=next_log_tokens,
+            )
+
+        if state in finished_states:
+            self.log.info("Exiting Job %s Run State: %s", run_id, state)
+            return {"JobRunState": state, "JobRunId": run_id}
+        if state in failed_states:
+            job_error_message = f"Exiting Job {run_id} Run State: {state}"
+            self.log.info(job_error_message)
+            raise AirflowException(job_error_message)
+        else:
+            self.log.info(
+                "Polling for AWS Glue Job %s current run state with status %s",
+                job_name,
+                state,
+            )
+            return None

     def has_job(self, job_name) -> bool:
         """
         Checks if the job already exists.
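The refactor funnels both completion loops through `_handle_state`, so the sync and async paths differ only in how they sleep. A standalone sketch of that pattern (illustrative names, not the hook's API):

```python
# Illustrative sketch of the shared-handler polling pattern used above;
# handle/wait_sync/wait_async are stand-in names, not Airflow APIs.
from __future__ import annotations

import asyncio
import time

POLL_INTERVAL = 6  # seconds; GlueJobHook.JOB_POLL_INTERVAL plays this role


def handle(state: str) -> dict | None:
    """Return a result for terminal states, raise on failure, else None."""
    if state in ("SUCCEEDED", "STOPPED"):
        return {"JobRunState": state}
    if state in ("FAILED", "TIMEOUT"):
        raise RuntimeError(f"job ended in state {state}")
    return None  # still running; the caller keeps polling


def wait_sync(check_state) -> dict:
    while (ret := handle(check_state())) is None:
        time.sleep(POLL_INTERVAL)
    return ret


async def wait_async(check_state) -> dict:
    while (ret := handle(await check_state())) is None:
        await asyncio.sleep(POLL_INTERVAL)
    return ret
```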
22 changes: 21 additions & 1 deletion airflow/providers/amazon/aws/operators/glue.py
@@ -21,10 +21,12 @@
 import urllib.parse
 from typing import TYPE_CHECKING, Sequence

+from airflow import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.links.glue import GlueJobRunDetailsLink
+from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger

 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -93,6 +95,7 @@ def __init__(
         wait_for_completion: bool = True,
         verbose: bool = False,
         update_config: bool = False,
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -114,6 +117,7 @@ def __init__(
         self.wait_for_completion = wait_for_completion
         self.verbose = verbose
         self.update_config = update_config
+        self.deferrable = deferrable

     def execute(self, context: Context):
         """
@@ -167,7 +171,18 @@ def execute(self, context: Context):
             job_run_id=glue_job_run["JobRunId"],
         )
         self.log.info("You can monitor this Glue Job run at: %s", glue_job_run_url)
-        if self.wait_for_completion:
+
+        if self.deferrable:
+            self.defer(
+                trigger=GlueJobCompleteTrigger(
+                    job_name=self.job_name,
+                    run_id=glue_job_run["JobRunId"],
+                    verbose=self.verbose,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
             glue_job_run = glue_job.job_completion(self.job_name, glue_job_run["JobRunId"], self.verbose)
             self.log.info(
                 "AWS Glue Job: %s status: %s. Run Id: %s",
@@ -178,3 +193,8 @@
         else:
             self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, glue_job_run["JobRunId"])
         return glue_job_run["JobRunId"]
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error in glue job: {event}")
+        return
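For context, here is how the new flag might be used in a DAG (a sketch; the DAG id, job, bucket, and role names are placeholders). Deferring hands the wait over to the triggerer process instead of blocking a worker slot:

```python
# Hypothetical DAG using the new flag; all names are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.glue import GlueJobOperator

with DAG("example_glue_deferrable", start_date=datetime(2023, 1, 1), schedule=None):
    run_job = GlueJobOperator(
        task_id="run_glue_job",
        job_name="my_glue_job",
        script_location="s3://my-bucket/scripts/job.py",
        iam_role_name="my-glue-role",
        deferrable=True,  # wait via GlueJobCompleteTrigger on the triggerer
    )
```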
62 changes: 62 additions & 0 deletions airflow/providers/amazon/aws/triggers/glue.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Any, AsyncIterator
+
+from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class GlueJobCompleteTrigger(BaseTrigger):
+    """
+    Watches for a Glue job, triggering when it finishes.
+
+    :param job_name: glue job name
+    :param run_id: the ID of the specific run to watch for that job
+    :param verbose: whether to print the job's logs in airflow logs or not
+    :param aws_conn_id: the Airflow connection to use for AWS credentials
+    """
+
+    def __init__(
+        self,
+        job_name: str,
+        run_id: str,
+        verbose: bool,
+        aws_conn_id: str,
+    ):
+        super().__init__()
+        self.job_name = job_name
+        self.run_id = run_id
+        self.verbose = verbose
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            # serialize the full import path so the triggerer can re-import the class
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "job_name": self.job_name,
+                "run_id": self.run_id,
+                "verbose": self.verbose,
+                "aws_conn_id": self.aws_conn_id,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        hook = GlueJobHook(aws_conn_id=self.aws_conn_id)
+        await hook.async_job_completion(self.job_name, self.run_id, self.verbose)
+        # triggers communicate by yielding TriggerEvents, not returning them
+        yield TriggerEvent({"status": "success", "message": "Job done"})
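`serialize()` is what lets a deferred task survive a triggerer restart: Airflow persists the classpath and kwargs, then re-imports the class and re-instantiates it. A simplified sketch of that round trip (the job name and run ID are placeholders; the real mechanics live in Airflow's triggerer job):

```python
# Simplified sketch of the serialize/rehydrate round trip the triggerer
# performs; values passed to the trigger here are placeholders.
from importlib import import_module

from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger

trigger = GlueJobCompleteTrigger(
    job_name="my_glue_job",  # placeholder
    run_id="jr_0123456789",  # placeholder run ID
    verbose=False,
    aws_conn_id="aws_default",
)
classpath, kwargs = trigger.serialize()
module_name, class_name = classpath.rsplit(".", 1)
rehydrated = getattr(import_module(module_name), class_name)(**kwargs)
assert isinstance(rehydrated, GlueJobCompleteTrigger)
```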
63 changes: 63 additions & 0 deletions tests/providers/amazon/aws/hooks/test_glue.py
@@ -26,6 +26,7 @@
 from botocore.exceptions import ClientError
 from moto import mock_glue, mock_iam

+from airflow import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

@@ -349,3 +350,65 @@ def test_print_job_logs_no_stream_yet(self, conn_mock: MagicMock, client_mock: M
         assert tokens.output_stream_continuation is None
         assert tokens.error_stream_continuation is None
         assert client_mock().get_paginator().paginate.call_count == 2
+
+    @mock.patch.object(GlueJobHook, "get_job_state")
+    def test_job_completion_success(self, get_state_mock: MagicMock):
+        hook = GlueJobHook()
+        hook.JOB_POLL_INTERVAL = 0
+        get_state_mock.side_effect = [
+            "RUNNING",
+            "RUNNING",
+            "SUCCEEDED",
+        ]
+
+        hook.job_completion("job_name", "run_id")
+
+        assert get_state_mock.call_count == 3
+        get_state_mock.assert_called_with("job_name", "run_id")
+
+    @mock.patch.object(GlueJobHook, "get_job_state")
+    def test_job_completion_failure(self, get_state_mock: MagicMock):
+        hook = GlueJobHook()
+        hook.JOB_POLL_INTERVAL = 0
+        get_state_mock.side_effect = [
+            "RUNNING",
+            "RUNNING",
+            "FAILED",
+        ]
+
+        with pytest.raises(AirflowException):
+            hook.job_completion("job_name", "run_id")
+
+        assert get_state_mock.call_count == 3
+
+    @pytest.mark.asyncio
+    @mock.patch.object(GlueJobHook, "async_get_job_state")
+    async def test_async_job_completion_success(self, get_state_mock: MagicMock):
+        hook = GlueJobHook()
+        hook.JOB_POLL_INTERVAL = 0
+        get_state_mock.side_effect = [
+            "RUNNING",
+            "RUNNING",
+            "SUCCEEDED",
+        ]
+
+        await hook.async_job_completion("job_name", "run_id")
+
+        assert get_state_mock.call_count == 3
+        get_state_mock.assert_called_with("job_name", "run_id")
+
+    @pytest.mark.asyncio
+    @mock.patch.object(GlueJobHook, "async_get_job_state")
+    async def test_async_job_completion_failure(self, get_state_mock: MagicMock):
+        hook = GlueJobHook()
+        hook.JOB_POLL_INTERVAL = 0
+        get_state_mock.side_effect = [
+            "RUNNING",
+            "RUNNING",
+            "FAILED",
+        ]
+
+        with pytest.raises(AirflowException):
+            await hook.async_job_completion("job_name", "run_id")
+
+        assert get_state_mock.call_count == 3
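The async tests lean on a detail of `unittest.mock`: since Python 3.8, patching an `async def` target produces an `AsyncMock`, whose `side_effect` values become the awaited results. A self-contained sketch of that behavior (the `Hook` class here is illustrative, not Airflow code):

```python
# Standalone sketch: patching an async method yields an AsyncMock whose
# side_effect list is consumed one awaited call at a time (Python 3.8+).
import asyncio
from unittest import mock


class Hook:
    async def state(self) -> str:
        return "RUNNING"


async def main() -> None:
    with mock.patch.object(Hook, "state", side_effect=["RUNNING", "SUCCEEDED"]) as m:
        hook = Hook()
        assert await hook.state() == "RUNNING"
        assert await hook.state() == "SUCCEEDED"
        assert m.call_count == 2


asyncio.run(main())
```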
22 changes: 22 additions & 0 deletions tests/providers/amazon/aws/operators/test_glue.py
@@ -21,6 +21,7 @@
 import pytest

 from airflow.configuration import conf
+from airflow.exceptions import TaskDeferred
 from airflow.models import TaskInstance
 from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook

@@ -98,6 +99,27 @@ def test_execute_without_failure(
         mock_print_job_logs.assert_not_called()
         assert glue.job_name == JOB_NAME

+    @mock.patch.object(GlueJobHook, "initialize_job")
+    @mock.patch.object(GlueJobHook, "get_conn")
+    def test_execute_deferrable(self, _, mock_initialize_job):
+        glue = GlueJobOperator(
+            task_id=TASK_ID,
+            job_name=JOB_NAME,
+            script_location="s3://folder/file",
+            aws_conn_id="aws_default",
+            region_name="us-west-2",
+            s3_bucket="some_bucket",
+            iam_role_name="my_test_role",
+            deferrable=True,
+        )
+        mock_initialize_job.return_value = {"JobRunState": "RUNNING", "JobRunId": JOB_RUN_ID}
+
+        with pytest.raises(TaskDeferred) as defer:
+            glue.execute(mock.MagicMock())
+
+        assert defer.value.trigger.job_name == JOB_NAME
+        assert defer.value.trigger.run_id == JOB_RUN_ID
+
     @mock.patch.object(GlueJobHook, "print_job_logs")
     @mock.patch.object(GlueJobHook, "get_job_state")
     @mock.patch.object(GlueJobHook, "initialize_job")
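A complementary check one might add (a sketch, not part of this commit): the operator's `execute_complete` should surface a non-success trigger event as an `AirflowException`:

```python
# Hypothetical follow-up test, not in this commit: the resume callback
# should raise when the trigger reports anything but success.
import pytest

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.glue import GlueJobOperator


def test_execute_complete_failure():
    op = GlueJobOperator(
        task_id="test_task",  # placeholder ids and locations
        job_name="my_job",
        script_location="s3://bucket/script.py",
    )
    with pytest.raises(AirflowException):
        op.execute_complete(context=None, event={"status": "failed"})
```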