Add compat module for typing airflow.utils.context.Context (#619)
The `airflow.utils.context` module was not introduced to OSS Airflow until 2.2.3. Importing it at the top level therefore sets an implicit minimum requirement of Airflow 2.2.3, which is higher than the minimum `apache-airflow` version that Astronomer Providers supports. This PR adds a typing compat module, used throughout the repo, for consistent typing of `context`.
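The compat module itself (`astronomer/providers/utils/typing_compat.py`) is added by this commit but does not appear in the excerpt below. A minimal sketch of the guarded-import pattern such a module typically uses (an illustrative assumption; the shipped file may differ in detail):

try:
    # Airflow >= 2.2.3 exposes a concrete Context class.
    from airflow.utils.context import Context
except ImportError:
    # Older Airflow passes a plain dict as the task context, so fall back
    # to a structurally equivalent alias that is valid in annotations.
    from typing import Any, Dict

    Context = Dict[str, Any]  # type: ignore[misc, assignment]

__all__ = ["Context"]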
josh-fell authored Sep 14, 2022
1 parent a040fb8 commit 14751af
Showing 33 changed files with 203 additions and 93 deletions.
77 changes: 77 additions & 0 deletions .circleci/scripts/pre_commit_context_typing_compat.py
@@ -0,0 +1,77 @@
#!/usr/bin/env python3
"""
Pre-commit hook to verify ``airflow.utils.context.Context`` is not imported in provider modules.
# TODO: This pre-commit hook can be removed once the repo has a minimum Apache Airflow requirement of 2.2.3+.
"""
from __future__ import annotations

import os
from ast import ImportFrom, NodeVisitor, parse
from pathlib import Path, PosixPath

ASTRONOMER_PROVIDERS_SOURCES_ROOT = Path(__file__).parents[2]
PROVIDERS_ROOT = ASTRONOMER_PROVIDERS_SOURCES_ROOT / "astronomer" / "providers"
TYPING_COMPAT_PATH = "astronomer/providers/utils/typing_compat.py"


class ImportCrawler(NodeVisitor):
"""AST crawler to determine if a module has an incompatible `airflow.utils.context.Context` import."""

def __init__(self) -> None:
self.has_incompatible_context_imports = False

def visit_ImportFrom(self, node: ImportFrom) -> None:
"""Visit an ImportFrom node to determine if `airflow.utils.context.Context` is imported directly."""
if self.has_incompatible_context_imports:
return

for alias in node.names:
if f"{node.module}.{alias.name}" == "airflow.utils.context.Context":
if not self.has_incompatible_context_imports:
self.has_incompatible_context_imports = True


def get_all_provider_files() -> list[PosixPath]:
"""Retrieve all eligible provider module files."""
provider_files = []
for (root, _, file_names) in os.walk(PROVIDERS_ROOT):
for file_name in file_names:
file_path = Path(root, file_name)
if (
file_path.is_file()
and file_path.name.endswith(".py")
and TYPING_COMPAT_PATH not in str(file_path)
):
provider_files.append(file_path)

return provider_files


def find_incompatible_context_imports(file_paths: list[PosixPath]) -> list[str]:
"""Retrieve any provider files that import `airflow.utils.context.Context` directly."""
incompatible_context_imports = []
for file_path in file_paths:
file_ast = parse(file_path.read_text(), filename=file_path.name)
crawler = ImportCrawler()
crawler.visit(file_ast)
if crawler.has_incompatible_context_imports:
incompatible_context_imports.append(str(file_path))

return incompatible_context_imports


if __name__ == "__main__":
provider_files = get_all_provider_files()
files_needing_typing_compat = find_incompatible_context_imports(provider_files)

    if files_needing_typing_compat:
error_message = (
"The following files are importing `airflow.utils.context.Context`. "
"This is not compatible with the minimum `apache-airflow` requirement of this repository. "
"Please use `astronomer.providers.utils.typing_compat.Context` instead.\n\n\t{}".format(
"\n\t".join(files_needing_typing_compat)
)
)

raise SystemExit(error_message)
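A quick illustration of what the crawler flags, assuming the definitions above are in scope:

from ast import parse

crawler = ImportCrawler()
crawler.visit(parse("from airflow.utils.context import Context\n"))
assert crawler.has_incompatible_context_imports

crawler = ImportCrawler()
crawler.visit(parse("from astronomer.providers.utils.typing_compat import Context\n"))
assert not crawler.has_incompatible_context_imports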
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
@@ -138,3 +138,8 @@ repos:
files: ^setup.cfg$|^README.rst$
pass_filenames: false
entry: .circleci/scripts/pre_commit_readme_extra.py
+  - id: check-context-typing-compat
+    name: Ensure modules use local typing compat for airflow.utils.context.Context
+    entry: .circleci/scripts/pre_commit_context_typing_compat.py
+    language: python
+    pass_filenames: false
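With this entry in place the check runs on every commit, and it can be exercised manually with `pre-commit run check-context-typing-compat --all-files`. Since `pass_filenames: false` is set, the script walks the provider tree itself rather than receiving staged file paths.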
6 changes: 3 additions & 3 deletions astronomer/providers/amazon/aws/operators/batch.py
@@ -11,9 +11,9 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.batch import BatchOperator
-from airflow.utils.context import Context

from astronomer.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+from astronomer.providers.utils.typing_compat import Context


class BatchOperatorAsync(BatchOperator):
@@ -53,7 +53,7 @@ class BatchOperatorAsync(BatchOperator):
| ``waiter = waiters.get_waiter("JobComplete")``
"""

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""
Airflow runs this method on the worker and defers using the trigger.
Submit the job and get the job_id using which we defer and poll in trigger
@@ -79,7 +79,7 @@ def execute(self, context: "Context") -> None:
method_name="execute_complete",
)

-    def execute_complete(self, context: Dict[str, Any], event: Dict[str, Any]) -> None:
+    def execute_complete(self, context: Context, event: Dict[str, Any]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
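Every operator and sensor touched in this commit follows the same shape as `BatchOperatorAsync` above: a synchronous `execute` that defers to a trigger, plus an `execute_complete` callback, with both now annotated via the compat `Context`. A condensed, hypothetical sketch of the pattern (`ExampleOperatorAsync` is invented, and `TimeDeltaTrigger` stands in for the repo's real triggers):

from datetime import timedelta
from typing import Any

from airflow.models.baseoperator import BaseOperator
from airflow.triggers.temporal import TimeDeltaTrigger

from astronomer.providers.utils.typing_compat import Context


class ExampleOperatorAsync(BaseOperator):
    def execute(self, context: Context) -> None:
        # Free the worker slot by deferring to a trigger that fires later.
        self.defer(
            timeout=timedelta(hours=1),
            trigger=TimeDeltaTrigger(timedelta(seconds=30)),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Context, event: Any = None) -> None:
        # Runs in a fresh worker slot once the trigger fires.
        self.log.info("Trigger fired with event: %s", event)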
6 changes: 3 additions & 3 deletions astronomer/providers/amazon/aws/operators/emr.py
@@ -3,9 +3,9 @@
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator
-from airflow.utils.context import Context

from astronomer.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger
+from astronomer.providers.utils.typing_compat import Context


class EmrContainerOperatorAsync(EmrContainerOperator):
@@ -28,7 +28,7 @@ class EmrContainerOperatorAsync(EmrContainerOperator):
Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
"""

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""Deferred and give control to trigger"""
hook = EmrContainerHook(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
job_id = hook.submit_job(
@@ -52,7 +52,7 @@ def execute(self, context: "Context") -> None:
method_name="execute_complete",
)

-    def execute_complete(self, context: Dict[str, Any], event: Dict[str, Any]) -> str:
+    def execute_complete(self, context: Context, event: Dict[str, Any]) -> str:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
18 changes: 8 additions & 10 deletions astronomer/providers/amazon/aws/operators/redshift_cluster.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import Any, Optional

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
@@ -11,9 +11,7 @@
from astronomer.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
)

-if TYPE_CHECKING:
-    from airflow.utils.context import Context
+from astronomer.providers.utils.typing_compat import Context


class RedshiftDeleteClusterOperatorAsync(RedshiftDeleteClusterOperator):
@@ -44,7 +42,7 @@ def __init__(
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""
Logic that the operator uses to correctly identify which trigger to
execute, and defer execution as expected.
@@ -74,7 +72,7 @@ def execute(self, context: "Context") -> None:
"Unable to delete cluster since cluster is currently in status: %s", cluster_state
)

-    def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
+    def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -114,7 +112,7 @@ def __init__(
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""
Logic that the operator uses to correctly identify which trigger to
execute, and defer execution as expected.
@@ -138,7 +136,7 @@ def execute(self, context: "Context") -> None:
"Unable to resume cluster since cluster is currently in status: %s", cluster_state
)

-    def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
+    def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -180,7 +178,7 @@ def __init__(
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""
Logic that the operator uses to correctly identify which trigger to
execute, and defer execution as expected.
@@ -204,7 +202,7 @@ def execute(self, context: "Context") -> None:
"Unable to pause cluster since cluster is currently in status: %s", cluster_state
)

-    def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
+    def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
6 changes: 3 additions & 3 deletions astronomer/providers/amazon/aws/operators/redshift_data.py
@@ -2,10 +2,10 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
-from airflow.utils.context import Context

from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from astronomer.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
+from astronomer.providers.utils.typing_compat import Context


class RedshiftDataOperatorAsync(RedshiftDataOperator):
@@ -32,7 +32,7 @@ def __init__(
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""
Makes a sync call to RedshiftDataHook, executes the query and gets back the list of query_ids and
defers trigger to poll for the status for the queries executed.
@@ -54,7 +54,7 @@ def execute(self, context: "Context") -> None:
method_name="execute_complete",
)

def execute_complete(self, context: "Context", event: Any = None) -> None:
def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
10 changes: 5 additions & 5 deletions astronomer/providers/amazon/aws/operators/redshift_sql.py
@@ -1,11 +1,11 @@
-from typing import Any, Dict, cast
+from typing import Any, cast

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.redshift_sql import RedshiftSQLOperator
-from airflow.utils.context import Context

from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from astronomer.providers.amazon.aws.triggers.redshift_sql import RedshiftSQLTrigger
+from astronomer.providers.utils.typing_compat import Context


class RedshiftSQLOperatorAsync(RedshiftSQLOperator):
@@ -30,15 +30,15 @@ def __init__(
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""
Makes a sync call to RedshiftDataHook and execute the query and gets back the query_ids list and
defers trigger to poll for the status for the query executed
"""
redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
query_ids, response = redshift_data_hook.execute_query(sql=cast(str, self.sql), params=self.params)
if response.get("status") == "error":
-            self.execute_complete({}, response)
+            self.execute_complete(cast(Context, {}), response)
return
context["ti"].xcom_push(key="return_value", value=query_ids)
self.defer(
@@ -52,7 +52,7 @@ def execute(self, context: "Context") -> None:
method_name="execute_complete",
)

-    def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
+    def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
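One wrinkle in redshift_sql.py above: the error path calls `execute_complete` directly with an empty dict, which no longer matches the `Context` annotation, hence `cast(Context, {})`. `typing.cast` performs no conversion at runtime; it only tells the type checker to accept the value:

from typing import cast

from astronomer.providers.utils.typing_compat import Context

empty_context = cast(Context, {})  # no runtime effect; satisfies mypy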
6 changes: 3 additions & 3 deletions astronomer/providers/amazon/aws/sensors/batch.py
@@ -3,9 +3,9 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.batch import BatchSensor
-from airflow.utils.context import Context

from astronomer.providers.amazon.aws.triggers.batch import BatchSensorTrigger
+from astronomer.providers.utils.typing_compat import Context


class BatchSensorAsync(BatchSensor):
@@ -34,7 +34,7 @@ def __init__(
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""Defers trigger class to poll for state of the job run until it reaches a failure or a success state"""
self.defer(
timeout=timedelta(seconds=self.timeout),
@@ -47,7 +47,7 @@ def execute(self, context: "Context") -> None:
method_name="execute_complete",
)

-    def execute_complete(self, context: Dict[str, Any], event: Dict[str, Any]) -> None:
+    def execute_complete(self, context: Context, event: Dict[str, Any]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
8 changes: 4 additions & 4 deletions astronomer/providers/amazon/aws/sensors/emr.py
@@ -7,13 +7,13 @@
EmrJobFlowSensor,
EmrStepSensor,
)
-from airflow.utils.context import Context

from astronomer.providers.amazon.aws.triggers.emr import (
EmrContainerSensorTrigger,
EmrJobFlowSensorTrigger,
EmrStepSensorTrigger,
)
+from astronomer.providers.utils.typing_compat import Context


class EmrContainerSensorAsync(EmrContainerSensor):
@@ -45,7 +45,7 @@ def execute(self, context: Context) -> None:
method_name="execute_complete",
)

-    def execute_complete(self, context: Dict[Any, Any], event: Dict[str, str]) -> None:
+    def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -92,7 +92,7 @@ def execute(self, context: Context) -> None:
method_name="execute_complete",
)

-    def execute_complete(self, context: Dict[str, Any], event: Dict[str, Any]) -> None:
+    def execute_complete(self, context: Context, event: Dict[str, Any]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -145,7 +145,7 @@ def execute(self, context: Context) -> None:
method_name="execute_complete",
)

-    def execute_complete(self, context: Dict[Any, Any], event: Dict[str, str]) -> None:
+    def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
4 changes: 2 additions & 2 deletions astronomer/providers/amazon/aws/sensors/redshift_cluster.py
@@ -3,11 +3,11 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor
-from airflow.utils.context import Context

from astronomer.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterSensorTrigger,
)
+from astronomer.providers.utils.typing_compat import Context


class RedshiftClusterSensorAsync(RedshiftClusterSensor):
@@ -41,7 +41,7 @@ def execute(self, context: Context) -> None:
method_name="execute_complete",
)

def execute_complete(self, context: "Context", event: Optional[Dict[Any, Any]] = None) -> None:
def execute_complete(self, context: Context, event: Optional[Dict[Any, Any]] = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
(Diff truncated: the remaining 23 changed files are not shown.)
