From f03dc5a5e7e90323099bf6eccab14452df2e7fc3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= <vandonr@amazon.com>
Date: Thu, 23 Feb 2023 10:05:14 -0800
Subject: [PATCH] add deferrable mode to glue operator

---
 .../providers/amazon/aws/hooks/base_aws.py    |  2 +-
 airflow/providers/amazon/aws/hooks/glue.py    | 87 ++++++++++++++-----
 .../providers/amazon/aws/operators/glue.py    | 22 ++++-
 airflow/providers/amazon/aws/triggers/glue.py | 62 +++++++++++++
 tests/providers/amazon/aws/hooks/test_glue.py | 63 ++++++++++++++
 .../amazon/aws/operators/test_glue.py         | 22 +++++
 .../amazon/aws/triggers/test_glue.py          | 69 +++++++++++++++
 7 files changed, 304 insertions(+), 23 deletions(-)
 create mode 100644 airflow/providers/amazon/aws/triggers/glue.py
 create mode 100644 tests/providers/amazon/aws/triggers/test_glue.py

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index 096ef038f3d8c..4d15945b3c5f8 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -603,7 +603,7 @@ def get_client_type(
         """Get the underlying boto3 client using boto3 session"""
         client_type = self.client_type
         session = self.get_session(region_name=region_name, deferrable=deferrable)
-        if not isinstance(session, boto3.session.Session):
+        if isinstance(session, AioSession):
             return session.create_client(
                 client_type,
                 endpoint_url=self.conn_config.endpoint_url,
diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py
index 6f4ed8342e072..f811980210afe 100644
--- a/airflow/providers/amazon/aws/hooks/glue.py
+++ b/airflow/providers/amazon/aws/hooks/glue.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 import time
 
 import boto3
@@ -194,6 +195,12 @@ def get_job_state(self, job_name: str, run_id: str) -> str:
         job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True)
         return job_run["JobRun"]["JobRunState"]
 
+    async def async_get_job_state(self, job_name: str, run_id: str) -> str:
+        """Fetch the JobRunState of the given job run through the hook's async client."""
+        async with self.async_conn as glue_client:
+            response = await glue_client.get_job_run(JobName=job_name, RunId=run_id)
+        return response["JobRun"]["JobRunState"]
+
     def print_job_logs(
         self,
         job_name: str,
@@ -264,33 +271,71 @@ def job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> d
         :param verbose: If True, more Glue Job Run logs show in the Airflow Task Logs.  (default: False)
         :return: Dict of JobRunState and JobRunId
         """
-        failed_states = ["FAILED", "TIMEOUT"]
-        finished_states = ["SUCCEEDED", "STOPPED"]
         next_log_tokens = self.LogContinuationTokens()
         while True:
-            if verbose:
-                self.print_job_logs(
-                    job_name=job_name,
-                    run_id=run_id,
-                    continuation_tokens=next_log_tokens,
-                )
-
             job_run_state = self.get_job_state(job_name, run_id)
-            if job_run_state in finished_states:
-                self.log.info("Exiting Job %s Run State: %s", run_id, job_run_state)
-                return {"JobRunState": job_run_state, "JobRunId": run_id}
-            if job_run_state in failed_states:
-                job_error_message = f"Exiting Job {run_id} Run State: {job_run_state}"
-                self.log.info(job_error_message)
-                raise AirflowException(job_error_message)
+            ret = self._handle_state(job_run_state, job_name, run_id, verbose, next_log_tokens)
+            if ret:
+                return ret
             else:
-                self.log.info(
-                    "Polling for AWS Glue Job %s current run state with status %s",
-                    job_name,
-                    job_run_state,
-                )
                 time.sleep(self.JOB_POLL_INTERVAL)
 
+    async def async_job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> dict[str, str]:
+        """
+        Asynchronous counterpart of job_completion, sleeping with asyncio between polls.
+
+        Returns the final state once the run has finished; raises AirflowException when it failed.
+
+        :param job_name: unique job name per AWS account
+        :param run_id: ID of the job run to wait for
+        :param verbose: If True, more Glue Job Run logs show in the Airflow Task Logs.  (default: False)
+        :return: Dict of JobRunState and JobRunId
+        """
+        log_tokens = self.LogContinuationTokens()
+        while True:
+            current_state = await self.async_get_job_state(job_name, run_id)
+            outcome = self._handle_state(current_state, job_name, run_id, verbose, log_tokens)
+            if outcome:
+                return outcome
+            await asyncio.sleep(self.JOB_POLL_INTERVAL)
+
+    def _handle_state(
+        self,
+        state: str,
+        job_name: str,
+        run_id: str,
+        verbose: bool,
+        next_log_tokens: GlueJobHook.LogContinuationTokens,
+    ) -> dict | None:
+        """
+        Interpret one polled job run state; shared by the sync and async completion loops.
+
+        :param state: the JobRunState that was just read for this run
+        :param job_name: unique job name per AWS account
+        :param run_id: ID of the job run being watched
+        :param verbose: if True, print_job_logs is called before the state is examined
+        :param next_log_tokens: continuation tokens handed to print_job_logs on every call
+        :return: dict of JobRunState and JobRunId once the run has ended, None if polling should go on
+        :raises AirflowException: when the state is FAILED or TIMEOUT
+        """
+        succeeded_or_stopped = ("SUCCEEDED", "STOPPED")
+        failed_or_timed_out = ("FAILED", "TIMEOUT")
+
+        if verbose:
+            self.print_job_logs(job_name=job_name, run_id=run_id, continuation_tokens=next_log_tokens)
+
+        if state in succeeded_or_stopped:
+            self.log.info("Exiting Job %s Run State: %s", run_id, state)
+            return {"JobRunState": state, "JobRunId": run_id}
+
+        if state in failed_or_timed_out:
+            job_error_message = f"Exiting Job {run_id} Run State: {state}"
+            self.log.info(job_error_message)
+            raise AirflowException(job_error_message)
+
+        self.log.info("Polling for AWS Glue Job %s current run state with status %s", job_name, state)
+        return None
+
     def has_job(self, job_name) -> bool:
         """
         Checks if the job already exists.
diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py
index 497df84d31418..794bfde9be598 100644
--- a/airflow/providers/amazon/aws/operators/glue.py
+++ b/airflow/providers/amazon/aws/operators/glue.py
@@ -21,10 +21,12 @@
 import urllib.parse
 from typing import TYPE_CHECKING, Sequence
 
+from airflow import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.links.glue import GlueJobRunDetailsLink
+from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -93,6 +95,7 @@ def __init__(
         wait_for_completion: bool = True,
         verbose: bool = False,
         update_config: bool = False,
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -114,6 +117,7 @@ def __init__(
         self.wait_for_completion = wait_for_completion
         self.verbose = verbose
         self.update_config = update_config
+        self.deferrable = deferrable
 
     def execute(self, context: Context):
         """
@@ -167,7 +171,18 @@ def execute(self, context: Context):
             job_run_id=glue_job_run["JobRunId"],
         )
         self.log.info("You can monitor this Glue Job run at: %s", glue_job_run_url)
-        if self.wait_for_completion:
+
+        if self.deferrable:
+            self.defer(
+                trigger=GlueJobCompleteTrigger(
+                    job_name=self.job_name,
+                    run_id=glue_job_run["JobRunId"],
+                    verbose=self.verbose,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
             glue_job_run = glue_job.job_completion(self.job_name, glue_job_run["JobRunId"], self.verbose)
             self.log.info(
                 "AWS Glue Job: %s status: %s. Run Id: %s",
@@ -178,3 +193,8 @@ def execute(self, context: Context):
         else:
             self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, glue_job_run["JobRunId"])
         return glue_job_run["JobRunId"]
+
+    def execute_complete(self, context, event=None):
+        """Resume after the trigger fires; fail the task unless the event reports success."""
+        if not event or event.get("status") != "success":
+            raise AirflowException(f"Error in glue job: {event}")
diff --git a/airflow/providers/amazon/aws/triggers/glue.py b/airflow/providers/amazon/aws/triggers/glue.py
new file mode 100644
index 0000000000000..bfcf3b4d370e6
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/glue.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Any
+
+from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class GlueJobCompleteTrigger(BaseTrigger):
+    """
+    Watches for a glue job, triggers when it finishes
+
+    :param job_name: glue job name
+    :param run_id: the ID of the specific run to watch for that job
+    :param verbose: whether to print the job's logs in airflow logs or not
+    :param aws_conn_id: the Airflow connection used for AWS credentials
+    """
+
+    def __init__(self, job_name: str, run_id: str, verbose: bool, aws_conn_id: str):
+        super().__init__()
+        self.job_name = job_name
+        self.run_id = run_id
+        self.verbose = verbose
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Return the importable classpath and the kwargs needed to rebuild this trigger."""
+        return (
+            # the class is re-imported from this path, so it must include the module
+            f"{self.__class__.__module__}.{self.__class__.__qualname__}",
+            {
+                "job_name": self.job_name,
+                "run_id": self.run_id,
+                # keep the bool: str(False) == "False" is truthy once passed back to __init__
+                "verbose": self.verbose,
+                "aws_conn_id": self.aws_conn_id,
+            },
+        )
+
+    async def run(self):
+        """Poll the job run until it ends, then emit a single success event."""
+        hook = GlueJobHook(aws_conn_id=self.aws_conn_id)
+        await hook.async_job_completion(self.job_name, self.run_id, self.verbose)
+        # BaseTrigger.run is expected to be an async generator, so the event is yielded
+        yield TriggerEvent({"status": "success", "message": "Job done"})
diff --git a/tests/providers/amazon/aws/hooks/test_glue.py b/tests/providers/amazon/aws/hooks/test_glue.py
index 9a46abb34fdb9..1497affd8622a 100644
--- a/tests/providers/amazon/aws/hooks/test_glue.py
+++ b/tests/providers/amazon/aws/hooks/test_glue.py
@@ -26,6 +26,7 @@
 from botocore.exceptions import ClientError
 from moto import mock_glue, mock_iam
 
+from airflow import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
 
@@ -349,3 +350,65 @@ def test_print_job_logs_no_stream_yet(self, conn_mock: MagicMock, client_mock: M
         assert tokens.output_stream_continuation is None
         assert tokens.error_stream_continuation is None
         assert client_mock().get_paginator().paginate.call_count == 2
+
+    @mock.patch.object(GlueJobHook, "get_job_state")
+    def test_job_completion_success(self, get_state_mock: MagicMock):
+        hook = GlueJobHook()
+        hook.JOB_POLL_INTERVAL = 0
+        get_state_mock.side_effect = [
+            "RUNNING",
+            "RUNNING",
+            "SUCCEEDED",
+        ]
+
+        hook.job_completion("job_name", "run_id")
+
+        assert get_state_mock.call_count == 3
+        get_state_mock.assert_called_with("job_name", "run_id")
+
+    @mock.patch.object(GlueJobHook, "get_job_state")
+    def test_job_completion_failure(self, get_state_mock: MagicMock):
+        hook = GlueJobHook()
+        hook.JOB_POLL_INTERVAL = 0
+        get_state_mock.side_effect = [
+            "RUNNING",
+            "RUNNING",
+            "FAILED",
+        ]
+
+        with pytest.raises(AirflowException):
+            hook.job_completion("job_name", "run_id")
+
+        assert get_state_mock.call_count == 3
+
+    @pytest.mark.asyncio
+    @mock.patch.object(GlueJobHook, "async_get_job_state")
+    async def test_async_job_completion_success(self, get_state_mock: MagicMock):
+        hook = GlueJobHook()
+        hook.JOB_POLL_INTERVAL = 0
+        get_state_mock.side_effect = [
+            "RUNNING",
+            "RUNNING",
+            "SUCCEEDED",
+        ]
+
+        await hook.async_job_completion("job_name", "run_id")
+
+        assert get_state_mock.call_count == 3
+        get_state_mock.assert_called_with("job_name", "run_id")
+
+    @pytest.mark.asyncio
+    @mock.patch.object(GlueJobHook, "async_get_job_state")
+    async def test_async_job_completion_failure(self, get_state_mock: MagicMock):
+        hook = GlueJobHook()
+        hook.JOB_POLL_INTERVAL = 0
+        get_state_mock.side_effect = [
+            "RUNNING",
+            "RUNNING",
+            "FAILED",
+        ]
+
+        with pytest.raises(AirflowException):
+            await hook.async_job_completion("job_name", "run_id")
+
+        assert get_state_mock.call_count == 3
diff --git a/tests/providers/amazon/aws/operators/test_glue.py b/tests/providers/amazon/aws/operators/test_glue.py
index db5ff1e6c232f..03b5e154f47e4 100644
--- a/tests/providers/amazon/aws/operators/test_glue.py
+++ b/tests/providers/amazon/aws/operators/test_glue.py
@@ -21,6 +21,7 @@
 import pytest
 
 from airflow.configuration import conf
+from airflow.exceptions import TaskDeferred
 from airflow.models import TaskInstance
 from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -98,6 +99,27 @@ def test_execute_without_failure(
         mock_print_job_logs.assert_not_called()
         assert glue.job_name == JOB_NAME
 
+    @mock.patch.object(GlueJobHook, "initialize_job")
+    @mock.patch.object(GlueJobHook, "get_conn")
+    def test_execute_deferrable(self, _, mock_initialize_job):
+        glue = GlueJobOperator(
+            task_id=TASK_ID,
+            job_name=JOB_NAME,
+            script_location="s3://folder/file",
+            aws_conn_id="aws_default",
+            region_name="us-west-2",
+            s3_bucket="some_bucket",
+            iam_role_name="my_test_role",
+            deferrable=True,
+        )
+        mock_initialize_job.return_value = {"JobRunState": "RUNNING", "JobRunId": JOB_RUN_ID}
+
+        with pytest.raises(TaskDeferred) as defer:
+            glue.execute(mock.MagicMock())
+
+        assert defer.value.trigger.job_name == JOB_NAME
+        assert defer.value.trigger.run_id == JOB_RUN_ID
+
     @mock.patch.object(GlueJobHook, "print_job_logs")
     @mock.patch.object(GlueJobHook, "get_job_state")
     @mock.patch.object(GlueJobHook, "initialize_job")
diff --git a/tests/providers/amazon/aws/triggers/test_glue.py b/tests/providers/amazon/aws/triggers/test_glue.py
new file mode 100644
index 0000000000000..70cc7e38d2091
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_glue.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+from asynctest import MagicMock, mock
+
+from airflow import AirflowException
+from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
+from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger
+
+
+@pytest.mark.asyncio
+@mock.patch.object(GlueJobHook, "async_get_job_state")
+async def test_wait_job(get_state_mock: MagicMock):
+    GlueJobHook.JOB_POLL_INTERVAL = 0.1
+    trigger = GlueJobCompleteTrigger(
+        job_name="job_name",
+        run_id="JobRunId",
+        verbose=False,
+        aws_conn_id="aws_conn_id",
+    )
+    get_state_mock.side_effect = [
+        "RUNNING",
+        "RUNNING",
+        "SUCCEEDED",
+    ]
+
+    event = await trigger.run().asend(None)
+
+    assert get_state_mock.call_count == 3
+    assert event.payload["status"] == "success"
+
+
+@pytest.mark.asyncio
+@mock.patch.object(GlueJobHook, "async_get_job_state")
+async def test_wait_job_failed(get_state_mock: MagicMock):
+    GlueJobHook.JOB_POLL_INTERVAL = 0.1
+    trigger = GlueJobCompleteTrigger(
+        job_name="job_name",
+        run_id="JobRunId",
+        verbose=False,
+        aws_conn_id="aws_conn_id",
+    )
+    get_state_mock.side_effect = [
+        "RUNNING",
+        "RUNNING",
+        "FAILED",
+    ]
+
+    with pytest.raises(AirflowException):
+        await trigger.run().asend(None)
+
+    assert get_state_mock.call_count == 3