Use base aws classes in AWS Glue DataBrew Operators/Triggers #41848

Merged · 4 commits · Sep 5, 2024
2 changes: 1 addition & 1 deletion airflow/providers/amazon/CHANGELOG.rst
@@ -34,7 +34,7 @@ Changelog
`Apache Airflow providers support policy <https://github.com/apache/airflow/blob/main/PROVIDERS.rst#minimum-supported-version-of-airflow-for-community-managed-providers>`_.

.. warning:: When deferrable mode was introduced for ``RedshiftDataOperator``, in version 8.17.0, tasks configured with
``deferrable=True`` and ``wait_for_completion=True`` wouldn't enter the deferred state. Instead, the task would occupy
``deferrable=True`` and ``wait_for_completion=True`` would not enter the deferred state. Instead, the task would occupy
an executor slot until the statement was completed. A workaround may have been to set ``wait_for_completion=False``.
In this version, tasks set up with ``wait_for_completion=False`` will not wait anymore, regardless of the value of
``deferrable``.
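A minimal sketch of the behaviour this warning describes (the ``database`` and ``sql`` values are assumed placeholders): with ``wait_for_completion=False`` the task now returns as soon as the statement is submitted, regardless of ``deferrable``.

from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

# Returns immediately after submitting the statement; no waiting, no deferral.
run_statement = RedshiftDataOperator(
    task_id="run_statement",
    database="dev",  # assumed placeholder
    sql="SELECT 1;",
    deferrable=True,
    wait_for_completion=False,
)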
70 changes: 53 additions & 17 deletions airflow/providers/amazon/aws/operators/glue_databrew.py
@@ -17,20 +17,22 @@
# under the License.
from __future__ import annotations

from functools import cached_property
import warnings
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.models import BaseOperator
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.glue_databrew import GlueDataBrewHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.glue_databrew import GlueDataBrewJobCompleteTrigger
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class GlueDataBrewStartJobOperator(BaseOperator):
class GlueDataBrewStartJobOperator(AwsBaseOperator[GlueDataBrewHook]):
"""
Start an AWS Glue DataBrew job.

@@ -47,36 +49,55 @@ class GlueDataBrewStartJobOperator(BaseOperator):
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
:param delay: Time in seconds to wait between status checks. Default is 30.
:param delay: Time in seconds to wait between status checks. (Deprecated; use ``waiter_delay``.)
:param waiter_delay: Time in seconds to wait between status checks. Default is 30.
:param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 60)
:return: dictionary with key run_id and value of the resulting job's run_id.

:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = (
aws_hook_class = GlueDataBrewHook

template_fields: Sequence[str] = aws_template_fields(
"job_name",
"wait_for_completion",
"delay",
"deferrable",
"waiter_delay",
"waiter_max_attemptsdeferrable",
)

def __init__(
self,
job_name: str,
wait_for_completion: bool = True,
delay: int = 30,
delay: int | None = None,
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
aws_conn_id: str | None = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
self.job_name = job_name
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
self.delay = delay
self.aws_conn_id = aws_conn_id

@cached_property
def hook(self) -> GlueDataBrewHook:
return GlueDataBrewHook(aws_conn_id=self.aws_conn_id)
if delay is not None:
warnings.warn(
"please use `waiter_delay` instead of delay.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
self.waiter_delay = delay

def execute(self, context: Context):
job = self.hook.conn.start_job_run(Name=self.job_name)
@@ -88,7 +109,14 @@ def execute(self, context: Context):
self.log.info("Deferring job %s with run_id %s", self.job_name, run_id)
self.defer(
trigger=GlueDataBrewJobCompleteTrigger(
aws_conn_id=self.aws_conn_id, job_name=self.job_name, run_id=run_id, delay=self.delay
job_name=self.job_name,
run_id=run_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
)
@@ -97,14 +125,22 @@ def execute(self, context: Context):
self.log.info(
"Waiting for AWS Glue DataBrew Job: %s. Run Id: %s to complete.", self.job_name, run_id
)
status = self.hook.job_completion(job_name=self.job_name, delay=self.delay, run_id=run_id)
status = self.hook.job_completion(
job_name=self.job_name,
delay=self.waiter_delay,
run_id=run_id,
max_attempts=self.waiter_max_attempts,
)
self.log.info("Glue DataBrew Job: %s status: %s", self.job_name, status)

return {"run_id": run_id}

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]:
event = validate_execute_complete_event(event)

if event["status"] != "success":
raise AirflowException(f"Error while running AWS Glue DataBrew job: {event}")

run_id = event.get("run_id", "")
status = event.get("status", "")

43 changes: 35 additions & 8 deletions airflow/providers/amazon/aws/triggers/glue_databrew.py
@@ -17,6 +17,9 @@

from __future__ import annotations

import warnings

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.glue_databrew import GlueDataBrewHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger

@@ -27,20 +30,38 @@ class GlueDataBrewJobCompleteTrigger(AwsBaseWaiterTrigger):

:param job_name: Glue DataBrew job name
:param run_id: the ID of the specific run to watch for that job
:param delay: Number of seconds to wait between two checks. Default is 10 seconds.
:param max_attempts: Maximum number of attempts to wait for the job to complete. Default is 60 attempts.
:param delay: Number of seconds to wait between two checks. (Deprecated; use ``waiter_delay``.)
:param waiter_delay: Number of seconds to wait between two checks. Default is 30 seconds.
:param max_attempts: Maximum number of attempts to wait for the job to complete. (Deprecated; use ``waiter_max_attempts``.)
:param waiter_max_attempts: Maximum number of attempts to wait for the job to complete. Default is 60 attempts.
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""

def __init__(
self,
job_name: str,
run_id: str,
aws_conn_id: str | None,
delay: int = 10,
max_attempts: int = 60,
delay: int | None = None,
max_attempts: int | None = None,
waiter_delay: int = 30,
waiter_max_attempts: int = 60,
aws_conn_id: str | None = "aws_default",
**kwargs,
):
if delay is not None:
warnings.warn(
"please use `waiter_delay` instead of delay.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
waiter_delay = delay or waiter_delay
if max_attempts is not None:
warnings.warn(
"please use `waiter_max_attempts` instead of max_attempts.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
waiter_max_attempts = max_attempts or waiter_max_attempts
super().__init__(
serialized_fields={"job_name": job_name, "run_id": run_id},
waiter_name="job_complete",
@@ -50,10 +71,16 @@ def __init__(
status_queries=["State"],
return_value=run_id,
return_key="run_id",
waiter_delay=delay,
waiter_max_attempts=max_attempts,
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
**kwargs,
)

def hook(self) -> GlueDataBrewHook:
return GlueDataBrewHook(aws_conn_id=self.aws_conn_id)
return GlueDataBrewHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)
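A sketch of constructing the trigger directly (all values assumed placeholders). The deprecated ``delay``/``max_attempts`` still work, but emit ``AirflowProviderDeprecationWarning`` and are mapped onto ``waiter_delay``/``waiter_max_attempts``:

from airflow.providers.amazon.aws.triggers.glue_databrew import GlueDataBrewJobCompleteTrigger

trigger = GlueDataBrewJobCompleteTrigger(
    job_name="my-databrew-job",  # assumed placeholder
    run_id="12345",
    waiter_delay=30,
    waiter_max_attempts=60,
    aws_conn_id="aws_default",
)

# Passing the old names still works but warns:
#   GlueDataBrewJobCompleteTrigger(..., delay=15)  ->  waiter_delay=15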
@@ -31,6 +31,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

44 changes: 44 additions & 0 deletions tests/providers/amazon/aws/operators/test_glue_databrew.py
@@ -23,6 +23,7 @@
import pytest
from moto import mock_aws

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.glue_databrew import GlueDataBrewHook
from airflow.providers.amazon.aws.operators.glue_databrew import GlueDataBrewStartJobOperator

@@ -36,6 +37,30 @@ def hook() -> Generator[GlueDataBrewHook, None, None]:


class TestGlueDataBrewOperator:
def test_init(self):
op = GlueDataBrewStartJobOperator(
task_id="task_test",
job_name=JOB_NAME,
aws_conn_id="fake-conn-id",
region_name="eu-central-1",
verify="/spam/egg.pem",
botocore_config={"read_timeout": 42},
)

assert op.hook.client_type == "databrew"
assert op.hook.resource_type is None
assert op.hook.aws_conn_id == "fake-conn-id"
assert op.hook._region_name == "eu-central-1"
assert op.hook._verify == "/spam/egg.pem"
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

op = GlueDataBrewStartJobOperator(task_id="fake_task_id", job_name=JOB_NAME)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None

@mock.patch.object(GlueDataBrewHook, "conn")
@mock.patch.object(GlueDataBrewHook, "get_waiter")
def test_start_job_wait_for_completion(self, mock_hook_get_waiter, mock_conn):
@@ -57,3 +82,22 @@ def test_start_job_no_wait(self, mock_hook_get_waiter, mock_conn):
mock_conn.start_job_run(mock.MagicMock(), return_value=TEST_RUN_ID)
operator.execute(None)
mock_hook_get_waiter.assert_not_called()

@mock.patch.object(GlueDataBrewHook, "conn")
@mock.patch.object(GlueDataBrewHook, "get_waiter")
def test_start_job_with_deprecation_parameters(self, mock_hook_get_waiter, mock_conn):
TEST_RUN_ID = "12345"

with pytest.warns(AirflowProviderDeprecationWarning):
operator = GlueDataBrewStartJobOperator(
task_id="task_test",
job_name=JOB_NAME,
wait_for_completion=False,
aws_conn_id="aws_default",
delay=15,
)

mock_conn.start_job_run(mock.MagicMock(), return_value=TEST_RUN_ID)
assert operator.waiter_delay == 15
operator.execute(None)
mock_hook_get_waiter.assert_not_called()
22 changes: 22 additions & 0 deletions tests/providers/amazon/aws/triggers/test_glue_databrew.py
@@ -18,6 +18,7 @@

import pytest

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.triggers.glue_databrew import GlueDataBrewJobCompleteTrigger

TEST_JOB_NAME = "test_job_name"
@@ -44,3 +45,24 @@ def test_serialize(self, trigger):

assert class_path == class_path2
assert args == args2

def test_serialize_with_deprecated_parameters(self, trigger):
with pytest.warns(AirflowProviderDeprecationWarning):
class_path, args = GlueDataBrewJobCompleteTrigger(
aws_conn_id="aws_default",
job_name=TEST_JOB_NAME,
run_id=TEST_JOB_RUN_ID,
delay=1,
max_attempts=1,
).serialize()

class_name = class_path.split(".")[-1]
clazz = globals()[class_name]
instance = clazz(**args)

class_path2, args2 = instance.serialize()

assert class_path == class_path2
assert args == args2
assert args.get("waiter_delay") == 1
assert args.get("waiter_max_attempts") == 1
2 changes: 1 addition & 1 deletion tests/system/providers/amazon/aws/example_glue_databrew.py
@@ -120,7 +120,7 @@ def delete_job(job_name: str):
)

# [START howto_operator_glue_databrew_start]
start_job = GlueDataBrewStartJobOperator(task_id="startjob", job_name=job_name, delay=15)
start_job = GlueDataBrewStartJobOperator(task_id="startjob", job_name=job_name, waiter_delay=15)
# [END howto_operator_glue_databrew_start]

delete_bucket = S3DeleteBucketOperator(
Expand Down