Skip to content

Commit

Permalink
Add testcases to ensure that 'soft_fail' argument is respected when r…
Browse files Browse the repository at this point in the history
…unning ExternalTaskSensor (#34652)

* Add testcase to ensure the soft_fail param is respected

* Refactor testcase

* Fix testcase
  • Loading branch information
utkarsharma2 authored Oct 1, 2023
1 parent 1fdc231 commit 0fee730
Showing 1 changed file with 134 additions and 1 deletion.
135 changes: 134 additions & 1 deletion tests/sensors/test_external_task_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import itertools
import logging
import os
import re
import tempfile
import zipfile
from datetime import time, timedelta
Expand All @@ -29,7 +30,7 @@

from airflow import exceptions, settings
from airflow.decorators import task as task_deco
from airflow.exceptions import AirflowException, AirflowSensorTimeout, TaskDeferred
from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException, TaskDeferred
from airflow.models import DagBag, DagRun, TaskInstance
from airflow.models.dag import DAG
from airflow.models.serialized_dag import SerializedDagModel
Expand Down Expand Up @@ -838,6 +839,138 @@ def test_external_task_group_when_there_is_no_TIs(self):
ignore_ti_state=True,
)

@pytest.mark.parametrize(
"kwargs, expected_message",
(
(
{
"external_task_ids": [TEST_TASK_ID, TEST_TASK_ID_ALTERNATE],
"failed_states": [State.FAILED],
},
f"Some of the external tasks {re.escape(str([TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]))}"
f" in DAG {TEST_DAG_ID} failed.",
),
(
{
"external_task_group_id": [TEST_TASK_ID, TEST_TASK_ID_ALTERNATE],
"failed_states": [State.FAILED],
},
f"The external task_group '{re.escape(str([TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]))}'"
f" in DAG '{TEST_DAG_ID}' failed.",
),
(
{"failed_states": [State.FAILED]},
f"The external DAG {TEST_DAG_ID} failed.",
),
),
)
@pytest.mark.parametrize(
"soft_fail, expected_exception",
(
(
False,
AirflowException,
),
(
True,
AirflowSkipException,
),
),
)
@mock.patch("airflow.sensors.external_task.ExternalTaskSensor.get_count")
@mock.patch("airflow.sensors.external_task.ExternalTaskSensor._get_dttm_filter")
def test_fail_poke(
self, _get_dttm_filter, get_count, soft_fail, expected_exception, kwargs, expected_message
):
_get_dttm_filter.return_value = []
get_count.return_value = 1
op = ExternalTaskSensor(
task_id="test_external_task_duplicate_task_ids",
external_dag_id=TEST_DAG_ID,
allowed_states=["success"],
dag=self.dag,
soft_fail=soft_fail,
deferrable=False,
**kwargs,
)
with pytest.raises(expected_exception, match=expected_message):
op.execute(context={})

@pytest.mark.parametrize(
"response_get_current, response_exists, kwargs, expected_message",
(
(None, None, {}, f"The external DAG {TEST_DAG_ID} does not exist."),
(
DAG(dag_id="test"),
False,
{},
f"The external DAG {TEST_DAG_ID} was deleted.",
),
(
DAG(dag_id="test"),
True,
{"external_task_ids": [TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]},
f"The external task {TEST_TASK_ID} in DAG {TEST_DAG_ID} does not exist.",
),
(
DAG(dag_id="test"),
True,
{"external_task_group_id": [TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]},
f"The external task group '{re.escape(str([TEST_TASK_ID, TEST_TASK_ID_ALTERNATE]))}'"
f" in DAG '{TEST_DAG_ID}' does not exist.",
),
),
)
@pytest.mark.parametrize(
"soft_fail, expected_exception",
(
(
False,
AirflowException,
),
(
True,
AirflowSkipException,
),
),
)
@mock.patch("airflow.sensors.external_task.ExternalTaskSensor._get_dttm_filter")
@mock.patch("airflow.models.dagbag.DagBag.get_dag")
@mock.patch("os.path.exists")
@mock.patch("airflow.models.dag.DagModel.get_current")
def test_fail__check_for_existence(
self,
get_current,
exists,
get_dag,
_get_dttm_filter,
soft_fail,
expected_exception,
response_get_current,
response_exists,
kwargs,
expected_message,
):
_get_dttm_filter.return_value = []
get_current.return_value = response_get_current
exists.return_value = response_exists
get_dag_response = mock.MagicMock()
get_dag.return_value = get_dag_response
get_dag_response.has_task.return_value = False
get_dag_response.has_task_group.return_value = False
op = ExternalTaskSensor(
task_id="test_external_task_duplicate_task_ids",
external_dag_id=TEST_DAG_ID,
allowed_states=["success"],
dag=self.dag,
soft_fail=soft_fail,
check_existence=True,
**kwargs,
)
expected_message = "Skipping due to soft_fail is set to True." if soft_fail else expected_message
with pytest.raises(expected_exception, match=expected_message):
op.execute(context={})


class TestExternalTaskAsyncSensor:
TASK_ID = "external_task_sensor_check"
Expand Down

0 comments on commit 0fee730

Please sign in to comment.