diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index be9a870f17b44..81c367393987d 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -24,6 +24,7 @@ from flask import Response +from airflow.jobs.job import Job, most_recent_job from airflow.serialization.serialized_objects import BaseSerialization from airflow.utils.session import create_session @@ -51,6 +52,12 @@ def _initialize_map() -> dict[str, Callable]: DagModel.get_current, DagFileProcessorManager.clear_nonexistent_import_errors, DagWarning.purge_inactive_dag_warnings, + Job._add_to_db, + Job._fetch_from_db, + Job._kill, + Job._update_heartbeat, + Job._update_in_db, + most_recent_job, MetastoreBackend._fetch_connection, MetastoreBackend._fetch_variable, XCom.get_value, diff --git a/airflow/jobs/backfill_job_runner.py b/airflow/jobs/backfill_job_runner.py index daa549dbe2ff5..92ade82ff93e1 100644 --- a/airflow/jobs/backfill_job_runner.py +++ b/airflow/jobs/backfill_job_runner.py @@ -62,7 +62,7 @@ from airflow.models.taskinstance import TaskInstanceKey -class BackfillJobRunner(BaseJobRunner[Job], LoggingMixin): +class BackfillJobRunner(BaseJobRunner, LoggingMixin): """ A backfill job runner consists of a dag or subdag for a specific time range. diff --git a/airflow/jobs/base_job_runner.py b/airflow/jobs/base_job_runner.py index 611579b239e9c..e26c100ea8513 100644 --- a/airflow/jobs/base_job_runner.py +++ b/airflow/jobs/base_job_runner.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING from airflow.utils.session import NEW_SESSION, provide_session @@ -27,22 +27,20 @@ from airflow.jobs.job import Job from airflow.serialization.pydantic.job import JobPydantic -J = TypeVar("J", "Job", "JobPydantic", "Job | JobPydantic") - -class BaseJobRunner(Generic[J]): +class BaseJobRunner: """Abstract class for job runners to derive from.""" job_type = "undefined" - def __init__(self, job: J) -> None: + def __init__(self, job: Job) -> None: if job.job_type and job.job_type != self.job_type: raise Exception( f"The job is already assigned a different job_type: {job.job_type}." f"This is a bug and should be reported." ) job.job_type = self.job_type - self.job: J = job + self.job: Job = job def _execute(self) -> int | None: """ @@ -65,7 +63,7 @@ def heartbeat_callback(self, session: Session = NEW_SESSION) -> None: @classmethod @provide_session - def most_recent_job(cls, session: Session = NEW_SESSION) -> Job | None: + def most_recent_job(cls, session: Session = NEW_SESSION) -> Job | JobPydantic | None: """Return the most recent job of this type, if any, based on last heartbeat received.""" from airflow.jobs.job import most_recent_job diff --git a/airflow/jobs/dag_processor_job_runner.py b/airflow/jobs/dag_processor_job_runner.py index e14ec7f1c4f15..76b2ab5925540 100644 --- a/airflow/jobs/dag_processor_job_runner.py +++ b/airflow/jobs/dag_processor_job_runner.py @@ -31,7 +31,7 @@ def empty_callback(_: Any) -> None: pass -class DagProcessorJobRunner(BaseJobRunner[Job], LoggingMixin): +class DagProcessorJobRunner(BaseJobRunner, LoggingMixin): """ DagProcessorJobRunner is a job runner that runs a DagFileProcessorManager processor. diff --git a/airflow/jobs/job.py b/airflow/jobs/job.py index 1f9070f1dd308..0afbb2a026923 100644 --- a/airflow/jobs/job.py +++ b/airflow/jobs/job.py @@ -26,6 +26,7 @@ from sqlalchemy.orm import backref, foreign, relationship from sqlalchemy.orm.session import make_transient +from airflow.api_internal.internal_api_call import internal_api_call from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.executors.executor_loader import ExecutorLoader @@ -38,11 +39,13 @@ from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname from airflow.utils.platform import getuser -from airflow.utils.session import NEW_SESSION, create_session, provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime from airflow.utils.state import JobState if TYPE_CHECKING: + import datetime + from sqlalchemy.orm.session import Session @@ -120,15 +123,10 @@ def executor(self): return ExecutorLoader.get_default_executor() @cached_property - def heartrate(self): - if self.job_type == "TriggererJob": - return conf.getfloat("triggerer", "JOB_HEARTBEAT_SEC") - else: - # Heartrate used to be hardcoded to scheduler, so in all other - # cases continue to use that value for back compat - return conf.getfloat("scheduler", "JOB_HEARTBEAT_SEC") + def heartrate(self) -> float: + return Job._heartrate(self.job_type) - def is_alive(self, grace_multiplier=2.1): + def is_alive(self, grace_multiplier=2.1) -> bool: """ Is this job currently alive. @@ -138,28 +136,23 @@ def is_alive(self, grace_multiplier=2.1): :param grace_multiplier: multiplier of heartrate to require heart beat within """ - if self.job_type == "SchedulerJob": - health_check_threshold: int = conf.getint("scheduler", "scheduler_health_check_threshold") - elif self.job_type == "TriggererJob": - health_check_threshold: int = conf.getint("triggerer", "triggerer_health_check_threshold") - else: - health_check_threshold: int = self.heartrate * grace_multiplier - return ( - self.state == JobState.RUNNING - and (timezone.utcnow() - self.latest_heartbeat).total_seconds() < health_check_threshold + return Job._is_alive( + job_type=self.job_type, + heartrate=self.heartrate, + state=self.state, + latest_heartbeat=self.latest_heartbeat, + grace_multiplier=grace_multiplier, ) @provide_session def kill(self, session: Session = NEW_SESSION) -> NoReturn: """Handle on_kill callback and updates state in database.""" - job = session.scalar(select(Job).where(Job.id == self.id).limit(1)) - job.end_date = timezone.utcnow() try: self.on_kill() except Exception as e: self.log.error("on_kill() method failed: %s", e) - session.merge(job) - session.commit() + + Job._kill(job_id=self.id, session=session) raise AirflowException("Job shut down externally.") def on_kill(self): @@ -191,11 +184,10 @@ def heartbeat( try: # This will cause it to load from the db - session.merge(self) + self._merge_from(Job._fetch_from_db(self, session)) previous_heartbeat = self.latest_heartbeat if self.state == JobState.RESTARTING: - # TODO: Make sure it is AIP-44 compliant self.kill() # Figure out how long to sleep for @@ -207,18 +199,14 @@ def heartbeat( sleep_for = max(0, seconds_remaining) sleep(sleep_for) - # Update last heartbeat time - with create_session() as session: - # Make the session aware of this object - session.merge(self) - self.latest_heartbeat = timezone.utcnow() - session.commit() - # At this point, the DB has updated. - previous_heartbeat = self.latest_heartbeat - - heartbeat_callback(session) - self.log.debug("[heartbeat]") - self.heartbeat_failed = False + job = Job._update_heartbeat(job=self, session=session) + self._merge_from(job) + + # At this point, the DB has updated. + previous_heartbeat = self.latest_heartbeat + + heartbeat_callback(session) + self.log.debug("[heartbeat]") except OperationalError: Stats.incr(convert_camel_to_snake(self.__class__.__name__) + "_heartbeat_failure", 1, 1) if not self.heartbeat_failed: @@ -242,26 +230,131 @@ def prepare_for_execution(self, session: Session = NEW_SESSION): Stats.incr(self.__class__.__name__.lower() + "_start", 1, 1) self.state = JobState.RUNNING self.start_date = timezone.utcnow() - session.add(self) - session.commit() + self._merge_from(Job._add_to_db(job=self, session=session)) make_transient(self) @provide_session def complete_execution(self, session: Session = NEW_SESSION): get_listener_manager().hook.before_stopping(component=self) self.end_date = timezone.utcnow() - session.merge(self) - session.commit() + Job._update_in_db(job=self, session=session) Stats.incr(self.__class__.__name__.lower() + "_end", 1, 1) @provide_session - def most_recent_job(self, session: Session = NEW_SESSION) -> Job | None: + def most_recent_job(self, session: Session = NEW_SESSION) -> Job | JobPydantic | None: """Return the most recent job of this type, if any, based on last heartbeat received.""" return most_recent_job(self.job_type, session=session) + def _merge_from(self, job: Job | JobPydantic | None): + if job is None: + self.log.error("Job is empty: %s", self.id) + return + self.id = job.id + self.dag_id = job.dag_id + self.state = job.state + self.job_type = job.job_type + self.start_date = job.start_date + self.end_date = job.end_date + self.latest_heartbeat = job.latest_heartbeat + self.executor_class = job.executor_class + self.hostname = job.hostname + self.unixname = job.unixname + + @staticmethod + def _heartrate(job_type: str) -> float: + if job_type == "TriggererJob": + return conf.getfloat("triggerer", "JOB_HEARTBEAT_SEC") + else: + # Heartrate used to be hardcoded to scheduler, so in all other + # cases continue to use that value for back compat + return conf.getfloat("scheduler", "JOB_HEARTBEAT_SEC") + + @staticmethod + def _is_alive( + job_type: str | None, + heartrate: float, + state: JobState | str | None, + latest_heartbeat: datetime.datetime, + grace_multiplier: float = 2.1, + ) -> bool: + health_check_threshold: float + if job_type == "SchedulerJob": + health_check_threshold = conf.getint("scheduler", "scheduler_health_check_threshold") + elif job_type == "TriggererJob": + health_check_threshold = conf.getint("triggerer", "triggerer_health_check_threshold") + else: + health_check_threshold = heartrate * grace_multiplier + return ( + state == JobState.RUNNING + and (timezone.utcnow() - latest_heartbeat).total_seconds() < health_check_threshold + ) + + @staticmethod + @internal_api_call + @provide_session + def _kill(job_id: str, session: Session = NEW_SESSION) -> Job | JobPydantic: + job = session.scalar(select(Job).where(Job.id == job_id).limit(1)) + job.end_date = timezone.utcnow() + session.merge(job) + session.commit() + return job + + @staticmethod + @internal_api_call + @provide_session + def _fetch_from_db(job: Job | JobPydantic, session: Session = NEW_SESSION) -> Job | JobPydantic | None: + if isinstance(job, Job): + # not Internal API + session.merge(job) + return job + # Internal API, + return session.scalar(select(Job).where(Job.id == job.id).limit(1)) + + @staticmethod + @internal_api_call + @provide_session + def _add_to_db(job: Job | JobPydantic, session: Session = NEW_SESSION) -> Job | JobPydantic: + if isinstance(job, JobPydantic): + orm_job = Job() + orm_job._merge_from(job) + else: + orm_job = job + session.add(orm_job) + session.commit() + return orm_job + + @staticmethod + @internal_api_call + @provide_session + def _update_in_db(job: Job | JobPydantic, session: Session = NEW_SESSION): + if isinstance(job, Job): + # not Internal API + session.merge(job) + session.commit() + # Internal API. + orm_job: Job | None = session.scalar(select(Job).where(Job.id == job.id).limit(1)) + if orm_job is None: + return + orm_job._merge_from(job) + session.merge(orm_job) + session.commit() + @staticmethod + @internal_api_call + @provide_session + def _update_heartbeat(job: Job | JobPydantic, session: Session = NEW_SESSION) -> Job | JobPydantic: + orm_job: Job | None = session.scalar(select(Job).where(Job.id == job.id).limit(1)) + if orm_job is None: + return job + orm_job.latest_heartbeat = timezone.utcnow() + session.merge(orm_job) + session.commit() + return orm_job + + +@internal_api_call @provide_session -def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job | None: +def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job | JobPydantic | None: """ Return the most recent job of this type, if any, based on last heartbeat received. @@ -285,7 +378,7 @@ def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job | None @provide_session def run_job( - job: Job | JobPydantic, execute_callable: Callable[[], int | None], session: Session = NEW_SESSION + job: Job, execute_callable: Callable[[], int | None], session: Session = NEW_SESSION ) -> int | None: """ Run the job. @@ -294,12 +387,7 @@ def run_job( same DB session and the session is kept open throughout the whole execution. :meta private: - - TODO: Maybe we should not keep the session during job execution ?. """ - # The below assert is a temporary one, to make MyPy happy with partial AIP-44 work - we will remove it - # once final AIP-44 changes are completed. - assert not isinstance(job, JobPydantic), "Job should be ORM object not Pydantic one here (AIP-44 WIP)" job.prepare_for_execution(session=session) try: return execute_job(job, execute_callable=execute_callable) @@ -307,7 +395,7 @@ def run_job( job.complete_execution(session=session) -def execute_job(job: Job | JobPydantic, execute_callable: Callable[[], int | None]) -> int | None: +def execute_job(job: Job, execute_callable: Callable[[], int | None]) -> int | None: """ Execute the job. @@ -322,8 +410,8 @@ def execute_job(job: Job | JobPydantic, execute_callable: Callable[[], int | Non database operations or over the Internal API call. :param job: Job to execute - it can be either DB job or it's Pydantic serialized version. It does - not really matter, because except of running the heartbeat and state setting, - the runner should not modify the job state. + not really matter, because except of running the heartbeat and state setting, + the runner should not modify the job state. :param execute_callable: callable to execute when running the job. @@ -344,7 +432,7 @@ def execute_job(job: Job | JobPydantic, execute_callable: Callable[[], int | Non def perform_heartbeat( - job: Job | JobPydantic, heartbeat_callback: Callable[[Session], None], only_if_necessary: bool + job: Job, heartbeat_callback: Callable[[Session], None], only_if_necessary: bool ) -> None: """ Perform heartbeat for the Job passed to it,optionally checking if it is necessary. @@ -354,13 +442,9 @@ def perform_heartbeat( :param only_if_necessary: only heartbeat if it is necessary (i.e. if there are things to run for triggerer for example) """ - # The below assert is a temporary one, to make MyPy happy with partial AIP-44 work - we will remove it - # once final AIP-44 changes are completed. - assert not isinstance(job, JobPydantic), "Job should be ORM object not Pydantic one here (AIP-44 WIP)" seconds_remaining: float = 0.0 if job.latest_heartbeat and job.heartrate: seconds_remaining = job.heartrate - (timezone.utcnow() - job.latest_heartbeat).total_seconds() if seconds_remaining > 0 and only_if_necessary: return - with create_session() as session: - job.heartbeat(heartbeat_callback=heartbeat_callback, session=session) + job.heartbeat(heartbeat_callback=heartbeat_callback) diff --git a/airflow/jobs/local_task_job_runner.py b/airflow/jobs/local_task_job_runner.py index 079ad4cbba083..e068d88203411 100644 --- a/airflow/jobs/local_task_job_runner.py +++ b/airflow/jobs/local_task_job_runner.py @@ -41,7 +41,6 @@ from airflow.jobs.job import Job from airflow.models.taskinstance import TaskInstance - from airflow.serialization.pydantic.job import JobPydantic SIGSEGV_MESSAGE = """ ******************************************* Received SIGSEGV ******************************************* @@ -74,14 +73,14 @@ ********************************************************************************************************""" -class LocalTaskJobRunner(BaseJobRunner["Job | JobPydantic"], LoggingMixin): +class LocalTaskJobRunner(BaseJobRunner, LoggingMixin): """LocalTaskJob runs a single task instance.""" job_type = "LocalTaskJob" def __init__( self, - job: Job | JobPydantic, + job: Job, task_instance: TaskInstance, # TODO add TaskInstancePydantic ignore_all_deps: bool = False, ignore_depends_on_past: bool = False, diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 36fd1bdfb9a29..8e9d79fb333d4 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -126,7 +126,7 @@ def _is_parent_process() -> bool: return multiprocessing.current_process().name == "MainProcess" -class SchedulerJobRunner(BaseJobRunner[Job], LoggingMixin): +class SchedulerJobRunner(BaseJobRunner, LoggingMixin): """ SchedulerJobRunner runs for a specific time interval and schedules jobs that are ready to run. diff --git a/airflow/jobs/triggerer_job_runner.py b/airflow/jobs/triggerer_job_runner.py index 34a271c3ac869..24a243bed8e88 100644 --- a/airflow/jobs/triggerer_job_runner.py +++ b/airflow/jobs/triggerer_job_runner.py @@ -60,7 +60,6 @@ from airflow.jobs.job import Job from airflow.models import TaskInstance - from airflow.serialization.pydantic.job import JobPydantic from airflow.triggers.base import BaseTrigger HANDLER_SUPPORTS_TRIGGERER = False @@ -237,7 +236,7 @@ def setup_queue_listener(): return None -class TriggererJobRunner(BaseJobRunner["Job | JobPydantic"], LoggingMixin): +class TriggererJobRunner(BaseJobRunner, LoggingMixin): """ Run active triggers in asyncio and update their dependent tests/DAGs once their events have fired. @@ -250,7 +249,7 @@ class TriggererJobRunner(BaseJobRunner["Job | JobPydantic"], LoggingMixin): def __init__( self, - job: Job | JobPydantic, + job: Job, capacity=None, ): super().__init__(job) diff --git a/airflow/serialization/pydantic/job.py b/airflow/serialization/pydantic/job.py index 27c8ad8ca7496..39627f9a993b4 100644 --- a/airflow/serialization/pydantic/job.py +++ b/airflow/serialization/pydantic/job.py @@ -14,13 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from datetime import datetime -from typing import Optional +import datetime +from functools import cached_property +from typing import TYPE_CHECKING, Optional from pydantic import BaseModel as BaseModelPydantic +from airflow.executors.executor_loader import ExecutorLoader from airflow.jobs.base_job_runner import BaseJobRunner +if TYPE_CHECKING: + from airflow.jobs.job import Job + def check_runner_initialized(job_runner: Optional[BaseJobRunner], job_type: str) -> BaseJobRunner: if job_runner is None: @@ -35,19 +40,34 @@ class JobPydantic(BaseModelPydantic): dag_id: Optional[str] state: Optional[str] job_type: Optional[str] - start_date: Optional[datetime] - end_date: Optional[datetime] - latest_heartbeat: datetime + start_date: Optional[datetime.datetime] + end_date: Optional[datetime.datetime] + latest_heartbeat: datetime.datetime executor_class: Optional[str] hostname: Optional[str] unixname: Optional[str] - # not an ORM field - heartrate: Optional[int] - max_tis_per_query: Optional[int] - class Config: """Make sure it deals automatically with SQLAlchemy ORM classes.""" from_attributes = True orm_mode = True # Pydantic 1.x compatibility. + + @cached_property + def executor(self): + return ExecutorLoader.get_default_executor() + + @cached_property + def heartrate(self) -> float: + assert self.job_type is not None + return Job._heartrate(self.job_type) + + def is_alive(self, grace_multiplier=2.1) -> bool: + """Is this job currently alive.""" + return Job._is_alive( + job_type=self.job_type, + heartrate=self.heartrate, + state=self.state, + latest_heartbeat=self.latest_heartbeat, + grace_multiplier=grace_multiplier, + ) diff --git a/tests/jobs/test_base_job.py b/tests/jobs/test_base_job.py index 998cc3bebe0f9..a808fda1b06a2 100644 --- a/tests/jobs/test_base_job.py +++ b/tests/jobs/test_base_job.py @@ -203,21 +203,16 @@ def test_is_alive_scheduler(self, job_type): job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=10) assert job.is_alive() is False, "Completed jobs even with recent heartbeat should not be alive" - @patch("airflow.jobs.job.create_session") - def test_heartbeat_failed(self, mock_create_session): + def test_heartbeat_failed(self): when = timezone.utcnow() - datetime.timedelta(seconds=60) - with create_session() as session: - mock_session = Mock(spec_set=session, name="MockSession") - mock_create_session.return_value.__enter__.return_value = mock_session - - job = Job(heartrate=10, state=State.RUNNING) - job.latest_heartbeat = when - - mock_session.commit.side_effect = OperationalError("Force fail", {}, None) + mock_session = Mock(name="MockSession") + mock_session.commit.side_effect = OperationalError("Force fail", {}, None) + job = Job(heartrate=10, state=State.RUNNING) + job.latest_heartbeat = when - job.heartbeat(heartbeat_callback=lambda: None) + job.heartbeat(heartbeat_callback=lambda: None, session=mock_session) - assert job.latest_heartbeat == when, "attribute not updated when heartbeat fails" + assert job.latest_heartbeat == when, "attribute not updated when heartbeat fails" @conf_vars( { diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index d16bf398b58a8..ba4f9604d7c9e 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -41,7 +41,6 @@ if TYPE_CHECKING: from airflow.jobs.job import Job - from airflow.serialization.pydantic.job import JobPydantic @pytest.fixture() @@ -336,7 +335,7 @@ def test_prune_dict(self, mode, expected): class MockJobRunner(BaseJobRunner): job_type = "MockJob" - def __init__(self, job: Job | JobPydantic, func=None): + def __init__(self, job: Job, func=None): super().__init__(job) self.job = job self.job.job_type = self.job_type