Decouple "job runner" from BaseJob ORM model #30255

Merged · 2 commits · Apr 10, 2023
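For context before the file-by-file diff: this PR replaces direct use of the concrete BaseJob subclasses (SchedulerJob, TriggererJob, LocalTaskJob, DagProcessorJob) at their call sites with a plain BaseJob ORM instance that delegates the actual work to a *JobRunner. The sketch below is illustrative only and is not part of the diff; it reuses names and keyword arguments visible in the changed call sites, with placeholder values, and assumes BaseJob needs no arguments here beyond job_runner.

# Illustrative sketch of the new call-site pattern (placeholder values, not from the diff).
from airflow.jobs.base_job import BaseJob
from airflow.jobs.scheduler_job import SchedulerJobRunner

# Before this PR the scheduler CLI did roughly:
#   job = SchedulerJob(subdir=..., num_runs=..., do_pickle=...)
#   job.run()
# After this PR the ORM model and the runner are separate objects:
job = BaseJob(
    job_runner=SchedulerJobRunner(
        subdir="dags/",    # placeholder; the CLI passes process_subdir(args.subdir)
        num_runs=-1,       # placeholder; the CLI passes args.num_runs
        do_pickle=False,   # placeholder; the CLI passes args.do_pickle
    )
)
job.run()

# Health checks can now look a job up by the runner's job_type instead of calling a
# classmethod on a concrete Job subclass (see standalone_command.py below):
#   recent = most_recent_job(SchedulerJobRunner.job_type)
#   alive = recent is not None and recent.is_alive()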
8 changes: 4 additions & 4 deletions airflow/api_connexion/endpoints/health_endpoint.py
@@ -18,8 +18,8 @@

from airflow.api_connexion.schemas.health_schema import health_schema
from airflow.api_connexion.types import APIResponse
from airflow.jobs.scheduler_job import SchedulerJob
from airflow.jobs.triggerer_job import TriggererJob
from airflow.jobs.scheduler_job import SchedulerJobRunner
from airflow.jobs.triggerer_job import TriggererJobRunner

HEALTHY = "healthy"
UNHEALTHY = "unhealthy"
@@ -33,7 +33,7 @@ def get_health() -> APIResponse:
scheduler_status = UNHEALTHY
triggerer_status: str | None = UNHEALTHY
try:
scheduler_job = SchedulerJob.most_recent_job()
scheduler_job = SchedulerJobRunner.most_recent_job()

if scheduler_job:
latest_scheduler_heartbeat = scheduler_job.latest_heartbeat.isoformat()
@@ -42,7 +42,7 @@
except Exception:
metadatabase_status = UNHEALTHY
try:
triggerer_job = TriggererJob.most_recent_job()
triggerer_job = TriggererJobRunner.most_recent_job()

if triggerer_job:
latest_triggerer_heartbeat = triggerer_job.latest_heartbeat.isoformat()
13 changes: 10 additions & 3 deletions airflow/cli/commands/dag_processor_command.py
@@ -19,30 +19,37 @@

import logging
from datetime import timedelta
from typing import Any

import daemon
from daemon.pidfile import TimeoutPIDLockFile

from airflow import settings
from airflow.configuration import conf
from airflow.jobs.dag_processor_job import DagProcessorJob
from airflow.jobs.base_job import BaseJob
from airflow.utils import cli as cli_utils
from airflow.utils.cli import setup_locations, setup_logging

log = logging.getLogger(__name__)


def _create_dag_processor_job(args) -> DagProcessorJob:
def _create_dag_processor_job(args: Any) -> BaseJob:
"""Creates DagFileProcessorProcess instance."""
from airflow.dag_processing.manager import DagFileProcessorManager

processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout")
processor_timeout = timedelta(seconds=processor_timeout_seconds)
return DagProcessorJob(

processor = DagFileProcessorManager(
processor_timeout=processor_timeout,
dag_directory=args.subdir,
max_runs=args.num_runs,
dag_ids=[],
pickle_dags=args.do_pickle,
)
return BaseJob(
job_runner=processor.job_runner,
)


@cli_utils.action_cli
15 changes: 9 additions & 6 deletions airflow/cli/commands/scheduler_command.py
@@ -28,13 +28,14 @@
from airflow.api_internal.internal_api_call import InternalApiConfig
from airflow.configuration import conf
from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.scheduler_job import SchedulerJob
from airflow.jobs.base_job import BaseJob
from airflow.jobs.scheduler_job import SchedulerJobRunner
from airflow.utils import cli as cli_utils
from airflow.utils.cli import process_subdir, setup_locations, setup_logging, sigint_handler, sigquit_handler
from airflow.utils.scheduler_health import serve_health_check


def _run_scheduler_job(job: SchedulerJob, *, skip_serve_logs: bool) -> None:
def _run_scheduler_job(job: BaseJob, *, skip_serve_logs: bool) -> None:
InternalApiConfig.force_database_direct_access()
enable_health_check = conf.getboolean("scheduler", "ENABLE_HEALTH_CHECK")
with _serve_logs(skip_serve_logs), _serve_health_check(enable_health_check):
@@ -46,10 +47,12 @@ def scheduler(args):
"""Starts Airflow Scheduler."""
print(settings.HEADER)

job = SchedulerJob(
subdir=process_subdir(args.subdir),
num_runs=args.num_runs,
do_pickle=args.do_pickle,
job = BaseJob(
job_runner=SchedulerJobRunner(
subdir=process_subdir(args.subdir),
num_runs=args.num_runs,
do_pickle=args.do_pickle,
)
)
ExecutorLoader.validate_database_executor_compatibility(job.executor)

14 changes: 8 additions & 6 deletions airflow/cli/commands/standalone_command.py
@@ -30,8 +30,10 @@
from airflow.configuration import AIRFLOW_HOME, conf, make_group_other_inaccessible
from airflow.executors import executor_constants
from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.scheduler_job import SchedulerJob
from airflow.jobs.triggerer_job import TriggererJob
from airflow.jobs.base_job import most_recent_job
from airflow.jobs.job_runner import BaseJobRunner
from airflow.jobs.scheduler_job import SchedulerJobRunner
from airflow.jobs.triggerer_job import TriggererJobRunner
from airflow.utils import db


@@ -215,8 +217,8 @@ def is_ready(self):
"""
return (
self.port_open(self.web_server_port)
and self.job_running(SchedulerJob)
and self.job_running(TriggererJob)
and self.job_running(SchedulerJobRunner)
and self.job_running(TriggererJobRunner)
)

def port_open(self, port):
@@ -235,13 +237,13 @@ def port_open(self, port):
return False
return True

def job_running(self, job):
def job_running(self, job_runner_class: type[BaseJobRunner]):
"""
Checks if the given job name is running and heartbeating correctly.

Used to tell if scheduler is alive.
"""
recent = job.most_recent_job()
recent = most_recent_job(job_runner_class.job_type)
if not recent:
return False
return recent.is_alive()
10 changes: 7 additions & 3 deletions airflow/cli/commands/task_command.py
@@ -37,7 +37,8 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException, DagRunNotFound, TaskInstanceNotFound
from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.local_task_job import LocalTaskJob
from airflow.jobs.base_job import BaseJob
from airflow.jobs.local_task_job import LocalTaskJobRunner
from airflow.listeners.listener import get_listener_manager
from airflow.models import DagPickle, TaskInstance
from airflow.models.dag import DAG
@@ -247,7 +248,7 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None:

def _run_task_by_local_task_job(args, ti: TaskInstance) -> TaskReturnCode | None:
"""Run LocalTaskJob, which monitors the raw task execution process."""
run_job = LocalTaskJob(
local_task_job_runner = LocalTaskJobRunner(
task_instance=ti,
mark_success=args.mark_success,
pickle_id=args.pickle,
@@ -259,9 +260,12 @@ def _run_task_by_local_task_job(args, ti: TaskInstance) -> TaskReturnCode | None
pool=args.pool,
external_executor_id=_extract_external_executor_id(args),
)
run_job = BaseJob(
job_runner=local_task_job_runner,
dag_id=ti.dag_id,
)
try:
ret = run_job.run()

finally:
if args.shut_down_logging:
logging.shutdown()
6 changes: 4 additions & 2 deletions airflow/cli/commands/triggerer_command.py
@@ -28,7 +28,8 @@

from airflow import settings
from airflow.configuration import conf
from airflow.jobs.triggerer_job import TriggererJob
from airflow.jobs.base_job import BaseJob
from airflow.jobs.triggerer_job import TriggererJobRunner
from airflow.utils import cli as cli_utils
from airflow.utils.cli import setup_locations, setup_logging, sigint_handler, sigquit_handler
from airflow.utils.serve_logs import serve_logs
@@ -54,7 +55,8 @@ def triggerer(args):
"""Starts Airflow Triggerer."""
settings.MASK_SECRETS_IN_LOGS = True
print(settings.HEADER)
job = TriggererJob(capacity=args.capacity)
triggerer_job_runner = TriggererJobRunner(capacity=args.capacity)
job = BaseJob(job_runner=triggerer_job_runner)

if args.daemon:
pid, stdout, stderr, log_file = setup_locations(
26 changes: 19 additions & 7 deletions airflow/dag_processing/manager.py
@@ -35,7 +35,7 @@
from importlib import import_module
from multiprocessing.connection import Connection as MultiprocessingConnection
from pathlib import Path
from typing import Any, NamedTuple, cast
from typing import TYPE_CHECKING, Any, NamedTuple, cast

from setproctitle import setproctitle
from sqlalchemy.orm import Session
@@ -46,7 +46,7 @@
from airflow.callbacks.callback_requests import CallbackRequest, SlaCallbackRequest
from airflow.configuration import conf
from airflow.dag_processing.processor import DagFileProcessorProcess
from airflow.jobs.base_job import BaseJob
from airflow.jobs.base_job import perform_heartbeat
from airflow.models import errors
from airflow.models.dag import DagModel
from airflow.models.dagwarning import DagWarning
@@ -67,6 +67,9 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import prohibit_commit, skip_locked, with_row_locks

if TYPE_CHECKING:
from airflow.jobs.dag_processor_job import DagProcessorJobRunner


class DagParsingStat(NamedTuple):
"""Information on processing progress."""
@@ -381,8 +384,9 @@ def __init__(
pickle_dags: bool,
signal_conn: MultiprocessingConnection | None = None,
async_mode: bool = True,
job: BaseJob | None = None,
):
from airflow.jobs.dag_processor_job import DagProcessorJobRunner

super().__init__()
# known files; this will be updated every `dag_dir_list_interval` and stuff added/removed accordingly
self._file_paths: list[str] = []
@@ -395,8 +399,9 @@ def __init__(
self._async_mode = async_mode
self._parsing_start_time: float | None = None
self._dag_directory = dag_directory
self._job = job

self._job_runner = DagProcessorJobRunner(
processor=self,
)
# Set the signal conn in to non-blocking mode, so that attempting to
# send when the buffer is full errors, rather than hangs for-ever
# attempting to send (this is to avoid deadlocks!)
@@ -461,6 +466,10 @@ def __init__(
else {}
)

@property
def job_runner(self) -> DagProcessorJobRunner:
return self._job_runner

def register_exit_signals(self):
"""Register signals that stop child processes."""
signal.signal(signal.SIGINT, self._exit_gracefully)
@@ -576,8 +585,11 @@ def _run_parsing_loop(self):
while True:
loop_start_time = time.monotonic()
ready = multiprocessing.connection.wait(self.waitables.keys(), timeout=poll_time)
if self._job:
self._job.heartbeat()
# We cannot (for now) store the job on the _job_runner cleanly because of the circular
# reference between the job and the job runner, so we use getattr here; this may be
# addressed in a future change when the two are decoupled further.
if getattr(self._job_runner, "job", None) is not None:
perform_heartbeat(self._job_runner.job, only_if_necessary=False)
Comment on lines +588 to +592

Member: I have a refactoring idea, but that can wait until after this is merged and doesn't need to block your other PRs.

Member: Attempted in astronomer@68d5231

Member (Author): I think we will not need it in the end. When we complete the refactoring (the final state is #30376), this line is gone; the job is defined individually in each *JobRunner and we always know which runner we access, so there is no need to add the type guard, I think.
if self._direct_scheduler_conn is not None and self._direct_scheduler_conn in ready:
agent_signal = self._direct_scheduler_conn.recv()
