From da2998d74302538acc8e7b99a7ef79b356af815d Mon Sep 17 00:00:00 2001 From: AutomationDev85 <96178949+AutomationDev85@users.noreply.github.com> Date: Wed, 16 Oct 2024 18:47:31 +0200 Subject: [PATCH] Added task_instance_mutation_hook for mapped operator index 0 (#42661) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added task_instance_mutation_hook for mapped operator index 0 * Added unit test --------- Co-authored-by: Marco Küttelwesch (cherry picked from commit b7007e2b146e6ef929a211925a3d4397b1e9955d) --- airflow/models/abstractoperator.py | 2 ++ tests/models/test_mappedoperator.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index 5e5d13d5dc266..45eb3c5fff189 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -654,6 +654,8 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence unmapped_ti.map_index = 0 self.log.debug("Updated in place to become %s", unmapped_ti) all_expanded_tis.append(unmapped_ti) + # execute hook for task instance map index 0 + task_instance_mutation_hook(unmapped_ti) session.flush() else: self.log.debug("Deleting the original task instance: %s", unmapped_ti) diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 2b0cd50165c45..3b7eff19036fc 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -20,6 +20,7 @@ from collections import defaultdict from datetime import timedelta from typing import TYPE_CHECKING +from unittest import mock from unittest.mock import patch import pendulum @@ -716,6 +717,30 @@ def test_expand_mapped_task_instance_with_named_index( assert indices == expected_rendered_names +@pytest.mark.parametrize( + "create_mapped_task", + [ + pytest.param(_create_mapped_with_name_template_classic, id="classic"), + pytest.param(_create_mapped_with_name_template_taskflow, id="taskflow"), + ], +) +def test_expand_mapped_task_task_instance_mutation_hook(dag_maker, session, create_mapped_task) -> None: + """Test that the tast_instance_mutation_hook is called.""" + expected_map_index = [0, 1, 2] + + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="task_2").expand(arg2=task1.output) + + dr = dag_maker.create_dagrun() + + with mock.patch("airflow.settings.task_instance_mutation_hook") as mock_hook: + expand_mapped_task(mapped, dr.run_id, task1.task_id, length=len(expected_map_index), session=session) + + for index, call in enumerate(mock_hook.call_args_list): + assert call.args[0].map_index == expected_map_index[index] + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @pytest.mark.parametrize( "map_index, expected",