Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle db isolation for mapped operators and task groups #39259

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from flask import Response

from airflow.jobs.job import Job, most_recent_job
from airflow.models.taskinstance import _record_task_map_for_downstreams
from airflow.models.xcom_arg import _get_task_map_length
from airflow.sensors.base import _orig_start_date
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.utils.session import create_session
Expand Down Expand Up @@ -66,12 +68,14 @@ def _initialize_map() -> dict[str, Callable]:
_defer_task,
_get_template_context,
_get_ti_db_access,
_get_task_map_length,
_update_rtif,
_orig_start_date,
_handle_failure,
_handle_reschedule,
_add_log,
_xcom_pull,
_record_task_map_for_downstreams,
DagFileProcessor.update_import_errors,
DagFileProcessor.manage_slas,
DagFileProcessorManager.deactivate_stale_dags,
Expand Down
33 changes: 28 additions & 5 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,8 +785,14 @@ def _execute_callable(context: Context, **execute_callable_kwargs):
for key, value in xcom_value.items():
task_instance.xcom_push(key=key, value=value, session=session_or_null)
task_instance.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session_or_null)
if TYPE_CHECKING:
assert task_orig.dag
_record_task_map_for_downstreams(
task_instance=task_instance, task=task_orig, value=xcom_value, session=session_or_null
task_instance=task_instance,
task=task_orig,
dag=task_orig.dag,
value=xcom_value,
session=session_or_null,
)
return result

Expand Down Expand Up @@ -1253,25 +1259,40 @@ def _refresh_from_task(
task_instance_mutation_hook(task_instance)


@internal_api_call
@provide_session
def _record_task_map_for_downstreams(
*, task_instance: TaskInstance | TaskInstancePydantic, task: Operator, value: Any, session: Session
*,
task_instance: TaskInstance | TaskInstancePydantic,
task: Operator,
dag: DAG,
value: Any,
session: Session,
) -> None:
"""
Record the task map for downstream tasks.

:param task_instance: the task instance
:param task: The task object
:param dag: the dag associated with the task
:param value: The value
:param session: SQLAlchemy ORM Session

:meta private:
"""
# when taking task over RPC, we need to add the dag back
if isinstance(task, MappedOperator):
if not task.dag:
task.dag = dag
elif not task._dag:
task._dag = dag
dstandish marked this conversation as resolved.
Show resolved Hide resolved

if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate.
return
# TODO: We don't push TaskMap for mapped task instances because it's not
# currently possible for a downstream to depend on one individual mapped
# task instance. This will change when we implement task mapping inside
# a mapped task group, and we'll need to further analyze the case.
# currently possible for a downstream to depend on one individual mapped
# task instance. This will change when we implement task mapping inside
# a mapped task group, and we'll need to further analyze the case.
dstandish marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(task, MappedOperator):
return
if value is None:
Expand Down Expand Up @@ -3339,6 +3360,8 @@ def render_templates(
# MappedOperator is useless for template rendering, and we need to be
# able to access the unmapped task instead.
original_task.render_template_fields(context, jinja_env)
if isinstance(self.task, MappedOperator):
self.task = context["ti"].task
dstandish marked this conversation as resolved.
Show resolved Hide resolved
dstandish marked this conversation as resolved.
Show resolved Hide resolved

return original_task

Expand Down
100 changes: 58 additions & 42 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@

from sqlalchemy import func, or_, select

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import AirflowException, XComNotFound
from airflow.models import MappedOperator, TaskInstance
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskmixin import DependencyMixin
from airflow.utils.db import exists_query
from airflow.utils.mixins import ResolveMixin
Expand Down Expand Up @@ -222,6 +223,53 @@ def __exit__(self, exc_type, exc_val, exc_tb):
SetupTeardownContext.set_work_task_roots_and_leaves()


@internal_api_call
@provide_session
def _get_task_map_length(
    *,
    dag_id: str,
    task_id: str,
    run_id: str,
    is_mapped: bool,
    session: Session = NEW_SESSION,
) -> int | None:
    """
    Return the map length contributed by one upstream task in a DAG run.

    The task is identified by plain values (``dag_id``, ``task_id``, ``run_id``)
    plus the ``is_mapped`` flag instead of by an operator object.

    :param dag_id: ID of the DAG the task belongs to
    :param task_id: ID of the task whose map length is requested
    :param run_id: ID of the DAG run to inspect
    :param is_mapped: whether the task is a mapped operator; selects which
        of the two lookups below is performed
    :param session: SQLAlchemy ORM Session
    :return: For a non-mapped task, the result of selecting ``TaskMap.length``
        for rows with a negative ``map_index``. For a mapped task, ``None``
        while any of its task instances has a NULL or unfinished state,
        otherwise the count of return-value XCom rows with ``map_index >= 0``.

    :meta private:
    """
    from airflow.models.taskinstance import TaskInstance
    from airflow.models.taskmap import TaskMap
    from airflow.models.xcom import XCom

    if not is_mapped:
        # Non-mapped upstream: the length is read from the TaskMap table.
        recorded_length = select(TaskMap.length).where(
            TaskMap.dag_id == dag_id,
            TaskMap.run_id == run_id,
            TaskMap.task_id == task_id,
            TaskMap.map_index < 0,
        )
        return session.scalar(recorded_length)

    # 'state' can be NULL, so it needs its own IS NULL test: relying on the
    # "IN" clause alone would end up comparing NULL with "=", which never
    # matches in SQL.
    still_running = or_(
        TaskInstance.state.is_(None),
        TaskInstance.state.in_(s.value for s in State.unfinished if s is not None),
    )
    has_unfinished_ti = exists_query(
        TaskInstance.dag_id == dag_id,
        TaskInstance.run_id == run_id,
        TaskInstance.task_id == task_id,
        still_running,
        session=session,
    )
    if has_unfinished_ti:
        # Not every expanded task instance is done yet, so no length is known.
        return None

    # Mapped upstream, all instances finished: count the pushed return values.
    pushed_count = select(func.count(XCom.map_index)).where(
        XCom.dag_id == dag_id,
        XCom.run_id == run_id,
        XCom.task_id == task_id,
        XCom.map_index >= 0,
        XCom.key == XCOM_RETURN_KEY,
    )
    return session.scalar(pushed_count)


class PlainXComArg(XComArg):
"""Reference to one single XCom without any additional semantics.

Expand Down Expand Up @@ -364,51 +412,19 @@ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
return super().zip(*others, fillvalue=fillvalue)

def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
    """
    Return the map length of this XComArg's operator for the given DAG run.

    The lookup is delegated to the module-level ``_get_task_map_length``
    helper so the query logic lives in one place; this method previously
    carried an inline copy of the same queries ahead of the delegation,
    which left the delegating call unreachable.

    :param run_id: ID of the DAG run to inspect
    :param session: SQLAlchemy ORM Session, forwarded to the helper
    :return: whatever ``_get_task_map_length`` returns for this operator:
        an integer length, or ``None`` when no length is available
    """
    operator = self.operator
    # Pass plain identifiers and a flag rather than the operator object:
    # the helper's signature takes only dag/task/run ids and ``is_mapped``.
    return _get_task_map_length(
        dag_id=operator.dag_id,
        task_id=operator.task_id,
        run_id=run_id,
        is_mapped=isinstance(operator, MappedOperator),
        session=session,
    )

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
from airflow.models.taskinstance import TaskInstance

ti = context["ti"]
if not isinstance(ti, TaskInstance):
raise NotImplementedError("Wait for AIP-44 implementation to complete")

if TYPE_CHECKING:
assert isinstance(ti, TaskInstance)
task_id = self.operator.task_id
map_indexes = ti.get_relevant_upstream_map_indexes(
self.operator,
Expand Down