From f03dc5a5e7e90323099bf6eccab14452df2e7fc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= <vandonr@amazon.com> Date: Thu, 23 Feb 2023 10:05:14 -0800 Subject: [PATCH] add deferrable mode to glue operator --- .../providers/amazon/aws/hooks/base_aws.py | 2 +- airflow/providers/amazon/aws/hooks/glue.py | 87 ++++++++++++++----- .../providers/amazon/aws/operators/glue.py | 22 ++++- airflow/providers/amazon/aws/triggers/glue.py | 62 +++++++++++++ tests/providers/amazon/aws/hooks/test_glue.py | 63 ++++++++++++++ .../amazon/aws/operators/test_glue.py | 22 +++++ .../amazon/aws/triggers/test_glue.py | 69 +++++++++++++++ 7 files changed, 304 insertions(+), 23 deletions(-) create mode 100644 airflow/providers/amazon/aws/triggers/glue.py create mode 100644 tests/providers/amazon/aws/triggers/test_glue.py diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index 096ef038f3d8c..4d15945b3c5f8 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -603,7 +603,7 @@ def get_client_type( """Get the underlying boto3 client using boto3 session""" client_type = self.client_type session = self.get_session(region_name=region_name, deferrable=deferrable) - if not isinstance(session, boto3.session.Session): + if isinstance(session, AioSession): return session.create_client( client_type, endpoint_url=self.conn_config.endpoint_url, diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py index 6f4ed8342e072..f811980210afe 100644 --- a/airflow/providers/amazon/aws/hooks/glue.py +++ b/airflow/providers/amazon/aws/hooks/glue.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import asyncio import time import boto3 @@ -194,6 +195,12 @@ def get_job_state(self, job_name: str, run_id: str) -> str: job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True) return job_run["JobRun"]["JobRunState"] + async def async_get_job_state(self, job_name: str, run_id: str) -> str: + """The async version of get_job_state.""" + async with self.async_conn as client: + job_run = await client.get_job_run(JobName=job_name, RunId=run_id) + return job_run["JobRun"]["JobRunState"] + def print_job_logs( self, job_name: str, @@ -264,33 +271,71 @@ def job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> d :param verbose: If True, more Glue Job Run logs show in the Airflow Task Logs. (default: False) :return: Dict of JobRunState and JobRunId """ - failed_states = ["FAILED", "TIMEOUT"] - finished_states = ["SUCCEEDED", "STOPPED"] next_log_tokens = self.LogContinuationTokens() while True: - if verbose: - self.print_job_logs( - job_name=job_name, - run_id=run_id, - continuation_tokens=next_log_tokens, - ) - job_run_state = self.get_job_state(job_name, run_id) - if job_run_state in finished_states: - self.log.info("Exiting Job %s Run State: %s", run_id, job_run_state) - return {"JobRunState": job_run_state, "JobRunId": run_id} - if job_run_state in failed_states: - job_error_message = f"Exiting Job {run_id} Run State: {job_run_state}" - self.log.info(job_error_message) - raise AirflowException(job_error_message) + ret = self._handle_state(job_run_state, job_name, run_id, verbose, next_log_tokens) + if ret: + return ret else: - self.log.info( - "Polling for AWS Glue Job %s current run state with status %s", - job_name, - job_run_state, - ) time.sleep(self.JOB_POLL_INTERVAL) + async def async_job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> dict[str, str]: + """ + Waits until Glue job with job_name completes or fails and return final state if finished. + Raises AirflowException when the job failed. + + :param job_name: unique job name per AWS account + :param run_id: The job-run ID of the predecessor job run + :param verbose: If True, more Glue Job Run logs show in the Airflow Task Logs. (default: False) + :return: Dict of JobRunState and JobRunId + """ + next_log_tokens = self.LogContinuationTokens() + while True: + job_run_state = await self.async_get_job_state(job_name, run_id) + ret = self._handle_state(job_run_state, job_name, run_id, verbose, next_log_tokens) + if ret: + return ret + else: + await asyncio.sleep(self.JOB_POLL_INTERVAL) + + def _handle_state( + self, + state: str, + job_name: str, + run_id: str, + verbose: bool, + next_log_tokens: GlueJobHook.LogContinuationTokens, + ) -> dict | None: + """ + This method is here mostly to avoid duplicating code between the sync and async methods calling it. + It doesn't really have a business logic. + """ + failed_states = ["FAILED", "TIMEOUT"] + finished_states = ["SUCCEEDED", "STOPPED"] + + if verbose: + self.print_job_logs( + job_name=job_name, + run_id=run_id, + continuation_tokens=next_log_tokens, + ) + + if state in finished_states: + self.log.info("Exiting Job %s Run State: %s", run_id, state) + return {"JobRunState": state, "JobRunId": run_id} + if state in failed_states: + job_error_message = f"Exiting Job {run_id} Run State: {state}" + self.log.info(job_error_message) + raise AirflowException(job_error_message) + else: + self.log.info( + "Polling for AWS Glue Job %s current run state with status %s", + job_name, + state, + ) + return None + def has_job(self, job_name) -> bool: """ Checks if the job already exists. diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py index 497df84d31418..794bfde9be598 100644 --- a/airflow/providers/amazon/aws/operators/glue.py +++ b/airflow/providers/amazon/aws/operators/glue.py @@ -21,10 +21,12 @@ import urllib.parse from typing import TYPE_CHECKING, Sequence +from airflow import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.glue import GlueJobHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook from airflow.providers.amazon.aws.links.glue import GlueJobRunDetailsLink +from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -93,6 +95,7 @@ def __init__( wait_for_completion: bool = True, verbose: bool = False, update_config: bool = False, + deferrable: bool = False, **kwargs, ): super().__init__(**kwargs) @@ -114,6 +117,7 @@ def __init__( self.wait_for_completion = wait_for_completion self.verbose = verbose self.update_config = update_config + self.deferrable = deferrable def execute(self, context: Context): """ @@ -167,7 +171,18 @@ def execute(self, context: Context): job_run_id=glue_job_run["JobRunId"], ) self.log.info("You can monitor this Glue Job run at: %s", glue_job_run_url) - if self.wait_for_completion: + + if self.deferrable: + self.defer( + trigger=GlueJobCompleteTrigger( + job_name=self.job_name, + run_id=glue_job_run["JobRunId"], + verbose=self.verbose, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + elif self.wait_for_completion: glue_job_run = glue_job.job_completion(self.job_name, glue_job_run["JobRunId"], self.verbose) self.log.info( "AWS Glue Job: %s status: %s. Run Id: %s", @@ -178,3 +193,8 @@ def execute(self, context: Context): else: self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, glue_job_run["JobRunId"]) return glue_job_run["JobRunId"] + + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error in glue job: {event}") + return diff --git a/airflow/providers/amazon/aws/triggers/glue.py b/airflow/providers/amazon/aws/triggers/glue.py new file mode 100644 index 0000000000000..bfcf3b4d370e6 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/glue.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any + +from airflow.providers.amazon.aws.hooks.glue import GlueJobHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class GlueJobCompleteTrigger(BaseTrigger): + """ + Watches for a glue job, triggers when it finishes + + :param job_name: glue job name + :param run_id: the ID of the specific run to watch for that job + :param verbose: whether to print the job's logs in airflow logs or not + :param aws_conn_id: You know what this is + """ + + def __init__( + self, + job_name: str, + run_id: str, + verbose: bool, + aws_conn_id: str, + ): + self.job_name = job_name + self.run_id = run_id + self.verbose = verbose + self.aws_conn_id = aws_conn_id + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + self.__class__.__qualname__, + { + "job_name": self.job_name, + "run_id": self.run_id, + "verbose": str(self.verbose), + "aws_conn_id": self.aws_conn_id, + }, + ) + + async def run(self): + hook = GlueJobHook(aws_conn_id=self.aws_conn_id) + await hook.async_job_completion(self.job_name, self.run_id, self.verbose) + return TriggerEvent({"status": "success", "message": "Job done"}) diff --git a/tests/providers/amazon/aws/hooks/test_glue.py b/tests/providers/amazon/aws/hooks/test_glue.py index 9a46abb34fdb9..1497affd8622a 100644 --- a/tests/providers/amazon/aws/hooks/test_glue.py +++ b/tests/providers/amazon/aws/hooks/test_glue.py @@ -26,6 +26,7 @@ from botocore.exceptions import ClientError from moto import mock_glue, mock_iam +from airflow import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.glue import GlueJobHook @@ -349,3 +350,65 @@ def test_print_job_logs_no_stream_yet(self, conn_mock: MagicMock, client_mock: M assert tokens.output_stream_continuation is None assert tokens.error_stream_continuation is None assert client_mock().get_paginator().paginate.call_count == 2 + + @mock.patch.object(GlueJobHook, "get_job_state") + def test_job_completion_success(self, get_state_mock: MagicMock): + hook = GlueJobHook() + hook.JOB_POLL_INTERVAL = 0 + get_state_mock.side_effect = [ + "RUNNING", + "RUNNING", + "SUCCEEDED", + ] + + hook.job_completion("job_name", "run_id") + + assert get_state_mock.call_count == 3 + get_state_mock.assert_called_with("job_name", "run_id") + + @mock.patch.object(GlueJobHook, "get_job_state") + def test_job_completion_failure(self, get_state_mock: MagicMock): + hook = GlueJobHook() + hook.JOB_POLL_INTERVAL = 0 + get_state_mock.side_effect = [ + "RUNNING", + "RUNNING", + "FAILED", + ] + + with pytest.raises(AirflowException): + hook.job_completion("job_name", "run_id") + + assert get_state_mock.call_count == 3 + + @pytest.mark.asyncio + @mock.patch.object(GlueJobHook, "async_get_job_state") + async def test_async_job_completion_success(self, get_state_mock: MagicMock): + hook = GlueJobHook() + hook.JOB_POLL_INTERVAL = 0 + get_state_mock.side_effect = [ + "RUNNING", + "RUNNING", + "SUCCEEDED", + ] + + await hook.async_job_completion("job_name", "run_id") + + assert get_state_mock.call_count == 3 + get_state_mock.assert_called_with("job_name", "run_id") + + @pytest.mark.asyncio + @mock.patch.object(GlueJobHook, "async_get_job_state") + async def test_async_job_completion_failure(self, get_state_mock: MagicMock): + hook = GlueJobHook() + hook.JOB_POLL_INTERVAL = 0 + get_state_mock.side_effect = [ + "RUNNING", + "RUNNING", + "FAILED", + ] + + with pytest.raises(AirflowException): + await hook.async_job_completion("job_name", "run_id") + + assert get_state_mock.call_count == 3 diff --git a/tests/providers/amazon/aws/operators/test_glue.py b/tests/providers/amazon/aws/operators/test_glue.py index db5ff1e6c232f..03b5e154f47e4 100644 --- a/tests/providers/amazon/aws/operators/test_glue.py +++ b/tests/providers/amazon/aws/operators/test_glue.py @@ -21,6 +21,7 @@ import pytest from airflow.configuration import conf +from airflow.exceptions import TaskDeferred from airflow.models import TaskInstance from airflow.providers.amazon.aws.hooks.glue import GlueJobHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -98,6 +99,27 @@ def test_execute_without_failure( mock_print_job_logs.assert_not_called() assert glue.job_name == JOB_NAME + @mock.patch.object(GlueJobHook, "initialize_job") + @mock.patch.object(GlueJobHook, "get_conn") + def test_execute_deferrable(self, _, mock_initialize_job): + glue = GlueJobOperator( + task_id=TASK_ID, + job_name=JOB_NAME, + script_location="s3://folder/file", + aws_conn_id="aws_default", + region_name="us-west-2", + s3_bucket="some_bucket", + iam_role_name="my_test_role", + deferrable=True, + ) + mock_initialize_job.return_value = {"JobRunState": "RUNNING", "JobRunId": JOB_RUN_ID} + + with pytest.raises(TaskDeferred) as defer: + glue.execute(mock.MagicMock()) + + assert defer.value.trigger.job_name == JOB_NAME + assert defer.value.trigger.run_id == JOB_RUN_ID + @mock.patch.object(GlueJobHook, "print_job_logs") @mock.patch.object(GlueJobHook, "get_job_state") @mock.patch.object(GlueJobHook, "initialize_job") diff --git a/tests/providers/amazon/aws/triggers/test_glue.py b/tests/providers/amazon/aws/triggers/test_glue.py new file mode 100644 index 0000000000000..70cc7e38d2091 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_glue.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pytest +from asynctest import MagicMock, mock + +from airflow import AirflowException +from airflow.providers.amazon.aws.hooks.glue import GlueJobHook +from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger + + +@pytest.mark.asyncio +@mock.patch.object(GlueJobHook, "async_get_job_state") +async def test_wait_job(get_state_mock: MagicMock): + GlueJobHook.JOB_POLL_INTERVAL = 0.1 + trigger = GlueJobCompleteTrigger( + job_name="job_name", + run_id="JobRunId", + verbose=False, + aws_conn_id="aws_conn_id", + ) + get_state_mock.side_effect = [ + "RUNNING", + "RUNNING", + "SUCCEEDED", + ] + + event = await trigger.run() + + assert get_state_mock.call_count == 3 + assert event.payload["status"] == "success" + + +@pytest.mark.asyncio +@mock.patch.object(GlueJobHook, "async_get_job_state") +async def test_wait_job_failed(get_state_mock: MagicMock): + GlueJobHook.JOB_POLL_INTERVAL = 0.1 + trigger = GlueJobCompleteTrigger( + job_name="job_name", + run_id="JobRunId", + verbose=False, + aws_conn_id="aws_conn_id", + ) + get_state_mock.side_effect = [ + "RUNNING", + "RUNNING", + "FAILED", + ] + + with pytest.raises(AirflowException): + await trigger.run() + + assert get_state_mock.call_count == 3