Skip to content

Commit

Permalink
Give up on trying to recreate task_id logic (#22794)
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 1d141a2a69fa8c6b738732b31fde87a9a905d46d
  • Loading branch information
uranusjr authored and Cloud Composer Team committed Sep 12, 2024
1 parent 0089318 commit ce8f917
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 19 deletions.
6 changes: 2 additions & 4 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,7 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg:
)
partial_kwargs.update(task_kwargs)

user_supplied_task_id = partial_kwargs.pop("task_id")
task_id = get_unique_task_id(user_supplied_task_id, dag, task_group)
task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
params = partial_kwargs.pop("params", None) or default_params

# Logic here should be kept in sync with BaseOperatorMeta.partial().
Expand All @@ -349,7 +348,6 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg:
_MappedOperator = cast(Any, DecoratedMappedOperator)
operator = _MappedOperator(
operator_class=self.operator_class,
user_supplied_task_id=user_supplied_task_id,
mapped_kwargs={},
partial_kwargs=partial_kwargs,
task_id=task_id,
Expand Down Expand Up @@ -432,7 +430,7 @@ def _get_unmap_kwargs(self) -> Dict[str, Any]:
return {
"dag": self.dag,
"task_group": self.task_group,
"task_id": self.user_supplied_task_id,
"task_id": self.task_id,
"op_kwargs": op_kwargs,
"multiple_outputs": self.multiple_outputs,
"python_callable": self.python_callable,
Expand Down
7 changes: 1 addition & 6 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ def partial(
from airflow.utils.task_group import TaskGroupContext

validate_mapping_kwargs(operator_class, "partial", kwargs)
user_supplied_task_id = task_id

dag = dag or DagContext.get_current_dag()
if dag:
Expand Down Expand Up @@ -286,11 +285,7 @@ def partial(
partial_kwargs["executor_config"] = partial_kwargs["executor_config"] or {}
partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"])

return OperatorPartial(
operator_class=operator_class,
user_supplied_task_id=user_supplied_task_id,
kwargs=partial_kwargs,
)
return OperatorPartial(operator_class=operator_class, kwargs=partial_kwargs)


class BaseOperatorMeta(abc.ABCMeta):
Expand Down
13 changes: 8 additions & 5 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ class OperatorPartial:
"""

operator_class: Type["BaseOperator"]
user_supplied_task_id: str
kwargs: Dict[str, Any]

_expand_called: bool = False # Set when expand() is called to ease user debugging.
Expand Down Expand Up @@ -207,7 +206,6 @@ def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":

op = MappedOperator(
operator_class=self.operator_class,
user_supplied_task_id=self.user_supplied_task_id,
mapped_kwargs=mapped_kwargs,
partial_kwargs=partial_kwargs,
task_id=task_id,
Expand Down Expand Up @@ -244,7 +242,6 @@ class MappedOperator(AbstractOperator):
# that can be used to unmap this into a SerializedBaseOperator.
operator_class: Union[Type["BaseOperator"], Dict[str, Any]]

user_supplied_task_id: str # This is the task_id supplied by the user.
mapped_kwargs: Dict[str, "Mappable"]
partial_kwargs: Dict[str, Any]

Expand Down Expand Up @@ -469,7 +466,7 @@ def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]:

def _get_unmap_kwargs(self) -> Dict[str, Any]:
return {
"task_id": self.user_supplied_task_id,
"task_id": self.task_id,
"dag": self.dag,
"task_group": self.task_group,
"params": self.params,
Expand All @@ -482,7 +479,13 @@ def _get_unmap_kwargs(self) -> Dict[str, Any]:
def unmap(self) -> "BaseOperator":
"""Get the "normal" Operator after applying the current mapping."""
if isinstance(self.operator_class, type):
return self.operator_class(**self._get_unmap_kwargs(), _airflow_from_mapped=True)
# We can't simply specify task_id here because BaseOperator further
# mangles the task_id based on the task hierarchy (namely, group_id
# is prepended, and '__N' appended to deduplicate). Instead of
# recreating the whole logic here, we just overwrite task_id later.
op = self.operator_class(**self._get_unmap_kwargs(), _airflow_from_mapped=True)
op.task_id = self.task_id
return op

# After a mapped operator is serialized, there's no real way to actually
# unmap it since we've lost access to the underlying operator class.
Expand Down
1 change: 0 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,6 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Operator:
mapped_kwargs={},
partial_kwargs={},
task_id=encoded_op["task_id"],
user_supplied_task_id=encoded_op["user_supplied_task_id"],
params={},
deps=MappedOperator.deps_for(BaseOperator),
operator_extra_links=BaseOperator.operator_extra_links,
Expand Down
35 changes: 35 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
Variable,
XCom,
)
from airflow.models.taskfail import TaskFail
from airflow.models.taskinstance import TaskInstance, load_error_file, set_error_file
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCOM_RETURN_KEY
Expand Down Expand Up @@ -1365,6 +1366,40 @@ def test_email_alert_with_config(self, mock_send_email, dag_maker):
assert 'template: test_email_alert_with_config' == title
assert 'template: test_email_alert_with_config' == body

@pytest.mark.parametrize("task_id", ["test_email_alert", "test_email_alert__1"])
@patch('airflow.models.taskinstance.send_email')
def test_failure_mapped_taskflow(self, mock_send_email, dag_maker, session, task_id):
with dag_maker(session=session) as dag:

@dag.task(email='to')
def test_email_alert(x):
raise RuntimeError("Fail please")

test_email_alert.expand(x=["a", "b"]) # This is 'test_email_alert'.
test_email_alert.expand(x=[1, 2, 3]) # This is 'test_email_alert__1'.

dr: DagRun = dag_maker.create_dagrun(execution_date=timezone.utcnow())

ti = dr.get_task_instance(task_id, map_index=0, session=session)
assert ti is not None

# The task will fail and trigger email reporting.
with pytest.raises(RuntimeError, match=r"^Fail please$"):
ti.run(session=session)

(email, title, body), _ = mock_send_email.call_args
assert email == "to"
assert title == f"Airflow alert: <TaskInstance: test_dag.{task_id} test map_index=0 [failed]>"
assert body.startswith("Try 1")
assert "test_email_alert" in body

tf = (
session.query(TaskFail)
.filter_by(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index)
.one_or_none()
)
assert tf, "TaskFail was recorded"

def test_set_duration(self):
task = DummyOperator(task_id='op', email='[email protected]')
ti = TI(task=task)
Expand Down
3 changes: 0 additions & 3 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,7 +1588,6 @@ def test_mapped_operator_serde():
'template_fields_renderers': {'bash_command': 'bash', 'env': 'json'},
'ui_color': '#f0ede4',
'ui_fgcolor': '#000',
'user_supplied_task_id': 'a',
'_expansion_kwargs_attr': 'mapped_kwargs',
}

Expand Down Expand Up @@ -1633,7 +1632,6 @@ def test_mapped_operator_xcomarg_serde():
'operator_extra_links': [],
'ui_color': '#fff',
'ui_fgcolor': '#000',
'user_supplied_task_id': 'task_2',
'_expansion_kwargs_attr': 'mapped_kwargs',
}

Expand Down Expand Up @@ -1721,7 +1719,6 @@ def x(arg1, arg2, arg3):
'template_ext': [],
'template_fields': ['op_args', 'op_kwargs'],
'template_fields_renderers': {"op_args": "py", "op_kwargs": "py"},
'user_supplied_task_id': 'x',
'_expansion_kwargs_attr': 'mapped_op_kwargs',
}

Expand Down

0 comments on commit ce8f917

Please sign in to comment.