diff --git a/README.md b/README.md index 64a3c8b3..697107d0 100644 --- a/README.md +++ b/README.md @@ -171,6 +171,51 @@ consumer_dag: bread_type: 'Sourdough' ``` ![custom_operators.png](img/custom_operators.png) + +### Callbacks +**dag-factory** also supports using "callbacks" at the DAG, Task, and TaskGroup level. These callbacks can be defined in +a few different ways. The first points directly to a Python function that has been defined in the `include/callbacks.py` +file. + +```yaml +example_dag1: + on_failure_callback: include.callbacks.example_callback1 +... +``` + +Here, the `on_success_callback` points to first a file, and then to a function name within that file. Notice that this +callback is defined using `default_args`, meaning this callback will be applied to all tasks. + +```yaml +example_dag1: + ... + default_args: + on_success_callback_file: /usr/local/airflow/include/callbacks.py + on_success_callback_name: example_callback1 +``` + +**dag-factory** users can also leverage provider-built tools when configuring callbacks. In this example, the +`send_slack_notification` function from the Slack provider is used to dispatch a message when a DAG failure occurs. This +function is passed to the `callback` key under `on_failure_callback`. This pattern allows for callback definitions to +take parameters (such as `text`, `channel`, and `username`, as shown here). + +**Note that this functionality is currently only supported for `on_failure_callback`'s defined at the DAG-level, or in +`default_args`. Support for other callback types and Task/TaskGroup-level definitions are coming soon.** + +```yaml +example_dag1: + on_failure_callback: + callback: airflow.providers.slack.notifications.slack.send_slack_notification + slack_conn_id: example_slack_id + text: | + :red_circle: Task Failed. + This task has failed and needs to be addressed. + Please remediate this issue ASAP. + channel: analytics-alerts + username: Airflow +... +``` + ## Notes ### HttpSensor (since 1.0.0) diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index e50b3abf..b58f9d7c 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -5,6 +5,7 @@ import re from copy import deepcopy from datetime import datetime, timedelta +from functools import partial from typing import Any, Callable, Dict, List, Union from airflow import DAG, configuration @@ -178,10 +179,9 @@ def get_dag_params(self) -> Dict[str, Any]: ) if utils.check_dict_key(dag_params["default_args"], "on_failure_callback"): - if isinstance(dag_params["default_args"]["on_failure_callback"], str): - dag_params["default_args"]["on_failure_callback"]: Callable = import_string( - dag_params["default_args"]["on_failure_callback"] - ) + dag_params["default_args"]["on_failure_callback"]: Callable = self.set_callback( + parameters=dag_params["default_args"], callback_type="on_failure_callback" + ) if utils.check_dict_key(dag_params["default_args"], "on_retry_callback"): if isinstance(dag_params["default_args"]["on_retry_callback"], str): @@ -198,8 +198,9 @@ def get_dag_params(self) -> Dict[str, Any]: dag_params["on_success_callback"]: Callable = import_string(dag_params["on_success_callback"]) if utils.check_dict_key(dag_params, "on_failure_callback"): - if isinstance(dag_params["on_failure_callback"], str): - dag_params["on_failure_callback"]: Callable = import_string(dag_params["on_failure_callback"]) + dag_params["on_failure_callback"]: Callable = self.set_callback( + parameters=dag_params, callback_type="on_failure_callback" + ) if utils.check_dict_key(dag_params, "on_success_callback_name") and utils.check_dict_key( dag_params, "on_success_callback_file" @@ -212,9 +213,8 @@ def get_dag_params(self) -> Dict[str, Any]: if utils.check_dict_key(dag_params, "on_failure_callback_name") and utils.check_dict_key( dag_params, "on_failure_callback_file" ): - dag_params["on_failure_callback"]: Callable = utils.get_python_callable( - dag_params["on_failure_callback_name"], - dag_params["on_failure_callback_file"], + dag_params["on_failure_callback"] = self.set_callback( + parameters=dag_params, callback_type="on_failure_callback", has_name_and_file=True ) if utils.check_dict_key(dag_params["default_args"], "on_success_callback_name") and utils.check_dict_key( @@ -229,10 +229,8 @@ def get_dag_params(self) -> Dict[str, Any]: if utils.check_dict_key(dag_params["default_args"], "on_failure_callback_name") and utils.check_dict_key( dag_params["default_args"], "on_failure_callback_file" ): - - dag_params["default_args"]["on_failure_callback"]: Callable = utils.get_python_callable( - dag_params["default_args"]["on_failure_callback_name"], - dag_params["default_args"]["on_failure_callback_file"], + dag_params["default_args"]["on_failure_callback"] = self.set_callback( + parameters=dag_params["default_args"], callback_type="on_failure_callback", has_name_and_file=True ) if utils.check_dict_key(dag_params, "template_searchpath"): @@ -805,3 +803,51 @@ def build(self) -> Dict[str, Union[str, DAG]]: self.set_dependencies(tasks, tasks_dict, dag_params.get("task_groups", {}), task_groups_dict) return {"dag_id": dag_params["dag_id"], "dag": dag} + + @staticmethod + def set_callback(parameters: Union[dict, str], callback_type: str, has_name_and_file=False) -> Callable: + """ + Update the passed-in config with the callback. + + :param parameters: + :param callback_type: + :param has_name_and_file: + :returns: Callable + """ + # There is scenario where a callback is passed in via a file and a name. For the most part, this will be a + # Python callable that is treated similarly to a Python callable that the PythonOperator may leverage. That + # being said, what if this is not a Python callable? What if this is another type? + if has_name_and_file: + return utils.get_python_callable( + python_callable_name=parameters[f"{callback_type}_name"], + python_callable_file=parameters[f"{callback_type}_file"], + ) + + # If the value stored at parameters[callback_type] is a string, it should be imported under the assumption that + # it is a function that is "ready to be called". If not returning the function, something like this could be + # used to update the config parameters[callback_type] = import_string(parameters[callback_type]) + if isinstance(parameters[callback_type], str): + return import_string(parameters[callback_type]) + + # Otherwise, if the parameter[callback_type] is a dictionary, it should be treated similar to the Python + # callable + elif isinstance(parameters[callback_type], dict): + # Pull the on_failure_callback dictionary from dag_params + on_state_callback_params: dict = parameters[callback_type] + + # Check to see if there is a "callback" key in the on_failure_callback dictionary. If there is, parse + # out that callable, and add the parameters + if utils.check_dict_key(on_state_callback_params, "callback"): + if isinstance(on_state_callback_params["callback"], str): + on_state_callback_callable: Callable = import_string(on_state_callback_params["callback"]) + del on_state_callback_params["callback"] + + # Return the callable, this time, using the params provided in the YAML file, rather than a .py + # file with a callable configured. If not returning the partial, something like this could be used + # to update the config ... parameters[callback_type]: Callable = partial(...) + if hasattr(on_state_callback_callable, "notify"): + return on_state_callback_callable(**on_state_callback_params) + + return partial(on_state_callback_callable, **on_state_callback_params) + + raise DagFactoryConfigException(f"Invalid type passed to {callback_type}") diff --git a/dev/dags/customized/callables/__init__.py b/dev/dags/customized/callables/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dev/dags/customized/callables/python.py b/dev/dags/customized/callables/python.py new file mode 100644 index 00000000..8c4a73ea --- /dev/null +++ b/dev/dags/customized/callables/python.py @@ -0,0 +1,16 @@ +""" +failure.py + +Create a callable that intentionally "fails". + +Author: Jake Roach +Date: 2024-10-22 +""" + + +def succeeding_task(): + print("Task has executed successfully!") + + +def failing_task(): + raise Exception("Intentionally failing this Task to trigger on_failure_callback.") diff --git a/dev/dags/customized/callbacks/__init__.py b/dev/dags/customized/callbacks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dev/dags/customized/callbacks/custom_callbacks.py b/dev/dags/customized/callbacks/custom_callbacks.py new file mode 100644 index 00000000..b94d7989 --- /dev/null +++ b/dev/dags/customized/callbacks/custom_callbacks.py @@ -0,0 +1,11 @@ +""" +example_callbacks.py + +Author: Jake Roach +Date: 2024-10-22 +""" + + +def output_message(context, param1, param2): + print("A callback has been raised!") + print(f"{param1} ---------- {param2}") diff --git a/dev/dags/example_callbacks.py b/dev/dags/example_callbacks.py new file mode 100644 index 00000000..030fdbbb --- /dev/null +++ b/dev/dags/example_callbacks.py @@ -0,0 +1,17 @@ +import os +from pathlib import Path + +# The following import is here so Airflow parses this file +# from airflow import DAG +import dagfactory + +DEFAULT_CONFIG_ROOT_DIR = "/usr/local/airflow/dags/" +CONFIG_ROOT_DIR = Path(os.getenv("CONFIG_ROOT_DIR", DEFAULT_CONFIG_ROOT_DIR)) + +config_file = str(CONFIG_ROOT_DIR / "example_callbacks.yml") + +example_dag_factory = dagfactory.DagFactory(config_file) + +# Creating task dependencies +example_dag_factory.clean_dags(globals()) +example_dag_factory.generate_dags(globals()) diff --git a/dev/dags/example_callbacks.yml b/dev/dags/example_callbacks.yml new file mode 100644 index 00000000..097c76d6 --- /dev/null +++ b/dev/dags/example_callbacks.yml @@ -0,0 +1,28 @@ +example_callbacks: + default_args: + start_date: "2024-01-01" + on_failure_callback: + callback: airflow.providers.slack.notifications.slack.send_slack_notification + slack_conn_id: slack_conn_id + text: | + :red_circle: Task Failed. + This task has failed and needs to be addressed. + Please remediate this issue ASAP. + channel: "#channel" + schedule_interval: "@daily" + catchup: False + on_failure_callback: + callback: customized.callbacks.custom_callbacks.output_message + param1: param1 + param2: param2 + tasks: + start: + operator: airflow.operators.python.PythonOperator + python_callable_file: $CONFIG_ROOT_DIR/customized/callables/python.py + python_callable_name: succeeding_task + end: + operator: airflow.operators.python.PythonOperator + python_callable_file: $CONFIG_ROOT_DIR/customized/callables/python.py + python_callable_name: failing_task + dependencies: + - start diff --git a/dev/requirements.txt b/dev/requirements.txt index 1bb359bb..36aeee7c 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -1 +1,2 @@ # Astro Runtime includes the following pre-installed providers packages: https://www.astronomer.io/docs/astro/runtime-image-architecture#provider-packages +apache-airflow-providers-slack diff --git a/pyproject.toml b/pyproject.toml index 94ed96f5..bb2a1e0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ [project.optional-dependencies] tests = [ + "apache-airflow-providers-slack", "pytest>=6.0", "pytest-cov", "pre-commit" @@ -95,7 +96,7 @@ universal = true [tool.pytest.ini_options] filterwarnings = ["ignore::DeprecationWarning"] minversion = "6.0" -markers = ["integration"] +markers = ["integration", "callbacks"] ###################################### # THIRD PARTY TOOLS diff --git a/tests/test_dagbuilder.py b/tests/test_dagbuilder.py index f777aa6d..7081e360 100644 --- a/tests/test_dagbuilder.py +++ b/tests/test_dagbuilder.py @@ -1,4 +1,5 @@ import datetime +import functools import os from pathlib import Path from unittest.mock import mock_open, patch @@ -261,6 +262,60 @@ }, } +# Alternative way to define callbacks (only "on_failure_callbacks" for now, more to come) +DAG_CONFIG_CALLBACK_WITH_PARAMETERS = { + "doc_md": "##here is a doc md string", + "default_args": { + "owner": "custom_owner", + "on_failure_callback": { + "callback": f"{__name__}.empty_callback_with_params", + "param_1": "value_1", + "param_2": "value_2", + }, + }, + "description": "this is an example dag", + "schedule_interval": "0 3 * * *", + "tags": ["tag1", "tag2"], + "on_failure_callback": { + "callback": f"{__name__}.empty_callback_with_params", + "param_1": "value_1", + "param_2": "value_2", + }, + "tasks": { + "task_1": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 1", + "execution_timeout_secs": 5, + }, + }, +} + +DAG_CONFIG_PROVIDER_CALLBACK_WITH_PARAMETERS = { + "doc_md": "##here is a doc md string", + "default_args": { + "owner": "custom_owner", + "on_failure_callback": { + "callback": "airflow.providers.slack.notifications.slack.send_slack_notification", + "slack_conn_id": "slack_conn_id", + "text": f""" + Sample, multi-line callback text. + """, + "channel": "#channel", + "username": "username" + }, + }, + "description": "this is an example dag", + "schedule_interval": "0 3 * * *", + "tags": ["tag1", "tag2"], + "tasks": { + "task_1": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 1", + "execution_timeout_secs": 5, + }, + }, +} + UTC = pendulum.timezone("UTC") DAG_CONFIG_TASK_GROUP_WITH_CALLBACKS = { @@ -711,6 +766,12 @@ def print_context_callback(context, **kwargs): print(context) +def empty_callback_with_params(context, param_1, param_2, **kwargs): + # Context is the first parameter passed into the callback + print(param_1) + print(param_2) + + def test_make_task_with_callback(): td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG) operator = "airflow.operators.python_operator.PythonOperator" @@ -734,6 +795,7 @@ def test_make_task_with_callback(): assert callable(actual.on_retry_callback) +@pytest.mark.callbacks def test_dag_with_callback_name_and_file(): td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_CALLBACK_NAME_AND_FILE, DEFAULT_CONFIG) dag = td.build().get("dag") @@ -754,6 +816,7 @@ def test_dag_with_callback_name_and_file(): assert not callable(td_task.on_failure_callback) +@pytest.mark.callbacks def test_dag_with_callback_name_and_file_default_args(): td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_CALLBACK_NAME_AND_FILE_DEFAULT_ARGS, DEFAULT_CONFIG) dag = td.build().get("dag") @@ -793,6 +856,67 @@ def test_make_dag_with_callback(): td.build() +@pytest.mark.callbacks +@pytest.mark.parametrize( + "callback_type,in_default_args", [("on_failure_callback", False), ("on_failure_callback", True)] +) +def test_dag_with_on_callback_str(callback_type, in_default_args): + # Using a different config (DAG_CONFIG_CALLBACK) than below + td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_CALLBACK, DEFAULT_CONFIG) + td.build() + + config_obj = td.dag_config.get("default_args") if in_default_args else td.dag_config + + # Validate the .set_callback() method works as expected when importing a string, + assert callback_type in config_obj + assert callable(config_obj.get(callback_type)) + assert config_obj.get(callback_type).__name__ == "print_context_callback" + + +@pytest.mark.callbacks +@pytest.mark.parametrize( + "callback_type,in_default_args", [("on_failure_callback", False), ("on_failure_callback", True)] +) +def test_dag_with_on_callback_and_params(callback_type, in_default_args): + # Import the DAG using the callback config that was build above + td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_CALLBACK_WITH_PARAMETERS, DEFAULT_CONFIG) + td.build() + + config_obj = td.dag_config.get("default_args") if in_default_args else td.dag_config + + # Check to see if callback_type is in the DAG config, and the type of value that is returned, pull the callback + assert callback_type in config_obj + on_callback: functools.partial = config_obj.get(callback_type) + + assert isinstance(on_callback, functools.partial) + assert callable(on_callback) + assert on_callback.func.__name__ == "empty_callback_with_params" + + # Parameters + assert "param_1" in on_callback.keywords + assert on_callback.keywords.get("param_1") == "value_1" + assert "param_2" in on_callback.keywords + assert on_callback.keywords.get("param_2") == "value_2" + + +@pytest.mark.callbacks +def test_dag_with_provider_callback(): + if version.parse(AIRFLOW_VERSION) >= version.parse("2.6.0"): + td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_PROVIDER_CALLBACK_WITH_PARAMETERS, DEFAULT_CONFIG) + td.build() + + # Check to see if the on_failure_callback exists and that it's a callback + assert td.dag_config.get("default_args").get("on_failure_callback") + + on_failure_callback = td.dag_config.get("default_args").get("on_failure_callback") + assert callable(on_failure_callback) + + # Check values + assert on_failure_callback.slack_conn_id == "slack_conn_id" + assert on_failure_callback.channel == "#channel" + assert on_failure_callback.username == "username" + + def test_get_dag_params_with_template_searchpath(): from dagfactory import utils diff --git a/tests/test_example_dags.py b/tests/test_example_dags.py index 11b23088..759e15fd 100644 --- a/tests/test_example_dags.py +++ b/tests/test_example_dags.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from pathlib import Path try: @@ -16,10 +17,12 @@ from . import utils as test_utils -EXAMPLE_DAGS_DIR = Path(__file__).parent.parent / "examples" +EXAMPLE_DAGS_DIR = Path(__file__).parent.parent / "dev/dags" AIRFLOW_IGNORE_FILE = EXAMPLE_DAGS_DIR / ".airflowignore" AIRFLOW_VERSION = Version(airflow.__version__) -IGNORED_DAG_FILES = [] +IGNORED_DAG_FILES = [ + "example_callbacks.py" +] MIN_VER_DAG_FILE_VER: dict[str, list[str]] = { "2.3": ["example_dynamic_task_mapping.py"], @@ -51,9 +54,11 @@ def get_dag_bag() -> DagBag: print(f"Adding {dagfile} to .airflowignore") file.writelines([f"{dagfile}\n"]) + # Print the contents of the .airflowignore file, and build the DagBag print(".airflowignore contents: ") print(AIRFLOW_IGNORE_FILE.read_text()) db = DagBag(EXAMPLE_DAGS_DIR, include_examples=False) + assert db.dags assert not db.import_errors return db