Skip to content

Commit

Permalink
Improve ATTRIBUTE_REMOVED sentinel to use class and more context (#40920
Browse files Browse the repository at this point in the history
)

ATTRIBUTE_REMOVED was a singleton object but then it made it
difficult to find out which attribute has been removed. This PR
changes the approach to allow for multiple AttributeRemoved objects
as sentinel and compare class ratehr than object. The object contains
name of the removed attribute. Better diagnostics at the expense of
a bit more memory used.
  • Loading branch information
potiuk authored Jul 23, 2024
1 parent 84d19bc commit efca50a
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 16 deletions.
8 changes: 4 additions & 4 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import ATTRIBUTE_REMOVED, NOTSET
from airflow.utils.types import NOTSET, AttributeRemoved
from airflow.utils.xcom import XCOM_RETURN_KEY

if TYPE_CHECKING:
Expand Down Expand Up @@ -1243,12 +1243,12 @@ def dag(self, dag: DAG | None):
return

# if set to removed, then just set and exit
if self._dag is ATTRIBUTE_REMOVED:
if self._dag.__class__ is AttributeRemoved:
self._dag = dag
return
# if setting to removed, then just set and exit
if dag is ATTRIBUTE_REMOVED:
self._dag = ATTRIBUTE_REMOVED # type: ignore[assignment]
if dag.__class__ is AttributeRemoved:
self._dag = AttributeRemoved("_dag") # type: ignore[assignment]
return

from airflow.models.dag import DAG
Expand Down
10 changes: 5 additions & 5 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.task_instance_session import set_current_task_instance_session
from airflow.utils.timeout import timeout
from airflow.utils.types import ATTRIBUTE_REMOVED
from airflow.utils.types import AttributeRemoved
from airflow.utils.xcom import XCOM_RETURN_KEY

TR = TaskReschedule
Expand Down Expand Up @@ -935,7 +935,7 @@ def _get_template_context(
assert task
assert task.dag

if task.dag is ATTRIBUTE_REMOVED:
if task.dag.__class__ is AttributeRemoved:
task.dag = dag # required after deserialization

dag_run = task_instance.get_dagrun(session)
Expand Down Expand Up @@ -1288,7 +1288,7 @@ def _record_task_map_for_downstreams(
:meta private:
"""
if task.dag is ATTRIBUTE_REMOVED:
if task.dag.__class__ is AttributeRemoved:
task.dag = dag # required after deserialization

if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate.
Expand Down Expand Up @@ -2672,7 +2672,7 @@ def ensure_dag(
"""Ensure that task has a dag object associated, might have been removed by serialization."""
if TYPE_CHECKING:
assert task_instance.task
if task_instance.task.dag is None or task_instance.task.dag is ATTRIBUTE_REMOVED:
if task_instance.task.dag is None or task_instance.task.dag.__class__ is AttributeRemoved:
task_instance.task.dag = DagBag(read_dags_from_db=True).get_dag(
dag_id=task_instance.dag_id, session=session
)
Expand Down Expand Up @@ -3465,7 +3465,7 @@ def render_templates(
assert self.task
assert ti.task

if ti.task.dag is ATTRIBUTE_REMOVED:
if ti.task.dag.__class__ is AttributeRemoved:
ti.task.dag = self.task.dag

# If self.task is mapped, this call replaces self.task to point to the
Expand Down
4 changes: 2 additions & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
from airflow.utils.timezone import from_timestamp, parse_timezone
from airflow.utils.types import ATTRIBUTE_REMOVED, NOTSET, ArgNotSet
from airflow.utils.types import NOTSET, ArgNotSet, AttributeRemoved

if TYPE_CHECKING:
from inspect import Parameter
Expand Down Expand Up @@ -1329,7 +1329,7 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
)
else:
op = SerializedBaseOperator(task_id=encoded_op["task_id"])
op.dag = ATTRIBUTE_REMOVED # type: ignore[assignment]
op.dag = AttributeRemoved("dag") # type: ignore[assignment]
cls.populate_operator(op, encoded_op)
return op

Expand Down
9 changes: 7 additions & 2 deletions airflow/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,16 @@ class AttributeRemoved:
:meta private:
"""

def __init__(self, attribute_name: str):
self.attribute_name = attribute_name

def __getattr__(self, item):
raise RuntimeError(f"Attribute was removed on serialization and must be set again: {item}.")
raise RuntimeError(
f"Attribute {self.attribute_name} was removed on "
f"serialization and must be set again - found when accessing {item}."
)


ATTRIBUTE_REMOVED = AttributeRemoved()
"""
Sentinel value for attributes removed on serialization.
Expand Down
6 changes: 3 additions & 3 deletions tests/serialization/test_pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from airflow.settings import _ENABLE_AIP_44
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.types import ATTRIBUTE_REMOVED, DagRunType
from airflow.utils.types import AttributeRemoved, DagRunType
from tests.models import DEFAULT_DATE

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -117,7 +117,7 @@ def target(val=None):
# roundtrip ti
sered = BaseSerialization.serialize(ti, use_pydantic_models=True)
desered = BaseSerialization.deserialize(sered, use_pydantic_models=True)
assert desered.task.dag is ATTRIBUTE_REMOVED
assert desered.task.dag.__class__ is AttributeRemoved
assert "operator_class" not in sered["__var"]["task"]

assert desered.task.__class__ == MappedOperator
Expand All @@ -135,7 +135,7 @@ def target(val=None):
# dag already has this task
assert dag.has_task(desered.task.task_id) is True
# but the task has no dag
assert desered.task.dag is ATTRIBUTE_REMOVED
assert desered.task.dag.__class__ is AttributeRemoved
# and there are no upstream / downstreams on the task cus those are wiped out on serialization
# and this is wrong / not great but that's how it is
assert desered.task.upstream_task_ids == set()
Expand Down

0 comments on commit efca50a

Please sign in to comment.