diff --git a/tests/test_dagbuilder.py b/tests/test_dagbuilder.py index eeeb60ab..d4b0b872 100644 --- a/tests/test_dagbuilder.py +++ b/tests/test_dagbuilder.py @@ -246,14 +246,7 @@ "param_2": "value_2", }, "on_failure_callback_name": "print_context_callback", - "on_failure_callback_file": __file__, - "on_retry_callback": { - "callback": "airflow.providers.slack.notifications.slack.send_slack_notification", - "slack_conn_id": "slack_conn_id", - "text": f"""Sample callback text.""", - "channel": "#channel", - "username": "username", - }, + "on_failure_callback_file": __file__ }, }, } @@ -684,7 +677,6 @@ def test_make_dag_with_callbacks(): if version.parse(AIRFLOW_VERSION) >= version.parse("2.6.0"): from airflow.providers.slack.notifications.slack import send_slack_notification - # TODO: Do something like this, but with TaskGroups (for Airflow versioning reasons...) dag_config_callbacks__with_provider = dict(DAG_CONFIG_CALLBACKS) dag_config_callbacks__with_provider["sla_miss_callback"] = { "callback": "airflow.providers.slack.notifications.slack.send_slack_notification", @@ -760,12 +752,13 @@ def test_make_dag_with_task_group_callbacks(): 3) There is a TaskGroup configured as part of the DAG, which has Tasks assigned to that group """ - # Import the DAG using the callback config that was build above. Like in the XXX test, we'll check the version of - # Airflow to ensure that TaskGroups can be built. + # Import the DAG using the callback config that was build above td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_TASK_GROUP_WITH_CALLBACKS, DEFAULT_CONFIG) + # This will be done only once; validate the exception that is raised if trying to use an invalid version of Airflow + # when building TaskGroups if version.parse(AIRFLOW_VERSION) < version.parse("2.2.0"): - error_message = "`task_groups` key can only be used with Airflow 2.x.x" # TODO: Apply this elsewhere + error_message = "`task_groups` key can only be used with Airflow 2.x.x" with pytest.raises(Exception, match=error_message): td.build() else: @@ -824,12 +817,22 @@ def test_dag_with_task_group_callbacks_tasks(): """ test_dag_with_task_group_callbacks_tasks """ - # Building the same DAG that was used for test_make_dag_with_task_group_callbacks and - # test_dag_with_task_group_callbacks_default_args - td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_TASK_GROUP_WITH_CALLBACKS, DEFAULT_CONFIG) + # Here, we're only going to build if the version is >= 2.6.0, and, we'll supplement task_4 with an additional + # callback (on_retry_callback) + if version.parse(AIRFLOW_VERSION) >= version.parse("2.6.0"): + from airflow.providers.slack.notifications.slack import send_slack_notification - if version.parse(AIRFLOW_VERSION) >= version.parse("2.2.0"): - dag = td.build()["dag"] # Also, pull the dag + dag_config_callbacks__with_provider = dict(DAG_CONFIG_TASK_GROUP_WITH_CALLBACKS) + dag_config_callbacks__with_provider["tasks"]["task_4"]["on_retry_callback"] = { + "callback": "airflow.providers.slack.notifications.slack.send_slack_notification", + "slack_conn_id": "slack_conn_id", + "text": f"""Sample callback text.""", + "channel": "#channel", + "username": "username", + } + + td = dagbuilder.DagBuilder("test_dag", dag_config_callbacks__with_provider, DEFAULT_CONFIG) + dag = td.build()["dag"] task_4 = dag.task_dict["task_4"] assert callable(task_4.on_execute_callback) @@ -844,14 +847,11 @@ def test_dag_with_task_group_callbacks_tasks(): assert callable(task_4.on_failure_callback) assert task_4.on_failure_callback.__name__ == "print_context_callback" - if version.parse(AIRFLOW_VERSION) >= version.parse("2.6.0"): - from airflow.providers.slack.notifications.slack import send_slack_notification - - assert isinstance(task_4.on_retry_callback, send_slack_notification) - assert callable(task_4.on_retry_callback) - assert task_4.on_retry_callback.slack_conn_id == "slack_conn_id" - assert task_4.on_retry_callback.channel == "#channel" - assert task_4.on_retry_callback.username == "username" + assert isinstance(task_4.on_retry_callback, send_slack_notification) + assert callable(task_4.on_retry_callback) + assert task_4.on_retry_callback.slack_conn_id == "slack_conn_id" + assert task_4.on_retry_callback.channel == "#channel" + assert task_4.on_retry_callback.username == "username" def test_make_timetable():