[AIRFLOW-4227] Use python3-style type annotations. #5030

Closed
wants to merge 1 commit into from
10 changes: 8 additions & 2 deletions airflow/contrib/hooks/sagemaker_hook.py
@@ -20,7 +20,7 @@
import tempfile
import time
import os
import collections
from typing import NamedTuple

import botocore.config
from botocore.exceptions import ClientError
@@ -41,7 +41,13 @@ class LogState:

# Position is a tuple that includes the last read timestamp and the number of items that were read
# at that time. This is used to figure out which event to start with on the next read.
Position = collections.namedtuple('Position', ['timestamp', 'skip'])
Position = NamedTuple(
"Position",
[
("timestamp", int),
("skip", bool),
]
)


def argmin(arr, f):
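For reviewers unfamiliar with the call form used above: the functional `typing.NamedTuple` spelling behaves exactly like `collections.namedtuple` at runtime while also carrying per-field types for checkers. A minimal standalone sketch (the values are illustrative, not from the Airflow code):

```python
from typing import NamedTuple

# Typed equivalent of collections.namedtuple('Position', ['timestamp', 'skip']):
# same field names and order, but a type checker now knows the field types.
Position = NamedTuple(
    "Position",
    [
        ("timestamp", int),
        ("skip", bool),
    ]
)

pos = Position(timestamp=1554000000, skip=False)
print(pos.timestamp)               # attribute access, as with namedtuple
print(pos == (1554000000, False))  # still an ordinary tuple: prints True
```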
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.

from typing import Dict

import yaml
from airflow.contrib.kubernetes.pod import Pod
from airflow.contrib.kubernetes.kubernetes_request_factory.kubernetes_request_factory \
@@ -40,8 +42,7 @@ class SimplePodRequestFactory(KubernetesRequestFactory):
def __init__(self):
pass

def create(self, pod):
# type: (Pod) -> dict
def create(self, pod: Pod) -> Dict:
req = yaml.safe_load(self._yaml)
self.extract_name(pod, req)
self.extract_labels(pod, req)
@@ -108,8 +109,7 @@ class ExtractXcomPodRequestFactory(KubernetesRequestFactory):
def __init__(self):
pass

def create(self, pod):
# type: (Pod) -> dict
def create(self, pod: Pod) -> Dict:
req = yaml.safe_load(self._yaml)
self.extract_name(pod, req)
self.extract_labels(pod, req)
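The hunks above show the PR's core pattern: a Python 2 `# type:` comment moves into the signature itself. A self-contained sketch of the before/after (the `Pod` class here is a stand-in for illustration, not the Airflow one):

```python
from typing import Dict

class Pod:
    """Stand-in for illustration only; not airflow.contrib.kubernetes.pod.Pod."""
    name = "example"

# Python 2 style: the signature is untyped and the types live in a comment.
def create_old(pod):
    # type: (Pod) -> Dict
    return {"name": pod.name}

# Python 3 style: the same information moves into the signature, where
# checkers, IDEs and inspect.signature() can all see it.
def create_new(pod: Pod) -> Dict:
    return {"name": pod.name}

print(create_old(Pod()) == create_new(Pod()))  # True: runtime behaviour is identical
```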
7 changes: 2 additions & 5 deletions airflow/contrib/kubernetes/pod_launcher.py
@@ -70,8 +70,7 @@ def delete_pod(self, pod):
if e.status != 404:
raise

def run_pod(self, pod, startup_timeout=120, get_logs=True):
# type: (Pod, int, bool) -> Tuple[State, Optional[str]]
def run_pod(self, pod: Pod, startup_timeout=120, get_logs=True) -> Tuple[State, Optional[str]]:
"""
Launches the pod synchronously and waits for completion.
Args:
@@ -91,9 +90,7 @@ def run_pod(self, pod, startup_timeout=120, get_logs=True):

return self._monitor_pod(pod, get_logs)

def _monitor_pod(self, pod, get_logs):
# type: (Pod, bool) -> Tuple[State, Optional[str]]

def _monitor_pod(self, pod: Pod, get_logs: bool) -> Tuple[State, Optional[str]]:
if get_logs:
logs = self._client.read_namespaced_pod_log(
name=pod.name,
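The `run_pod`/`_monitor_pod` signatures above return a pair whose second member may be absent. A small standalone sketch of the `Tuple[..., Optional[...]]` idiom (with `State` simplified to a `str` alias, an assumption for brevity):

```python
from typing import Optional, Tuple

State = str  # simplified stand-in for airflow.utils.state.State

def monitor_pod(get_logs: bool) -> Tuple[State, Optional[str]]:
    # The log string is only present when logs were requested.
    if get_logs:
        return "success", "scheduler log line..."
    return "success", None

state, logs = monitor_pod(get_logs=True)
if logs is not None:  # checkers require this narrowing before using str methods
    print(state, logs.upper())
```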
18 changes: 12 additions & 6 deletions airflow/contrib/operators/azure_container_instances_operator.py
@@ -17,9 +17,8 @@
# specific language governing permissions and limitations
# under the License.

from collections import namedtuple
from time import sleep
from typing import Dict, Sequence
from typing import Dict, Iterable, NamedTuple

from airflow.contrib.hooks.azure_container_instance_hook import AzureContainerInstanceHook
from airflow.contrib.hooks.azure_container_registry_hook import AzureContainerRegistryHook
@@ -38,13 +37,20 @@
from msrestazure.azure_exceptions import CloudError


Volume = namedtuple(
'Volume',
['conn_id', 'account_name', 'share_name', 'mount_path', 'read_only'],
Volume = NamedTuple(
"Volume",
[
("conn_id", str),
("account_name", str),
("share_name", str),
("mount_path", str),
("read_only", bool),
]
)


DEFAULT_ENVIRONMENT_VARIABLES = {} # type: Dict[str, str]
DEFAULT_VOLUMES = [] # type: Sequence[Volume]
DEFAULT_VOLUMES = [] # type: Iterable[Volume]
DEFAULT_MEMORY_IN_GB = 2.0
DEFAULT_CPU = 1.0

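Note that the module-level defaults above keep the `# type:` comment form rather than a variable annotation, possibly to stay parseable on interpreters predating PEP 526 (an assumption; the diff does not say why). A sketch of the two equivalent spellings:

```python
from typing import Dict, Iterable, NamedTuple

Volume = NamedTuple("Volume", [("conn_id", str), ("read_only", bool)])

# Comment form, as kept in the diff: parseable on any Python version.
DEFAULT_ENVIRONMENT_VARIABLES = {}  # type: Dict[str, str]
DEFAULT_VOLUMES = []  # type: Iterable[Volume]

# PEP 526 variable annotations: the same meaning, but Python 3.6+ only.
ANNOTATED_ENVIRONMENT_VARIABLES: Dict[str, str] = {}
ANNOTATED_VOLUMES: Iterable[Volume] = []
```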
4 changes: 2 additions & 2 deletions airflow/contrib/operators/gcp_compute_operator.py
@@ -17,6 +17,7 @@
# specific language governing permissions and limitations
# under the License.
from copy import deepcopy
from typing import Dict

from googleapiclient.errors import HttpError

@@ -452,8 +453,7 @@ def __init__(self,
project_id=project_id, zone=self.zone, resource_id=resource_id,
gcp_conn_id=gcp_conn_id, api_version=api_version, *args, **kwargs)

def _possibly_replace_template(self, dictionary):
# type: (dict) -> None
def _possibly_replace_template(self, dictionary: Dict) -> None:
if dictionary.get('instanceTemplate') == self.source_template:
dictionary['instanceTemplate'] = self.destination_template
self._change_performed = True
14 changes: 8 additions & 6 deletions airflow/contrib/operators/qubole_operator.py
@@ -145,12 +145,14 @@ class QuboleOperator(BaseOperator):
handler in task definition.
"""

template_fields = ('query', 'script_location', 'sub_command', 'script', 'files',
'archives', 'program', 'cmdline', 'sql', 'where_clause', 'tags',
'extract_query', 'boundary_query', 'macros', 'name', 'parameters',
'dbtap_id', 'hive_table', 'db_table', 'split_column', 'note_id',
'db_update_keys', 'export_dir', 'partition_spec', 'qubole_conn_id',
'arguments', 'user_program_arguments', 'cluster_label') # type: Iterable[str]
template_fields = (
'query', 'script_location', 'sub_command', 'script', 'files',
'archives', 'program', 'cmdline', 'sql', 'where_clause', 'tags',
'extract_query', 'boundary_query', 'macros', 'name', 'parameters',
'dbtap_id', 'hive_table', 'db_table', 'split_column', 'note_id',
'db_update_keys', 'export_dir', 'partition_spec', 'qubole_conn_id',
'arguments', 'user_program_arguments', 'cluster_label',
) # type: Iterable[str]

template_ext = ('.txt',) # type: Iterable[str]
ui_color = '#3064A1'
5 changes: 2 additions & 3 deletions airflow/contrib/utils/gcp_field_sanitizer.py
@@ -98,7 +98,7 @@
components in all elements of the array.
"""

from typing import List
from typing import Iterable

from airflow import LoggingMixin, AirflowException

Expand All @@ -119,8 +119,7 @@ class GcpBodyFieldSanitizer(LoggingMixin):
:type sanitize_specs: list[str]

"""
def __init__(self, sanitize_specs):
# type: (List[str]) -> None
def __init__(self, sanitize_specs: Iterable[str]) -> None:
super().__init__()
self._sanitize_specs = sanitize_specs

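The switch from `List[str]` to `Iterable[str]` above widens what callers may pass. A standalone sketch of why the broader contract helps (the class below is a toy, not `GcpBodyFieldSanitizer`):

```python
from typing import Iterable

class FieldSanitizer:
    """Toy class for illustration only."""
    def __init__(self, sanitize_specs: Iterable[str]) -> None:
        # Only iteration is needed, so materialize once and keep a copy.
        self._sanitize_specs = list(sanitize_specs)

FieldSanitizer(["fields.secret"])             # list: accepted
FieldSanitizer(("fields.secret",))            # tuple: also accepted
FieldSanitizer(s for s in ["fields.secret"])  # generator: accepted too
# Annotated as List[str], a checker would reject the last two calls.
```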
29 changes: 15 additions & 14 deletions airflow/contrib/utils/gcp_field_validator.py
@@ -194,8 +194,7 @@ class GcpBodyFieldValidator(LoggingMixin):
:type api_version: str

"""
def __init__(self, validation_specs, api_version):
# type: (Sequence[Dict], str) -> None
def __init__(self, validation_specs: Sequence[Dict], api_version: str) -> None:
super().__init__()
self._validation_specs = validation_specs
self._api_version = api_version
@@ -207,9 +206,15 @@ def _get_field_name_with_parent(field_name, parent):
return field_name

@staticmethod
def _sanity_checks(children_validation_specs, field_type, full_field_path,
regexp, allow_empty, custom_validation, value):
# type: (dict, str, str, str, Callable, object) -> None
def _sanity_checks(
children_validation_specs: Dict,
field_type: str,
full_field_path: str,
regexp: str,
allow_empty: bool,
custom_validation: Callable,
value,
) -> None:
if value is None and field_type != 'union':
raise GcpFieldValidationException(
"The required body field '{}' is missing. Please add it.".
@@ -236,8 +241,7 @@ def _sanity_checks(children_validation_specs, field_type, full_field_path,
format(full_field_path))

@staticmethod
def _validate_regexp(full_field_path, regexp, value):
# type: (str, str, str) -> None
def _validate_regexp(full_field_path: str, regexp: str, value: str) -> None:
if not re.match(regexp, value):
# Note matching of only the beginning as we assume the regexps all-or-nothing
raise GcpFieldValidationException(
@@ -246,15 +250,13 @@ def _validate_regexp(full_field_path, regexp, value):
format(full_field_path, value, regexp))

@staticmethod
def _validate_is_empty(full_field_path, value):
# type: (str, str) -> None
def _validate_is_empty(full_field_path: str, value: str) -> None:
if not value:
raise GcpFieldValidationException(
"The body field '{}' can't be empty. Please provide a value."
.format(full_field_path, value))

def _validate_dict(self, children_validation_specs, full_field_path, value):
# type: (dict, str, dict) -> None
def _validate_dict(self, children_validation_specs: Dict, full_field_path: str, value: Dict) -> None:
for child_validation_spec in children_validation_specs:
self._validate_field(validation_spec=child_validation_spec,
dictionary_to_validate=value,
@@ -272,9 +274,8 @@ def _validate_dict(self, children_validation_specs, full_field_path, value):
self._get_field_name_with_parent(field_name, full_field_path),
children_validation_specs)

def _validate_union(self, children_validation_specs, full_field_path,
dictionary_to_validate):
# type: (dict, str, dict) -> None
def _validate_union(self, children_validation_specs: Dict, full_field_path: str,
dictionary_to_validate: Dict) -> None:
field_found = False
found_field_name = None
for child_validation_spec in children_validation_specs:
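`_sanity_checks` above annotates `custom_validation` as a bare `Callable` and leaves `value` unannotated (implicitly `Any`). A sketch of the difference between the bare and parameterized forms (toy functions, hypothetical names):

```python
from typing import Any, Callable

# Bare Callable, as in the diff: any callable with any signature passes.
def run_validation(custom_validation: Callable, value: Any) -> None:
    custom_validation(value)

# Parameterized form: exactly one positional argument, return value ignored.
def run_validation_strict(custom_validation: Callable[[Any], None],
                          value: Any) -> None:
    custom_validation(value)

run_validation(print, "checked")
run_validation_strict(lambda v: print("checked:", v), "checked")
```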
6 changes: 3 additions & 3 deletions airflow/hooks/base_hook.py
@@ -63,7 +63,7 @@ def _get_connection_from_env(cls, conn_id):
return conn

@classmethod
def get_connections(cls, conn_id): # type: (str) -> Iterable[Connection]
def get_connections(cls, conn_id: str) -> Iterable[Connection]:
conn = cls._get_connection_from_env(conn_id)
if conn:
conns = [conn]
@@ -72,15 +72,15 @@ def get_connections(cls, conn_id): # type: (str) -> Iterable[Connection]
return conns

@classmethod
def get_connection(cls, conn_id): # type: (str) -> Connection
def get_connection(cls, conn_id: str) -> Connection:
conn = random.choice(list(cls.get_connections(conn_id)))
if conn.host:
log = LoggingMixin().log
log.info("Using connection to: %s", conn.debug_info())
return conn

@classmethod
def get_hook(cls, conn_id): # type: (str) -> BaseHook
def get_hook(cls, conn_id: str) -> 'BaseHook':
connection = cls.get_connection(conn_id)
return connection.get_hook()

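One detail worth noting in the hunk above: `get_hook` returns `'BaseHook'` as a string because the class body is still executing when the annotation is evaluated, so the name is not bound yet. A minimal reproduction of the problem and the string-literal fix (toy class):

```python
class BaseThing:
    @classmethod
    def get_thing(cls) -> 'BaseThing':
        # The quoted name is resolved lazily, after the class object exists.
        return cls()

    # An unquoted annotation here would raise NameError at import time,
    # because BaseThing is not bound until the class body finishes executing:
    #     def get_thing(cls) -> BaseThing: ...

print(type(BaseThing.get_thing()).__name__)  # BaseThing
```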
66 changes: 33 additions & 33 deletions airflow/models/baseoperator.py
@@ -238,40 +238,40 @@ class derived from this one results in the creation of a task object,
@apply_defaults
def __init__(
self,
task_id, # type: str
owner=configuration.conf.get('operators', 'DEFAULT_OWNER'), # type: str
email=None, # type: Optional[str]
email_on_retry=True, # type: bool
email_on_failure=True, # type: bool
retries=0, # type: int
retry_delay=timedelta(seconds=300), # type: timedelta
retry_exponential_backoff=False, # type: bool
max_retry_delay=None, # type: Optional[datetime]
start_date=None, # type: Optional[datetime]
end_date=None, # type: Optional[datetime]
task_id: str,
owner: str = configuration.conf.get('operators', 'DEFAULT_OWNER'),
email: Optional[str] = None,
email_on_retry: bool = True,
email_on_failure: bool = True,
retries: int = 0,
retry_delay: timedelta = timedelta(seconds=300),
retry_exponential_backoff: bool = False,
max_retry_delay: Optional[datetime] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
schedule_interval=None, # not hooked as of now
depends_on_past=False, # type: bool
wait_for_downstream=False, # type: bool
dag=None, # type: Optional[DAG]
params=None, # type: Optional[Dict]
default_args=None, # type: Optional[Dict]
priority_weight=1, # type: int
weight_rule=WeightRule.DOWNSTREAM, # type: str
queue=configuration.conf.get('celery', 'default_queue'), # type: str
pool=None, # type: Optional[str]
sla=None, # type: Optional[timedelta]
execution_timeout=None, # type: Optional[timedelta]
on_failure_callback=None, # type: Optional[Callable]
on_success_callback=None, # type: Optional[Callable]
on_retry_callback=None, # type: Optional[Callable]
trigger_rule=TriggerRule.ALL_SUCCESS, # type: str
resources=None, # type: Optional[Dict]
run_as_user=None, # type: Optional[str]
task_concurrency=None, # type: Optional[int]
executor_config=None, # type: Optional[Dict]
do_xcom_push=True, # type: bool
inlets=None, # type: Optional[Dict]
outlets=None, # type: Optional[Dict]
depends_on_past: bool = False,
wait_for_downstream: bool = False,
dag: Optional['DAG'] = None,
params: Optional[Dict] = None,
default_args: Optional[Dict] = None,
priority_weight: int = 1,
weight_rule: str = WeightRule.DOWNSTREAM,
queue: str = configuration.conf.get('celery', 'default_queue'),
pool: Optional[str] = None,
sla: Optional[timedelta] = None,
execution_timeout: Optional[timedelta] = None,
on_failure_callback: Optional[Callable] = None,
on_success_callback: Optional[Callable] = None,
on_retry_callback: Optional[Callable] = None,
trigger_rule: str = TriggerRule.ALL_SUCCESS,
resources: Optional[Dict] = None,
run_as_user: Optional[str] = None,
task_concurrency: Optional[int] = None,
executor_config: Optional[Dict] = None,
do_xcom_push: bool = True,
inlets: Optional[Dict] = None,
outlets: Optional[Dict] = None,
*args,
**kwargs
):
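A cosmetic but easy-to-miss point in the rewritten signature above: once a parameter carries an annotation, PEP 8 asks for spaces around the default's `=`, whereas unannotated defaults keep the tight form. A compact illustration (hypothetical operator, not `BaseOperator`):

```python
from datetime import timedelta
from typing import Optional

class ExampleOperator:
    """Toy operator for illustration only."""
    def __init__(
        self,
        task_id: str,                                    # required, no default
        retries: int = 0,                                # annotated: spaces around '='
        retry_delay: timedelta = timedelta(seconds=300),
        pool: Optional[str] = None,                      # may be None, hence Optional
        do_xcom_push=True,                               # unannotated: no spaces (PEP 8)
    ) -> None:
        self.task_id = task_id
        self.retries = retries
        self.retry_delay = retry_delay
        self.pool = pool
        self.do_xcom_push = do_xcom_push

op = ExampleOperator(task_id="example_task", retries=3)
print(op.retry_delay)  # 0:05:00
```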
47 changes: 23 additions & 24 deletions airflow/models/dag.py
@@ -167,30 +167,29 @@ class DAG(BaseDag, LoggingMixin):

def __init__(
self,
dag_id, # type: str
description='', # type: str
schedule_interval=timedelta(days=1), # type: Optional[ScheduleInterval]
start_date=None, # type: Optional[datetime]
end_date=None, # type: Optional[datetime]
full_filepath=None, # type: Optional[str]
template_searchpath=None, # type: Optional[Union[str, Iterable[str]]]
template_undefined=jinja2.Undefined, # type: Type[jinja2.Undefined]
user_defined_macros=None, # type: Optional[Dict]
user_defined_filters=None, # type: Optional[Dict]
default_args=None, # type: Optional[Dict]
concurrency=configuration.conf.getint('core', 'dag_concurrency'), # type: int
max_active_runs=configuration.conf.getint(
'core', 'max_active_runs_per_dag'), # type: int
dagrun_timeout=None, # type: Optional[timedelta]
sla_miss_callback=None, # type: Optional[Callable]
default_view=None, # type: Optional[str]
orientation=configuration.conf.get('webserver', 'dag_orientation'), # type: str
catchup=configuration.conf.getboolean('scheduler', 'catchup_by_default'), # type: bool
on_success_callback=None, # type: Optional[Callable]
on_failure_callback=None, # type: Optional[Callable]
doc_md=None, # type: Optional[str]
params=None, # type: Optional[Dict]
access_control=None # type: Optional[Dict]
dag_id: str,
description: str = '',
schedule_interval: Optional[ScheduleInterval] = timedelta(days=1),
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
full_filepath: Optional[str] = None,
template_searchpath: Optional[Union[str, Iterable[str]]] = None,
template_undefined: Type[jinja2.Undefined] = jinja2.Undefined,
user_defined_macros: Optional[Dict] = None,
user_defined_filters: Optional[Dict] = None,
default_args: Optional[Dict] = None,
concurrency: int = configuration.conf.getint('core', 'dag_concurrency'),
max_active_runs: int = configuration.conf.getint('core', 'max_active_runs_per_dag'),
dagrun_timeout: Optional[timedelta] = None,
sla_miss_callback: Optional[Callable] = None,
default_view: Optional[str] = None,
orientation: str = configuration.conf.get('webserver', 'dag_orientation'),
catchup: bool = configuration.conf.getboolean('scheduler', 'catchup_by_default'),
on_success_callback: Optional[Callable] = None,
on_failure_callback: Optional[Callable] = None,
doc_md: Optional[str] = None,
params: Optional[Dict] = None,
access_control: Optional[Dict] = None
):
self.user_defined_macros = user_defined_macros
self.user_defined_filters = user_defined_filters
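One subtlety in the `DAG` signature above: `schedule_interval` is `Optional[ScheduleInterval]` yet defaults to `timedelta(days=1)`. `Optional` only says `None` is among the allowed values; it implies nothing about the default. A small sketch (with `ScheduleInterval` reduced to a two-member `Union` stand-in, an assumption for brevity):

```python
from datetime import timedelta
from typing import Optional, Union

ScheduleInterval = Union[str, timedelta]  # reduced stand-in for Airflow's alias

def describe_schedule(
    schedule_interval: Optional[ScheduleInterval] = timedelta(days=1),
) -> str:
    if schedule_interval is None:
        return "no automatic runs"
    return "runs every {}".format(schedule_interval)

print(describe_schedule())          # default: runs every 1 day, 0:00:00
print(describe_schedule(None))      # None must be passed explicitly
print(describe_schedule("@daily"))  # a cron preset string is also allowed
```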