Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use base aws classes in AWS Glue Crawlers Operators/Sensors/Triggers #40504

Merged
merged 2 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 22 additions & 19 deletions airflow/providers/amazon/aws/operators/glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook


class GlueCrawlerOperator(BaseOperator):
class GlueCrawlerOperator(AwsBaseOperator[GlueCrawlerHook]):
"""
Creates, updates and triggers an AWS Glue Crawler.

Expand All @@ -45,45 +45,45 @@ class GlueCrawlerOperator(BaseOperator):
:ref:`howto/operator:GlueCrawlerOperator`

:param config: Configurations for the AWS Glue crawler
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check crawler status
:param wait_for_completion: Whether to wait for crawl execution completion. (default: True)
:param deferrable: If True, the operator will wait asynchronously for the crawl to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("config",)
aws_hook_class = GlueCrawlerHook

template_fields: Sequence[str] = aws_template_fields(
"config",
)
ui_color = "#ededed"

def __init__(
self,
config,
aws_conn_id="aws_default",
region_name: str | None = None,
poll_interval: int = 5,
wait_for_completion: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.poll_interval = poll_interval
self.wait_for_completion = wait_for_completion
self.deferrable = deferrable
self.region_name = region_name
self.config = config

@cached_property
def hook(self) -> GlueCrawlerHook:
"""Create and return a GlueCrawlerHook."""
return GlueCrawlerHook(self.aws_conn_id, region_name=self.region_name)

def execute(self, context: Context):
def execute(self, context: Context) -> str:
"""
Execute AWS Glue Crawler from Airflow.

Expand All @@ -103,6 +103,9 @@ def execute(self, context: Context):
crawler_name=crawler_name,
waiter_delay=self.poll_interval,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
)
Expand Down
28 changes: 16 additions & 12 deletions airflow/providers/amazon/aws/sensors/glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from deprecated import deprecated

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
from airflow.sensors.base import BaseSensorOperator
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class GlueCrawlerSensor(BaseSensorOperator):
class GlueCrawlerSensor(AwsBaseSensor[GlueCrawlerHook]):
"""
Waits for an AWS Glue crawler to reach any of the statuses below.

Expand All @@ -41,19 +41,27 @@ class GlueCrawlerSensor(BaseSensorOperator):
:ref:`howto/sensor:GlueCrawlerSensor`

:param crawler_name: The AWS Glue crawler unique name
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
If this is None or empty then the default boto3 behaviour is used. If
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("crawler_name",)
aws_hook_class = GlueCrawlerHook

def __init__(self, *, crawler_name: str, aws_conn_id: str | None = "aws_default", **kwargs) -> None:
template_fields: Sequence[str] = aws_template_fields(
"crawler_name",
)

def __init__(self, *, crawler_name: str, **kwargs) -> None:
super().__init__(**kwargs)
self.crawler_name = crawler_name
self.aws_conn_id = aws_conn_id
self.success_statuses = "SUCCEEDED"
self.errored_statuses = ("FAILED", "CANCELLED")

Expand All @@ -79,7 +87,3 @@ def poke(self, context: Context):
def get_hook(self) -> GlueCrawlerHook:
    """
    Return a new or pre-existing GlueCrawlerHook.

    Thin alias that simply delegates to the ``hook`` property provided by
    the AwsBaseSensor base class; prefer accessing ``self.hook`` directly.
    """
    # NOTE(review): the `deprecated` import above suggests this method carries a
    # @deprecated decorator just outside this view — confirm before relying on it.
    return self.hook

@cached_property
def hook(self) -> GlueCrawlerHook:
return GlueCrawlerHook(aws_conn_id=self.aws_conn_id)
9 changes: 8 additions & 1 deletion airflow/providers/amazon/aws/triggers/glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
aws_conn_id: str | None = "aws_default",
waiter_delay: int = 5,
waiter_max_attempts: int = 1500,
**kwargs,
):
if poll_interval is not None:
warnings.warn(
Expand All @@ -62,7 +63,13 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return GlueCrawlerHook(aws_conn_id=self.aws_conn_id)
return GlueCrawlerHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)
5 changes: 5 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/glue.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down
105 changes: 90 additions & 15 deletions tests/providers/amazon/aws/operators/test_glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Generator
from unittest import mock

import pytest
from moto import mock_aws

from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
from airflow.providers.amazon.aws.hooks.sts import StsHook
from airflow.providers.amazon.aws.operators.glue_crawler import GlueCrawlerOperator

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection

mock_crawler_name = "test-crawler"
mock_role_name = "test-role"
mock_config = {
Expand Down Expand Up @@ -81,20 +90,86 @@


class TestGlueCrawlerOperator:
@pytest.fixture
def mock_conn(self) -> Generator[BaseAwsConnection, None, None]:
    """Patch GlueCrawlerHook.get_conn so tests never build a real boto3 client."""
    with mock.patch.object(GlueCrawlerHook, "get_conn") as _conn:
        # create_crawler should yield the crawler name, mirroring the hook's contract
        _conn.create_crawler.return_value = mock_crawler_name
        yield _conn

@pytest.fixture
def crawler_hook(self) -> Generator[GlueCrawlerHook, None, None]:
    """Provide a GlueCrawlerHook backed by moto's mocked AWS environment."""
    with mock_aws():
        hook = GlueCrawlerHook(aws_conn_id="aws_default")
        yield hook

def setup_method(self):
self.glue = GlueCrawlerOperator(task_id="test_glue_crawler_operator", config=mock_config)

@mock.patch("airflow.providers.amazon.aws.operators.glue_crawler.GlueCrawlerHook")
def test_execute_without_failure(self, mock_hook):
mock_hook.return_value.has_crawler.return_value = True
self.glue.execute({})

mock_hook.assert_has_calls(
[
mock.call("aws_default", region_name=None),
mock.call().has_crawler("test-crawler"),
mock.call().update_crawler(**mock_config),
mock.call().start_crawler(mock_crawler_name),
mock.call().wait_for_crawler_completion(crawler_name=mock_crawler_name, poll_interval=5),
]
self.op = GlueCrawlerOperator(task_id="test_glue_crawler_operator", config=mock_config)

def test_init(self):
    """Generic AWS arguments passed to the operator must reach its hook."""
    customized = GlueCrawlerOperator(
        task_id="test_glue_crawler_operator",
        aws_conn_id="fake-conn-id",
        region_name="eu-west-2",
        verify=True,
        botocore_config={"read_timeout": 42},
        config=mock_config,
    )
    hook = customized.hook

    assert hook.client_type == "glue"
    assert hook.resource_type is None
    assert hook.aws_conn_id == "fake-conn-id"
    assert hook._region_name == "eu-west-2"
    assert hook._verify is True
    assert hook._config is not None
    assert hook._config.read_timeout == 42

    # Omitting the generic arguments must fall back to the documented defaults.
    defaulted = GlueCrawlerOperator(task_id="test_glue_crawler_operator", config=mock_config)
    default_hook = defaulted.hook

    assert default_hook.aws_conn_id == "aws_default"
    assert default_hook._region_name is None
    assert default_hook._verify is None
    assert default_hook._config is None

@mock.patch.object(GlueCrawlerHook, "glue_client")
@mock.patch.object(StsHook, "get_account_number")
def test_execute_update_and_start_crawler(self, sts_mock, mock_glue_client):
    """An existing crawler is updated then started; execute() returns its name."""
    sts_mock.get_account_number.return_value = 123456789012
    # Returning an empty crawler description forces the config-diff path (update).
    mock_glue_client.get_crawler.return_value = {"Crawler": {}}
    self.op.wait_for_completion = False

    crawler_name = self.op.execute({})

    # BUG FIX: the original *assigned* to call_count (a silent no-op),
    # so these checks never actually ran. Assert the counts instead.
    assert mock_glue_client.get_crawler.call_count == 2
    assert mock_glue_client.update_crawler.call_count == 1
    assert mock_glue_client.start_crawler.call_count == 1
    assert crawler_name == mock_crawler_name

@mock.patch.object(GlueCrawlerHook, "has_crawler")
@mock.patch.object(GlueCrawlerHook, "glue_client")
def test_execute_create_and_start_crawler(self, mock_glue_client, mock_has_crawler):
    """A crawler that does not exist yet is created and started by execute()."""
    # Pretend no crawler with this name exists, so execute() takes the create path.
    mock_has_crawler.return_value = False
    mock_glue_client.create_crawler.return_value = {}
    self.op.wait_for_completion = False

    returned_name = self.op.execute({})

    mock_glue_client.create_crawler.assert_called_once()
    assert returned_name == mock_crawler_name

@pytest.mark.parametrize(
    "wait_for_completion, deferrable",
    [
        pytest.param(False, False, id="no_wait"),
        pytest.param(True, False, id="wait"),
        pytest.param(False, True, id="defer"),
    ],
)
@mock.patch.object(GlueCrawlerHook, "get_waiter")
def test_crawler_wait_combinations(self, _, wait_for_completion, deferrable, mock_conn, crawler_hook):
    """Waiter usage and defer() follow the wait_for_completion/deferrable flags."""
    # Replace defer so the deferrable branch does not raise TaskDeferred.
    self.op.defer = mock.MagicMock()
    self.op.wait_for_completion = wait_for_completion
    self.op.deferrable = deferrable

    response = self.op.execute({})

    assert response == mock_crawler_name
    # bool-as-int arithmetic: the waiter runs exactly once iff wait_for_completion,
    # and defer() is invoked exactly once iff deferrable.
    assert crawler_hook.get_waiter.call_count == wait_for_completion
    assert self.op.defer.call_count == deferrable
25 changes: 25 additions & 0 deletions tests/providers/amazon/aws/sensors/test_glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,28 @@ def test_fail_poke(self, get_crawler, soft_fail, expected_exception):
message = f"Status: {crawler_status}"
with pytest.raises(expected_exception, match=message):
self.sensor.poke(context={})

def test_base_aws_op_attributes(self):
    """Generic AWS parameters given to the sensor are forwarded to its hook."""
    # Default construction: hook should use the documented fallbacks.
    defaulted = GlueCrawlerSensor(
        task_id="test_glue_crawler_sensor",
        crawler_name="aws_test_glue_crawler",
    )
    assert defaulted.hook.client_type == "glue"
    assert defaulted.hook.aws_conn_id == "aws_default"
    assert defaulted.hook._region_name is None
    assert defaulted.hook._verify is None
    assert defaulted.hook._config is None

    # Explicit generic parameters must all land on the hook unchanged.
    customized = GlueCrawlerSensor(
        task_id="test_glue_crawler_sensor",
        crawler_name="aws_test_glue_crawler",
        aws_conn_id="aws-test-custom-conn",
        region_name="eu-west-1",
        verify=False,
        botocore_config={"read_timeout": 42},
    )
    assert customized.hook.aws_conn_id == "aws-test-custom-conn"
    assert customized.hook._region_name == "eu-west-1"
    assert customized.hook._verify is False
    assert customized.hook._config is not None
    assert customized.hook._config.read_timeout == 42
24 changes: 23 additions & 1 deletion tests/providers/amazon/aws/triggers/test_glue_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@
# under the License.
from __future__ import annotations

from unittest.mock import patch
from unittest import mock
from unittest.mock import AsyncMock, patch

import pytest

from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger
from airflow.triggers.base import TriggerEvent
from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type


class TestGlueCrawlerCompleteTrigger:
Expand Down Expand Up @@ -47,3 +53,19 @@ def test_serialization(self, mock_warn):
"waiter_max_attempts": 1500,
"aws_conn_id": "aws_default",
}

@pytest.mark.asyncio
@mock.patch.object(GlueCrawlerHook, "get_waiter")
@mock.patch.object(GlueCrawlerHook, "async_conn")
async def test_run_success(self, mock_async_conn, mock_get_waiter):
    """Trigger yields a success TriggerEvent once the waiter completes."""
    # async_conn is entered as an async context manager inside run();
    # wait() must be awaitable, hence the AsyncMock.
    mock_async_conn.__aenter__.return_value = mock.MagicMock()
    mock_get_waiter().wait = AsyncMock()
    crawler_name = "test_crawler"
    trigger = GlueCrawlerCompleteTrigger(crawler_name=crawler_name)

    # Drive the async generator one step to collect the first emitted event.
    generator = trigger.run()
    response = await generator.asend(None)

    assert response == TriggerEvent({"status": "success", "value": None})
    # Verifies the trigger requested the "crawler_ready" waiter specifically.
    assert_expected_waiter_type(mock_get_waiter, "crawler_ready")
    mock_get_waiter().wait.assert_called_once()