Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add compat module for typing airflow.utils.context.Context #619

Merged
merged 7 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions .circleci/scripts/pre_commit_context_typing_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python3
"""
Pre-commit hook to verify ``airflow.utils.context.Context`` is not imported in provider modules.

# TODO: This pre-commit hook can be removed once the repo has a minimum Apache Airflow requirement of 2.3.3+.
"""
from __future__ import annotations

import os
from ast import ImportFrom, NodeVisitor, parse
from pathlib import Path, PosixPath

ASTRONOMER_PROVIDERS_SOURCES_ROOT = Path(__file__).parents[2]
PROVIDERS_ROOT = ASTRONOMER_PROVIDERS_SOURCES_ROOT / "astronomer" / "providers"
TYPING_COMPAT_PATH = "astronomer/providers/utils/typing_compat.py"


class ImportCrawler(NodeVisitor):
    """AST visitor that flags modules importing ``airflow.utils.context.Context`` directly.

    Limitation: only ``from airflow.utils.context import Context`` style imports are
    detected; ``import airflow.utils.context`` followed by attribute access is not.
    """

    def __init__(self) -> None:
        # Set to True the first time an incompatible import is seen; never reset.
        self.has_incompatible_context_imports = False

    def visit_ImportFrom(self, node: ImportFrom) -> None:
        """Flag the crawler if this ``from ... import ...`` pulls in ``Context`` directly."""
        if self.has_incompatible_context_imports:
            return  # Already flagged; no need to inspect further imports.
        if node.module is None:
            # Relative import ("from . import x") -- cannot be the airflow module.
            return
        for alias in node.names:
            if f"{node.module}.{alias.name}" == "airflow.utils.context.Context":
                self.has_incompatible_context_imports = True
                return


def get_all_provider_files() -> list[Path]:
    """Retrieve all eligible provider module files.

    :return: every ``.py`` file under ``PROVIDERS_ROOT`` except the typing-compat
        shim itself (``astronomer/providers/utils/typing_compat.py``).
    """
    # Path.rglob("*.py") replaces the manual os.walk + endswith bookkeeping.
    # is_file() is kept because rglob can, in principle, match directories too.
    return [
        file_path
        for file_path in PROVIDERS_ROOT.rglob("*.py")
        if file_path.is_file() and TYPING_COMPAT_PATH not in str(file_path)
    ]


def find_incompatible_context_imports(file_paths: list[PosixPath]) -> list[str]:
    """Retrieve any provider files that import ``airflow.utils.context.Context`` directly.

    :param file_paths: provider module files to scan.
    :return: string paths of the offending files; empty when all are compliant.
    """
    incompatible_context_imports = []
    for file_path in file_paths:
        # Pass the full path (not just the basename) so a SyntaxError raised by
        # ``parse`` points at the exact offending file.
        file_ast = parse(file_path.read_text(), filename=str(file_path))
        crawler = ImportCrawler()
        crawler.visit(file_ast)
        if crawler.has_incompatible_context_imports:
            incompatible_context_imports.append(str(file_path))

    return incompatible_context_imports


if __name__ == "__main__":
    provider_files = get_all_provider_files()
    files_needing_typing_compat = find_incompatible_context_imports(provider_files)

    # A non-empty list fails the pre-commit hook with an actionable message.
    if files_needing_typing_compat:
        error_message = (
            "The following files are importing `airflow.utils.context.Context`. "
            "This is not compatible with the minimum `apache-airflow` requirement of this repository. "
            "Please use `astronomer.providers.utils.typing_compat.Context` instead.\n\n\t{}".format(
                "\n\t".join(files_needing_typing_compat)
            )
        )

        raise SystemExit(error_message)
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,8 @@ repos:
files: ^setup.cfg$|^README.rst$
pass_filenames: false
entry: .circleci/scripts/pre_commit_readme_extra.py
- id: check-context-typing-compat
name: Ensure modules use local typing compat for airflow.utils.context.Context
entry: .circleci/scripts/pre_commit_context_typing_compat.py
language: python
pass_filenames: false
6 changes: 3 additions & 3 deletions astronomer/providers/amazon/aws/operators/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.batch import BatchOperator
from airflow.utils.context import Context

from astronomer.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
from astronomer.providers.utils.typing_compat import Context


class BatchOperatorAsync(BatchOperator):
Expand Down Expand Up @@ -53,7 +53,7 @@ class BatchOperatorAsync(BatchOperator):
| ``waiter = waiters.get_waiter("JobComplete")``
"""

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""
Airflow runs this method on the worker and defers using the trigger.
Submit the job and get the job_id using which we defer and poll in trigger
Expand All @@ -79,7 +79,7 @@ def execute(self, context: "Context") -> None:
method_name="execute_complete",
)

def execute_complete(self, context: Dict[str, Any], event: Dict[str, Any]) -> None:
def execute_complete(self, context: Context, event: Dict[str, Any]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down
6 changes: 3 additions & 3 deletions astronomer/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator
from airflow.utils.context import Context

from astronomer.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger
from astronomer.providers.utils.typing_compat import Context


class EmrContainerOperatorAsync(EmrContainerOperator):
Expand All @@ -28,7 +28,7 @@ class EmrContainerOperatorAsync(EmrContainerOperator):
Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
"""

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""Deferred and give control to trigger"""
hook = EmrContainerHook(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
job_id = hook.submit_job(
Expand All @@ -52,7 +52,7 @@ def execute(self, context: "Context") -> None:
method_name="execute_complete",
)

def execute_complete(self, context: Dict[str, Any], event: Dict[str, Any]) -> str:
def execute_complete(self, context: Context, event: Dict[str, Any]) -> str:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down
18 changes: 8 additions & 10 deletions astronomer/providers/amazon/aws/operators/redshift_cluster.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import Any, Optional

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
Expand All @@ -11,9 +11,7 @@
from astronomer.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
)

if TYPE_CHECKING:
from airflow.utils.context import Context
from astronomer.providers.utils.typing_compat import Context


class RedshiftDeleteClusterOperatorAsync(RedshiftDeleteClusterOperator):
Expand Down Expand Up @@ -44,7 +42,7 @@ def __init__(
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""
Logic that the operator uses to correctly identify which trigger to
execute, and defer execution as expected.
Expand Down Expand Up @@ -74,7 +72,7 @@ def execute(self, context: "Context") -> None:
"Unable to delete cluster since cluster is currently in status: %s", cluster_state
)

def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down Expand Up @@ -114,7 +112,7 @@ def __init__(
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""
Logic that the operator uses to correctly identify which trigger to
execute, and defer execution as expected.
Expand All @@ -138,7 +136,7 @@ def execute(self, context: "Context") -> None:
"Unable to resume cluster since cluster is currently in status: %s", cluster_state
)

def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down Expand Up @@ -180,7 +178,7 @@ def __init__(
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""
Logic that the operator uses to correctly identify which trigger to
execute, and defer execution as expected.
Expand All @@ -204,7 +202,7 @@ def execute(self, context: "Context") -> None:
"Unable to pause cluster since cluster is currently in status: %s", cluster_state
)

def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down
6 changes: 3 additions & 3 deletions astronomer/providers/amazon/aws/operators/redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
from airflow.utils.context import Context

from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from astronomer.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
from astronomer.providers.utils.typing_compat import Context


class RedshiftDataOperatorAsync(RedshiftDataOperator):
Expand All @@ -32,7 +32,7 @@ def __init__(
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""
Makes a sync call to RedshiftDataHook, executes the query and gets back the list of query_ids and
defers trigger to poll for the status for the queries executed.
Expand All @@ -54,7 +54,7 @@ def execute(self, context: "Context") -> None:
method_name="execute_complete",
)

def execute_complete(self, context: "Context", event: Any = None) -> None:
def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down
10 changes: 5 additions & 5 deletions astronomer/providers/amazon/aws/operators/redshift_sql.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, Dict, cast
from typing import Any, cast

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.redshift_sql import RedshiftSQLOperator
from airflow.utils.context import Context

from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from astronomer.providers.amazon.aws.triggers.redshift_sql import RedshiftSQLTrigger
from astronomer.providers.utils.typing_compat import Context


class RedshiftSQLOperatorAsync(RedshiftSQLOperator):
Expand All @@ -30,15 +30,15 @@ def __init__(
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""
Makes a sync call to RedshiftDataHook and execute the query and gets back the query_ids list and
defers trigger to poll for the status for the query executed
"""
redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
query_ids, response = redshift_data_hook.execute_query(sql=cast(str, self.sql), params=self.params)
if response.get("status") == "error":
self.execute_complete({}, response)
self.execute_complete(cast(Context, {}), response)
return
context["ti"].xcom_push(key="return_value", value=query_ids)
self.defer(
Expand All @@ -52,7 +52,7 @@ def execute(self, context: "Context") -> None:
method_name="execute_complete",
)

def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down
6 changes: 3 additions & 3 deletions astronomer/providers/amazon/aws/sensors/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.batch import BatchSensor
from airflow.utils.context import Context

from astronomer.providers.amazon.aws.triggers.batch import BatchSensorTrigger
from astronomer.providers.utils.typing_compat import Context


class BatchSensorAsync(BatchSensor):
Expand Down Expand Up @@ -34,7 +34,7 @@ def __init__(
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: "Context") -> None:
def execute(self, context: Context) -> None:
"""Defers trigger class to poll for state of the job run until it reaches a failure or a success state"""
self.defer(
timeout=timedelta(seconds=self.timeout),
Expand All @@ -47,7 +47,7 @@ def execute(self, context: "Context") -> None:
method_name="execute_complete",
)

def execute_complete(self, context: Dict[str, Any], event: Dict[str, Any]) -> None:
def execute_complete(self, context: Context, event: Dict[str, Any]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down
8 changes: 4 additions & 4 deletions astronomer/providers/amazon/aws/sensors/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
EmrJobFlowSensor,
EmrStepSensor,
)
from airflow.utils.context import Context

from astronomer.providers.amazon.aws.triggers.emr import (
EmrContainerSensorTrigger,
EmrJobFlowSensorTrigger,
EmrStepSensorTrigger,
)
from astronomer.providers.utils.typing_compat import Context


class EmrContainerSensorAsync(EmrContainerSensor):
Expand Down Expand Up @@ -45,7 +45,7 @@ def execute(self, context: Context) -> None:
method_name="execute_complete",
)

def execute_complete(self, context: Dict[Any, Any], event: Dict[str, str]) -> None:
def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down Expand Up @@ -92,7 +92,7 @@ def execute(self, context: Context) -> None:
method_name="execute_complete",
)

def execute_complete(self, context: Dict[str, Any], event: Dict[str, Any]) -> None:
def execute_complete(self, context: Context, event: Dict[str, Any]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down Expand Up @@ -145,7 +145,7 @@ def execute(self, context: Context) -> None:
method_name="execute_complete",
)

def execute_complete(self, context: Dict[Any, Any], event: Dict[str, str]) -> None:
def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down
4 changes: 2 additions & 2 deletions astronomer/providers/amazon/aws/sensors/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor
from airflow.utils.context import Context

from astronomer.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterSensorTrigger,
)
from astronomer.providers.utils.typing_compat import Context


class RedshiftClusterSensorAsync(RedshiftClusterSensor):
Expand Down Expand Up @@ -41,7 +41,7 @@ def execute(self, context: Context) -> None:
method_name="execute_complete",
)

def execute_complete(self, context: "Context", event: Optional[Dict[Any, Any]] = None) -> None:
def execute_complete(self, context: Context, event: Optional[Dict[Any, Any]] = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
Expand Down
Loading