diff --git a/airflow/providers/amazon/aws/operators/glue_crawler.py b/airflow/providers/amazon/aws/operators/glue_crawler.py index cd760a3fba4de..1ce73eea07163 100644 --- a/airflow/providers/amazon/aws/operators/glue_crawler.py +++ b/airflow/providers/amazon/aws/operators/glue_crawler.py @@ -17,22 +17,22 @@ # under the License. from __future__ import annotations -from functools import cached_property from typing import TYPE_CHECKING, Any, Sequence from airflow.configuration import conf from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger from airflow.providers.amazon.aws.utils import validate_execute_complete_event +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields if TYPE_CHECKING: from airflow.utils.context import Context -from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook -class GlueCrawlerOperator(BaseOperator): +class GlueCrawlerOperator(AwsBaseOperator[GlueCrawlerHook]): """ Creates, updates and triggers an AWS Glue Crawler. @@ -45,45 +45,45 @@ class GlueCrawlerOperator(BaseOperator): :ref:`howto/operator:GlueCrawlerOperator` :param config: Configurations for the AWS Glue crawler - :param aws_conn_id: The Airflow connection used for AWS credentials. - If this is None or empty then the default boto3 behaviour is used. If - running Airflow in a distributed manner and aws_conn_id is None or - empty, then default boto3 configuration would be used (and must be - maintained on each worker node). :param poll_interval: Time (in seconds) to wait between two consecutive calls to check crawler status :param wait_for_completion: Whether to wait for crawl execution completion. (default: True) :param deferrable: If True, the operator will wait asynchronously for the crawl to complete. This implies waiting for completion. 
This mode requires aiobotocore module to be installed. (default: False) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("config",) + aws_hook_class = GlueCrawlerHook + + template_fields: Sequence[str] = aws_template_fields( + "config", + ) ui_color = "#ededed" def __init__( self, config, - aws_conn_id="aws_default", - region_name: str | None = None, poll_interval: int = 5, wait_for_completion: bool = True, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(**kwargs) - self.aws_conn_id = aws_conn_id self.poll_interval = poll_interval self.wait_for_completion = wait_for_completion self.deferrable = deferrable - self.region_name = region_name self.config = config - @cached_property - def hook(self) -> GlueCrawlerHook: - """Create and return a GlueCrawlerHook.""" - return GlueCrawlerHook(self.aws_conn_id, region_name=self.region_name) - - def execute(self, context: Context): + def execute(self, context: Context) -> str: """ Execute AWS Glue Crawler from Airflow. 
@@ -103,6 +103,9 @@ def execute(self, context: Context): crawler_name=crawler_name, waiter_delay=self.poll_interval, aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + botocore_config=self.botocore_config, ), method_name="execute_complete", ) diff --git a/airflow/providers/amazon/aws/sensors/glue_crawler.py b/airflow/providers/amazon/aws/sensors/glue_crawler.py index 47e87e082d910..ce35aef2cfb2d 100644 --- a/airflow/providers/amazon/aws/sensors/glue_crawler.py +++ b/airflow/providers/amazon/aws/sensors/glue_crawler.py @@ -17,20 +17,20 @@ # under the License. from __future__ import annotations -from functools import cached_property from typing import TYPE_CHECKING, Sequence from deprecated import deprecated from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook -from airflow.sensors.base import BaseSensorOperator +from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields if TYPE_CHECKING: from airflow.utils.context import Context -class GlueCrawlerSensor(BaseSensorOperator): +class GlueCrawlerSensor(AwsBaseSensor[GlueCrawlerHook]): """ Waits for an AWS Glue crawler to reach any of the statuses below. @@ -41,19 +41,27 @@ class GlueCrawlerSensor(BaseSensorOperator): :ref:`howto/sensor:GlueCrawlerSensor` :param crawler_name: The AWS Glue crawler unique name - :param aws_conn_id: aws connection to use, defaults to 'aws_default' - If this is None or empty then the default boto3 behaviour is used. If + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). 
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("crawler_name",) + aws_hook_class = GlueCrawlerHook - def __init__(self, *, crawler_name: str, aws_conn_id: str | None = "aws_default", **kwargs) -> None: + template_fields: Sequence[str] = aws_template_fields( + "crawler_name", + ) + + def __init__(self, *, crawler_name: str, **kwargs) -> None: super().__init__(**kwargs) self.crawler_name = crawler_name - self.aws_conn_id = aws_conn_id self.success_statuses = "SUCCEEDED" self.errored_statuses = ("FAILED", "CANCELLED") @@ -79,7 +87,3 @@ def poke(self, context: Context): def get_hook(self) -> GlueCrawlerHook: """Return a new or pre-existing GlueCrawlerHook.""" return self.hook - - @cached_property - def hook(self) -> GlueCrawlerHook: - return GlueCrawlerHook(aws_conn_id=self.aws_conn_id) diff --git a/airflow/providers/amazon/aws/triggers/glue_crawler.py b/airflow/providers/amazon/aws/triggers/glue_crawler.py index f3d975f0a5a7c..6bd52bef6f554 100644 --- a/airflow/providers/amazon/aws/triggers/glue_crawler.py +++ b/airflow/providers/amazon/aws/triggers/glue_crawler.py @@ -43,6 +43,7 @@ def __init__( aws_conn_id: str | None = "aws_default", waiter_delay: int = 5, waiter_max_attempts: int = 1500, + **kwargs, ): if poll_interval is not None: warnings.warn( @@ -62,7 +63,13 @@ def __init__( waiter_delay=waiter_delay, waiter_max_attempts=waiter_max_attempts, aws_conn_id=aws_conn_id, + **kwargs, ) def hook(self) -> AwsGenericHook: - return GlueCrawlerHook(aws_conn_id=self.aws_conn_id) + return GlueCrawlerHook( + aws_conn_id=self.aws_conn_id, + 
region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) diff --git a/docs/apache-airflow-providers-amazon/operators/glue.rst b/docs/apache-airflow-providers-amazon/operators/glue.rst index 024d6708b6462..626fa49990811 100644 --- a/docs/apache-airflow-providers-amazon/operators/glue.rst +++ b/docs/apache-airflow-providers-amazon/operators/glue.rst @@ -29,6 +29,11 @@ Prerequisite Tasks .. include:: ../_partials/prerequisite_tasks.rst +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + Operators --------- diff --git a/tests/providers/amazon/aws/operators/test_glue_crawler.py b/tests/providers/amazon/aws/operators/test_glue_crawler.py index 931f9bc7d1a1a..a7a38e78bb4a7 100644 --- a/tests/providers/amazon/aws/operators/test_glue_crawler.py +++ b/tests/providers/amazon/aws/operators/test_glue_crawler.py @@ -16,10 +16,19 @@ # under the License. from __future__ import annotations +from typing import TYPE_CHECKING, Generator from unittest import mock +import pytest +from moto import mock_aws + +from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook +from airflow.providers.amazon.aws.hooks.sts import StsHook from airflow.providers.amazon.aws.operators.glue_crawler import GlueCrawlerOperator +if TYPE_CHECKING: + from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection + mock_crawler_name = "test-crawler" mock_role_name = "test-role" mock_config = { @@ -81,20 +90,86 @@ class TestGlueCrawlerOperator: + @pytest.fixture + def mock_conn(self) -> Generator[BaseAwsConnection, None, None]: + with mock.patch.object(GlueCrawlerHook, "get_conn") as _conn: + _conn.create_crawler.return_value = mock_crawler_name + yield _conn + + @pytest.fixture + def crawler_hook(self) -> Generator[GlueCrawlerHook, None, None]: + with mock_aws(): + hook = GlueCrawlerHook(aws_conn_id="aws_default") + yield hook + def setup_method(self): - self.glue = 
GlueCrawlerOperator(task_id="test_glue_crawler_operator", config=mock_config) - - @mock.patch("airflow.providers.amazon.aws.operators.glue_crawler.GlueCrawlerHook") - def test_execute_without_failure(self, mock_hook): - mock_hook.return_value.has_crawler.return_value = True - self.glue.execute({}) - - mock_hook.assert_has_calls( - [ - mock.call("aws_default", region_name=None), - mock.call().has_crawler("test-crawler"), - mock.call().update_crawler(**mock_config), - mock.call().start_crawler(mock_crawler_name), - mock.call().wait_for_crawler_completion(crawler_name=mock_crawler_name, poll_interval=5), - ] + self.op = GlueCrawlerOperator(task_id="test_glue_crawler_operator", config=mock_config) + + def test_init(self): + op = GlueCrawlerOperator( + task_id="test_glue_crawler_operator", + aws_conn_id="fake-conn-id", + region_name="eu-west-2", + verify=True, + botocore_config={"read_timeout": 42}, + config=mock_config, ) + + assert op.hook.client_type == "glue" + assert op.hook.resource_type is None + assert op.hook.aws_conn_id == "fake-conn-id" + assert op.hook._region_name == "eu-west-2" + assert op.hook._verify is True + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + op = GlueCrawlerOperator(task_id="test_glue_crawler_operator", config=mock_config) + + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + + @mock.patch.object(GlueCrawlerHook, "glue_client") + @mock.patch.object(StsHook, "get_account_number") + def test_execute_update_and_start_crawler(self, sts_mock, mock_glue_client): + sts_mock.return_value = 123456789012 + mock_glue_client.get_crawler.return_value = {"Crawler": {}} + self.op.wait_for_completion = False + crawler_name = self.op.execute({}) + + assert mock_glue_client.get_crawler.call_count == 2 + assert mock_glue_client.update_crawler.call_count == 1 + assert mock_glue_client.start_crawler.call_count == 1 + assert 
crawler_name == mock_crawler_name + + @mock.patch.object(GlueCrawlerHook, "has_crawler") + @mock.patch.object(GlueCrawlerHook, "glue_client") + def test_execute_create_and_start_crawler(self, mock_glue_client, mock_has_crawler): + mock_has_crawler.return_value = False + mock_glue_client.create_crawler.return_value = {} + self.op.wait_for_completion = False + crawler_name = self.op.execute({}) + + assert crawler_name == mock_crawler_name + mock_glue_client.create_crawler.assert_called_once() + + @pytest.mark.parametrize( + "wait_for_completion, deferrable", + [ + pytest.param(False, False, id="no_wait"), + pytest.param(True, False, id="wait"), + pytest.param(False, True, id="defer"), + ], + ) + @mock.patch.object(GlueCrawlerHook, "get_waiter") + def test_crawler_wait_combinations(self, _, wait_for_completion, deferrable, mock_conn, crawler_hook): + self.op.defer = mock.MagicMock() + self.op.wait_for_completion = wait_for_completion + self.op.deferrable = deferrable + + response = self.op.execute({}) + + assert response == mock_crawler_name + assert crawler_hook.get_waiter.call_count == wait_for_completion + assert self.op.defer.call_count == deferrable diff --git a/tests/providers/amazon/aws/sensors/test_glue_crawler.py b/tests/providers/amazon/aws/sensors/test_glue_crawler.py index 9eab5e35e3d80..83b3795a1ae13 100644 --- a/tests/providers/amazon/aws/sensors/test_glue_crawler.py +++ b/tests/providers/amazon/aws/sensors/test_glue_crawler.py @@ -64,3 +64,28 @@ def test_fail_poke(self, get_crawler, soft_fail, expected_exception): message = f"Status: {crawler_status}" with pytest.raises(expected_exception, match=message): self.sensor.poke(context={}) + + def test_base_aws_op_attributes(self): + op = GlueCrawlerSensor( + task_id="test_glue_crawler_sensor", + crawler_name="aws_test_glue_crawler", + ) + assert op.hook.client_type == "glue" + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert 
op.hook._config is None + + op = GlueCrawlerSensor( + task_id="test_glue_crawler_sensor", + crawler_name="aws_test_glue_crawler", + aws_conn_id="aws-test-custom-conn", + region_name="eu-west-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + assert op.hook.aws_conn_id == "aws-test-custom-conn" + assert op.hook._region_name == "eu-west-1" + assert op.hook._verify is False + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 diff --git a/tests/providers/amazon/aws/triggers/test_glue_crawler.py b/tests/providers/amazon/aws/triggers/test_glue_crawler.py index 1ba2d610877db..dd46b1882a79d 100644 --- a/tests/providers/amazon/aws/triggers/test_glue_crawler.py +++ b/tests/providers/amazon/aws/triggers/test_glue_crawler.py @@ -16,9 +16,15 @@ # under the License. from __future__ import annotations -from unittest.mock import patch +from unittest import mock +from unittest.mock import AsyncMock, patch +import pytest + +from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger +from airflow.triggers.base import TriggerEvent +from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type class TestGlueCrawlerCompleteTrigger: @@ -47,3 +53,19 @@ def test_serialization(self, mock_warn): "waiter_max_attempts": 1500, "aws_conn_id": "aws_default", } + + @pytest.mark.asyncio + @mock.patch.object(GlueCrawlerHook, "get_waiter") + @mock.patch.object(GlueCrawlerHook, "async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.__aenter__.return_value = mock.MagicMock() + mock_get_waiter().wait = AsyncMock() + crawler_name = "test_crawler" + trigger = GlueCrawlerCompleteTrigger(crawler_name=crawler_name) + + generator = trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "value": None}) + 
assert_expected_waiter_type(mock_get_waiter, "crawler_ready") + mock_get_waiter().wait.assert_called_once()