Skip to content

Commit

Permalink
Added task_instance_mutation_hook for mapped operator index 0 (apache…
Browse files Browse the repository at this point in the history
…#42661)

* Added task_instance_mutation_hook for mapped operator index 0

* Added unit test

---------

Co-authored-by: Marco Küttelwesch <[email protected]>
  • Loading branch information
2 people authored and harjeevanmaan committed Oct 23, 2024
1 parent 49a2c67 commit b223714
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
2 changes: 2 additions & 0 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,8 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
unmapped_ti.map_index = 0
self.log.debug("Updated in place to become %s", unmapped_ti)
all_expanded_tis.append(unmapped_ti)
# execute hook for task instance map index 0
task_instance_mutation_hook(unmapped_ti)
session.flush()
else:
self.log.debug("Deleting the original task instance: %s", unmapped_ti)
Expand Down
25 changes: 25 additions & 0 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from collections import defaultdict
from datetime import timedelta
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import patch

import pendulum
Expand Down Expand Up @@ -730,6 +731,30 @@ def test_expand_mapped_task_instance_with_named_index(
assert indices == expected_rendered_names


@pytest.mark.parametrize(
"create_mapped_task",
[
pytest.param(_create_mapped_with_name_template_classic, id="classic"),
pytest.param(_create_mapped_with_name_template_taskflow, id="taskflow"),
],
)
def test_expand_mapped_task_task_instance_mutation_hook(dag_maker, session, create_mapped_task) -> None:
"""Test that the tast_instance_mutation_hook is called."""
expected_map_index = [0, 1, 2]

with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id="task_2").expand(arg2=task1.output)

dr = dag_maker.create_dagrun()

with mock.patch("airflow.settings.task_instance_mutation_hook") as mock_hook:
expand_mapped_task(mapped, dr.run_id, task1.task_id, length=len(expected_map_index), session=session)

for index, call in enumerate(mock_hook.call_args_list):
assert call.args[0].map_index == expected_map_index[index]


@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
@pytest.mark.parametrize(
"map_index, expected",
Expand Down

0 comments on commit b223714

Please sign in to comment.