Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Don't raise a warning in ExecutorSafeguard when execute is called from an extended operator #42849

Merged
merged 12 commits into from
Oct 12, 2024
Merged
11 changes: 10 additions & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import warnings
from datetime import datetime, timedelta
from functools import total_ordering, wraps
from threading import local
from types import FunctionType
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -392,14 +393,22 @@ class ExecutorSafeguard:
"""

test_mode = conf.getboolean("core", "unit_test_mode")
_sentinel = local()
_sentinel.callers = {}

@classmethod
def decorator(cls, func):
@wraps(func)
def wrapper(self, *args, **kwargs):
from airflow.decorators.base import DecoratedOperator

sentinel = kwargs.pop(f"{self.__class__.__name__}__sentinel", None)
sentinel_key = f"{self.__class__.__name__}__sentinel"
sentinel = kwargs.pop(sentinel_key, None)

if sentinel:
cls._sentinel.callers[sentinel_key] = sentinel
else:
sentinel = cls._sentinel.callers.pop(f"{func.__qualname__.split('.')[0]}__sentinel", None)

if not cls.test_mode and not sentinel == _sentinel and not isinstance(self, DecoratedOperator):
message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside TaskInstance!"
Expand Down
3 changes: 2 additions & 1 deletion tests/api_fastapi/views/public/test_dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from airflow.utils.session import provide_session
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags

from dev.tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags

pytestmark = pytest.mark.db_test

Expand Down
24 changes: 23 additions & 1 deletion tests/models/test_baseoperatormeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def execute(self, context: Context) -> Any:
return f"Hello {self.owner}!"


class ExtendedHelloWorldOperator(HelloWorldOperator):
def execute(self, context: Context) -> Any:
return super().execute(context)


class TestExecutorSafeguard:
def setup_method(self):
ExecutorSafeguard.test_mode = False
Expand All @@ -49,12 +54,29 @@ def teardown_method(self, method):

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
@pytest.mark.db_test
def test_executor_when_classic_operator_called_from_dag(self, dag_maker):
@patch.object(HelloWorldOperator, "log")
def test_executor_when_classic_operator_called_from_dag(self, mock_log, dag_maker):
with dag_maker() as dag:
HelloWorldOperator(task_id="hello_operator")

dag_run = dag.test()
assert dag_run.state == DagRunState.SUCCESS
mock_log.warning.assert_not_called()

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
@pytest.mark.db_test
@patch.object(HelloWorldOperator, "log")
def test_executor_when_extended_classic_operator_called_from_dag(
self,
mock_log,
dag_maker,
):
with dag_maker() as dag:
ExtendedHelloWorldOperator(task_id="hello_operator")

dag_run = dag.test()
assert dag_run.state == DagRunState.SUCCESS
mock_log.warning.assert_not_called()

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
@pytest.mark.parametrize(
Expand Down