Skip to content

Commit

Permalink
Use base aws classes in AWS DMS Operators/Sensors (apache#36772)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored Jan 15, 2024
1 parent bd21dec commit 01724d8
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 149 deletions.
126 changes: 57 additions & 69 deletions airflow/providers/amazon/aws/operators/dms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@

from typing import TYPE_CHECKING, Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.dms import DmsHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class DmsCreateTaskOperator(BaseOperator):
class DmsCreateTaskOperator(AwsBaseOperator[DmsHook]):
"""
Creates AWS DMS replication task.
Expand All @@ -42,13 +43,19 @@ class DmsCreateTaskOperator(BaseOperator):
:param migration_type: Migration type ('full-load'|'cdc'|'full-load-and-cdc'), full-load by default.
:param create_task_kwargs: Extra arguments for DMS replication task creation.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = (
aws_hook_class = DmsHook
template_fields: Sequence[str] = aws_template_fields(
"replication_task_id",
"source_endpoint_arn",
"target_endpoint_arn",
Expand All @@ -57,7 +64,6 @@ class DmsCreateTaskOperator(BaseOperator):
"migration_type",
"create_task_kwargs",
)
template_ext: Sequence[str] = ()
template_fields_renderers = {
"table_mappings": "json",
"create_task_kwargs": "json",
Expand Down Expand Up @@ -92,9 +98,7 @@ def execute(self, context: Context):
:return: replication task arn
"""
dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)

task_arn = dms_hook.create_replication_task(
task_arn = self.hook.create_replication_task(
replication_task_id=self.replication_task_id,
source_endpoint_arn=self.source_endpoint_arn,
target_endpoint_arn=self.target_endpoint_arn,
Expand All @@ -108,7 +112,7 @@ def execute(self, context: Context):
return task_arn


class DmsDeleteTaskOperator(BaseOperator):
class DmsDeleteTaskOperator(AwsBaseOperator[DmsHook]):
"""
Deletes AWS DMS replication task.
Expand All @@ -118,39 +122,35 @@ class DmsDeleteTaskOperator(BaseOperator):
:param replication_task_arn: Replication task ARN
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("replication_task_arn",)
template_ext: Sequence[str] = ()
template_fields_renderers: dict[str, str] = {}
aws_hook_class = DmsHook
template_fields: Sequence[str] = aws_template_fields("replication_task_arn")

def __init__(
self,
*,
replication_task_arn: str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
def __init__(self, *, replication_task_arn: str | None = None, **kwargs):
super().__init__(**kwargs)
self.replication_task_arn = replication_task_arn
self.aws_conn_id = aws_conn_id

def execute(self, context: Context):
"""
Delete AWS DMS replication task from Airflow.
:return: replication task arn
"""
dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)
dms_hook.delete_replication_task(replication_task_arn=self.replication_task_arn)
self.hook.delete_replication_task(replication_task_arn=self.replication_task_arn)
self.log.info("DMS replication task(%s) has been deleted.", self.replication_task_arn)


class DmsDescribeTasksOperator(BaseOperator):
class DmsDescribeTasksOperator(AwsBaseOperator[DmsHook]):
"""
Describes AWS DMS replication tasks.
Expand All @@ -160,38 +160,35 @@ class DmsDescribeTasksOperator(BaseOperator):
:param describe_tasks_kwargs: Describe tasks command arguments
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("describe_tasks_kwargs",)
template_ext: Sequence[str] = ()
aws_hook_class = DmsHook
template_fields: Sequence[str] = aws_template_fields("describe_tasks_kwargs")
template_fields_renderers: dict[str, str] = {"describe_tasks_kwargs": "json"}

def __init__(
self,
*,
describe_tasks_kwargs: dict | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
def __init__(self, *, describe_tasks_kwargs: dict | None = None, **kwargs):
super().__init__(**kwargs)
self.describe_tasks_kwargs = describe_tasks_kwargs or {}
self.aws_conn_id = aws_conn_id

def execute(self, context: Context) -> tuple[str | None, list]:
"""
Describe AWS DMS replication tasks from Airflow.
:return: Marker and list of replication tasks
"""
dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)
return dms_hook.describe_replication_tasks(**self.describe_tasks_kwargs)
return self.hook.describe_replication_tasks(**self.describe_tasks_kwargs)


class DmsStartTaskOperator(BaseOperator):
class DmsStartTaskOperator(AwsBaseOperator[DmsHook]):
"""
Starts AWS DMS replication task.
Expand All @@ -204,18 +201,23 @@ class DmsStartTaskOperator(BaseOperator):
('start-replication'|'resume-processing'|'reload-target')
:param start_task_kwargs: Extra start replication task arguments
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = (
aws_hook_class = DmsHook
template_fields: Sequence[str] = aws_template_fields(
"replication_task_arn",
"start_replication_task_type",
"start_task_kwargs",
)
template_ext: Sequence[str] = ()
template_fields_renderers = {"start_task_kwargs": "json"}

def __init__(
Expand All @@ -234,22 +236,16 @@ def __init__(
self.aws_conn_id = aws_conn_id

def execute(self, context: Context):
"""
Start AWS DMS replication task from Airflow.
:return: replication task arn
"""
dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)

dms_hook.start_replication_task(
"""Start AWS DMS replication task from Airflow."""
self.hook.start_replication_task(
replication_task_arn=self.replication_task_arn,
start_replication_task_type=self.start_replication_task_type,
**self.start_task_kwargs,
)
self.log.info("DMS replication task(%s) is starting.", self.replication_task_arn)


class DmsStopTaskOperator(BaseOperator):
class DmsStopTaskOperator(AwsBaseOperator[DmsHook]):
"""
Stops AWS DMS replication task.
Expand All @@ -259,33 +255,25 @@ class DmsStopTaskOperator(BaseOperator):
:param replication_task_arn: Replication task ARN
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("replication_task_arn",)
template_ext: Sequence[str] = ()
template_fields_renderers: dict[str, str] = {}
aws_hook_class = DmsHook
template_fields: Sequence[str] = aws_template_fields("replication_task_arn")

def __init__(
self,
*,
replication_task_arn: str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
def __init__(self, *, replication_task_arn: str | None = None, **kwargs):
super().__init__(**kwargs)
self.replication_task_arn = replication_task_arn
self.aws_conn_id = aws_conn_id

def execute(self, context: Context):
"""
Stop AWS DMS replication task from Airflow.
:return: replication task arn
"""
dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)
dms_hook.stop_replication_task(replication_task_arn=self.replication_task_arn)
"""Stop AWS DMS replication task from Airflow."""
self.hook.stop_replication_task(replication_task_arn=self.replication_task_arn)
self.log.info("DMS replication task(%s) is stopping.", self.replication_task_arn)
55 changes: 31 additions & 24 deletions airflow/providers/amazon/aws/sensors/dms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,47 +17,53 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Iterable, Sequence

from deprecated import deprecated

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.dms import DmsHook
from airflow.sensors.base import BaseSensorOperator
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class DmsTaskBaseSensor(BaseSensorOperator):
class DmsTaskBaseSensor(AwsBaseSensor[DmsHook]):
"""
Contains general sensor behavior for DMS task.
Subclasses should set ``target_statuses`` and ``termination_statuses`` fields.
:param replication_task_arn: AWS DMS replication task ARN
:param aws_conn_id: aws connection to uses
:param target_statuses: the target statuses, sensor waits until
the task reaches any of these states
:param termination_statuses: the termination statuses, sensor fails when
the task reaches any of these states
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("replication_task_arn",)
template_ext: Sequence[str] = ()
aws_hook_class = DmsHook
template_fields: Sequence[str] = aws_template_fields("replication_task_arn")

def __init__(
self,
replication_task_arn: str,
aws_conn_id="aws_default",
target_statuses: Iterable[str] | None = None,
termination_statuses: Iterable[str] | None = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.aws_conn_id = aws_conn_id
super().__init__(**kwargs)
self.replication_task_arn = replication_task_arn
self.target_statuses: Iterable[str] = target_statuses or []
self.termination_statuses: Iterable[str] = termination_statuses or []
Expand All @@ -67,14 +73,8 @@ def get_hook(self) -> DmsHook:
"""Get DmsHook."""
return self.hook

@cached_property
def hook(self) -> DmsHook:
return DmsHook(self.aws_conn_id)

def poke(self, context: Context):
status: str | None = self.hook.get_task_status(self.replication_task_arn)

if not status:
if not (status := self.hook.get_task_status(self.replication_task_arn)):
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Failed to read task status, task with ARN {self.replication_task_arn} not found"
if self.soft_fail:
Expand Down Expand Up @@ -105,15 +105,21 @@ class DmsTaskCompletedSensor(DmsTaskBaseSensor):
:ref:`howto/sensor:DmsTaskCompletedSensor`
:param replication_task_arn: AWS DMS replication task ARN
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

template_fields: Sequence[str] = ("replication_task_arn",)
template_ext: Sequence[str] = ()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.target_statuses = ["stopped"]
self.termination_statuses = [
def __init__(self, **kwargs):
kwargs["target_statuses"] = ["stopped"]
kwargs["termination_statuses"] = [
"creating",
"deleting",
"failed",
Expand All @@ -123,3 +129,4 @@ def __init__(self, *args, **kwargs):
"ready",
"testing",
]
super().__init__(**kwargs)
5 changes: 5 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/dms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

Expand Down
Loading

0 comments on commit 01724d8

Please sign in to comment.