Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use sentinel to mark dag as removed on reserialization #39825

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET
from airflow.utils.types import ATTRIBUTE_REMOVED, NOTSET
from airflow.utils.xcom import XCOM_RETURN_KEY

if TYPE_CHECKING:
Expand Down Expand Up @@ -1245,11 +1245,21 @@ def dag(self) -> DAG: # type: ignore[override]
@dag.setter
def dag(self, dag: DAG | None):
"""Operators can be assigned to one DAG, one time. Repeat assignments to that same DAG are ok."""
from airflow.models.dag import DAG

if dag is None:
self._dag = None
return

# if set to removed, then just set and exit
if self._dag is ATTRIBUTE_REMOVED:
self._dag = dag
return
# if setting to removed, then just set and exit
if dag is ATTRIBUTE_REMOVED:
self._dag = ATTRIBUTE_REMOVED # type: ignore[assignment]
return

from airflow.models.dag import DAG

if not isinstance(dag, DAG):
raise TypeError(f"Expected DAG; received {dag.__class__.__name__}")
elif self.has_dag() and self.dag is not dag:
Expand Down
47 changes: 20 additions & 27 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.task_instance_session import set_current_task_instance_session
from airflow.utils.timeout import timeout
from airflow.utils.types import ATTRIBUTE_REMOVED
from airflow.utils.xcom import XCOM_RETURN_KEY

TR = TaskReschedule
Expand Down Expand Up @@ -902,13 +903,15 @@ def _clear_next_method_args(*, task_instance: TaskInstance | TaskInstancePydanti
def _get_template_context(
*,
task_instance: TaskInstance | TaskInstancePydantic,
dag: DAG,
session: Session | None = None,
ignore_param_exceptions: bool = True,
) -> Context:
"""
Return TI Context.

:param task_instance: the task instance
:param task_instance: the task instance for the task
:param dag: the DAG for the task
:param session: SQLAlchemy ORM Session
:param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict

Expand All @@ -928,27 +931,10 @@ def _get_template_context(
assert task_instance.task
assert task
assert task.dag
try:
dag: DAG = task.dag
except AirflowException:
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

if isinstance(task_instance, TaskInstancePydantic):
ti = session.scalar(
select(TaskInstance).where(
TaskInstance.task_id == task_instance.task_id,
TaskInstance.dag_id == task_instance.dag_id,
TaskInstance.run_id == task_instance.run_id,
TaskInstance.map_index == task_instance.map_index,
)
)
dag = ti.dag_model.serialized_dag.dag
if hasattr(task_instance.task, "_dag"): # BaseOperator
task_instance.task._dag = dag
else: # MappedOperator
task_instance.task.dag = dag
else:
raise
if task.dag is ATTRIBUTE_REMOVED:
task.dag = dag # required after deserialization

dag_run = task_instance.get_dagrun(session)
data_interval = dag.get_run_data_interval(dag_run)

Expand Down Expand Up @@ -1278,12 +1264,8 @@ def _record_task_map_for_downstreams(

:meta private:
"""
# when taking task over RPC, we need to add the dag back
if isinstance(task, MappedOperator):
if not task.dag:
task.dag = dag
elif not task._dag:
task._dag = dag
if task.dag is ATTRIBUTE_REMOVED:
task.dag = dag # required after deserialization

if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate.
return
Expand Down Expand Up @@ -3313,8 +3295,12 @@ def get_template_context(
:param session: SQLAlchemy ORM Session
:param ignore_param_exceptions: flag to suppress value exceptions while initializing the ParamsDict
"""
if TYPE_CHECKING:
assert self.task
assert self.task.dag
return _get_template_context(
task_instance=self,
dag=self.task.dag,
session=session,
ignore_param_exceptions=ignore_param_exceptions,
)
Expand Down Expand Up @@ -3374,8 +3360,15 @@ def render_templates(
context = self.get_template_context()
original_task = self.task

ti = context["ti"]

if TYPE_CHECKING:
assert original_task
assert self.task
assert ti.task

if ti.task.dag is ATTRIBUTE_REMOVED:
ti.task.dag = self.task.dag

# If self.task is mapped, this call replaces self.task to point to the
# unmapped BaseOperator created by this function! This is because the
Expand Down
18 changes: 18 additions & 0 deletions airflow/serialization/pydantic/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,12 @@ def get_template_context(
"""
from airflow.models.taskinstance import _get_template_context

if TYPE_CHECKING:
assert self.task
assert self.task.dag
return _get_template_context(
task_instance=self,
dag=self.task.dag,
session=session,
ignore_param_exceptions=ignore_param_exceptions,
)
Expand Down Expand Up @@ -518,6 +522,20 @@ def _handle_reschedule(
)
_set_ti_attrs(self, updated_ti) # _handle_reschedule is a remote call that mutates the TI

def get_relevant_upstream_map_indexes(
    self,
    upstream: Operator,
    ti_count: int | None,
    *,
    session: Session | None = None,
) -> int | range | None:
    """
    Delegate to ``TaskInstance.get_relevant_upstream_map_indexes`` for this object.

    :param upstream: forwarded unchanged as ``upstream``
    :param ti_count: forwarded unchanged as ``ti_count``
    :param session: forwarded unchanged as ``session``
    :return: whatever the ``TaskInstance`` implementation returns
    """
    # The ORM implementation is invoked through the class with this object
    # passed explicitly as ``self``; the type checker is told to ignore the
    # resulting argument-type mismatch.
    forwarded = {
        "upstream": upstream,
        "ti_count": ti_count,
        "session": session,
    }
    orm_impl = TaskInstance.get_relevant_upstream_map_indexes
    return orm_impl(self=self, **forwarded)  # type: ignore[arg-type]


if is_pydantic_2_installed():
TaskInstancePydantic.model_rebuild()
4 changes: 2 additions & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
from airflow.utils.timezone import from_timestamp, parse_timezone
from airflow.utils.types import NOTSET, ArgNotSet
from airflow.utils.types import ATTRIBUTE_REMOVED, NOTSET, ArgNotSet

if TYPE_CHECKING:
from inspect import Parameter
Expand Down Expand Up @@ -1297,7 +1297,7 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
)
else:
op = SerializedBaseOperator(task_id=encoded_op["task_id"])

op.dag = ATTRIBUTE_REMOVED # type: ignore[assignment]
cls.populate_operator(op, encoded_op)
return op

Expand Down
19 changes: 19 additions & 0 deletions airflow/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,25 @@ def is_arg_passed(arg: Union[ArgNotSet, None] = NOTSET) -> bool:
"""Sentinel value for argument default. See ``ArgNotSet``."""


class AttributeRemoved:
    """
    Sentinel type to signal when an attribute was removed on serialization.

    Accessing any regular attribute on the sentinel raises ``RuntimeError`` so
    that code which forgot to restore the removed attribute fails loudly.

    Dunder lookups that are not defined on the class raise ``AttributeError``
    instead, because ``copy``, ``pickle``, ``hasattr`` and similar machinery
    probe for optional protocol methods and only treat ``AttributeError`` as
    "not supported".

    The sentinel is compared by identity (``x is ATTRIBUTE_REMOVED``), so
    copying and pickling hand back the singleton rather than a new instance.

    :meta private:
    """

    def __getattr__(self, item):
        # Only reached when normal lookup fails. Protocol probes for optional
        # dunder hooks must see AttributeError, otherwise they abort with the
        # RuntimeError that is meant for real attribute access.
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(item)
        raise RuntimeError("Attribute was removed on serialization and must be set again.")

    def __copy__(self):
        # Identity must survive copying; callers test ``is ATTRIBUTE_REMOVED``.
        return self

    def __deepcopy__(self, memo):
        # Same as ``__copy__``: a deep copy of the sentinel is the sentinel.
        return self

    def __reduce__(self):
        # Returning a string makes pickle store a reference to the module-level
        # global of that name, so unpickling yields the same singleton.
        return "ATTRIBUTE_REMOVED"

    def __repr__(self):
        return "ATTRIBUTE_REMOVED"


ATTRIBUTE_REMOVED = AttributeRemoved()
"""
Sentinel value for attributes removed on serialization.

:meta private:
"""


class DagRunType(str, enum.Enum):
"""Class with DagRun types."""

Expand Down
1 change: 1 addition & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3568,6 +3568,7 @@ def test_operator_field_with_serialization(self, create_task_instance):
deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
assert deserialized_op.task_type == "EmptyOperator"
# Verify that ti.operator field renders correctly "with" Serialization
deserialized_op.dag = ti.task.dag
ser_ti = TI(task=deserialized_op, run_id=None)
assert ser_ti.operator == "EmptyOperator"
assert ser_ti.task.operator_name == "EmptyOperator"
Expand Down
5 changes: 4 additions & 1 deletion tests/providers/postgres/operators/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,10 @@ def test_parameters_are_templatized(create_task_instance_of_operator):
task_id="test-task",
)
task: SQLExecuteQueryOperator = ti.render_templates(
{"param": {"conn_id": "pg", "table": "foo", "bar": "egg"}}
{
"param": {"conn_id": "pg", "table": "foo", "bar": "egg"},
"ti": ti,
}
)
assert task.conn_id == "pg"
assert task.sql == "SELECT * FROM foo WHERE spam = %(spam)s;"
Expand Down
32 changes: 25 additions & 7 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2509,21 +2509,31 @@ def test_operator_expand_deserialized_unmap():

ser_mapped = BaseSerialization.serialize(mapped)
deser_mapped = BaseSerialization.deserialize(ser_mapped)
deser_mapped.dag = None

ser_normal = BaseSerialization.serialize(normal)
deser_normal = BaseSerialization.deserialize(ser_normal)
deser_normal.dag = None
assert deser_mapped.unmap(None) == deser_normal


@pytest.mark.db_test
def test_sensor_expand_deserialized_unmap():
"""Unmap a deserialized mapped sensor should be similar to deserializing a non-mapped sensor"""
normal = BashSensor(task_id="a", bash_command=[1, 2], mode="reschedule")
mapped = BashSensor.partial(task_id="a", mode="reschedule").expand(bash_command=[1, 2])

serialize = SerializedBaseOperator.serialize

deserialize = SerializedBaseOperator.deserialize
assert deserialize(serialize(mapped)).unmap(None) == deserialize(serialize(normal))
dag = DAG(dag_id="hello", start_date=None)
with dag:
normal = BashSensor(task_id="a", bash_command=[1, 2], mode="reschedule")
mapped = BashSensor.partial(task_id="b", mode="reschedule").expand(bash_command=[1, 2])
ser_mapped = SerializedBaseOperator.serialize(mapped)
deser_mapped = SerializedBaseOperator.deserialize(ser_mapped)
deser_mapped.dag = dag
deser_unmapped = deser_mapped.unmap(None)
ser_normal = SerializedBaseOperator.serialize(normal)
deser_normal = SerializedBaseOperator.deserialize(ser_normal)
deser_normal.dag = dag
comps = set(BashSensor._comps)
comps.remove("task_id")
assert all(getattr(deser_unmapped, c, None) == getattr(deser_normal, c, None) for c in comps)


def test_task_resources_serde():
Expand Down Expand Up @@ -2625,6 +2635,10 @@ def x(arg1, arg2, arg3):
"retry_delay": timedelta(seconds=30),
}

# this dag is not pickleable in this context, so we have to simply
# set it to None
deserialized.dag = None

# Ensure the serialized operator can also be correctly pickled, to ensure
# correct interaction between DAG pickling and serialization. This is done
# here so we don't need to duplicate tests between pickled and non-pickled
Expand Down Expand Up @@ -2721,6 +2735,10 @@ def x(arg1, arg2, arg3):
"retry_delay": timedelta(seconds=30),
}

# this dag is not pickleable in this context, so we have to simply
# set it to None
deserialized.dag = None

# Ensure the serialized operator can also be correctly pickled, to ensure
# correct interaction between DAG pickling and serialization. This is done
# here so we don't need to duplicate tests between pickled and non-pickled
Expand Down
27 changes: 20 additions & 7 deletions tests/serialization/test_pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from airflow.jobs.job import Job
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
from airflow.models import MappedOperator
from airflow.models.dag import DagModel
from airflow.models.dag import DAG, DagModel
from airflow.models.dataset import (
DagScheduleDatasetReference,
DatasetEvent,
Expand All @@ -43,7 +43,7 @@
from airflow.settings import _ENABLE_AIP_44
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.types import DagRunType
from airflow.utils.types import ATTRIBUTE_REMOVED, DagRunType
from tests.models import DEFAULT_DATE

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -89,7 +89,7 @@ def test_deserialize_ti_mapped_op_reserialized_with_refresh_from_task(session, d
"task_id": "target",
}

with dag_maker():
with dag_maker() as dag:

@task
def source():
Expand Down Expand Up @@ -117,7 +117,7 @@ def target(val=None):
# roundtrip ti
sered = BaseSerialization.serialize(ti, use_pydantic_models=True)
desered = BaseSerialization.deserialize(sered, use_pydantic_models=True)

assert desered.task.dag is ATTRIBUTE_REMOVED
assert "operator_class" not in sered["__var"]["task"]

assert desered.task.__class__ == MappedOperator
Expand All @@ -130,9 +130,22 @@ def target(val=None):

assert isinstance(desered.task.operator_class, dict)

resered = BaseSerialization.serialize(desered, use_pydantic_models=True)
deresered = BaseSerialization.deserialize(resered, use_pydantic_models=True)
assert deresered.task.operator_class == desered.task.operator_class == op_class_dict_expected
# let's check that we can safely add back dag...
assert isinstance(dag, DAG)
# dag already has this task
assert dag.has_task(desered.task.task_id) is True
# but the task has no dag
assert desered.task.dag is ATTRIBUTE_REMOVED
# and there are no upstreams / downstreams on the task because those are wiped out on serialization;
# this is not ideal, but it is the current behavior
assert desered.task.upstream_task_ids == set()
assert desered.task.downstream_task_ids == set()
# add the dag back
desered.task.dag = dag
# no error is raised when the dag is added back,
# but the upstream / downstream relationships are still not restored
assert desered.task.upstream_task_ids == set()
assert desered.task.downstream_task_ids == set()


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
Expand Down
9 changes: 9 additions & 0 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,11 @@ def test_serialize_deserialize_pydantic(input, pydantic_class, encoded_type, cmp
reserialized = BaseSerialization.serialize(deserialized, use_pydantic_models=True)
dereserialized = BaseSerialization.deserialize(reserialized, use_pydantic_models=True)
assert isinstance(dereserialized, pydantic_class)

if encoded_type == "task_instance":
deserialized.task.dag = None
dereserialized.task.dag = None

assert dereserialized == deserialized

# Verify recursive behavior
Expand Down Expand Up @@ -394,6 +399,10 @@ def test_all_pydantic_models_round_trip():
serialized = BaseSerialization.serialize(pydantic_instance, use_pydantic_models=True)
deserialized = BaseSerialization.deserialize(serialized, use_pydantic_models=True)
assert isinstance(deserialized, c)
if isinstance(pydantic_instance, TaskInstancePydantic):
# we can't access the dag on deserialization; but there is no dag here.
deserialized.task.dag = None
pydantic_instance.task.dag = None
assert pydantic_instance == deserialized


Expand Down
Loading