Add more type hints to the code base (#30503)
* Fully type Pool

Also fix a bug where create_or_update_pool silently fails when an empty
name is given. An error is now raised instead (a sketch of the guard
appears after this list).

* Add types to 'airflow dags'

* Add types to 'airflow task' and 'airflow job'

* Improve KubernetesExecutor typing

* Add types to BackfillJob

This surfaces an existing typing bug: pickle_id was annotated as str in
the executors, while in practice it is an int. The annotation is fixed to
keep things straight (see the illustration after this list).

* Add types to job classes

* Fix missing DagModel case in SchedulerJob

* Add types to DagCode

* Add more types to DagRun

* Add types to serialized DAG model

* Add more types to TaskInstance and TaskReschedule

* Add types to Trigger

* Add types to MetastoreBackend

* Add types to external task sensor

* Add types to AirflowSecurityManager

This uncovers a couple of incorrect type hints in the base
SecurityManager (in fab_security), which are also fixed.

* Add types to views

This slightly improves how view functions are type-checked and should
prevent some trivial bugs (see the sketch after this list).
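
A minimal sketch of the empty-name guard described in the Pool bullet above; the exception type, message, and body shown here are assumptions for illustration, not the committed code:

```python
def create_or_update_pool(name: str, slots: int, description: str) -> None:
    # Previously an empty or whitespace-only name fell through silently,
    # creating nothing and giving the caller no signal.
    if not name:
        raise ValueError("Pool name must not be empty")
    ...  # create or update the Pool row as before
```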
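
For the BackfillJob bullet, a simplified before/after of the pickle_id annotation fix; the function names are hypothetical stand-ins, not actual executor methods:

```python
from __future__ import annotations


# Before: annotated as str, although callers pass the DagPickle database id.
def queue_workload_old(dag_id: str, pickle_id: str | None = None) -> None: ...


# After: the annotation matches the int that is actually passed at runtime.
def queue_workload_new(dag_id: str, pickle_id: int | None = None) -> None: ...
```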
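
And for the views bullet, a generic Flask illustration (not Airflow's actual view code) of what return-type annotations on view functions buy:

```python
from __future__ import annotations

from flask import Flask, Response, make_response

app = Flask(__name__)


@app.route("/health")
def health() -> Response:
    # With a declared return type, accidentally returning None or another
    # non-response value becomes a type error instead of a runtime 500.
    return make_response("healthy", 200)
```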
uranusjr authored Apr 7, 2023
1 parent 0b83f06 commit 1a85446
Showing 32 changed files with 562 additions and 390 deletions.
5 changes: 3 additions & 2 deletions airflow/api/common/delete_dag.py
@@ -21,20 +21,21 @@
 import logging
 
 from sqlalchemy import and_, or_
+from sqlalchemy.orm import Session
 
 from airflow import models
 from airflow.exceptions import AirflowException, DagNotFound
 from airflow.models import DagModel, TaskFail
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.utils.db import get_sqla_model_classes
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import State
 
 log = logging.getLogger(__name__)
 
 
 @provide_session
-def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> int:
+def delete_dag(dag_id: str, keep_records_in_log: bool = True, session: Session = NEW_SESSION) -> int:
     """
     Delete a DAG by a dag_id.
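
The change above is the pattern this commit applies everywhere: `session=None` becomes `session: Session = NEW_SESSION`, so function bodies see a plain Session instead of Optional[Session]. Below is a self-contained sketch of how such a sentinel-plus-decorator pair can work; it paraphrases the idea and is not Airflow's exact implementation (the real provide_session in airflow.utils.session also handles positional arguments, among other details):

```python
from __future__ import annotations

from functools import wraps
from typing import Any, Callable, TypeVar, cast

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker

RT = TypeVar("RT")

# The sentinel is annotated as Session but never used as one: the decorator
# always replaces it, so bodies can use `session` without None checks.
NEW_SESSION: Session = cast(Session, None)

_session_factory = sessionmaker(bind=create_engine("sqlite://"))


def provide_session(func: Callable[..., RT]) -> Callable[..., RT]:
    """Inject a fresh Session when the caller did not supply one."""

    @wraps(func)
    def wrapper(*args: Any, session: Session = NEW_SESSION, **kwargs: Any) -> RT:
        if session is not NEW_SESSION:
            # The caller manages its own session and transaction.
            return func(*args, session=session, **kwargs)
        with _session_factory() as new_session:
            result = func(*args, session=new_session, **kwargs)
            new_session.commit()
            return result

    return wrapper


@provide_session
def table_count(table: str, session: Session = NEW_SESSION) -> int:
    # `session` is statically a Session here, so no Optional handling is needed.
    return session.execute(text(f"SELECT count(*) FROM {table}")).scalar_one()
```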
11 changes: 6 additions & 5 deletions airflow/api/common/experimental/pool.py
@@ -19,15 +19,16 @@
 from __future__ import annotations
 
 from deprecated import deprecated
+from sqlalchemy.orm import Session
 
 from airflow.exceptions import AirflowBadRequest, PoolNotFound
 from airflow.models import Pool
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 
 
 @deprecated(reason="Use Pool.get_pool() instead", version="2.2.4")
 @provide_session
-def get_pool(name, session=None):
+def get_pool(name, session: Session = NEW_SESSION):
     """Get pool by a given name."""
     if not (name and name.strip()):
         raise AirflowBadRequest("Pool name shouldn't be empty")
@@ -41,14 +42,14 @@ def get_pool(name, session=None):
 
 @deprecated(reason="Use Pool.get_pools() instead", version="2.2.4")
 @provide_session
-def get_pools(session=None):
+def get_pools(session: Session = NEW_SESSION):
     """Get all pools."""
     return session.query(Pool).all()
 
 
 @deprecated(reason="Use Pool.create_pool() instead", version="2.2.4")
 @provide_session
-def create_pool(name, slots, description, session=None):
+def create_pool(name, slots, description, session: Session = NEW_SESSION):
     """Create a pool with given parameters."""
     if not (name and name.strip()):
         raise AirflowBadRequest("Pool name shouldn't be empty")
@@ -79,7 +80,7 @@ def create_pool(name, slots, description, session=None):
 
 @deprecated(reason="Use Pool.delete_pool() instead", version="2.2.4")
 @provide_session
-def delete_pool(name, session=None):
+def delete_pool(name, session: Session = NEW_SESSION):
     """Delete pool by a given name."""
     if not (name and name.strip()):
         raise AirflowBadRequest("Pool name shouldn't be empty")
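
Each deprecation notice above points at the Pool model's own classmethods; a usage sketch of the recommended replacements (argument names mirror the wrappers above, and running this requires an initialized Airflow database):

```python
from airflow.models import Pool

Pool.create_pool(name="etl", slots=8, description="Slots for ETL tasks")
etl_pool = Pool.get_pool("etl")
all_pools = Pool.get_pools()
Pool.delete_pool("etl")
```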
97 changes: 50 additions & 47 deletions airflow/cli/commands/dag_command.py
@@ -21,9 +21,11 @@
 import errno
 import json
 import logging
+import operator
 import signal
 import subprocess
 import sys
+import warnings
 
 from graphviz.dot import Dot
 from sqlalchemy.orm import Session
@@ -47,33 +49,7 @@
 log = logging.getLogger(__name__)
 
 
-@cli_utils.action_cli
-def dag_backfill(args, dag=None):
-    """Creates backfill job or dry run for a DAG or list of DAGs using regex."""
-    logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)
-
-    signal.signal(signal.SIGTERM, sigint_handler)
-
-    import warnings
-
-    warnings.warn(
-        "--ignore-first-depends-on-past is deprecated as the value is always set to True",
-        category=RemovedInAirflow3Warning,
-    )
-
-    if args.ignore_first_depends_on_past is False:
-        args.ignore_first_depends_on_past = True
-
-    if not args.start_date and not args.end_date:
-        raise AirflowException("Provide a start_date and/or end_date")
-
-    if not dag:
-        dags = get_dags(args.subdir, dag_id=args.dag_id, use_regex=args.treat_dag_as_regex)
-    else:
-        dags = dag if type(dag) == list else [dag]
-
-    dags.sort(key=lambda d: d.dag_id)
-
+def _run_dag_backfill(dags: list[DAG], args) -> None:
     # If only one date is passed, using same as start and end
     args.end_date = args.end_date or args.start_date
     args.start_date = args.start_date or args.end_date
@@ -133,12 +109,39 @@ def dag_backfill(args, dag=None):
                 print(str(vr))
                 sys.exit(1)
 
+
+@cli_utils.action_cli
+def dag_backfill(args, dag: list[DAG] | DAG | None = None) -> None:
+    """Creates backfill job or dry run for a DAG or list of DAGs using regex."""
+    logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)
+    signal.signal(signal.SIGTERM, sigint_handler)
+    warnings.warn(
+        "--ignore-first-depends-on-past is deprecated as the value is always set to True",
+        category=RemovedInAirflow3Warning,
+    )
+
+    if args.ignore_first_depends_on_past is False:
+        args.ignore_first_depends_on_past = True
+
+    if not args.start_date and not args.end_date:
+        raise AirflowException("Provide a start_date and/or end_date")
+
+    if not dag:
+        dags = get_dags(args.subdir, dag_id=args.dag_id, use_regex=args.treat_dag_as_regex)
+    elif isinstance(dag, list):
+        dags = dag
+    else:
+        dags = [dag]
+    del dag
+
+    dags.sort(key=lambda d: d.dag_id)
+    _run_dag_backfill(dags, args)
+    if len(dags) > 1:
+        log.info("All of the backfills are done.")
+
 
 @cli_utils.action_cli
-def dag_trigger(args):
+def dag_trigger(args) -> None:
     """Creates a dag run for the specified dag."""
     api_client = get_current_api_client()
     try:
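
Two small details in the rewritten dag_backfill above are worth noting: `type(dag) == list` became `isinstance(dag, list)`, which type checkers understand as narrowing, and `del dag` keeps the more broadly typed name from being reused afterwards. A standalone illustration of the narrowing:

```python
from __future__ import annotations


def as_list(value: list[str] | str | None) -> list[str]:
    if value is None:
        return []
    if isinstance(value, list):
        return value  # narrowed to list[str] here
    return [value]  # narrowed to str here
```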
@@ -159,7 +162,7 @@ def dag_trigger(args):
 
 
 @cli_utils.action_cli
-def dag_delete(args):
+def dag_delete(args) -> None:
     """Deletes all DB records related to the specified dag."""
     api_client = get_current_api_client()
     if (
@@ -177,18 +180,18 @@ def dag_delete(args):
 
 
 @cli_utils.action_cli
-def dag_pause(args):
+def dag_pause(args) -> None:
     """Pauses a DAG."""
     set_is_paused(True, args)
 
 
 @cli_utils.action_cli
-def dag_unpause(args):
+def dag_unpause(args) -> None:
     """Unpauses a DAG."""
     set_is_paused(False, args)
 
 
-def set_is_paused(is_paused, args):
+def set_is_paused(is_paused: bool, args) -> None:
     """Sets is_paused for DAG by a given dag_id."""
     dag = DagModel.get_dagmodel(args.dag_id)
 
@@ -200,7 +203,7 @@ def set_is_paused(is_paused, args):
     print(f"Dag: {args.dag_id}, paused: {is_paused}")
 
 
-def dag_dependencies_show(args):
+def dag_dependencies_show(args) -> None:
     """Displays DAG dependencies, save to file or show as imgcat image."""
     dot = render_dag_dependencies(SerializedDagModel.get_dag_dependencies())
     filename = args.save
@@ -219,7 +222,7 @@ def dag_dependencies_show(args):
     print(dot.source)
 
 
-def dag_show(args):
+def dag_show(args) -> None:
     """Displays DAG or saves it's graphic representation to the file."""
     dag = get_dag(args.subdir, args.dag_id)
     dot = render_dag(dag)
@@ -239,7 +242,7 @@ def dag_show(args):
     print(dot.source)
 
 
-def _display_dot_via_imgcat(dot: Dot):
+def _display_dot_via_imgcat(dot: Dot) -> None:
     data = dot.pipe(format="png")
     try:
         with subprocess.Popen("imgcat", stdout=subprocess.PIPE, stdin=subprocess.PIPE) as proc:
@@ -255,15 +258,15 @@ def _display_dot_via_imgcat(dot: Dot):
             raise
 
 
-def _save_dot_to_file(dot: Dot, filename: str):
+def _save_dot_to_file(dot: Dot, filename: str) -> None:
     filename_without_ext, _, ext = filename.rpartition(".")
     dot.render(filename=filename_without_ext, format=ext, cleanup=True)
     print(f"File {filename} saved")
 
 
 @cli_utils.action_cli
 @provide_session
-def dag_state(args, session=NEW_SESSION):
+def dag_state(args, session: Session = NEW_SESSION) -> None:
     """
     Returns the state (and conf if exists) of a DagRun at the command line.
     >>> airflow dags state tutorial 2015-01-01T00:00:00.000000
@@ -284,7 +287,7 @@ def dag_state(args, session=NEW_SESSION):
 
 
 @cli_utils.action_cli
-def dag_next_execution(args):
+def dag_next_execution(args) -> None:
     """
     Returns the next execution datetime of a DAG at the command line.
     >>> airflow dags next-execution tutorial
@@ -312,15 +315,15 @@ def print_execution_interval(interval: DataInterval | None):
     next_interval = dag.get_next_data_interval(last_parsed_dag)
     print_execution_interval(next_interval)
 
-    for i in range(1, args.num_executions):
+    for _ in range(1, args.num_executions):
         next_info = dag.next_dagrun_info(next_interval, restricted=False)
         next_interval = None if next_info is None else next_info.data_interval
         print_execution_interval(next_interval)
 
 
 @cli_utils.action_cli
 @suppress_logs_and_warning
-def dag_list_dags(args):
+def dag_list_dags(args) -> None:
     """Displays dags with or without stats at the command line."""
     dagbag = DagBag(process_subdir(args.subdir))
     if dagbag.import_errors:
@@ -332,7 +335,7 @@ def dag_list_dags(args):
             file=sys.stderr,
         )
     AirflowConsole().print_as(
-        data=sorted(dagbag.dags.values(), key=lambda d: d.dag_id),
+        data=sorted(dagbag.dags.values(), key=operator.attrgetter("dag_id")),
        output=args.output,
         mapper=lambda x: {
             "dag_id": x.dag_id,
@@ -345,7 +348,7 @@
 
 @cli_utils.action_cli
 @suppress_logs_and_warning
-def dag_list_import_errors(args):
+def dag_list_import_errors(args) -> None:
     """Displays dags with import errors on the command line."""
     dagbag = DagBag(process_subdir(args.subdir))
     data = []
@@ -359,7 +362,7 @@ def dag_list_import_errors(args):
 
 
 @cli_utils.action_cli
 @suppress_logs_and_warning
-def dag_report(args):
+def dag_report(args) -> None:
     """Displays dagbag stats at the command line."""
     dagbag = DagBag(process_subdir(args.subdir))
     AirflowConsole().print_as(
@@ -378,7 +381,7 @@
 
 @cli_utils.action_cli
 @suppress_logs_and_warning
 @provide_session
-def dag_list_jobs(args, dag=None, session=NEW_SESSION):
+def dag_list_jobs(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
     """Lists latest n jobs."""
     queries = []
     if dag:
@@ -408,7 +411,7 @@ def dag_list_jobs(args, dag=None, session=NEW_SESSION):
 @cli_utils.action_cli
 @suppress_logs_and_warning
 @provide_session
-def dag_list_dag_runs(args, dag=None, session=NEW_SESSION):
+def dag_list_dag_runs(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
     """Lists dag runs for a given DAG."""
     if dag:
         args.dag_id = dag.dag_id
@@ -445,7 +448,7 @@ def dag_list_dag_runs(args, dag=None, session=NEW_SESSION):
 
 @provide_session
 @cli_utils.action_cli
-def dag_test(args, dag=None, session=None):
+def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
     """Execute one single DagRun for a given DAG and execution date."""
     run_conf = None
     if args.conf:
@@ -481,7 +484,7 @@ def dag_test(args, dag=None, session=None):
 
 @provide_session
 @cli_utils.action_cli
-def dag_reserialize(args, session: Session = NEW_SESSION):
+def dag_reserialize(args, session: Session = NEW_SESSION) -> None:
     """Serialize a DAG instance."""
     session.query(SerializedDagModel).delete(synchronize_session=False)
 
6 changes: 4 additions & 2 deletions airflow/cli/commands/jobs_command.py
@@ -16,14 +16,16 @@
 # under the License.
 from __future__ import annotations
 
+from sqlalchemy.orm import Session
+
 from airflow.jobs.base_job import BaseJob
 from airflow.utils.net import get_hostname
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import State
 
 
 @provide_session
-def check(args, session=None):
+def check(args, session: Session = NEW_SESSION) -> None:
     """Checks if job(s) are still alive."""
     if args.allow_multiple and not args.limit > 1:
         raise SystemExit("To use option --allow-multiple, you must set the limit to a value greater than 1.")