diff --git a/airflow/providers/amazon/aws/operators/dms.py b/airflow/providers/amazon/aws/operators/dms.py
index 8a32f33f18fb7..0107b1c3826a3 100644
--- a/airflow/providers/amazon/aws/operators/dms.py
+++ b/airflow/providers/amazon/aws/operators/dms.py
@@ -19,14 +19,15 @@
 
 from typing import TYPE_CHECKING, Sequence
 
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.dms import DmsHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class DmsCreateTaskOperator(BaseOperator):
+class DmsCreateTaskOperator(AwsBaseOperator[DmsHook]):
     """
     Creates AWS DMS replication task.
 
@@ -42,13 +43,19 @@ class DmsCreateTaskOperator(BaseOperator):
     :param migration_type: Migration type ('full-load'|'cdc'|'full-load-and-cdc'), full-load by default.
     :param create_task_kwargs: Extra arguments for DMS replication task creation.
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = (
+    aws_hook_class = DmsHook
+    template_fields: Sequence[str] = aws_template_fields(
         "replication_task_id",
         "source_endpoint_arn",
         "target_endpoint_arn",
@@ -57,7 +64,6 @@ class DmsCreateTaskOperator(BaseOperator):
         "migration_type",
         "create_task_kwargs",
     )
-    template_ext: Sequence[str] = ()
     template_fields_renderers = {
         "table_mappings": "json",
         "create_task_kwargs": "json",
@@ -92,9 +98,7 @@ def execute(self, context: Context):
 
         :return: replication task arn
         """
-        dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)
-
-        task_arn = dms_hook.create_replication_task(
+        task_arn = self.hook.create_replication_task(
             replication_task_id=self.replication_task_id,
             source_endpoint_arn=self.source_endpoint_arn,
             target_endpoint_arn=self.target_endpoint_arn,
@@ -108,7 +112,7 @@ def execute(self, context: Context):
 
         return task_arn
 
 
-class DmsDeleteTaskOperator(BaseOperator):
+class DmsDeleteTaskOperator(AwsBaseOperator[DmsHook]):
     """
     Deletes AWS DMS replication task.
 
@@ -118,26 +122,23 @@ class DmsDeleteTaskOperator(BaseOperator):
 
     :param replication_task_arn: Replication task ARN
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("replication_task_arn",)
-    template_ext: Sequence[str] = ()
-    template_fields_renderers: dict[str, str] = {}
+    aws_hook_class = DmsHook
+    template_fields: Sequence[str] = aws_template_fields("replication_task_arn")
 
-    def __init__(
-        self,
-        *,
-        replication_task_arn: str | None = None,
-        aws_conn_id: str = "aws_default",
-        **kwargs,
-    ):
+    def __init__(self, *, replication_task_arn: str | None = None, **kwargs):
         super().__init__(**kwargs)
         self.replication_task_arn = replication_task_arn
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context):
         """
@@ -145,12 +146,11 @@ def execute(self, context: Context):
 
         :return: replication task arn
         """
-        dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)
-        dms_hook.delete_replication_task(replication_task_arn=self.replication_task_arn)
+        self.hook.delete_replication_task(replication_task_arn=self.replication_task_arn)
 
         self.log.info("DMS replication task(%s) has been deleted.", self.replication_task_arn)
 
 
-class DmsDescribeTasksOperator(BaseOperator):
+class DmsDescribeTasksOperator(AwsBaseOperator[DmsHook]):
     """
     Describes AWS DMS replication tasks.
 
@@ -160,26 +160,24 @@ class DmsDescribeTasksOperator(BaseOperator):
 
     :param describe_tasks_kwargs: Describe tasks command arguments
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("describe_tasks_kwargs",)
-    template_ext: Sequence[str] = ()
+    aws_hook_class = DmsHook
+    template_fields: Sequence[str] = aws_template_fields("describe_tasks_kwargs")
     template_fields_renderers: dict[str, str] = {"describe_tasks_kwargs": "json"}
 
-    def __init__(
-        self,
-        *,
-        describe_tasks_kwargs: dict | None = None,
-        aws_conn_id: str = "aws_default",
-        **kwargs,
-    ):
+    def __init__(self, *, describe_tasks_kwargs: dict | None = None, **kwargs):
         super().__init__(**kwargs)
         self.describe_tasks_kwargs = describe_tasks_kwargs or {}
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context) -> tuple[str | None, list]:
         """
@@ -187,11 +185,10 @@ def execute(self, context: Context) -> tuple[str | None, list]:
 
         :return: Marker and list of replication tasks
         """
-        dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)
-        return dms_hook.describe_replication_tasks(**self.describe_tasks_kwargs)
+        return self.hook.describe_replication_tasks(**self.describe_tasks_kwargs)
 
 
-class DmsStartTaskOperator(BaseOperator):
+class DmsStartTaskOperator(AwsBaseOperator[DmsHook]):
     """
     Starts AWS DMS replication task.
 
@@ -204,18 +201,23 @@ class DmsStartTaskOperator(BaseOperator):
         ('start-replication'|'resume-processing'|'reload-target')
     :param start_task_kwargs: Extra start replication task arguments
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = (
+    aws_hook_class = DmsHook
+    template_fields: Sequence[str] = aws_template_fields(
         "replication_task_arn",
         "start_replication_task_type",
         "start_task_kwargs",
     )
-    template_ext: Sequence[str] = ()
     template_fields_renderers = {"start_task_kwargs": "json"}
 
     def __init__(
@@ -234,14 +236,8 @@ def __init__(
         self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context):
-        """
-        Start AWS DMS replication task from Airflow.
-
-        :return: replication task arn
-        """
-        dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)
-
-        dms_hook.start_replication_task(
+        """Start AWS DMS replication task from Airflow."""
+        self.hook.start_replication_task(
             replication_task_arn=self.replication_task_arn,
             start_replication_task_type=self.start_replication_task_type,
             **self.start_task_kwargs,
@@ -249,7 +245,7 @@ def execute(self, context: Context):
 
         self.log.info("DMS replication task(%s) is starting.", self.replication_task_arn)
 
 
-class DmsStopTaskOperator(BaseOperator):
+class DmsStopTaskOperator(AwsBaseOperator[DmsHook]):
     """
     Stops AWS DMS replication task.
 
@@ -259,33 +255,25 @@ class DmsStopTaskOperator(BaseOperator):
 
     :param replication_task_arn: Replication task ARN
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("replication_task_arn",)
-    template_ext: Sequence[str] = ()
-    template_fields_renderers: dict[str, str] = {}
+    aws_hook_class = DmsHook
+    template_fields: Sequence[str] = aws_template_fields("replication_task_arn")
 
-    def __init__(
-        self,
-        *,
-        replication_task_arn: str | None = None,
-        aws_conn_id: str = "aws_default",
-        **kwargs,
-    ):
+    def __init__(self, *, replication_task_arn: str | None = None, **kwargs):
         super().__init__(**kwargs)
         self.replication_task_arn = replication_task_arn
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context):
-        """
-        Stop AWS DMS replication task from Airflow.
-
-        :return: replication task arn
-        """
-        dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)
-        dms_hook.stop_replication_task(replication_task_arn=self.replication_task_arn)
+        """Stop AWS DMS replication task from Airflow."""
+        self.hook.stop_replication_task(replication_task_arn=self.replication_task_arn)
 
         self.log.info("DMS replication task(%s) is stopping.", self.replication_task_arn)
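
With the migration above, the generic AWS parameters (aws_conn_id, region_name, verify,
botocore_config) are accepted by every DMS operator and forwarded to the DmsHook that
AwsBaseOperator builds. A minimal usage sketch follows; the ARN and task wiring are
hypothetical, shown only to illustrate how the new parameters are passed:

    from airflow.providers.amazon.aws.operators.dms import DmsStopTaskOperator

    stop_task = DmsStopTaskOperator(
        task_id="stop_dms_task",
        replication_task_arn="arn:aws:dms:us-east-1:123456789012:task:EXAMPLE",  # hypothetical ARN
        aws_conn_id="aws_default",
        region_name="us-east-1",  # overrides the region configured on the connection
        verify=True,  # passed through to the boto3 session
        botocore_config={"read_timeout": 42},  # wrapped into botocore.config.Config by the base class
    )
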
diff --git a/airflow/providers/amazon/aws/sensors/dms.py b/airflow/providers/amazon/aws/sensors/dms.py
index d6ce3b3b1b94a..864a3b5276c32 100644
--- a/airflow/providers/amazon/aws/sensors/dms.py
+++ b/airflow/providers/amazon/aws/sensors/dms.py
@@ -17,47 +17,53 @@
 # under the License.
 from __future__ import annotations
 
-from functools import cached_property
 from typing import TYPE_CHECKING, Iterable, Sequence
 
 from deprecated import deprecated
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.dms import DmsHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class DmsTaskBaseSensor(BaseSensorOperator):
+class DmsTaskBaseSensor(AwsBaseSensor[DmsHook]):
     """
     Contains general sensor behavior for DMS task.
 
     Subclasses should set ``target_statuses`` and ``termination_statuses`` fields.
 
     :param replication_task_arn: AWS DMS replication task ARN
-    :param aws_conn_id: aws connection to uses
     :param target_statuses: the target statuses, sensor waits until
         the task reaches any of these states
     :param termination_statuses: the termination statuses, sensor fails when
         the task reaches any of these states
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("replication_task_arn",)
-    template_ext: Sequence[str] = ()
+    aws_hook_class = DmsHook
+    template_fields: Sequence[str] = aws_template_fields("replication_task_arn")
 
     def __init__(
         self,
         replication_task_arn: str,
-        aws_conn_id="aws_default",
         target_statuses: Iterable[str] | None = None,
         termination_statuses: Iterable[str] | None = None,
-        *args,
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
-        self.aws_conn_id = aws_conn_id
+        super().__init__(**kwargs)
        self.replication_task_arn = replication_task_arn
        self.target_statuses: Iterable[str] = target_statuses or []
        self.termination_statuses: Iterable[str] = termination_statuses or []
@@ -67,14 +73,8 @@ def get_hook(self) -> DmsHook:
         """Get DmsHook."""
         return self.hook
 
-    @cached_property
-    def hook(self) -> DmsHook:
-        return DmsHook(self.aws_conn_id)
-
     def poke(self, context: Context):
-        status: str | None = self.hook.get_task_status(self.replication_task_arn)
-
-        if not status:
+        if not (status := self.hook.get_task_status(self.replication_task_arn)):
             # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
             message = f"Failed to read task status, task with ARN {self.replication_task_arn} not found"
             if self.soft_fail:
@@ -105,15 +105,21 @@ class DmsTaskCompletedSensor(DmsTaskBaseSensor):
         :ref:`howto/sensor:DmsTaskCompletedSensor`
 
     :param replication_task_arn: AWS DMS replication task ARN
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("replication_task_arn",)
-    template_ext: Sequence[str] = ()
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.target_statuses = ["stopped"]
-        self.termination_statuses = [
+    def __init__(self, **kwargs):
+        kwargs["target_statuses"] = ["stopped"]
+        kwargs["termination_statuses"] = [
             "creating",
             "deleting",
             "failed",
@@ -123,3 +129,4 @@ def __init__(self, *args, **kwargs):
             "ready",
             "testing",
         ]
+        super().__init__(**kwargs)
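
DmsTaskBaseSensor still leaves target_statuses and termination_statuses to subclasses;
DmsTaskCompletedSensor now injects them through kwargs before delegating to super().__init__.
A sketch of what a custom subclass could look like under the new base class (the class name
and status sets are illustrative, not part of the provider):

    from airflow.providers.amazon.aws.sensors.dms import DmsTaskBaseSensor

    class DmsTaskRunningSensor(DmsTaskBaseSensor):
        # Hypothetical sensor: waits until the replication task reports "running".
        def __init__(self, **kwargs):
            kwargs["target_statuses"] = ["running"]
            kwargs["termination_statuses"] = ["failed", "stopped", "deleting"]
            super().__init__(**kwargs)
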
diff --git a/docs/apache-airflow-providers-amazon/operators/dms.rst b/docs/apache-airflow-providers-amazon/operators/dms.rst
index 3a9f38e72ad25..2c30e3ca6ec88 100644
--- a/docs/apache-airflow-providers-amazon/operators/dms.rst
+++ b/docs/apache-airflow-providers-amazon/operators/dms.rst
@@ -36,6 +36,11 @@ Prerequisite Tasks
 
 .. include:: ../_partials/prerequisite_tasks.rst
 
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
 Operators
 ---------
 
diff --git a/tests/providers/amazon/aws/operators/test_dms.py b/tests/providers/amazon/aws/operators/test_dms.py
index c56e01e4c3d58..a9fa96ed91505 100644
--- a/tests/providers/amazon/aws/operators/test_dms.py
+++ b/tests/providers/amazon/aws/operators/test_dms.py
@@ -60,14 +60,34 @@ class TestDmsCreateTaskOperator:
     }
 
     def test_init(self):
-        create_operator = DmsCreateTaskOperator(task_id="create_task", **self.TASK_DATA)
-
-        assert create_operator.replication_task_id == self.TASK_DATA["replication_task_id"]
-        assert create_operator.source_endpoint_arn == self.TASK_DATA["source_endpoint_arn"]
-        assert create_operator.target_endpoint_arn == self.TASK_DATA["target_endpoint_arn"]
-        assert create_operator.replication_instance_arn == self.TASK_DATA["replication_instance_arn"]
-        assert create_operator.migration_type == "full-load"
-        assert create_operator.table_mappings == self.TASK_DATA["table_mappings"]
+        op = DmsCreateTaskOperator(
+            task_id="create_task",
+            **self.TASK_DATA,
+            # Generic hooks parameters
+            aws_conn_id="fake-conn-id",
+            region_name="ca-west-1",
+            verify=True,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.replication_task_id == self.TASK_DATA["replication_task_id"]
+        assert op.source_endpoint_arn == self.TASK_DATA["source_endpoint_arn"]
+        assert op.target_endpoint_arn == self.TASK_DATA["target_endpoint_arn"]
+        assert op.replication_instance_arn == self.TASK_DATA["replication_instance_arn"]
+        assert op.migration_type == "full-load"
+        assert op.table_mappings == self.TASK_DATA["table_mappings"]
+        assert op.hook.client_type == "dms"
+        assert op.hook.resource_type is None
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "ca-west-1"
+        assert op.hook._verify is True
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = DmsCreateTaskOperator(task_id="create_task", **self.TASK_DATA)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
 
     @mock.patch.object(DmsHook, "get_task_status", side_effect=("ready",))
     @mock.patch.object(DmsHook, "create_replication_task", return_value=TASK_ARN)
@@ -112,9 +132,29 @@ class TestDmsDeleteTaskOperator:
     }
 
     def test_init(self):
-        dms_operator = DmsDeleteTaskOperator(task_id="delete_task", replication_task_arn=TASK_ARN)
-
-        assert dms_operator.replication_task_arn == TASK_ARN
+        op = DmsDeleteTaskOperator(
+            task_id="delete_task",
+            replication_task_arn=TASK_ARN,
+            # Generic hooks parameters
+            aws_conn_id="fake-conn-id",
+            region_name="us-east-1",
+            verify=False,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.replication_task_arn == TASK_ARN
+        assert op.hook.client_type == "dms"
+        assert op.hook.resource_type is None
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "us-east-1"
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = DmsDeleteTaskOperator(task_id="delete_task", replication_task_arn=TASK_ARN)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
 
     @mock.patch.object(DmsHook, "get_task_status", side_effect=("deleting",))
     @mock.patch.object(DmsHook, "delete_replication_task")
@@ -166,11 +206,31 @@ def setup_method(self):
         self.dag = DAG("dms_describe_tasks_operator", default_args=args, schedule="@once")
 
     def test_init(self):
-        dms_operator = DmsDescribeTasksOperator(
+        op = DmsDescribeTasksOperator(
+            task_id="describe_tasks",
+            describe_tasks_kwargs={"Filters": [self.FILTER]},
+            # Generic hooks parameters
+            aws_conn_id="fake-conn-id",
+            region_name="eu-west-2",
+            verify="/foo/bar/spam.egg",
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.describe_tasks_kwargs == {"Filters": [self.FILTER]}
+        assert op.hook.client_type == "dms"
+        assert op.hook.resource_type is None
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "eu-west-2"
+        assert op.hook._verify == "/foo/bar/spam.egg"
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = DmsDescribeTasksOperator(
             task_id="describe_tasks", describe_tasks_kwargs={"Filters": [self.FILTER]}
         )
-
-        assert dms_operator.describe_tasks_kwargs == {"Filters": [self.FILTER]}
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
 
     @mock.patch.object(DmsHook, "describe_replication_tasks", return_value=(None, MOCK_RESPONSE))
     @mock.patch.object(DmsHook, "get_conn")
@@ -211,10 +271,30 @@ class TestDmsStartTaskOperator:
     }
 
     def test_init(self):
-        dms_operator = DmsStartTaskOperator(task_id="start_task", replication_task_arn=TASK_ARN)
-
-        assert dms_operator.replication_task_arn == TASK_ARN
-        assert dms_operator.start_replication_task_type == "start-replication"
+        op = DmsStartTaskOperator(
+            task_id="start_task",
+            replication_task_arn=TASK_ARN,
+            # Generic hooks parameters
+            aws_conn_id="fake-conn-id",
+            region_name="us-west-1",
+            verify=False,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.replication_task_arn == TASK_ARN
+        assert op.start_replication_task_type == "start-replication"
+        assert op.hook.client_type == "dms"
+        assert op.hook.resource_type is None
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "us-west-1"
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = DmsStartTaskOperator(task_id="start_task", replication_task_arn=TASK_ARN)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
 
     @mock.patch.object(DmsHook, "get_task_status", side_effect=("starting",))
     @mock.patch.object(DmsHook, "start_replication_task")
@@ -248,9 +328,29 @@ class TestDmsStopTaskOperator:
     }
 
     def test_init(self):
-        dms_operator = DmsStopTaskOperator(task_id="stop_task", replication_task_arn=TASK_ARN)
-
-        assert dms_operator.replication_task_arn == TASK_ARN
+        op = DmsStopTaskOperator(
+            task_id="stop_task",
+            replication_task_arn=TASK_ARN,
+            # Generic hooks parameters
+            aws_conn_id="fake-conn-id",
+            region_name="eu-west-1",
+            verify=True,
+            botocore_config={"read_timeout": 42},
+        )
+        assert op.replication_task_arn == TASK_ARN
+        assert op.hook.client_type == "dms"
+        assert op.hook.resource_type is None
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "eu-west-1"
+        assert op.hook._verify is True
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = DmsStopTaskOperator(task_id="stop_task", replication_task_arn=TASK_ARN)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
 
     @mock.patch.object(DmsHook, "get_task_status", side_effect=("stopping",))
     @mock.patch.object(DmsHook, "stop_replication_task")
diff --git a/tests/providers/amazon/aws/sensors/test_dms.py b/tests/providers/amazon/aws/sensors/test_dms.py
index 810510c80b91d..eb99949ea053b 100644
--- a/tests/providers/amazon/aws/sensors/test_dms.py
+++ b/tests/providers/amazon/aws/sensors/test_dms.py
@@ -20,51 +20,94 @@
 
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.dms import DmsHook
 from airflow.providers.amazon.aws.sensors.dms import DmsTaskCompletedSensor
 
 
+@pytest.fixture
+def mocked_get_task_status():
+    with mock.patch.object(DmsHook, "get_task_status") as m:
+        yield m
+
+
 class TestDmsTaskCompletedSensor:
     def setup_method(self):
-        self.sensor = DmsTaskCompletedSensor(
-            task_id="test_dms_sensor",
-            aws_conn_id="aws_default",
-            replication_task_arn="task_arn",
-        )
+        self.default_op_kwargs = {
+            "task_id": "test_dms_sensor",
+            "aws_conn_id": None,
+            "replication_task_arn": "task_arn",
+        }
 
-    @mock.patch.object(DmsHook, "get_task_status", side_effect=("stopped",))
-    def test_poke_stopped(self, mock_get_task_status):
-        assert self.sensor.poke(None)
+    def test_init(self):
+        self.default_op_kwargs.pop("aws_conn_id", None)
 
-    @mock.patch.object(DmsHook, "get_task_status", side_effect=("running",))
-    def test_poke_running(self, mock_get_task_status):
-        assert not self.sensor.poke(None)
+        sensor = DmsTaskCompletedSensor(
+            **self.default_op_kwargs,
+            # Generic hooks parameters
+            aws_conn_id="fake-conn-id",
+            region_name="ca-west-1",
+            verify=True,
+            botocore_config={"read_timeout": 42},
+        )
+        assert sensor.hook.client_type == "dms"
+        assert sensor.hook.resource_type is None
+        assert sensor.hook.aws_conn_id == "fake-conn-id"
+        assert sensor.hook._region_name == "ca-west-1"
+        assert sensor.hook._verify is True
+        assert sensor.hook._config is not None
+        assert sensor.hook._config.read_timeout == 42
 
-    @mock.patch.object(DmsHook, "get_task_status", side_effect=("starting",))
-    def test_poke_starting(self, mock_get_task_status):
-        assert not self.sensor.poke(None)
+        sensor = DmsTaskCompletedSensor(task_id="test_dms_sensor", replication_task_arn="task_arn")
+        assert sensor.hook.aws_conn_id == "aws_default"
+        assert sensor.hook._region_name is None
+        assert sensor.hook._verify is None
+        assert sensor.hook._config is None
 
-    @mock.patch.object(DmsHook, "get_task_status", side_effect=("ready",))
-    def test_poke_ready(self, mock_get_task_status):
-        with pytest.raises(AirflowException) as ctx:
-            self.sensor.poke(None)
-        assert "Unexpected status: ready" in str(ctx.value)
+    @pytest.mark.parametrize("status", ["stopped"])
+    def test_poke_completed(self, mocked_get_task_status, status):
+        mocked_get_task_status.return_value = status
+        assert DmsTaskCompletedSensor(**self.default_op_kwargs).poke({})
 
-    @mock.patch.object(DmsHook, "get_task_status", side_effect=("creating",))
-    def test_poke_creating(self, mock_get_task_status):
-        with pytest.raises(AirflowException) as ctx:
-            self.sensor.poke(None)
-        assert "Unexpected status: creating" in str(ctx.value)
+    @pytest.mark.parametrize("status", ["running", "starting"])
+    def test_poke_not_completed(self, mocked_get_task_status, status):
+        mocked_get_task_status.return_value = status
+        assert not DmsTaskCompletedSensor(**self.default_op_kwargs).poke({})
 
-    @mock.patch.object(DmsHook, "get_task_status", side_effect=("failed",))
-    def test_poke_failed(self, mock_get_task_status):
-        with pytest.raises(AirflowException) as ctx:
-            self.sensor.poke(None)
-        assert "Unexpected status: failed" in str(ctx.value)
+    @pytest.mark.parametrize(
+        "status",
+        [
+            "creating",
+            "deleting",
+            "failed",
+            "failed-move",
+            "modifying",
+            "moving",
+            "ready",
+            "testing",
+        ],
+    )
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception",
+        [
+            pytest.param(True, AirflowSkipException, id="soft-fail"),
+            pytest.param(False, AirflowException, id="non-soft-fail"),
+        ],
+    )
+    def test_poke_terminated_status(self, mocked_get_task_status, status, soft_fail, expected_exception):
+        mocked_get_task_status.return_value = status
+        error_message = f"Unexpected status: {status}"
+        with pytest.raises(expected_exception, match=error_message):
+            DmsTaskCompletedSensor(**self.default_op_kwargs, soft_fail=soft_fail).poke({})
 
-    @mock.patch.object(DmsHook, "get_task_status", side_effect=("deleting",))
-    def test_poke_deleting(self, mock_get_task_status):
-        with pytest.raises(AirflowException) as ctx:
-            self.sensor.poke(None)
-        assert "Unexpected status: deleting" in str(ctx.value)
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception",
+        [
+            pytest.param(True, AirflowSkipException, id="soft-fail"),
+            pytest.param(False, AirflowException, id="non-soft-fail"),
+        ],
+    )
+    def test_poke_none_status(self, mocked_get_task_status, soft_fail, expected_exception):
+        mocked_get_task_status.return_value = None
+        with pytest.raises(expected_exception, match="task with ARN .* not found"):
+            DmsTaskCompletedSensor(**self.default_op_kwargs, soft_fail=soft_fail).poke({})
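
To exercise both updated test modules locally, something like the following should work from
the repository root (assuming a development environment with the amazon provider's test
dependencies installed):

    pytest tests/providers/amazon/aws/operators/test_dms.py tests/providers/amazon/aws/sensors/test_dms.py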