diff --git a/airflow/api/common/delete_dag.py b/airflow/api/common/delete_dag.py index 39f1461fccc9e..6a1ad271c1e96 100644 --- a/airflow/api/common/delete_dag.py +++ b/airflow/api/common/delete_dag.py @@ -21,20 +21,21 @@ import logging from sqlalchemy import and_, or_ +from sqlalchemy.orm import Session from airflow import models from airflow.exceptions import AirflowException, DagNotFound from airflow.models import DagModel, TaskFail from airflow.models.serialized_dag import SerializedDagModel from airflow.utils.db import get_sqla_model_classes -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import State log = logging.getLogger(__name__) @provide_session -def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> int: +def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session = NEW_SESSION) -> int: """ Delete a DAG by a dag_id. diff --git a/airflow/api/common/experimental/pool.py b/airflow/api/common/experimental/pool.py index a37c3e4086042..4fa6f04badc46 100644 --- a/airflow/api/common/experimental/pool.py +++ b/airflow/api/common/experimental/pool.py @@ -19,15 +19,16 @@ from __future__ import annotations from deprecated import deprecated +from sqlalchemy.orm import Session from airflow.exceptions import AirflowBadRequest, PoolNotFound from airflow.models import Pool -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session @deprecated(reason="Use Pool.get_pool() instead", version="2.2.4") @provide_session -def get_pool(name, session=None): +def get_pool(name, session: Session = NEW_SESSION): """Get pool by a given name.""" if not (name and name.strip()): raise AirflowBadRequest("Pool name shouldn't be empty") @@ -41,14 +42,14 @@ def get_pool(name, session=None): @deprecated(reason="Use Pool.get_pools() instead", version="2.2.4") @provide_session -def get_pools(session=None): +def get_pools(session: Session = NEW_SESSION): """Get all pools.""" return session.query(Pool).all() @deprecated(reason="Use Pool.create_pool() instead", version="2.2.4") @provide_session -def create_pool(name, slots, description, session=None): +def create_pool(name, slots, description, session: Session = NEW_SESSION): """Create a pool with given parameters.""" if not (name and name.strip()): raise AirflowBadRequest("Pool name shouldn't be empty") @@ -79,7 +80,7 @@ def create_pool(name, slots, description, session=None): @deprecated(reason="Use Pool.delete_pool() instead", version="2.2.4") @provide_session -def delete_pool(name, session=None): +def delete_pool(name, session: Session = NEW_SESSION): """Delete pool by a given name.""" if not (name and name.strip()): raise AirflowBadRequest("Pool name shouldn't be empty") diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py index ef891fba62832..f1ef3abc6e7f6 100644 --- a/airflow/cli/commands/dag_command.py +++ b/airflow/cli/commands/dag_command.py @@ -21,9 +21,11 @@ import errno import json import logging +import operator import signal import subprocess import sys +import warnings from graphviz.dot import Dot from sqlalchemy.orm import Session @@ -47,33 +49,7 @@ log = logging.getLogger(__name__) -@cli_utils.action_cli -def dag_backfill(args, dag=None): - """Creates backfill job or dry run for a DAG or list of DAGs using regex.""" - logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT) - - 
signal.signal(signal.SIGTERM, sigint_handler) - - import warnings - - warnings.warn( - "--ignore-first-depends-on-past is deprecated as the value is always set to True", - category=RemovedInAirflow3Warning, - ) - - if args.ignore_first_depends_on_past is False: - args.ignore_first_depends_on_past = True - - if not args.start_date and not args.end_date: - raise AirflowException("Provide a start_date and/or end_date") - - if not dag: - dags = get_dags(args.subdir, dag_id=args.dag_id, use_regex=args.treat_dag_as_regex) - else: - dags = dag if type(dag) == list else [dag] - - dags.sort(key=lambda d: d.dag_id) - +def _run_dag_backfill(dags: list[DAG], args) -> None: # If only one date is passed, using same as start and end args.end_date = args.end_date or args.start_date args.start_date = args.start_date or args.end_date @@ -133,12 +109,39 @@ def dag_backfill(args, dag=None): print(str(vr)) sys.exit(1) + +@cli_utils.action_cli +def dag_backfill(args, dag: list[DAG] | DAG | None = None) -> None: + """Creates backfill job or dry run for a DAG or list of DAGs using regex.""" + logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT) + signal.signal(signal.SIGTERM, sigint_handler) + warnings.warn( + "--ignore-first-depends-on-past is deprecated as the value is always set to True", + category=RemovedInAirflow3Warning, + ) + + if args.ignore_first_depends_on_past is False: + args.ignore_first_depends_on_past = True + + if not args.start_date and not args.end_date: + raise AirflowException("Provide a start_date and/or end_date") + + if not dag: + dags = get_dags(args.subdir, dag_id=args.dag_id, use_regex=args.treat_dag_as_regex) + elif isinstance(dag, list): + dags = dag + else: + dags = [dag] + del dag + + dags.sort(key=lambda d: d.dag_id) + _run_dag_backfill(dags, args) if len(dags) > 1: log.info("All of the backfills are done.") @cli_utils.action_cli -def dag_trigger(args): +def dag_trigger(args) -> None: """Creates a dag run for the specified dag.""" api_client = get_current_api_client() try: @@ -159,7 +162,7 @@ def dag_trigger(args): @cli_utils.action_cli -def dag_delete(args): +def dag_delete(args) -> None: """Deletes all DB records related to the specified dag.""" api_client = get_current_api_client() if ( @@ -177,18 +180,18 @@ def dag_delete(args): @cli_utils.action_cli -def dag_pause(args): +def dag_pause(args) -> None: """Pauses a DAG.""" set_is_paused(True, args) @cli_utils.action_cli -def dag_unpause(args): +def dag_unpause(args) -> None: """Unpauses a DAG.""" set_is_paused(False, args) -def set_is_paused(is_paused, args): +def set_is_paused(is_paused: bool, args) -> None: """Sets is_paused for DAG by a given dag_id.""" dag = DagModel.get_dagmodel(args.dag_id) @@ -200,7 +203,7 @@ def set_is_paused(is_paused, args): print(f"Dag: {args.dag_id}, paused: {is_paused}") -def dag_dependencies_show(args): +def dag_dependencies_show(args) -> None: """Displays DAG dependencies, save to file or show as imgcat image.""" dot = render_dag_dependencies(SerializedDagModel.get_dag_dependencies()) filename = args.save @@ -219,7 +222,7 @@ def dag_dependencies_show(args): print(dot.source) -def dag_show(args): +def dag_show(args) -> None: """Displays DAG or saves it's graphic representation to the file.""" dag = get_dag(args.subdir, args.dag_id) dot = render_dag(dag) @@ -239,7 +242,7 @@ def dag_show(args): print(dot.source) -def _display_dot_via_imgcat(dot: Dot): +def _display_dot_via_imgcat(dot: Dot) -> None: data = dot.pipe(format="png") try: with subprocess.Popen("imgcat", 
stdout=subprocess.PIPE, stdin=subprocess.PIPE) as proc: @@ -255,7 +258,7 @@ def _display_dot_via_imgcat(dot: Dot): raise -def _save_dot_to_file(dot: Dot, filename: str): +def _save_dot_to_file(dot: Dot, filename: str) -> None: filename_without_ext, _, ext = filename.rpartition(".") dot.render(filename=filename_without_ext, format=ext, cleanup=True) print(f"File {filename} saved") @@ -263,7 +266,7 @@ def _save_dot_to_file(dot: Dot, filename: str): @cli_utils.action_cli @provide_session -def dag_state(args, session=NEW_SESSION): +def dag_state(args, session: Session = NEW_SESSION) -> None: """ Returns the state (and conf if exists) of a DagRun at the command line. >>> airflow dags state tutorial 2015-01-01T00:00:00.000000 @@ -284,7 +287,7 @@ def dag_state(args, session=NEW_SESSION): @cli_utils.action_cli -def dag_next_execution(args): +def dag_next_execution(args) -> None: """ Returns the next execution datetime of a DAG at the command line. >>> airflow dags next-execution tutorial @@ -312,7 +315,7 @@ def print_execution_interval(interval: DataInterval | None): next_interval = dag.get_next_data_interval(last_parsed_dag) print_execution_interval(next_interval) - for i in range(1, args.num_executions): + for _ in range(1, args.num_executions): next_info = dag.next_dagrun_info(next_interval, restricted=False) next_interval = None if next_info is None else next_info.data_interval print_execution_interval(next_interval) @@ -320,7 +323,7 @@ def print_execution_interval(interval: DataInterval | None): @cli_utils.action_cli @suppress_logs_and_warning -def dag_list_dags(args): +def dag_list_dags(args) -> None: """Displays dags with or without stats at the command line.""" dagbag = DagBag(process_subdir(args.subdir)) if dagbag.import_errors: @@ -332,7 +335,7 @@ def dag_list_dags(args): file=sys.stderr, ) AirflowConsole().print_as( - data=sorted(dagbag.dags.values(), key=lambda d: d.dag_id), + data=sorted(dagbag.dags.values(), key=operator.attrgetter("dag_id")), output=args.output, mapper=lambda x: { "dag_id": x.dag_id, @@ -345,7 +348,7 @@ def dag_list_dags(args): @cli_utils.action_cli @suppress_logs_and_warning -def dag_list_import_errors(args): +def dag_list_import_errors(args) -> None: """Displays dags with import errors on the command line.""" dagbag = DagBag(process_subdir(args.subdir)) data = [] @@ -359,7 +362,7 @@ def dag_list_import_errors(args): @cli_utils.action_cli @suppress_logs_and_warning -def dag_report(args): +def dag_report(args) -> None: """Displays dagbag stats at the command line.""" dagbag = DagBag(process_subdir(args.subdir)) AirflowConsole().print_as( @@ -378,7 +381,7 @@ def dag_report(args): @cli_utils.action_cli @suppress_logs_and_warning @provide_session -def dag_list_jobs(args, dag=None, session=NEW_SESSION): +def dag_list_jobs(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None: """Lists latest n jobs.""" queries = [] if dag: @@ -408,7 +411,7 @@ def dag_list_jobs(args, dag=None, session=NEW_SESSION): @cli_utils.action_cli @suppress_logs_and_warning @provide_session -def dag_list_dag_runs(args, dag=None, session=NEW_SESSION): +def dag_list_dag_runs(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None: """Lists dag runs for a given DAG.""" if dag: args.dag_id = dag.dag_id @@ -445,7 +448,7 @@ def dag_list_dag_runs(args, dag=None, session=NEW_SESSION): @provide_session @cli_utils.action_cli -def dag_test(args, dag=None, session=None): +def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None: """Execute one single 
DagRun for a given DAG and execution date.""" run_conf = None if args.conf: @@ -481,7 +484,7 @@ def dag_test(args, dag=None, session=None): @provide_session @cli_utils.action_cli -def dag_reserialize(args, session: Session = NEW_SESSION): +def dag_reserialize(args, session: Session = NEW_SESSION) -> None: """Serialize a DAG instance.""" session.query(SerializedDagModel).delete(synchronize_session=False) diff --git a/airflow/cli/commands/jobs_command.py b/airflow/cli/commands/jobs_command.py index c030b5ea9b31b..959f7ebc4c841 100644 --- a/airflow/cli/commands/jobs_command.py +++ b/airflow/cli/commands/jobs_command.py @@ -16,14 +16,16 @@ # under the License. from __future__ import annotations +from sqlalchemy.orm import Session + from airflow.jobs.base_job import BaseJob from airflow.utils.net import get_hostname -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import State @provide_session -def check(args, session=None): +def check(args, session: Session = NEW_SESSION) -> None: """Checks if job(s) are still alive.""" if args.allow_multiple and not args.limit > 1: raise SystemExit("To use option --allow-multiple, you must set the limit to a value greater than 1.") diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 42d0b9e86ed55..c2546f91a4dae 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -25,7 +25,7 @@ import sys import textwrap from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress -from typing import Generator, Union +from typing import Generator, Union, cast import pendulum from pendulum.parsing.exceptions import ParserError @@ -40,15 +40,15 @@ from airflow.jobs.local_task_job import LocalTaskJob from airflow.listeners.listener import get_listener_manager from airflow.models import DagPickle, TaskInstance -from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG from airflow.models.dagrun import DagRun -from airflow.models.operator import needs_expansion +from airflow.models.operator import Operator, needs_expansion +from airflow.models.param import ParamsDict from airflow.models.taskinstance import TaskReturnCode from airflow.settings import IS_K8S_EXECUTOR_POD from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS -from airflow.typing_compat import Literal +from airflow.typing_compat import Literal, Protocol from airflow.utils import cli as cli_utils from airflow.utils.cli import ( get_dag, @@ -146,7 +146,7 @@ def _get_dag_run( @provide_session def _get_ti( - task: BaseOperator, + task: Operator, map_index: int, *, exec_date_or_run_id: str | None = None, @@ -155,6 +155,9 @@ def _get_ti( session: Session = NEW_SESSION, ) -> tuple[TaskInstance, bool]: """Get the task instance through DagRun.run_id, if that fails, get the TI the old way.""" + dag = task.dag + if dag is None: + raise ValueError("Cannot get task instance for a task not assigned to a DAG") if not exec_date_or_run_id and not create_if_necessary: raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.") if needs_expansion(task): @@ -163,7 +166,7 @@ def _get_ti( elif map_index >= 0: raise RuntimeError("map_index passed to non-mapped task") dag_run, dr_created = _get_dag_run( - dag=task.dag, + dag=dag, exec_date_or_run_id=exec_date_or_run_id, create_if_necessary=create_if_necessary, session=session, @@ 
-173,7 +176,7 @@ def _get_ti( if ti_or_none is None: if not create_if_necessary: raise TaskInstanceNotFound( - f"TaskInstance for {task.dag.dag_id}, {task.task_id}, map={map_index} with " + f"TaskInstance for {dag.dag_id}, {task.task_id}, map={map_index} with " f"run_id or execution_date of {exec_date_or_run_id!r} not found" ) # TODO: Validate map_index is in range? @@ -197,13 +200,13 @@ def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None | Tas """ if args.local: return _run_task_by_local_task_job(args, ti) - elif args.raw: + if args.raw: return _run_raw_task(args, ti) - else: - return _run_task_by_executor(args, dag, ti) + _run_task_by_executor(args, dag, ti) + return None -def _run_task_by_executor(args, dag, ti): +def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None: """ Sends the task to the executor for execution. @@ -242,7 +245,7 @@ def _run_task_by_executor(args, dag, ti): executor.end() -def _run_task_by_local_task_job(args, ti) -> TaskReturnCode | None: +def _run_task_by_local_task_job(args, ti: TaskInstance) -> TaskReturnCode | None: """Run LocalTaskJob, which monitors the raw task execution process.""" run_job = LocalTaskJob( task_instance=ti, @@ -344,7 +347,7 @@ class TaskCommandMarker: @cli_utils.action_cli(check_db=False) -def task_run(args, dag=None): +def task_run(args, dag: DAG | None = None) -> TaskReturnCode | None: """ Run a single task instance. @@ -392,13 +395,12 @@ def task_run(args, dag=None): if args.pickle: print(f"Loading pickle id: {args.pickle}") - dag = get_dag_by_pickle(args.pickle) + _dag = get_dag_by_pickle(args.pickle) elif not dag: - dag = get_dag(args.subdir, args.dag_id) + _dag = get_dag(args.subdir, args.dag_id) else: - # Use DAG from parameter - pass - task = dag.get_task(task_id=args.task_id) + _dag = dag + task = _dag.get_task(task_id=args.task_id) ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, pool=args.pool) ti.init_run_context(raw=args.raw) @@ -416,10 +418,10 @@ def task_run(args, dag=None): task_return_code = None try: if args.interactive: - task_return_code = _run_task_by_selected_method(args, dag, ti) + task_return_code = _run_task_by_selected_method(args, _dag, ti) else: with _move_task_handlers_to_root(ti), _redirect_stdout_to_ti_log(ti): - task_return_code = _run_task_by_selected_method(args, dag, ti) + task_return_code = _run_task_by_selected_method(args, _dag, ti) if task_return_code == TaskReturnCode.DEFERRED: _set_task_deferred_context_var() finally: @@ -431,7 +433,7 @@ def task_run(args, dag=None): @cli_utils.action_cli(check_db=False) -def task_failed_deps(args): +def task_failed_deps(args) -> None: """ Get task instance dependencies that were not met. @@ -461,7 +463,7 @@ def task_failed_deps(args): @cli_utils.action_cli(check_db=False) @suppress_logs_and_warning -def task_state(args): +def task_state(args) -> None: """ Returns the state of a TaskInstance at the command line. >>> airflow tasks state tutorial sleep 2015-01-01 @@ -475,7 +477,7 @@ def task_state(args): @cli_utils.action_cli(check_db=False) @suppress_logs_and_warning -def task_list(args, dag=None): +def task_list(args, dag: DAG | None = None) -> None: """Lists the tasks within a DAG at the command line.""" dag = dag or get_dag(args.subdir, args.dag_id) if args.tree: @@ -485,7 +487,12 @@ def task_list(args, dag=None): print("\n".join(tasks)) -SUPPORTED_DEBUGGER_MODULES: list[str] = [ +class _SupportedDebugger(Protocol): + def post_mortem(self) -> None: + ... 
+ + +SUPPORTED_DEBUGGER_MODULES = [ "pudb", "web_pdb", "ipdb", @@ -493,7 +500,7 @@ def task_list(args, dag=None): ] -def _guess_debugger(): +def _guess_debugger() -> _SupportedDebugger: """ Trying to guess the debugger used by the user. @@ -506,18 +513,19 @@ def _guess_debugger(): * `ipdb `__ * `pdb `__ """ - for mod in SUPPORTED_DEBUGGER_MODULES: + exc: Exception + for mod_name in SUPPORTED_DEBUGGER_MODULES: try: - return importlib.import_module(mod) - except ImportError: - continue - return importlib.import_module("pdb") + return cast(_SupportedDebugger, importlib.import_module(mod_name)) + except ImportError as e: + exc = e + raise exc @cli_utils.action_cli(check_db=False) @suppress_logs_and_warning @provide_session -def task_states_for_dag_run(args, session=None): +def task_states_for_dag_run(args, session: Session = NEW_SESSION) -> None: """Get the status of all task instances in a DagRun.""" dag_run = ( session.query(DagRun) @@ -560,7 +568,7 @@ def format_task_instance(ti: TaskInstance) -> dict[str, str]: @cli_utils.action_cli(check_db=False) -def task_test(args, dag=None): +def task_test(args, dag: DAG | None = None) -> None: """Tests task for a given dag_id.""" # We want to log output from operators etc to show up here. Normally # airflow.task would redirect to a file, but here we want it to propagate @@ -590,7 +598,7 @@ def task_test(args, dag=None): passed_in_params = json.loads(args.task_params) task.params.update(passed_in_params) - if task.params: + if task.params and isinstance(task.params, ParamsDict): task.params.validate() ti, dr_created = _get_ti( @@ -621,7 +629,7 @@ def task_test(args, dag=None): @cli_utils.action_cli(check_db=False) @suppress_logs_and_warning -def task_render(args, dag=None): +def task_render(args, dag: DAG | None = None) -> None: """Renders and displays templated fields for a given task.""" if not dag: dag = get_dag(args.subdir, args.dag_id) @@ -643,7 +651,7 @@ def task_render(args, dag=None): @cli_utils.action_cli(check_db=False) -def task_clear(args): +def task_clear(args) -> None: """Clears all task instances or only those matched by regex for a DAG(s).""" logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT) @@ -680,13 +688,13 @@ class LoggerMutationHelper: :meta private: """ - def __init__(self, logger): + def __init__(self, logger: logging.Logger) -> None: self.handlers = logger.handlers[:] self.level = logger.level self.propagate = logger.propagate self.source_logger = logger - def apply(self, logger, replace=True): + def apply(self, logger: logging.Logger, replace: bool = True) -> None: """ Set ``logger`` with attrs stored on instance. @@ -702,7 +710,7 @@ def apply(self, logger, replace=True): if logger is not logging.getLogger(): logger.propagate = self.propagate - def move(self, logger, replace=True): + def move(self, logger: logging.Logger, replace: bool = True) -> None: """ Replace ``logger`` attrs with those from source. 
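The `_SupportedDebugger` protocol and the reworked `_guess_debugger` above combine `typing.Protocol`, `importlib`, and `typing.cast` so that a dynamically imported module gets a static shape. A minimal self-contained sketch of that pattern follows; the names and the fallback error handling here are illustrative only, not Airflow's actual helpers.

from __future__ import annotations

import importlib
from typing import Protocol, cast


class PostMortemDebugger(Protocol):
    """Structural type: any module exposing post_mortem() satisfies it."""

    def post_mortem(self) -> None:
        ...


CANDIDATE_MODULES = ["pudb", "web_pdb", "ipdb", "pdb"]  # tried in this order


def guess_debugger() -> PostMortemDebugger:
    """Return the first importable debugger module, typed via the protocol."""
    error: ImportError | None = None
    for name in CANDIDATE_MODULES:
        try:
            # import_module() returns ModuleType; cast() records that only the
            # post_mortem() attribute declared by the protocol will be used.
            return cast(PostMortemDebugger, importlib.import_module(name))
        except ImportError as exc:
            error = exc
    if error is None:
        error = ImportError("no debugger module available")
    raise error

Because the protocol only declares `post_mortem()`, any of the candidate modules satisfies it structurally without inheriting from a common base class.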
@@ -713,11 +721,11 @@ def move(self, logger, replace=True): self.source_logger.propagate = True self.source_logger.handlers[:] = [] - def reset(self): + def reset(self) -> None: self.apply(self.source_logger) - def __enter__(self): + def __enter__(self) -> LoggerMutationHelper: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.reset() diff --git a/airflow/cli/simple_table.py b/airflow/cli/simple_table.py index 87f7ce9f3b853..c2c60df95c072 100644 --- a/airflow/cli/simple_table.py +++ b/airflow/cli/simple_table.py @@ -18,7 +18,7 @@ import inspect import json -from typing import Any, Callable +from typing import Any, Callable, Sequence from rich.box import ASCII_DOUBLE_HEAD from rich.console import Console @@ -27,10 +27,15 @@ from tabulate import tabulate from airflow.plugins_manager import PluginsDirectorySource +from airflow.typing_compat import TypeGuard from airflow.utils import yaml from airflow.utils.platform import is_tty +def is_data_sequence(data: Sequence[dict | Any]) -> TypeGuard[Sequence[dict]]: + return all(isinstance(d, dict) for d in data) + + class AirflowConsole(Console): """Airflow rich console.""" @@ -88,7 +93,12 @@ def _normalize_data(self, value: Any, output: str) -> list | str | dict | None: return None return str(value) - def print_as(self, data: list[dict | Any], output: str, mapper: Callable | None = None): + def print_as( + self, + data: Sequence[dict | Any], + output: str, + mapper: Callable[[Any], dict] | None = None, + ) -> None: """Prints provided using format specified by output argument.""" output_to_renderer: dict[str, Callable[[Any], None]] = { "json": self.print_as_json, @@ -102,13 +112,12 @@ def print_as(self, data: list[dict | Any], output: str, mapper: Callable | None f"Unknown formatter: {output}. 
Allowed options: {list(output_to_renderer.keys())}" ) - if not all(isinstance(d, dict) for d in data) and not mapper: - raise ValueError("To tabulate non-dictionary data you need to provide `mapper` function") - if mapper: - dict_data: list[dict] = [mapper(d) for d in data] - else: + dict_data: Sequence[dict] = [mapper(d) for d in data] + elif is_data_sequence(data): dict_data = data + else: + raise ValueError("To tabulate non-dictionary data you need to provide `mapper` function") dict_data = [{k: self._normalize_data(v, output) for k, v in d.items()} for d in dict_data] renderer(dict_data) diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index db8cdd2d56faf..cf1947992d645 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -151,7 +151,7 @@ def queue_task_instance( self, task_instance: TaskInstance, mark_success: bool = False, - pickle_id: str | None = None, + pickle_id: int | None = None, ignore_all_deps: bool = False, ignore_depends_on_past: bool = False, wait_for_past_depends_before_skipping: bool = False, diff --git a/airflow/executors/celery_kubernetes_executor.py b/airflow/executors/celery_kubernetes_executor.py index 94e9df684e7c1..00a7f15830599 100644 --- a/airflow/executors/celery_kubernetes_executor.py +++ b/airflow/executors/celery_kubernetes_executor.py @@ -116,7 +116,7 @@ def queue_task_instance( self, task_instance: TaskInstance, mark_success: bool = False, - pickle_id: str | None = None, + pickle_id: int | None = None, ignore_all_deps: bool = False, ignore_depends_on_past: bool = False, wait_for_past_depends_before_skipping: bool = False, diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py index c506453db984b..60fd51282e730 100644 --- a/airflow/executors/debug_executor.py +++ b/airflow/executors/debug_executor.py @@ -95,7 +95,7 @@ def queue_task_instance( self, task_instance: TaskInstance, mark_success: bool = False, - pickle_id: str | None = None, + pickle_id: int | None = None, ignore_all_deps: bool = False, ignore_depends_on_past: bool = False, wait_for_past_depends_before_skipping: bool = False, diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index 276afba3219ff..3c56e7ceaf26e 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -36,6 +36,7 @@ from kubernetes import client, watch from kubernetes.client import Configuration, models as k8s from kubernetes.client.rest import ApiException +from sqlalchemy.orm import Session from urllib3.exceptions import ReadTimeoutError from airflow.configuration import conf @@ -50,7 +51,7 @@ from airflow.utils import timezone from airflow.utils.event_scheduler import EventScheduler from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import State, TaskInstanceState ALL_NAMESPACES = "ALL_NAMESPACES" @@ -494,7 +495,7 @@ def _list_pods(self, query_kwargs): return pods @provide_session - def clear_not_launched_queued_tasks(self, session=None) -> None: + def clear_not_launched_queued_tasks(self, session: Session = NEW_SESSION) -> None: """ Clear tasks that were not yet launched, but were previously queued. 
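The new `is_data_sequence` helper in simple_table.py above uses a `TypeGuard` return type so the `elif` branch of `print_as` can narrow `Sequence[dict | Any]` to `Sequence[dict]`. A short standalone sketch of how such a guard behaves, assuming plain `typing`/`typing_extensions` imports rather than `airflow.typing_compat`:

from __future__ import annotations

from typing import Any, Sequence

try:
    from typing import TypeGuard  # Python 3.10+
except ImportError:
    from typing_extensions import TypeGuard  # older interpreters


def is_dict_sequence(data: Sequence[dict | Any]) -> TypeGuard[Sequence[dict]]:
    """True only if every element is a dict; type checkers narrow accordingly."""
    return all(isinstance(d, dict) for d in data)


def render(data: Sequence[dict | Any]) -> list[str]:
    if is_dict_sequence(data):
        # Inside this branch the checker treats data as Sequence[dict], so
        # .items() is known to exist on every row.
        return [", ".join(f"{k}={v}" for k, v in row.items()) for row in data]
    raise ValueError("non-dict rows need an explicit mapper")

At runtime the guard is just an `all()`/`isinstance()` scan; the benefit is purely for static analysis, which is why the same check can replace the separate pre-validation that `print_as` used to do.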
diff --git a/airflow/executors/local_kubernetes_executor.py b/airflow/executors/local_kubernetes_executor.py index 9ce34dce8b60b..916d83839138e 100644 --- a/airflow/executors/local_kubernetes_executor.py +++ b/airflow/executors/local_kubernetes_executor.py @@ -117,7 +117,7 @@ def queue_task_instance( self, task_instance: TaskInstance, mark_success: bool = False, - pickle_id: str | None = None, + pickle_id: int | None = None, ignore_all_deps: bool = False, ignore_depends_on_past: bool = False, wait_for_past_depends_before_skipping: bool = False, diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py index ee37b9d51026e..9b750f11278f6 100644 --- a/airflow/jobs/backfill_job.py +++ b/airflow/jobs/backfill_job.py @@ -17,8 +17,9 @@ # under the License. from __future__ import annotations +import datetime import time -from typing import TYPE_CHECKING, Any, Iterable, Iterator, Sequence +from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, Sequence import attr import pendulum @@ -45,11 +46,12 @@ from airflow.timetables.base import DagRunInfo from airflow.utils import helpers, timezone from airflow.utils.configuration import conf as airflow_conf, tmp_configuration_copy -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunType if TYPE_CHECKING: + from airflow.executors.base_executor import BaseExecutor from airflow.models.abstractoperator import AbstractOperator @@ -122,7 +124,7 @@ def __init__( disable_retry=False, *args, **kwargs, - ): + ) -> None: """ Create a BackfillJob. @@ -164,7 +166,7 @@ def __init__( self.disable_retry = disable_retry super().__init__(*args, **kwargs) - def _update_counters(self, ti_status, session): + def _update_counters(self, ti_status: _DagRunTaskStatus, session: Session) -> None: """ Updates the counters per state of the tasks that were running. @@ -240,7 +242,9 @@ def _update_counters(self, ti_status, session): session.flush() def _manage_executor_state( - self, running, session + self, + running: Mapping[TaskInstanceKey, TaskInstance], + session: Session, ) -> Iterator[tuple[AbstractOperator, str, Sequence[TaskInstance], int]]: """ Compare task instances' states with that of the executor. @@ -293,7 +297,12 @@ def _iter_task_needing_expansion() -> Iterator[AbstractOperator]: yield node, ti.run_id, new_tis, num_mapped_tis @provide_session - def _get_dag_run(self, dagrun_info: DagRunInfo, dag: DAG, session: Session = None): + def _get_dag_run( + self, + dagrun_info: DagRunInfo, + dag: DAG, + session: Session = NEW_SESSION, + ) -> DagRun | None: """ Return an existing dag run for the given run date or create one. @@ -355,7 +364,12 @@ def _get_dag_run(self, dagrun_info: DagRunInfo, dag: DAG, session: Session = Non return run @provide_session - def _task_instances_for_dag_run(self, dag, dag_run, session=None): + def _task_instances_for_dag_run( + self, + dag: DAG, + dag_run: DagRun, + session: Session = NEW_SESSION, + ) -> dict[TaskInstanceKey, TaskInstance]: """ Return a map of task instance keys to task instance objects for the given dag run. 
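Most hunks in this diff swap `session=None` defaults for `session: Session = NEW_SESSION` under `@provide_session`, which keeps the annotation non-Optional while the decorator still injects a real session when the caller passes none. A rough sketch of how a sentinel-plus-decorator pair like this can be built; this is a simplified stand-in with a dummy Session class, not Airflow's implementation.

from __future__ import annotations

import functools
from typing import Any, Callable, TypeVar, cast

F = TypeVar("F", bound=Callable[..., Any])


class Session:  # stand-in for sqlalchemy.orm.Session
    def close(self) -> None:
        pass


# The sentinel is typed as Any so functions can annotate
# ``session: Session = NEW_SESSION`` without the checker objecting that the
# default is not a Session instance.
NEW_SESSION: Any = cast(Any, object())


def provide_session(func: F) -> F:
    """Create and clean up a Session unless the caller supplied one."""

    @functools.wraps(func)
    def wrapper(*args: Any, session: Session = NEW_SESSION, **kwargs: Any) -> Any:
        if session is not NEW_SESSION:
            return func(*args, session=session, **kwargs)
        real_session = Session()
        try:
            return func(*args, session=real_session, **kwargs)
        finally:
            real_session.close()

    return cast(F, wrapper)


@provide_session
def count_dags(prefix: str, session: Session = NEW_SESSION) -> int:
    # ``session`` is always a concrete Session here, never None.
    return 0

Airflow's real decorator is more involved (it also handles a positionally passed session); the point of the sketch is only how the sentinel lets the annotation stay `Session` instead of `Session | None`.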
@@ -371,7 +385,7 @@ def _task_instances_for_dag_run(self, dag, dag_run, session=None): self.reset_state_for_orphaned_tasks(filter_by_dag_run=dag_run, session=session) # for some reason if we don't refresh the reference to run is lost - dag_run.refresh_from_db() + dag_run.refresh_from_db(session=session) make_transient(dag_run) dag_run.dag = dag @@ -389,7 +403,7 @@ def _task_instances_for_dag_run(self, dag, dag_run, session=None): raise return tasks_to_run - def _log_progress(self, ti_status): + def _log_progress(self, ti_status: _DagRunTaskStatus) -> None: self.log.info( "[backfill progress] | finished run %s of %s | tasks waiting: %s | succeeded: %s | " "running: %s | failed: %s | skipped: %s | deadlocked: %s | not ready: %s", @@ -408,10 +422,10 @@ def _log_progress(self, ti_status): def _process_backfill_task_instances( self, - ti_status, - executor, - pickle_id, - start_date=None, + ti_status: _DagRunTaskStatus, + executor: BaseExecutor, + pickle_id: int | None, + start_date: datetime.datetime | None = None, *, session: Session, ) -> list: @@ -674,7 +688,7 @@ def to_keep(key: TaskInstanceKey) -> bool: return executed_run_dates @provide_session - def _collect_errors(self, ti_status: _DagRunTaskStatus, session=None): + def _collect_errors(self, ti_status: _DagRunTaskStatus, session: Session = NEW_SESSION) -> Iterator[str]: def tabulate_ti_keys_set(ti_keys: Iterable[TaskInstanceKey]) -> str: # Sorting by execution date first sorted_ti_keys: Any = sorted( @@ -696,12 +710,11 @@ def tabulate_ti_keys_set(ti_keys: Iterable[TaskInstanceKey]) -> str: return tabulate(sorted_ti_keys, headers=headers) - err = "" if ti_status.failed: - err += "Some task instances failed:\n" - err += tabulate_ti_keys_set(ti_status.failed) + yield "Some task instances failed:\n" + yield tabulate_ti_keys_set(ti_status.failed) if ti_status.deadlocked: - err += "BackfillJob is deadlocked." + yield "BackfillJob is deadlocked." deadlocked_depends_on_past = any( t.are_dependencies_met( dep_context=DepContext(ignore_depends_on_past=False), @@ -714,31 +727,37 @@ def tabulate_ti_keys_set(ti_keys: Iterable[TaskInstanceKey]) -> str: for t in ti_status.deadlocked ) if deadlocked_depends_on_past: - err += ( + yield ( "Some of the deadlocked tasks were unable to run because " 'of "depends_on_past" relationships. Try running the ' "backfill with the option " '"ignore_first_depends_on_past=True" or passing "-I" at ' "the command line." 
) - err += "\nThese tasks have succeeded:\n" - err += tabulate_ti_keys_set(ti_status.succeeded) - err += "\n\nThese tasks are running:\n" - err += tabulate_ti_keys_set(ti_status.running) - err += "\n\nThese tasks have failed:\n" - err += tabulate_ti_keys_set(ti_status.failed) - err += "\n\nThese tasks are skipped:\n" - err += tabulate_ti_keys_set(ti_status.skipped) - err += "\n\nThese tasks are deadlocked:\n" - err += tabulate_ti_keys_set([ti.key for ti in ti_status.deadlocked]) - - return err + yield "\nThese tasks have succeeded:\n" + yield tabulate_ti_keys_set(ti_status.succeeded) + yield "\n\nThese tasks are running:\n" + yield tabulate_ti_keys_set(ti_status.running) + yield "\n\nThese tasks have failed:\n" + yield tabulate_ti_keys_set(ti_status.failed) + yield "\n\nThese tasks are skipped:\n" + yield tabulate_ti_keys_set(ti_status.skipped) + yield "\n\nThese tasks are deadlocked:\n" + yield tabulate_ti_keys_set([ti.key for ti in ti_status.deadlocked]) def _get_dag_with_subdags(self) -> list[DAG]: return [self.dag] + self.dag.subdags @provide_session - def _execute_dagruns(self, dagrun_infos, ti_status, executor, pickle_id, start_date, session=None): + def _execute_dagruns( + self, + dagrun_infos: Iterable[DagRunInfo], + ti_status: _DagRunTaskStatus, + executor: BaseExecutor, + pickle_id: int | None, + start_date: datetime.datetime | None, + session: Session = NEW_SESSION, + ) -> None: """ Compute and execute dag runs and their respective task instances for the given dates. @@ -754,10 +773,9 @@ def _execute_dagruns(self, dagrun_infos, ti_status, executor, pickle_id, start_d for dagrun_info in dagrun_infos: for dag in self._get_dag_with_subdags(): dag_run = self._get_dag_run(dagrun_info, dag, session=session) - tis_map = self._task_instances_for_dag_run(dag, dag_run, session=session) if dag_run is None: continue - + tis_map = self._task_instances_for_dag_run(dag, dag_run, session=session) ti_status.active_runs.append(dag_run) ti_status.to_run.update(tis_map or {}) @@ -772,7 +790,11 @@ def _execute_dagruns(self, dagrun_infos, ti_status, executor, pickle_id, start_d ti_status.executed_dag_run_dates.update(processed_dag_run_dates) @provide_session - def _set_unfinished_dag_runs_to_failed(self, dag_runs, session=None): + def _set_unfinished_dag_runs_to_failed( + self, + dag_runs: Iterable[DagRun], + session: Session = NEW_SESSION, + ) -> None: """ Go through the dag_runs and update the state based on the task_instance state. Then set DAG runs that are not finished to failed. @@ -788,7 +810,7 @@ def _set_unfinished_dag_runs_to_failed(self, dag_runs, session=None): session.merge(dag_run) @provide_session - def _execute(self, session=None): + def _execute(self, session: Session = NEW_SESSION) -> None: """Initialize all required components of a dag for a specified date range and execute the tasks.""" ti_status = BackfillJob._DagRunTaskStatus() @@ -876,7 +898,7 @@ def _execute(self, session=None): ) remaining_dates = ti_status.total_runs - len(ti_status.executed_dag_run_dates) - err = self._collect_errors(ti_status=ti_status, session=session) + err = "".join(self._collect_errors(ti_status=ti_status, session=session)) if err: if not self.continue_on_failures or ti_status.deadlocked: raise BackfillUnfinished(err, ti_status) @@ -901,7 +923,11 @@ def _execute(self, session=None): self.log.info("Backfill done for DAG %s. 
Exiting.", self.dag) @provide_session - def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None) -> int | None: + def reset_state_for_orphaned_tasks( + self, + filter_by_dag_run: DagRun | None = None, + session: Session = NEW_SESSION, + ) -> int | None: """ Reset state of orphaned tasks. diff --git a/airflow/jobs/base_job.py b/airflow/jobs/base_job.py index 919e1a7fbca80..4c7c68bb51744 100644 --- a/airflow/jobs/base_job.py +++ b/airflow/jobs/base_job.py @@ -18,10 +18,11 @@ from __future__ import annotations from time import sleep +from typing import NoReturn from sqlalchemy import Column, Index, Integer, String, case from sqlalchemy.exc import OperationalError -from sqlalchemy.orm import backref, foreign, relationship +from sqlalchemy.orm import Session, backref, foreign, relationship from sqlalchemy.orm.session import make_transient from airflow.compat.functools import cached_property @@ -36,7 +37,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname from airflow.utils.platform import getuser -from airflow.utils.session import create_session, provide_session +from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.sqlalchemy import UtcDateTime from airflow.utils.state import State @@ -121,7 +122,7 @@ def executor(self): @classmethod @provide_session - def most_recent_job(cls, session=None) -> BaseJob | None: + def most_recent_job(cls, session: Session = NEW_SESSION) -> BaseJob | None: """ Return the most recent job of this type, if any, based on last heartbeat received. @@ -160,7 +161,7 @@ def is_alive(self, grace_multiplier=2.1): ) @provide_session - def kill(self, session=None): + def kill(self, session: Session = NEW_SESSION) -> NoReturn: """Handles on_kill callback and updates state in database.""" job = session.query(BaseJob).filter(BaseJob.id == self.id).first() job.end_date = timezone.utcnow() @@ -176,10 +177,10 @@ def on_kill(self): """Will be called when an external kill command is received.""" @provide_session - def heartbeat_callback(self, session=None) -> None: + def heartbeat_callback(self, session: Session = NEW_SESSION) -> None: """Callback that is called during heartbeat. 
This method should be overwritten.""" - def heartbeat(self, only_if_necessary: bool = False): + def heartbeat(self, only_if_necessary: bool = False) -> None: """ Heartbeats update the job's entry in the database with a timestamp for the latest_heartbeat and allows for the job to be killed @@ -245,7 +246,7 @@ def heartbeat(self, only_if_necessary: bool = False): # We didn't manage to heartbeat, so make sure that the timestamp isn't updated self.latest_heartbeat = previous_heartbeat - def run(self): + def run(self) -> int | None: """Starts the job.""" Stats.incr(self.__class__.__name__.lower() + "_start", 1, 1) # Adding an entry in the DB @@ -275,5 +276,5 @@ def run(self): Stats.incr(self.__class__.__name__.lower() + "_end", 1, 1) return ret - def _execute(self): + def _execute(self) -> int | None: raise NotImplementedError("This method needs to be overridden") diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index ccc52e0b471df..c64b6443e1b55 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -20,6 +20,7 @@ import signal import psutil +from sqlalchemy.orm import Session from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -31,7 +32,7 @@ from airflow.utils.log.file_task_handler import _set_task_deferred_context_var from airflow.utils.net import get_hostname from airflow.utils.platform import IS_WINDOWS -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import State SIGSEGV_MESSAGE = """ @@ -229,7 +230,7 @@ def on_kill(self): self.task_runner.on_finish() @provide_session - def heartbeat_callback(self, session=None): + def heartbeat_callback(self, session: Session = NEW_SESSION) -> None: """Self destruct task if state has been moved away from running externally.""" if self.terminating: # ensure termination if processes are created later @@ -273,7 +274,10 @@ def heartbeat_callback(self, session=None): # A DagRun timeout will cause tasks to be externally marked as skipped. 
dagrun = ti.get_dagrun(session=session) execution_time = (dagrun.end_date or timezone.utcnow()) - dagrun.start_date - dagrun_timeout = ti.task.dag.dagrun_timeout + if ti.task.dag is not None: + dagrun_timeout = ti.task.dag.dagrun_timeout + else: + dagrun_timeout = None if dagrun_timeout and execution_time > dagrun_timeout: self.log.warning("DagRun timed out after %s.", str(execution_time)) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 5f8a88b730013..d0b2352be39dd 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -1293,11 +1293,11 @@ def _schedule_dag_run( callback: DagCallbackRequest | None = None dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) + dag_model = DM.get_dagmodel(dag_run.dag_id, session) - if not dag: - self.log.error("Couldn't find dag %s in DagBag/DB!", dag_run.dag_id) + if not dag or not dag_model: + self.log.error("Couldn't find DAG %s in DAG bag or database!", dag_run.dag_id) return callback - dag_model = DM.get_dagmodel(dag.dag_id, session) if ( dag_run.start_date @@ -1401,6 +1401,10 @@ def _send_sla_callbacks_to_processor(self, dag: DAG) -> None: return dag_model = DagModel.get_dagmodel(dag.dag_id) + if not dag_model: + self.log.error("Couldn't find DAG %s in database!", dag.dag_id) + return + request = SlaCallbackRequest( full_filepath=dag.fileloc, dag_id=dag.dag_id, diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 048d244711744..902ca2436f729 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -3271,7 +3271,7 @@ def timezone(self): @staticmethod @provide_session - def get_dagmodel(dag_id, session=NEW_SESSION): + def get_dagmodel(dag_id: str, session: Session = NEW_SESSION) -> DagModel | None: return session.get( DagModel, dag_id, diff --git a/airflow/models/dagcode.py b/airflow/models/dagcode.py index 47b9588cc8d00..84ccf6a186810 100644 --- a/airflow/models/dagcode.py +++ b/airflow/models/dagcode.py @@ -24,13 +24,14 @@ from sqlalchemy import BigInteger, Column, String, Text from sqlalchemy.dialects.mysql import MEDIUMTEXT +from sqlalchemy.orm import Session from sqlalchemy.sql.expression import literal from airflow.exceptions import AirflowException, DagCodeNotFound from airflow.models.base import Base from airflow.utils import timezone from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime log = logging.getLogger(__name__) @@ -59,7 +60,7 @@ def __init__(self, full_filepath: str, source_code: str | None = None): self.source_code = source_code or DagCode.code(self.fileloc) @provide_session - def sync_to_db(self, session=None): + def sync_to_db(self, session: Session = NEW_SESSION) -> None: """Writes code into database. :param session: ORM Session @@ -68,7 +69,7 @@ def sync_to_db(self, session=None): @classmethod @provide_session - def bulk_sync_to_db(cls, filelocs: Iterable[str], session=None): + def bulk_sync_to_db(cls, filelocs: Iterable[str], session: Session = NEW_SESSION) -> None: """Writes code in bulk into database. 
:param filelocs: file paths of DAGs to sync @@ -125,7 +126,7 @@ def bulk_sync_to_db(cls, filelocs: Iterable[str], session=None): @classmethod @provide_session - def remove_deleted_code(cls, alive_dag_filelocs: list[str], session=None): + def remove_deleted_code(cls, alive_dag_filelocs: list[str], session: Session = NEW_SESSION) -> None: """Deletes code not included in alive_dag_filelocs. :param alive_dag_filelocs: file paths of alive DAGs @@ -141,7 +142,7 @@ def remove_deleted_code(cls, alive_dag_filelocs: list[str], session=None): @classmethod @provide_session - def has_dag(cls, fileloc: str, session=None) -> bool: + def has_dag(cls, fileloc: str, session: Session = NEW_SESSION) -> bool: """Checks a file exist in dag_code table. :param fileloc: the file to check @@ -175,7 +176,7 @@ def _get_code_from_file(fileloc): @classmethod @provide_session - def _get_code_from_db(cls, fileloc, session=None): + def _get_code_from_db(cls, fileloc, session: Session = NEW_SESSION) -> str: dag_code = session.query(cls).filter(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)).first() if not dag_code: raise DagCodeNotFound() diff --git a/airflow/models/dagpickle.py b/airflow/models/dagpickle.py index caa319e9840f4..ae2c04ccaf510 100644 --- a/airflow/models/dagpickle.py +++ b/airflow/models/dagpickle.py @@ -17,6 +17,8 @@ # under the License. from __future__ import annotations +from typing import TYPE_CHECKING + import dill from sqlalchemy import BigInteger, Column, Integer, PickleType @@ -24,6 +26,9 @@ from airflow.utils import timezone from airflow.utils.sqlalchemy import UtcDateTime +if TYPE_CHECKING: + from airflow.models.dag import DAG + class DagPickle(Base): """ @@ -44,9 +49,9 @@ class DagPickle(Base): __tablename__ = "dag_pickle" - def __init__(self, dag): + def __init__(self, dag: DAG) -> None: self.dag_id = dag.dag_id if hasattr(dag, "template_env"): - dag.template_env = None + dag.template_env = None # type: ignore[attr-defined] self.pickle_hash = hash(dag) self.pickle = dag diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 3ce5d25b2e4ba..6f30abe276723 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -266,13 +266,18 @@ def refresh_from_db(self, session: Session = NEW_SESSION) -> None: @classmethod @provide_session - def active_runs_of_dags(cls, dag_ids=None, only_running=False, session=None) -> dict[str, int]: + def active_runs_of_dags( + cls, + dag_ids: Iterable[str] | None = None, + only_running: bool = False, + session: Session = NEW_SESSION, + ) -> dict[str, int]: """Get the number of active dag runs for each dag.""" query = session.query(cls.dag_id, func.count("*")) if dag_ids is not None: # 'set' called to avoid duplicate dag_ids, but converted back to 'list' # because SQLAlchemy doesn't accept a set here. 
- query = query.filter(cls.dag_id.in_(list(set(dag_ids)))) + query = query.filter(cls.dag_id.in_(set(dag_ids))) if only_running: query = query.filter(cls.state == State.RUNNING) else: @@ -596,7 +601,7 @@ def recalculate(self) -> _UnfinishedStates: dag_id=self.dag_id, run_id=self.run_id, is_failure_callback=True, - processor_subdir=dag_model.processor_subdir, + processor_subdir=None if dag_model is None else dag_model.processor_subdir, msg="task_failure", ) @@ -617,7 +622,7 @@ def recalculate(self) -> _UnfinishedStates: dag_id=self.dag_id, run_id=self.run_id, is_failure_callback=False, - processor_subdir=dag_model.processor_subdir, + processor_subdir=None if dag_model is None else dag_model.processor_subdir, msg="success", ) @@ -638,7 +643,7 @@ def recalculate(self) -> _UnfinishedStates: dag_id=self.dag_id, run_id=self.run_id, is_failure_callback=True, - processor_subdir=dag_model.processor_subdir, + processor_subdir=None if dag_model is None else dag_model.processor_subdir, msg="all_tasks_deadlocked", ) @@ -1230,7 +1235,7 @@ def is_backfill(self) -> bool: @classmethod @provide_session - def get_latest_runs(cls, session=None) -> list[DagRun]: + def get_latest_runs(cls, session: Session = NEW_SESSION) -> list[DagRun]: """Returns the latest DagRun for each DAG""" subquery = ( session.query(cls.dag_id, func.max(cls.execution_date).label("execution_date")) diff --git a/airflow/models/pool.py b/airflow/models/pool.py index d6b8ad38872dd..e7376c31eac1b 100644 --- a/airflow/models/pool.py +++ b/airflow/models/pool.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Iterable +from typing import Any, Iterable from sqlalchemy import Column, Integer, String, Text, func from sqlalchemy.orm.session import Session @@ -58,13 +58,13 @@ def __repr__(self): @staticmethod @provide_session - def get_pools(session: Session = NEW_SESSION): + def get_pools(session: Session = NEW_SESSION) -> list[Pool]: """Get all pools.""" return session.query(Pool).all() @staticmethod @provide_session - def get_pool(pool_name: str, session: Session = NEW_SESSION): + def get_pool(pool_name: str, session: Session = NEW_SESSION) -> Pool | None: """ Get the Pool with specific pool name from the Pools. @@ -76,7 +76,7 @@ def get_pool(pool_name: str, session: Session = NEW_SESSION): @staticmethod @provide_session - def get_default_pool(session: Session = NEW_SESSION): + def get_default_pool(session: Session = NEW_SESSION) -> Pool | None: """ Get the Pool of the default_pool from the Pools. 
@@ -104,11 +104,17 @@ def is_default_pool(id: int, session: Session = NEW_SESSION) -> bool: @staticmethod @provide_session - def create_or_update_pool(name: str, slots: int, description: str, session: Session = NEW_SESSION): + def create_or_update_pool( + name: str, + slots: int, + description: str, + session: Session = NEW_SESSION, + ) -> Pool: """Create a pool with given parameters or update it if it already exists.""" if not name: - return - pool = session.query(Pool).filter_by(pool=name).first() + raise ValueError("Pool name must not be empty") + + pool = session.query(Pool).filter_by(pool=name).one_or_none() if pool is None: pool = Pool(pool=name, slots=slots, description=description) session.add(pool) @@ -117,12 +123,11 @@ def create_or_update_pool(name: str, slots: int, description: str, session: Sess pool.description = description session.commit() - return pool @staticmethod @provide_session - def delete_pool(name: str, session: Session = NEW_SESSION): + def delete_pool(name: str, session: Session = NEW_SESSION) -> Pool: """Delete pool by a given name.""" if name == Pool.DEFAULT_POOL_NAME: raise AirflowException(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted") @@ -196,7 +201,7 @@ def slots_stats( return pools - def to_json(self): + def to_json(self) -> dict[str, Any]: """ Get the Pool in a json structure @@ -210,7 +215,7 @@ def to_json(self): } @provide_session - def occupied_slots(self, session: Session = NEW_SESSION): + def occupied_slots(self, session: Session = NEW_SESSION) -> int: """ Get the number of slots used by running/queued tasks at the moment. @@ -222,13 +227,13 @@ def occupied_slots(self, session: Session = NEW_SESSION): return int( session.query(func.sum(TaskInstance.pool_slots)) .filter(TaskInstance.pool == self.pool) - .filter(TaskInstance.state.in_(list(EXECUTION_STATES))) + .filter(TaskInstance.state.in_(EXECUTION_STATES)) .scalar() or 0 ) @provide_session - def running_slots(self, session: Session = NEW_SESSION): + def running_slots(self, session: Session = NEW_SESSION) -> int: """ Get the number of slots used by running tasks at the moment. @@ -246,7 +251,7 @@ def running_slots(self, session: Session = NEW_SESSION): ) @provide_session - def queued_slots(self, session: Session = NEW_SESSION): + def queued_slots(self, session: Session = NEW_SESSION) -> int: """ Get the number of slots used by queued tasks at the moment. @@ -264,7 +269,7 @@ def queued_slots(self, session: Session = NEW_SESSION): ) @provide_session - def scheduled_slots(self, session: Session = NEW_SESSION): + def scheduled_slots(self, session: Session = NEW_SESSION) -> int: """ Get the number of slots scheduled at the moment. 
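The `create_or_update_pool` change above now rejects empty names with a `ValueError` (instead of silently returning `None`) and looks the pool up with `one_or_none()`, which raises if the supposedly unique name ever matched multiple rows, rather than `first()`. A self-contained rough equivalent against a throwaway SQLAlchemy model makes the behaviour concrete; the model, table name, and engine here are illustrative only.

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Pool(Base):
    __tablename__ = "slot_pool"
    id = Column(Integer, primary_key=True)
    pool = Column(String(256), unique=True)
    slots = Column(Integer, default=0)
    description = Column(String(256))


def create_or_update_pool(name: str, slots: int, description: str, session: Session) -> Pool:
    if not name:
        # Raising is clearer than silently returning None for an empty name.
        raise ValueError("Pool name must not be empty")
    # one_or_none() returns the row or None, and raises MultipleResultsFound
    # if more than one row matches the supposedly unique name.
    pool = session.query(Pool).filter_by(pool=name).one_or_none()
    if pool is None:
        pool = Pool(pool=name, slots=slots, description=description)
        session.add(pool)
    else:
        pool.slots = slots
        pool.description = description
    session.commit()
    return pool


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        print(create_or_update_pool("default_pool", 128, "demo pool", session).slots)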
@@ -291,5 +296,4 @@ def open_slots(self, session: Session = NEW_SESSION) -> float: """ if self.slots == -1: return float("inf") - else: - return self.slots - self.occupied_slots(session) + return self.slots - self.occupied_slots(session) diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index 53e5e2ccbd223..6535bd36d4921 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -35,7 +35,7 @@ from airflow.serialization.serialized_objects import DagDependency, SerializedDAG from airflow.settings import COMPRESS_SERIALIZED_DAGS, MIN_SERIALIZED_DAG_UPDATE_INTERVAL, json from airflow.utils import timezone -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime log = logging.getLogger(__name__) @@ -92,7 +92,7 @@ class SerializedDagModel(Base): load_op_links = True - def __init__(self, dag: DAG, processor_subdir: str | None = None): + def __init__(self, dag: DAG, processor_subdir: str | None = None) -> None: self.dag_id = dag.dag_id self.fileloc = dag.fileloc self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc) @@ -115,7 +115,7 @@ def __init__(self, dag: DAG, processor_subdir: str | None = None): # when COMPRESS_SERIALIZED_DAGS is True self.__data_cache = dag_data - def __repr__(self): + def __repr__(self) -> str: return f"" @classmethod @@ -125,7 +125,7 @@ def write_dag( dag: DAG, min_update_interval: int | None = None, processor_subdir: str | None = None, - session: Session = None, + session: Session = NEW_SESSION, ) -> bool: """Serializes a DAG and writes it into database. If the record already exists, it checks if the Serialized DAG changed or not. If it is @@ -174,7 +174,7 @@ def write_dag( @classmethod @provide_session - def read_all_dags(cls, session: Session = None) -> dict[str, SerializedDAG]: + def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDAG]: """Reads all DAGs in serialized_dag table. :param session: ORM Session @@ -199,7 +199,7 @@ def read_all_dags(cls, session: Session = None) -> dict[str, SerializedDAG]: return dags @property - def data(self): + def data(self) -> dict | None: # use __data_cache to avoid decompress and loads if not hasattr(self, "__data_cache") or self.__data_cache is None: if self._data_compressed: @@ -210,19 +210,20 @@ def data(self): return self.__data_cache @property - def dag(self): + def dag(self) -> SerializedDAG: """The DAG deserialized from the ``data`` column""" SerializedDAG._load_operator_extra_links = self.load_op_links - if isinstance(self.data, dict): - dag = SerializedDAG.from_dict(self.data) + data = self.data + elif isinstance(self.data, str): + data = json.loads(self.data) else: - dag = SerializedDAG.from_json(self.data) - return dag + raise ValueError("invalid or missing serialized DAG data") + return SerializedDAG.from_dict(data) @classmethod @provide_session - def remove_dag(cls, dag_id: str, session: Session = None): + def remove_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> None: """Deletes a DAG with given dag_id. 
:param dag_id: dag_id to be deleted :param session: ORM Session @@ -232,8 +233,11 @@ def remove_dag(cls, dag_id: str, session: Session = None): @classmethod @provide_session def remove_deleted_dags( - cls, alive_dag_filelocs: list[str], processor_subdir: str | None = None, session=None - ): + cls, + alive_dag_filelocs: list[str], + processor_subdir: str | None = None, + session: Session = NEW_SESSION, + ) -> None: """Deletes DAGs not included in alive_dag_filelocs. :param alive_dag_filelocs: file paths of alive DAGs @@ -260,7 +264,7 @@ def remove_deleted_dags( @classmethod @provide_session - def has_dag(cls, dag_id: str, session: Session = None) -> bool: + def has_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> bool: """Checks a DAG exist in serialized_dag table. :param dag_id: the DAG to check @@ -270,7 +274,7 @@ def has_dag(cls, dag_id: str, session: Session = None) -> bool: @classmethod @provide_session - def get_dag(cls, dag_id: str, session: Session = None) -> SerializedDAG | None: + def get_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> SerializedDAG | None: row = cls.get(dag_id, session=session) if row: return row.dag @@ -278,7 +282,7 @@ def get_dag(cls, dag_id: str, session: Session = None) -> SerializedDAG | None: @classmethod @provide_session - def get(cls, dag_id: str, session: Session = None) -> SerializedDagModel | None: + def get(cls, dag_id: str, session: Session = NEW_SESSION) -> SerializedDagModel | None: """ Get the SerializedDAG for the given dag ID. It will cope with being passed the ID of a subdag by looking up the @@ -299,7 +303,11 @@ def get(cls, dag_id: str, session: Session = None) -> SerializedDagModel | None: @staticmethod @provide_session - def bulk_sync_to_db(dags: list[DAG], processor_subdir: str | None = None, session: Session = None): + def bulk_sync_to_db( + dags: list[DAG], + processor_subdir: str | None = None, + session: Session = NEW_SESSION, + ) -> None: """ Saves DAGs as Serialized DAG objects in the database. Each DAG is saved in a separate database query. @@ -319,7 +327,7 @@ def bulk_sync_to_db(dags: list[DAG], processor_subdir: str | None = None, sessio @classmethod @provide_session - def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> datetime | None: + def get_last_updated_datetime(cls, dag_id: str, session: Session = NEW_SESSION) -> datetime | None: """ Get the date when the Serialized DAG associated to DAG was last updated in serialized_dag table @@ -331,7 +339,7 @@ def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> date @classmethod @provide_session - def get_max_last_updated_datetime(cls, session: Session = None) -> datetime | None: + def get_max_last_updated_datetime(cls, session: Session = NEW_SESSION) -> datetime | None: """ Get the maximum date when any DAG was last updated in serialized_dag table @@ -341,7 +349,7 @@ def get_max_last_updated_datetime(cls, session: Session = None) -> datetime | No @classmethod @provide_session - def get_latest_version_hash(cls, dag_id: str, session: Session = None) -> str | None: + def get_latest_version_hash(cls, dag_id: str, session: Session = NEW_SESSION) -> str | None: """ Get the latest DAG version for a given DAG ID. 
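Several files in this diff import names purely for annotations inside an `if TYPE_CHECKING:` block (for example `BaseExecutor` in backfill_job.py and `DAG` in dagpickle.py), which avoids import cycles and keeps runtime imports cheap. A generic sketch of the idiom, using hypothetical module and method names:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; skipped at runtime, so a circular
    # import between this module and the executor module cannot occur.
    from myproject.executor import Executor  # hypothetical import


def drain_queue(executor: Executor, task_ids: list[str]) -> None:
    # With ``from __future__ import annotations`` the hint is stored as a string,
    # so ``Executor`` need not exist at runtime for this function to be defined.
    for task_id in task_ids:
        executor.queue_command(task_id)  # hypothetical method, illustration only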
@@ -353,7 +361,7 @@ def get_latest_version_hash(cls, dag_id: str, session: Session = None) -> str | @classmethod @provide_session - def get_dag_dependencies(cls, session: Session = None) -> dict[str, list[DagDependency]]: + def get_dag_dependencies(cls, session: Session = NEW_SESSION) -> dict[str, list[DagDependency]]: """ Get the dependencies between DAGs diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index e4dbcb0e5eeb4..24b81de03103d 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -613,12 +613,12 @@ def command_as_list( wait_for_past_depends_before_skipping=False, ignore_ti_state=False, local=False, - pickle_id=None, + pickle_id: int | None = None, raw=False, job_id=None, pool=None, cfg_path=None, - ): + ) -> list[str]: """ Returns a command that can be executed anywhere where airflow is installed. This command is part of the message sent to executors by @@ -2613,7 +2613,7 @@ def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> Colum @Sentry.enrich_errors @provide_session - def schedule_downstream_tasks(self, session=None): + def schedule_downstream_tasks(self, session: Session = NEW_SESSION) -> None: """ The mini-scheduler for scheduling downstream tasks of this task instance :meta: private diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py index bbe09145c0c95..d8fe2c3ff28c0 100644 --- a/airflow/models/taskreschedule.py +++ b/airflow/models/taskreschedule.py @@ -23,14 +23,15 @@ from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc, event, text from sqlalchemy.ext.associationproxy import association_proxy -from sqlalchemy.orm import relationship +from sqlalchemy.orm import Query, Session, relationship from airflow.models.base import COLLATION_ARGS, ID_LEN, Base -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime if TYPE_CHECKING: - from airflow.models.baseoperator import BaseOperator + from airflow.models.operator import Operator + from airflow.models.taskinstance import TaskInstance class TaskReschedule(Base): @@ -75,14 +76,14 @@ class TaskReschedule(Base): def __init__( self, - task: BaseOperator, + task: Operator, run_id: str, try_number: int, start_date: datetime.datetime, end_date: datetime.datetime, reschedule_date: datetime.datetime, map_index: int = -1, - ): + ) -> None: self.dag_id = task.dag_id self.task_id = task.task_id self.run_id = run_id @@ -95,7 +96,12 @@ def __init__( @staticmethod @provide_session - def query_for_task_instance(task_instance, descending=False, session=None, try_number=None): + def query_for_task_instance( + task_instance: TaskInstance, + descending: bool = False, + session: Session = NEW_SESSION, + try_number: int | None = None, + ) -> Query: """ Returns query for task reschedules for a given the task instance. @@ -123,7 +129,11 @@ def query_for_task_instance(task_instance, descending=False, session=None, try_n @staticmethod @provide_session - def find_for_task_instance(task_instance, session=None, try_number=None): + def find_for_task_instance( + task_instance: TaskInstance, + session: Session = NEW_SESSION, + try_number: int | None = None, + ) -> list[TaskReschedule]: """ Returns all task reschedules for the task instance and try number, in ascending order. 
diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py index 271941f0a3f8c..f6195fa3d6cb8 100644 --- a/airflow/models/trigger.py +++ b/airflow/models/trigger.py @@ -21,7 +21,7 @@ from typing import Any, Iterable from sqlalchemy import Column, Integer, String, func, or_ -from sqlalchemy.orm import joinedload, relationship +from sqlalchemy.orm import Session, joinedload, relationship from airflow.api_internal.internal_api_call import internal_api_call from airflow.models.base import Base @@ -29,9 +29,9 @@ from airflow.triggers.base import BaseTrigger from airflow.utils import timezone from airflow.utils.retries import run_with_db_retries -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, with_row_locks -from airflow.utils.state import State +from airflow.utils.state import TaskInstanceState class Trigger(Base): @@ -68,7 +68,12 @@ class Trigger(Base): task_instance = relationship("TaskInstance", back_populates="trigger", lazy="joined", uselist=False) - def __init__(self, classpath: str, kwargs: dict[str, Any], created_date: datetime.datetime | None = None): + def __init__( + self, + classpath: str, + kwargs: dict[str, Any], + created_date: datetime.datetime | None = None, + ) -> None: super().__init__() self.classpath = classpath self.kwargs = kwargs @@ -76,7 +81,7 @@ def __init__(self, classpath: str, kwargs: dict[str, Any], created_date: datetim @classmethod @internal_api_call - def from_object(cls, trigger: BaseTrigger): + def from_object(cls, trigger: BaseTrigger) -> Trigger: """ Alternative constructor that creates a trigger row based directly off of a Trigger object. @@ -87,7 +92,7 @@ def from_object(cls, trigger: BaseTrigger): @classmethod @internal_api_call @provide_session - def bulk_fetch(cls, ids: Iterable[int], session=None) -> dict[int, Trigger]: + def bulk_fetch(cls, ids: Iterable[int], session: Session = NEW_SESSION) -> dict[int, Trigger]: """ Fetches all the Triggers by ID and returns a dict mapping ID -> Trigger instance @@ -106,7 +111,7 @@ def bulk_fetch(cls, ids: Iterable[int], session=None) -> dict[int, Trigger]: @classmethod @internal_api_call @provide_session - def clean_unused(cls, session=None): + def clean_unused(cls, session: Session = NEW_SESSION) -> None: """ Deletes all triggers that have no tasks/DAGs dependent on them (triggers have a one-to-many relationship to both) @@ -115,7 +120,7 @@ def clean_unused(cls, session=None): for attempt in run_with_db_retries(): with attempt: session.query(TaskInstance).filter( - TaskInstance.state != State.DEFERRED, TaskInstance.trigger_id.isnot(None) + TaskInstance.state != TaskInstanceState.DEFERRED, TaskInstance.trigger_id.isnot(None) ).update({TaskInstance.trigger_id: None}) # Get all triggers that have no task instances depending on them... ids = [ @@ -133,13 +138,13 @@ def clean_unused(cls, session=None): @classmethod @internal_api_call @provide_session - def submit_event(cls, trigger_id, event, session=None): + def submit_event(cls, trigger_id, event, session: Session = NEW_SESSION) -> None: """ Takes an event from an instance of itself, and triggers all dependent tasks to resume. 
""" for task_instance in session.query(TaskInstance).filter( - TaskInstance.trigger_id == trigger_id, TaskInstance.state == State.DEFERRED + TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED ): # Add the event's payload into the kwargs for the task next_kwargs = task_instance.next_kwargs or {} @@ -148,12 +153,12 @@ def submit_event(cls, trigger_id, event, session=None): # Remove ourselves as its trigger task_instance.trigger_id = None # Finally, mark it as scheduled so it gets re-queued - task_instance.state = State.SCHEDULED + task_instance.state = TaskInstanceState.SCHEDULED @classmethod @internal_api_call @provide_session - def submit_failure(cls, trigger_id, exc=None, session=None): + def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) -> None: """ Called when a trigger has failed unexpectedly, and we need to mark everything that depended on it as failed. Notably, we have to actually @@ -170,7 +175,7 @@ def submit_failure(cls, trigger_id, exc=None, session=None): in-process, but we can't do that right now. """ for task_instance in session.query(TaskInstance).filter( - TaskInstance.trigger_id == trigger_id, TaskInstance.state == State.DEFERRED + TaskInstance.trigger_id == trigger_id, TaskInstance.state == TaskInstanceState.DEFERRED ): # Add the error and set the next_method to the fail state traceback = format_exception(type(exc), exc, exc.__traceback__) if exc else None @@ -179,19 +184,19 @@ def submit_failure(cls, trigger_id, exc=None, session=None): # Remove ourselves as its trigger task_instance.trigger_id = None # Finally, mark it as scheduled so it gets re-queued - task_instance.state = State.SCHEDULED + task_instance.state = TaskInstanceState.SCHEDULED @classmethod @internal_api_call @provide_session - def ids_for_triggerer(cls, triggerer_id, session=None): + def ids_for_triggerer(cls, triggerer_id, session: Session = NEW_SESSION) -> list[int]: """Retrieves a list of triggerer_ids.""" return [row[0] for row in session.query(cls.id).filter(cls.triggerer_id == triggerer_id)] @classmethod @internal_api_call @provide_session - def assign_unassigned(cls, triggerer_id, capacity, session=None): + def assign_unassigned(cls, triggerer_id, capacity, session: Session = NEW_SESSION) -> None: """ Takes a triggerer_id and the capacity for that triggerer and assigns unassigned triggers until that capacity is reached, or there are no more unassigned triggers. 
diff --git a/airflow/secrets/metastore.py b/airflow/secrets/metastore.py index 803ede8c304bb..dc81675d771ff 100644 --- a/airflow/secrets/metastore.py +++ b/airflow/secrets/metastore.py @@ -21,9 +21,11 @@ import warnings from typing import TYPE_CHECKING +from sqlalchemy.orm import Session + from airflow.exceptions import RemovedInAirflow3Warning from airflow.secrets import BaseSecretsBackend -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: from airflow.models.connection import Connection @@ -33,7 +35,7 @@ class MetastoreBackend(BaseSecretsBackend): """Retrieves Connection object and Variable from airflow metastore database.""" @provide_session - def get_connection(self, conn_id, session=None) -> Connection | None: + def get_connection(self, conn_id: str, session: Session = NEW_SESSION) -> Connection | None: from airflow.models.connection import Connection conn = session.query(Connection).filter(Connection.conn_id == conn_id).first() @@ -41,7 +43,7 @@ def get_connection(self, conn_id, session=None) -> Connection | None: return conn @provide_session - def get_connections(self, conn_id, session=None) -> list[Connection]: + def get_connections(self, conn_id: str, session: Session = NEW_SESSION) -> list[Connection]: warnings.warn( "This method is deprecated. Please use " "`airflow.secrets.metastore.MetastoreBackend.get_connection`.", @@ -54,7 +56,7 @@ def get_connections(self, conn_id, session=None) -> list[Connection]: return [] @provide_session - def get_variable(self, key: str, session=None): + def get_variable(self, key: str, session: Session = NEW_SESSION) -> str | None: """ Get Airflow Variable from Metadata DB. diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index 36b1a002f8458..a79302ddce491 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -35,11 +35,13 @@ from airflow.sensors.base import BaseSensorOperator from airflow.utils.file import correct_maybe_zipped from airflow.utils.helpers import build_airflow_url_with_query -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import State if TYPE_CHECKING: - from sqlalchemy.orm import Query + from sqlalchemy.orm import Query, Session + + from airflow.utils.context import Context class ExternalDagLink(BaseOperatorLink): @@ -215,7 +217,7 @@ def _get_dttm_filter(self, context): return dttm if isinstance(dttm, list) else [dttm] @provide_session - def poke(self, context, session=None): + def poke(self, context: Context, session: Session = NEW_SESSION) -> bool: # delay check to poke rather than __init__ in case it was supplied as XComArgs if self.external_task_ids and len(self.external_task_ids) > len(set(self.external_task_ids)): raise ValueError("Duplicate task_ids passed in external_task_ids parameter") diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py index 91bc9e638d3a5..d5d55c20cbd48 100644 --- a/airflow/utils/cli.py +++ b/airflow/utils/cli.py @@ -32,12 +32,14 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable, TypeVar, cast +from sqlalchemy.orm import Session + from airflow import settings from airflow.exceptions import AirflowException, RemovedInAirflow3Warning from airflow.utils import cli_action_loggers from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler from airflow.utils.platform import getuser, is_terminal_support_colors -from 
airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session T = TypeVar("T", bound=Callable) @@ -253,7 +255,7 @@ def get_dags(subdir: str | None, dag_id: str, use_regex: bool = False): @provide_session -def get_dag_by_pickle(pickle_id, session=None): +def get_dag_by_pickle(pickle_id: int, session: Session = NEW_SESSION) -> DAG: """Fetch DAG from the database using pickling""" from airflow.models import DagPickle diff --git a/airflow/utils/log/secrets_masker.py b/airflow/utils/log/secrets_masker.py index f73444174ca46..cac82d0dd12e2 100644 --- a/airflow/utils/log/secrets_masker.py +++ b/airflow/utils/log/secrets_masker.py @@ -28,6 +28,7 @@ Dict, Generator, Iterable, + Iterator, List, TextIO, Tuple, @@ -352,11 +353,61 @@ class RedactedIO(TextIO): def __init__(self): self.target = sys.stdout - self.fileno = sys.stdout.fileno + + def __enter__(self) -> TextIO: + return self.target.__enter__() + + def __exit__(self, t, v, b) -> None: + return self.target.__exit__(t, v, b) + + def __iter__(self) -> Iterator[str]: + return iter(self.target) + + def __next__(self) -> str: + return next(self.target) + + def close(self) -> None: + return self.target.close() + + def fileno(self) -> int: + return self.target.fileno() + + def flush(self) -> None: + return self.target.flush() + + def isatty(self) -> bool: + return self.target.isatty() + + def read(self, n: int = -1) -> str: + return self.target.read(n) + + def readable(self) -> bool: + return self.target.readable() + + def readline(self, n: int = -1) -> str: + return self.target.readline(n) + + def readlines(self, n: int = -1) -> list[str]: + return self.target.readlines(n) + + def seek(self, offset: int, whence: int = 0) -> int: + return self.target.seek(offset, whence) + + def seekable(self) -> bool: + return self.target.seekable() + + def tell(self) -> int: + return self.target.tell() + + def truncate(self, s: int | None = None) -> int: + return self.target.truncate(s) + + def writable(self) -> bool: + return self.target.writable() def write(self, s: str) -> int: s = redact(s) return self.target.write(s) - def flush(self) -> None: - return self.target.flush() + def writelines(self, lines) -> None: + self.target.writelines(lines) diff --git a/airflow/www/fab_security/sqla/manager.py b/airflow/www/fab_security/sqla/manager.py index a7533c5d7f237..e55255d67c212 100644 --- a/airflow/www/fab_security/sqla/manager.py +++ b/airflow/www/fab_security/sqla/manager.py @@ -232,7 +232,7 @@ def update_user(self, user): def get_user_by_id(self, pk): return self.get_session.get(self.user_model, pk) - def add_role(self, name: str) -> Role | None: + def add_role(self, name: str) -> Role: role = self.find_role(name) if role is None: try: @@ -546,7 +546,7 @@ def perms_include_action(self, perms, action_name): return True return False - def add_permission_to_role(self, role: Role, permission: Permission) -> None: + def add_permission_to_role(self, role: Role, permission: Permission | None) -> None: """ Add an existing permission pair to a role. 
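The `secrets_masker.py` hunk above drops the old `self.fileno = sys.stdout.fileno` shortcut and instead has `RedactedIO` delegate every `TextIO` method to the wrapped stream explicitly, intercepting only `write()` to run redaction. A rougher way to get the same runtime behaviour is a `__getattr__` fallback, sketched below with an invented `RedactingWriter`/`redact` pair; this is only an illustration of the wrapping idea, not the code in the diff.

    from __future__ import annotations

    import sys
    from typing import TextIO


    def redact(s: str) -> str:
        # Invented stand-in for the secrets masker's redaction helper.
        return s.replace("s3cr3t", "***")


    class RedactingWriter:
        """Wrap a text stream, intercept write(), and forward everything else."""

        def __init__(self, target: TextIO | None = None) -> None:
            self.target = target or sys.stdout

        def write(self, s: str) -> int:
            return self.target.write(redact(s))

        def __getattr__(self, name: str):
            # Fallback delegation covers fileno(), flush(), isatty(), readable(), ...
            return getattr(self.target, name)


    print("password=s3cr3t", file=RedactingWriter())  # emits "password=***"

A `__getattr__` fallback works at runtime but does not make the class a structural `typing.TextIO` for static type checkers, which is presumably why the diff enumerates each method explicitly instead.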
diff --git a/airflow/www/security.py b/airflow/www/security.py index 72b2b95acb4ee..201c9ada0ad45 100644 --- a/airflow/www/security.py +++ b/airflow/www/security.py @@ -18,7 +18,7 @@ from __future__ import annotations import warnings -from typing import Sequence +from typing import Any, Collection, Container, Iterable, Sequence from flask import g from sqlalchemy import or_ @@ -28,7 +28,7 @@ from airflow.models import DagBag, DagModel from airflow.security import permissions from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.www.fab_security.sqla.manager import SecurityManager from airflow.www.fab_security.sqla.models import Permission, Resource, Role, User from airflow.www.fab_security.views import ( @@ -158,7 +158,7 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin): # DEFAULT ROLE CONFIGURATIONS ########################################################################### - ROLE_CONFIGS = [ + ROLE_CONFIGS: list[dict[str, Any]] = [ {"role": "Public", "perms": []}, {"role": "Viewer", "perms": VIEWER_PERMISSIONS}, { @@ -189,7 +189,7 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin): useroidmodelview = CustomUserOIDModelView userstatschartview = CustomUserStatsChartView - def __init__(self, appbuilder): + def __init__(self, appbuilder) -> None: super().__init__(appbuilder) # Go and fix up the SQLAInterface used from the stock one to our subclass. @@ -204,13 +204,13 @@ def __init__(self, appbuilder): view.datamodel = CustomSQLAInterface(view.datamodel.obj) self.perms = None - def create_db(self): + def create_db(self) -> None: if not self.appbuilder.update_perms: self.log.debug("Skipping db since appbuilder disables update_perms") return super().create_db() - def _get_root_dag_id(self, dag_id): + def _get_root_dag_id(self, dag_id: str) -> str: if "." in dag_id: dm = ( self.appbuilder.get_session.query(DagModel.dag_id, DagModel.root_dag_id) @@ -220,7 +220,7 @@ def _get_root_dag_id(self, dag_id): return dm.root_dag_id or dm.dag_id return dag_id - def init_role(self, role_name, perms): + def init_role(self, role_name, perms) -> None: """ Initialize the role with actions and related resources. :param role_name: @@ -234,7 +234,7 @@ def init_role(self, role_name, perms): ) self.bulk_sync_roles([{"role": role_name, "perms": perms}]) - def bulk_sync_roles(self, roles): + def bulk_sync_roles(self, roles: Iterable[dict[str, Any]]) -> None: """Sync the provided roles and permissions.""" existing_roles = self._get_all_roles_with_permissions() non_dag_perms = self._get_all_non_dag_permissions() @@ -252,7 +252,7 @@ def bulk_sync_roles(self, roles): if perm not in role.permissions: self.add_permission_to_role(role, perm) - def delete_role(self, role_name): + def delete_role(self, role_name: str) -> None: """ Delete the given Role @@ -279,7 +279,7 @@ def get_user_roles(user=None): user = g.user return user.roles - def get_readable_dags(self, user): + def get_readable_dags(self, user) -> Iterable[DagModel]: """Gets the DAGs readable by authenticated user.""" warnings.warn( "`get_readable_dags` has been deprecated. 
Please use `get_readable_dag_ids` instead.", @@ -290,7 +290,7 @@ def get_readable_dags(self, user): warnings.simplefilter("ignore", RemovedInAirflow3Warning) return self.get_accessible_dags([permissions.ACTION_CAN_READ], user) - def get_editable_dags(self, user): + def get_editable_dags(self, user) -> Iterable[DagModel]: """Gets the DAGs editable by authenticated user.""" warnings.warn( "`get_editable_dags` has been deprecated. Please use `get_editable_dag_ids` instead.", @@ -302,7 +302,12 @@ def get_editable_dags(self, user): return self.get_accessible_dags([permissions.ACTION_CAN_EDIT], user) @provide_session - def get_accessible_dags(self, user_actions, user, session=None): + def get_accessible_dags( + self, + user_actions: Container[str] | None, + user, + session: Session = NEW_SESSION, + ) -> Iterable[DagModel]: warnings.warn( "`get_accessible_dags` has been deprecated. Please use `get_accessible_dag_ids` instead.", RemovedInAirflow3Warning, @@ -320,7 +325,12 @@ def get_editable_dag_ids(self, user) -> set[str]: return self.get_accessible_dag_ids(user, [permissions.ACTION_CAN_EDIT]) @provide_session - def get_accessible_dag_ids(self, user, user_actions=None, session=None) -> set[str]: + def get_accessible_dag_ids( + self, + user, + user_actions: Container[str] | None = None, + session: Session = NEW_SESSION, + ) -> set[str]: """Generic function to get readable or writable DAGs for user.""" if not user_actions: user_actions = [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ] @@ -372,25 +382,25 @@ def can_access_some_dags(self, action: str, dag_id: str | None = None) -> bool: return any(self.get_readable_dag_ids(user)) return any(self.get_editable_dag_ids(user)) - def can_read_dag(self, dag_id, user=None) -> bool: + def can_read_dag(self, dag_id: str, user=None) -> bool: """Determines whether a user has DAG read access.""" root_dag_id = self._get_root_dag_id(dag_id) dag_resource_name = permissions.resource_name_for_dag(root_dag_id) return self.has_access(permissions.ACTION_CAN_READ, dag_resource_name, user=user) - def can_edit_dag(self, dag_id, user=None) -> bool: + def can_edit_dag(self, dag_id: str, user=None) -> bool: """Determines whether a user has DAG edit access.""" root_dag_id = self._get_root_dag_id(dag_id) dag_resource_name = permissions.resource_name_for_dag(root_dag_id) return self.has_access(permissions.ACTION_CAN_EDIT, dag_resource_name, user=user) - def can_delete_dag(self, dag_id, user=None) -> bool: + def can_delete_dag(self, dag_id: str, user=None) -> bool: """Determines whether a user has DAG delete access.""" root_dag_id = self._get_root_dag_id(dag_id) dag_resource_name = permissions.resource_name_for_dag(root_dag_id) return self.has_access(permissions.ACTION_CAN_DELETE, dag_resource_name, user=user) - def prefixed_dag_id(self, dag_id): + def prefixed_dag_id(self, dag_id: str) -> str: """Returns the permission name for a DAG id.""" warnings.warn( "`prefixed_dag_id` has been deprecated. 
" @@ -401,13 +411,13 @@ def prefixed_dag_id(self, dag_id): root_dag_id = self._get_root_dag_id(dag_id) return permissions.resource_name_for_dag(root_dag_id) - def is_dag_resource(self, resource_name): + def is_dag_resource(self, resource_name: str) -> bool: """Determines if a resource belongs to a DAG or all DAGs.""" if resource_name == permissions.RESOURCE_DAG: return True return resource_name.startswith(permissions.RESOURCE_DAG_PREFIX) - def has_access(self, action_name, resource_name, user=None) -> bool: + def has_access(self, action_name: str, resource_name: str, user=None) -> bool: """ Verify whether a given user could perform a certain action (e.g can_read, can_write, can_delete) on the given resource. @@ -430,13 +440,13 @@ def has_access(self, action_name, resource_name, user=None) -> bool: return False - def _has_role(self, role_name_or_list, user): + def _has_role(self, role_name_or_list: Container, user) -> bool: """Whether the user has this role name""" if not isinstance(role_name_or_list, list): role_name_or_list = [role_name_or_list] return any(r.name in role_name_or_list for r in user.roles) - def has_all_dags_access(self, user): + def has_all_dags_access(self, user) -> bool: """ Has all the dag access in any of the 3 cases: 1. Role needs to be in (Admin, Viewer, User, Op). @@ -451,15 +461,15 @@ def has_all_dags_access(self, user): or self.can_edit_all_dags(user) ) - def can_edit_all_dags(self, user=None): + def can_edit_all_dags(self, user=None) -> bool: """Has can_edit action on DAG resource""" return self.has_access(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG, user) - def can_read_all_dags(self, user=None): + def can_read_all_dags(self, user=None) -> bool: """Has can_read action on DAG resource""" return self.has_access(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG, user) - def clean_perms(self): + def clean_perms(self) -> None: """FAB leaves faulty permissions that need to be cleaned up""" self.log.debug("Cleaning faulty perms") sesh = self.appbuilder.get_session @@ -481,7 +491,7 @@ def clean_perms(self): if deleted_count: self.log.info("Deleted %s faulty permissions", deleted_count) - def _merge_perm(self, action_name, resource_name): + def _merge_perm(self, action_name: str, resource_name: str) -> None: """ Add the new (action, resource) to assoc_permission_role if it doesn't exist. It will add the related entry to ab_permission and ab_resource two meta tables as well. @@ -502,7 +512,7 @@ def _merge_perm(self, action_name, resource_name): if not perm and action_name and resource_name: self.create_permission(action_name, resource_name) - def add_homepage_access_to_custom_roles(self): + def add_homepage_access_to_custom_roles(self) -> None: """ Add Website.can_read access to all custom roles. @@ -576,7 +586,7 @@ def create_dag_specific_permissions(self) -> None: if dag.access_control: self.sync_perm_for_dag(dag_resource_name, dag.access_control) - def update_admin_permission(self): + def update_admin_permission(self) -> None: """ Admin should have all the permissions, except the dag permissions. because Admin already has Dags permission. @@ -598,7 +608,7 @@ def update_admin_permission(self): session.commit() - def sync_roles(self): + def sync_roles(self) -> None: """ 1. Init the default role(Admin, Viewer, User, Op, public) with related permissions. 
@@ -617,7 +627,7 @@ def sync_roles(self): self.update_admin_permission() self.clean_perms() - def sync_resource_permissions(self, perms=None): + def sync_resource_permissions(self, perms: Iterable[tuple[str, str]] | None = None) -> None: """Populates resource-based permissions.""" if not perms: return @@ -626,7 +636,11 @@ def sync_resource_permissions(self, perms=None): self.create_resource(resource_name) self.create_permission(action_name, resource_name) - def sync_perm_for_dag(self, dag_id, access_control=None): + def sync_perm_for_dag( + self, + dag_id: str, + access_control: dict[str, Collection[str]] | None = None, + ) -> None: """ Sync permissions for given dag id. The dag id surely exists in our dag bag as only / refresh button or DagBag will call this function @@ -644,7 +658,7 @@ def sync_perm_for_dag(self, dag_id, access_control=None): if access_control: self._sync_dag_view_permissions(dag_resource_name, access_control) - def _sync_dag_view_permissions(self, dag_id, access_control): + def _sync_dag_view_permissions(self, dag_id: str, access_control: dict[str, Collection[str]]) -> None: """ Set the access policy on the given DAG's ViewModel. @@ -667,7 +681,7 @@ def _revoke_stale_permissions(resource: Resource): for perm in existing_dag_perms: non_admin_roles = [role for role in perm.role if role.name != "Admin"] for role in non_admin_roles: - target_perms_for_role = access_control.get(role.name, {}) + target_perms_for_role = access_control.get(role.name, ()) if perm.action.name not in target_perms_for_role: self.log.info( "Revoking '%s' on DAG '%s' for role '%s'", @@ -703,7 +717,7 @@ def _revoke_stale_permissions(resource: Resource): if dag_perm: self.add_permission_to_role(role, dag_perm) - def create_perm_vm_for_all_dag(self): + def create_perm_vm_for_all_dag(self) -> None: """Create perm-vm if not exist and insert into FAB security model for all-dags.""" # create perm for global logical dag for resource_name in self.DAG_RESOURCES: @@ -711,7 +725,9 @@ def create_perm_vm_for_all_dag(self): self._merge_perm(action_name, resource_name) def check_authorization( - self, perms: Sequence[tuple[str, str]] | None = None, dag_id: str | None = None + self, + perms: Sequence[tuple[str, str]] | None = None, + dag_id: str | None = None, ) -> bool: """Checks that the logged in user has the specified permissions.""" if not perms: diff --git a/airflow/www/views.py b/airflow/www/views.py index e6c184cd72aab..9abe396c59c6c 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -19,6 +19,7 @@ import collections import copy +import datetime import itertools import json import logging @@ -29,7 +30,6 @@ import warnings from bisect import insort_left from collections import defaultdict -from datetime import datetime, timedelta from functools import wraps from json import JSONDecodeError from typing import Any, Callable, Collection, Iterator, Mapping, MutableMapping, Sequence @@ -117,7 +117,7 @@ from airflow.utils.log.log_reader import TaskLogReader from airflow.utils.net import get_hostname from airflow.utils.session import NEW_SESSION, create_session, provide_session -from airflow.utils.state import State, TaskInstanceState +from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.strings import to_boolean from airflow.utils.task_group import MappedTaskGroup, TaskGroup, task_group_to_dict from airflow.utils.timezone import td_format, utcnow @@ -198,11 +198,11 @@ def get_date_time_num_runs_dag_runs_form_data(www_request, session, dag): if base_date: base_date = 
_safe_parse_datetime(base_date) elif run_id: - base_date = (timezone.utcnow() + timedelta(seconds=1)).replace(microsecond=0) + base_date = (timezone.utcnow() + datetime.timedelta(seconds=1)).replace(microsecond=0) else: # The DateTimeField widget truncates milliseconds and would loose # the first dag run. Round to next second. - base_date = (date_time + timedelta(seconds=1)).replace(microsecond=0) + base_date = (date_time + datetime.timedelta(seconds=1)).replace(microsecond=0) default_dag_run = conf.getint("webserver", "default_dag_run_display_number") num_runs = www_request.args.get("num_runs", default=default_dag_run, type=int) @@ -214,7 +214,7 @@ def get_date_time_num_runs_dag_runs_form_data(www_request, session, dag): # loaded and the actual requested run would be excluded by the limit(). Once # the user has changed base date to be anything else we want to use that instead. query_date = base_date - if date_time < base_date and date_time + timedelta(seconds=1) >= base_date: + if date_time < base_date and date_time + datetime.timedelta(seconds=1) >= base_date: query_date = date_time drs = ( @@ -247,7 +247,7 @@ def get_date_time_num_runs_dag_runs_form_data(www_request, session, dag): } -def _safe_parse_datetime(v, allow_empty=False) -> datetime | None: +def _safe_parse_datetime(v, allow_empty=False) -> datetime.datetime | None: """ Parse datetime and return error message for invalid dates @@ -863,7 +863,7 @@ def _iter_parsed_moved_data_table_names(): robots_file_access_count = ( session.query(Log) .filter(Log.event == "robots") - .filter(Log.dttm > (utcnow() - timedelta(days=7))) + .filter(Log.dttm > (utcnow() - datetime.timedelta(days=7))) .count() ) if robots_file_access_count > 0: @@ -930,7 +930,7 @@ def datasets(self): @expose("/next_run_datasets_summary", methods=["POST"]) @auth.has_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) @provide_session - def next_run_datasets_summary(self, session=None): + def next_run_datasets_summary(self, session: Session = NEW_SESSION): """Next run info for dataset triggered DAGs.""" allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) @@ -966,7 +966,7 @@ def next_run_datasets_summary(self, session=None): ] ) @provide_session - def dag_stats(self, session=None): + def dag_stats(self, session: Session = NEW_SESSION): """Dag statistics.""" dr = models.DagRun @@ -987,13 +987,11 @@ def dag_stats(self, session=None): if not filter_dag_ids: return flask.json.jsonify({}) - payload = {} dag_state_stats = dag_state_stats.filter(dr.dag_id.in_(filter_dag_ids)) - data = {} + data: dict[str, dict[str, int]] = collections.defaultdict(dict) + payload: dict[str, list[dict[str, Any]]] = collections.defaultdict(list) for dag_id, state, count in dag_state_stats: - if dag_id not in data: - data[dag_id] = {} data[dag_id][state] = count for dag_id in filter_dag_ids: @@ -1013,7 +1011,7 @@ def dag_stats(self, session=None): ] ) @provide_session - def task_stats(self, session=None): + def task_stats(self, session: Session = NEW_SESSION): """Task Statistics""" allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) @@ -1106,9 +1104,8 @@ def task_stats(self, session=None): ) data = get_task_stats_from_query(qry) - payload = {} + payload: dict[str, list[dict[str, Any]]] = collections.defaultdict(list) for dag_id in filter_dag_ids: - payload[dag_id] = [] for state in State.task_states: count = data.get(dag_id, {}).get(state, 0) payload[dag_id].append({"state": state, "count": count}) @@ -1122,7 +1119,7 @@ def 
task_stats(self, session=None): ] ) @provide_session - def last_dagruns(self, session=None): + def last_dagruns(self, session: Session = NEW_SESSION): """Last DAG runs""" allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) @@ -1196,11 +1193,11 @@ def legacy_code(self): ] ) @provide_session - def code(self, dag_id, session=None): + def code(self, dag_id, session: Session = NEW_SESSION): """Dag Code.""" dag = get_airflow_app().dag_bag.get_dag(dag_id, session=session) dag_model = DagModel.get_dagmodel(dag_id, session=session) - if not dag: + if not dag or not dag_model: flash(f'DAG "{dag_id}" seems to be missing.', "error") return redirect(url_for("Airflow.index")) @@ -1243,7 +1240,7 @@ def legacy_dag_details(self): ] ) @provide_session - def dag_details(self, dag_id, session=None): + def dag_details(self, dag_id, session: Session = NEW_SESSION): """Get Dag details.""" from airflow.models.dag import DagOwnerAttributes @@ -1266,7 +1263,7 @@ def dag_details(self, dag_id, session=None): .all() ) - active_runs = models.DagRun.find(dag_id=dag_id, state=State.RUNNING, external_trigger=False) + active_runs = models.DagRun.find(dag_id=dag_id, state=DagRunState.RUNNING, external_trigger=False) tags = session.query(models.DagTag).filter(models.DagTag.dag_id == dag_id).all() @@ -1488,28 +1485,25 @@ def rendered_k8s(self, *, session: Session = NEW_SESSION): ) @action_logging @provide_session - def get_logs_with_metadata(self, session=None): + def get_logs_with_metadata(self, session: Session = NEW_SESSION): """Retrieve logs including metadata.""" dag_id = request.args.get("dag_id") task_id = request.args.get("task_id") - execution_date = request.args.get("execution_date") + execution_date_str = request.args["execution_date"] map_index = request.args.get("map_index", -1, type=int) try_number = request.args.get("try_number", type=int) - metadata = request.args.get("metadata", "{}") + metadata_str = request.args.get("metadata", "{}") response_format = request.args.get("format", "json") # Validate JSON metadata try: - metadata = json.loads(metadata) - # metadata may be null - if not metadata: - metadata = {} + metadata: dict = json.loads(metadata_str) or {} except json.decoder.JSONDecodeError: return {"error": "Invalid JSON metadata"}, 400 # Convert string datetime into actual datetime try: - execution_date = timezone.parse(execution_date) + execution_date = timezone.parse(execution_date_str) except ValueError: error_message = ( f"Given execution date, {execution_date}, could not be identified as a date. 
" @@ -1565,9 +1559,9 @@ def get_logs_with_metadata(self, session=None): headers={"Content-Disposition": f"attachment; filename={attachment_filename}"}, ) except AttributeError as e: - error_message = [f"Task log handler does not support read logs.\n{str(e)}\n"] + error_messages = [f"Task log handler does not support read logs.\n{str(e)}\n"] metadata["end_of_log"] = True - return {"message": error_message, "error": True, "metadata": metadata} + return {"message": error_messages, "error": True, "metadata": metadata} @expose("/log") @auth.has_access( @@ -1579,9 +1573,9 @@ def get_logs_with_metadata(self, session=None): ) @action_logging @provide_session - def log(self, session=None): + def log(self, session: Session = NEW_SESSION): """Retrieve log.""" - dag_id = request.args.get("dag_id") + dag_id = request.args["dag_id"] task_id = request.args.get("task_id") map_index = request.args.get("map_index", -1, type=int) execution_date = request.args.get("execution_date") @@ -1629,7 +1623,7 @@ def log(self, session=None): ) @action_logging @provide_session - def redirect_to_external_log(self, session=None): + def redirect_to_external_log(self, session: Session = NEW_SESSION): """Redirects to external log.""" dag_id = request.args.get("dag_id") task_id = request.args.get("task_id") @@ -1666,7 +1660,7 @@ def redirect_to_external_log(self, session=None): ) @action_logging @provide_session - def task(self, session): + def task(self, session: Session = NEW_SESSION): """Retrieve task.""" dag_id = request.args.get("dag_id") task_id = request.args.get("task_id") @@ -1721,7 +1715,7 @@ def task(self, session): attr_renderers = wwwutils.get_attr_renderer() - attrs_to_skip = getattr(task, "HIDE_ATTRS_FROM_UI", set()) + attrs_to_skip: frozenset[str] = getattr(task, "HIDE_ATTRS_FROM_UI", frozenset()) def include_task_attrs(attr_name): return not ( @@ -1797,9 +1791,9 @@ def include_task_attrs(attr_name): ) @action_logging @provide_session - def xcom(self, session=None): + def xcom(self, session: Session = NEW_SESSION): """Retrieve XCOM.""" - dag_id = request.args.get("dag_id") + dag_id = request.args["dag_id"] task_id = request.args.get("task_id") map_index = request.args.get("map_index", -1, type=int) # Carrying execution_date through, even though it's irrelevant for @@ -1884,10 +1878,10 @@ def delete(self): ) @action_logging @provide_session - def trigger(self, session=None): + def trigger(self, session: Session = NEW_SESSION): """Triggers DAG Run.""" - dag_id = request.values.get("dag_id") - run_id = request.values.get("run_id") + dag_id = request.values["dag_id"] + run_id = request.values.get("run_id", "") origin = get_safe_url(request.values.get("origin")) unpause = request.values.get("unpause") request_conf = request.values.get("conf") @@ -2047,7 +2041,9 @@ def trigger(self, session=None): ) if unpause and dag.get_is_paused(): - models.DagModel.get_dagmodel(dag_id).set_is_paused(is_paused=False) + dag_model = models.DagModel.get_dagmodel(dag_id) + if dag_model is not None: + dag_model.set_is_paused(is_paused=False) try: dag.create_dagrun( @@ -2082,8 +2078,8 @@ def trigger(self, session=None): def _clear_dag_tis( self, dag: DAG, - start_date: datetime | None, - end_date: datetime | None, + start_date: datetime.datetime | None, + end_date: datetime.datetime | None, *, origin: str | None, task_ids: Collection[str | tuple[str, int]] | None = None, @@ -2266,7 +2262,7 @@ def dagrun_clear(self, *, session: Session = NEW_SESSION): ] ) @provide_session - def blocked(self, session=None): + def blocked(self, session: 
Session = NEW_SESSION): """Mark Dag Blocked.""" allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) @@ -2341,7 +2337,13 @@ def _mark_dagrun_state_as_success(self, dag_id, dag_run_id, confirmed): return htmlsafe_json_dumps(details, separators=(",", ":")) @provide_session - def _mark_dagrun_state_as_queued(self, dag_id: str, dag_run_id: str, confirmed: bool, session=None): + def _mark_dagrun_state_as_queued( + self, + dag_id: str, + dag_run_id: str, + confirmed: bool, + session: Session = NEW_SESSION, + ): if not dag_run_id: return {"status": "error", "message": "Invalid dag_run_id"} @@ -2658,7 +2660,7 @@ def legacy_tree(self): @gzipped @action_logging @provide_session - def grid(self, dag_id, session=None): + def grid(self, dag_id: str, session: Session = NEW_SESSION): """Get Dag's grid view.""" dag = get_airflow_app().dag_bag.get_dag(dag_id, session=session) dag_model = DagModel.get_dagmodel(dag_id, session=session) @@ -2736,7 +2738,7 @@ def legacy_calendar(self): @gzipped @action_logging @provide_session - def calendar(self, dag_id, session=None): + def calendar(self, dag_id: str, session: Session = NEW_SESSION): """Get DAG runs as calendar""" def _convert_to_date(session, column): @@ -2792,14 +2794,14 @@ def _convert_to_date(session, column): year = last_automated_data_interval.end.year restriction = TimeRestriction(dag.start_date, dag.end_date, False) - dates = collections.Counter() + dates: dict[datetime.date, int] = collections.Counter() if isinstance(dag.timetable, CronMixin): # Optimized calendar generation for timetables based on a cron expression. - dates_iter: Iterator[datetime | None] = croniter( + dates_iter: Iterator[datetime.datetime | None] = croniter( dag.timetable._expression, start_time=last_automated_data_interval.end, - ret_type=datetime, + ret_type=datetime.datetime, ) for dt in dates_iter: if dt is None: @@ -2832,23 +2834,17 @@ def _convert_to_date(session, column): ) now = DateTime.utcnow() - data = { "dag_states": data_dag_states, - "start_date": (dag.start_date or DateTime.utcnow()).date().isoformat(), + "start_date": (dag.start_date or now).date().isoformat(), "end_date": (dag.end_date or now).date().isoformat(), } - doc_md = wwwutils.wrapped_markdown(getattr(dag, "doc_md", None)) - - # avoid spaces to reduce payload size - data = htmlsafe_json_dumps(data, separators=(",", ":")) - return self.render_template( "airflow/calendar.html", dag=dag, - doc_md=doc_md, - data=data, + doc_md=wwwutils.wrapped_markdown(getattr(dag, "doc_md", None)), + data=htmlsafe_json_dumps(data, separators=(",", ":")), # Avoid spaces to reduce payload size. 
root=root, dag_model=dag_model, ) @@ -2878,7 +2874,7 @@ def legacy_graph(self): @gzipped @action_logging @provide_session - def graph(self, dag_id, session=None): + def graph(self, dag_id: str, session: Session = NEW_SESSION): """Get DAG as Graph.""" dag = get_airflow_app().dag_bag.get_dag(dag_id, session=session) dag_model = DagModel.get_dagmodel(dag_id, session=session) @@ -2995,7 +2991,7 @@ def legacy_duration(self): ) @action_logging @provide_session - def duration(self, dag_id, session=None): + def duration(self, dag_id: str, session: Session = NEW_SESSION): """Get Dag as duration graph.""" dag = get_airflow_app().dag_bag.get_dag(dag_id, session=session) dag_model = DagModel.get_dagmodel(dag_id, session=session) @@ -3007,11 +3003,11 @@ def duration(self, dag_id, session=None): wwwutils.check_dag_warnings(dag.dag_id, session) default_dag_run = conf.getint("webserver", "default_dag_run_display_number") - base_date = request.args.get("base_date") + base_date_str = request.args.get("base_date") num_runs = request.args.get("num_runs", default=default_dag_run, type=int) - if base_date: - base_date = _safe_parse_datetime(base_date) + if base_date_str: + base_date = _safe_parse_datetime(base_date_str) else: base_date = dag.get_latest_execution_date() or timezone.utcnow() @@ -3055,7 +3051,7 @@ def duration(self, dag_id, session=None): ) if dag.partial: ti_fails = ti_fails.filter(TaskFail.task_id.in_([t.task_id for t in dag.tasks])) - fails_totals = defaultdict(int) + fails_totals: dict[tuple[str, str, str], int] = defaultdict(int) for failed_task_instance in ti_fails: dict_key = ( failed_task_instance.dag_id, @@ -3068,8 +3064,8 @@ def duration(self, dag_id, session=None): # we must group any mapped TIs by dag_id, task_id, run_id mapped_tis = set() tis_grouped = itertools.groupby(task_instances, lambda x: (x.dag_id, x.task_id, x.run_id)) - for key, tis in tis_grouped: - tis = list(tis) + for _, group in tis_grouped: + tis = list(group) duration = sum(x.duration for x in tis if x.duration) if duration: first_ti = tis[0] @@ -3156,7 +3152,7 @@ def legacy_tries(self): ) @action_logging @provide_session - def tries(self, dag_id, session=None): + def tries(self, dag_id: str, session: Session = NEW_SESSION): """Shows all tries.""" dag = get_airflow_app().dag_bag.get_dag(dag_id, session=session) dag_model = DagModel.get_dagmodel(dag_id, session=session) @@ -3168,11 +3164,11 @@ def tries(self, dag_id, session=None): wwwutils.check_dag_warnings(dag.dag_id, session) default_dag_run = conf.getint("webserver", "default_dag_run_display_number") - base_date = request.args.get("base_date") + base_date_str = request.args.get("base_date") num_runs = request.args.get("num_runs", default=default_dag_run, type=int) - if base_date: - base_date = _safe_parse_datetime(base_date) + if base_date_str: + base_date = _safe_parse_datetime(base_date_str) else: base_date = dag.get_latest_execution_date() or timezone.utcnow() @@ -3251,7 +3247,7 @@ def legacy_landing_times(self): ) @action_logging @provide_session - def landing_times(self, dag_id, session=None): + def landing_times(self, dag_id: str, session: Session = NEW_SESSION): """Shows landing times.""" dag = get_airflow_app().dag_bag.get_dag(dag_id, session=session) dag_model = DagModel.get_dagmodel(dag_id, session=session) @@ -3263,11 +3259,11 @@ def landing_times(self, dag_id, session=None): wwwutils.check_dag_warnings(dag.dag_id, session) default_dag_run = conf.getint("webserver", "default_dag_run_display_number") - base_date = request.args.get("base_date") + 
base_date_str = request.args.get("base_date") num_runs = request.args.get("num_runs", default=default_dag_run, type=int) - if base_date: - base_date = _safe_parse_datetime(base_date) + if base_date_str: + base_date = _safe_parse_datetime(base_date_str) else: base_date = dag.get_latest_execution_date() or timezone.utcnow() @@ -3286,12 +3282,11 @@ def landing_times(self, dag_id, session=None): height=chart_height, chart_attr=self.line_chart_attr, ) - y_points = {} - x_points = {} + + y_points: dict[str, list[float]] = collections.defaultdict(list) + x_points: dict[str, list[tuple[int]]] = collections.defaultdict(list) for task in dag.tasks: task_id = task.task_id - y_points[task_id] = [] - x_points[task_id] = [] for ti in tis: if ti.task_id != task.task_id: continue @@ -3373,7 +3368,7 @@ def legacy_gantt(self): ) @action_logging @provide_session - def gantt(self, dag_id, session=None): + def gantt(self, dag_id: str, session: Session = NEW_SESSION): """Show GANTT chart.""" dag = get_airflow_app().dag_bag.get_dag(dag_id, session=session) dag_model = DagModel.get_dagmodel(dag_id, session=session) @@ -3419,7 +3414,7 @@ def gantt(self, dag_id, session=None): # or the try_number of the last complete run # https://issues.apache.org/jira/browse/AIRFLOW-2143 try_count = ti.prev_attempted_tries if ti.prev_attempted_tries != 0 else ti.try_number - task_dict = alchemy_to_dict(ti) + task_dict = alchemy_to_dict(ti) or {} task_dict["end_date"] = task_dict["end_date"] or timezone.utcnow() task_dict["extraLinks"] = dag.get_task(ti.task_id).extra_links task_dict["try_number"] = try_count @@ -3440,7 +3435,7 @@ def gantt(self, dag_id, session=None): prev_task_id = failed_task_instance.task_id tf_count += 1 task = dag.get_task(failed_task_instance.task_id) - task_dict = alchemy_to_dict(failed_task_instance) + task_dict = alchemy_to_dict(failed_task_instance) or {} end_date = task_dict["end_date"] or timezone.utcnow() task_dict["end_date"] = end_date task_dict["start_date"] = task_dict["start_date"] or end_date @@ -3544,7 +3539,7 @@ def extra_links(self, *, session: Session = NEW_SESSION): @gzipped @action_logging @provide_session - def graph_data(self, session=None): + def graph_data(self, session: Session = NEW_SESSION): """Get Graph Data""" dag_id = request.args.get("dag_id") dag = get_airflow_app().dag_bag.get_dag(dag_id, session=session) @@ -3859,7 +3854,7 @@ def legacy_audit_log(self): ] ) @provide_session - def audit_log(self, dag_id: str, session=None): + def audit_log(self, dag_id: str, session: Session = NEW_SESSION): dag = get_airflow_app().dag_bag.get_dag(dag_id, session=session) dag_model = DagModel.get_dagmodel(dag_id, session=session) if not dag: @@ -4198,7 +4193,13 @@ def action_mulemailsentfalse(self, items: list[SlaMiss]): return self._set_notification_property(items, "email_sent", False) @provide_session - def _set_notification_property(self, items: list[SlaMiss], attr: str, new_value: bool, session=None): + def _set_notification_property( + self, + items: list[SlaMiss], + attr: str, + new_value: bool, + session: Session = NEW_SESSION, + ): try: count = 0 for sla in items: @@ -4295,11 +4296,11 @@ class ConnectionFormWidget(FormWidget): """Form widget used to display connection""" @cached_property - def field_behaviours(self): + def field_behaviours(self) -> str: return json.dumps(ProvidersManager().field_behaviours) @cached_property - def testable_connection_types(self): + def testable_connection_types(self) -> list[str]: return [ connection_type for connection_type, hook_info in 
ProvidersManager().hooks.items() @@ -4444,7 +4445,7 @@ def action_muldelete(self, items): (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION), ] ) - def action_mulduplicate(self, connections, session=None): + def action_mulduplicate(self, connections, session: Session = NEW_SESSION): """Duplicate Multiple connections""" for selected_conn in connections: new_conn_id = selected_conn.conn_id @@ -5125,7 +5126,7 @@ def action_set_running(self, drs: list[DagRun]): return self._set_dag_runs_to_active_state(drs, State.RUNNING) @provide_session - def _set_dag_runs_to_active_state(self, drs: list[DagRun], state: str, session=None): + def _set_dag_runs_to_active_state(self, drs: list[DagRun], state: str, session: Session = NEW_SESSION): """This routine only supports Running and Queued state.""" try: count = 0 @@ -5150,7 +5151,7 @@ def _set_dag_runs_to_active_state(self, drs: list[DagRun], state: str, session=N @action_has_dag_edit_access @provide_session @action_logging - def action_set_failed(self, drs: list[DagRun], session=None): + def action_set_failed(self, drs: list[DagRun], session: Session = NEW_SESSION): """Set state to failed.""" try: count = 0 @@ -5178,7 +5179,7 @@ def action_set_failed(self, drs: list[DagRun], session=None): @action_has_dag_edit_access @provide_session @action_logging - def action_set_success(self, drs: list[DagRun], session=None): + def action_set_success(self, drs: list[DagRun], session: Session = NEW_SESSION): """Set state to success.""" try: count = 0 @@ -5201,7 +5202,7 @@ def action_set_success(self, drs: list[DagRun], session=None): @action_has_dag_edit_access @provide_session @action_logging - def action_clear(self, drs: list[DagRun], session=None): + def action_clear(self, drs: list[DagRun], session: Session = NEW_SESSION): """Clears the state.""" try: count = 0 @@ -5312,7 +5313,7 @@ def duration_f(self): end_date = self.get("end_date") duration = self.get("duration") if end_date and duration: - return td_format(timedelta(seconds=duration)) + return td_format(datetime.timedelta(seconds=duration)) return None formatters_columns = { @@ -5472,7 +5473,7 @@ def duration_f(self): end_date = self.get("end_date") duration = self.get("duration") if end_date and duration: - return td_format(timedelta(seconds=duration)) + return td_format(datetime.timedelta(seconds=duration)) return None formatters_columns = { @@ -5502,7 +5503,7 @@ def duration_f(self): @action_has_dag_edit_access @provide_session @action_logging - def action_clear(self, task_instances, session=None): + def action_clear(self, task_instances, session: Session = NEW_SESSION): """Clears the action.""" try: dag_to_tis = collections.defaultdict(list) @@ -5530,7 +5531,7 @@ def action_muldelete(self, items): return redirect(self.get_redirect()) @provide_session - def set_task_instance_state(self, tis, target_state, session=None): + def set_task_instance_state(self, tis, target_state, session: Session = NEW_SESSION): """Set task instance state.""" try: count = len(tis) @@ -5593,7 +5594,7 @@ class AutocompleteView(AirflowBaseView): @auth.has_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) @provide_session @expose("/dagmodel/autocomplete") - def autocomplete(self, session=None): + def autocomplete(self, session: Session = NEW_SESSION): """Autocomplete.""" query = unquote(request.args.get("query", "")) @@ -5638,7 +5639,7 @@ def autocomplete(self, session=None): class DagDependenciesView(AirflowBaseView): """View to show dependencies between DAGs""" - refresh_interval = timedelta( + 
refresh_interval = datetime.timedelta( seconds=conf.getint( "webserver", "dag_dependencies_refresh_interval", diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py index 053704fb059ab..273f107b4b7c8 100644 --- a/tests/www/views/test_views.py +++ b/tests/www/views/test_views.py @@ -408,7 +408,7 @@ def test_get_task_stats_from_query(): INVALID_DATETIME_RESPONSE, ), ( - "/log?execution_date=invalid", + "/log?dag_id=tutorial&execution_date=invalid", INVALID_DATETIME_RESPONSE, ), (