From 14751afd7876bc8e62a85fa5c85bcf0116eaf9de Mon Sep 17 00:00:00 2001
From: Josh Fell <48934154+josh-fell@users.noreply.github.com>
Date: Wed, 14 Sep 2022 11:38:10 -0400
Subject: [PATCH] Add compat module for typing `airflow.utils.context.Context`
 (#619)

The `airflow.utils.context` module was not introduced to OSS Airflow until
2.2.3. Importing this module at the top level sets an implicit minimum
requirement of Airflow 2.2.3, which is higher than the minimum
`apache-airflow` requirement set for Astronomer Providers. This PR adds a
typing compat module to be used throughout the repo for consistent typing
of `context`.
---
 .../pre_commit_context_typing_compat.py       | 77 +++++++++++++++++++
 .pre-commit-config.yaml                       |  5 ++
 .../providers/amazon/aws/operators/batch.py   |  6 +-
 .../providers/amazon/aws/operators/emr.py     |  6 +-
 .../amazon/aws/operators/redshift_cluster.py  | 18 ++---
 .../amazon/aws/operators/redshift_data.py     |  6 +-
 .../amazon/aws/operators/redshift_sql.py      | 10 +--
 .../providers/amazon/aws/sensors/batch.py     |  6 +-
 .../providers/amazon/aws/sensors/emr.py       |  8 +-
 .../amazon/aws/sensors/redshift_cluster.py    |  4 +-
 astronomer/providers/amazon/aws/sensors/s3.py |  8 +-
 .../apache/hive/sensors/hive_partition.py     |  6 +-
 .../hive/sensors/named_hive_partition.py      |  6 +-
 .../providers/apache/livy/operators/livy.py   |  6 +-
 .../kubernetes/operators/kubernetes_pod.py    |  2 +-
 .../providers/core/sensors/external_task.py   |  2 +-
 .../providers/core/sensors/filesystem.py      |  4 +-
 .../databricks/operators/databricks.py        | 14 ++--
 .../providers/dbt/cloud/operators/dbt.py      |  6 +-
 astronomer/providers/dbt/cloud/sensors/dbt.py |  6 +-
 .../google/cloud/operators/bigquery.py        |  2 +-
 .../google/cloud/operators/dataproc.py        | 12 +--
 .../cloud/operators/kubernetes_engine.py      |  6 +-
 .../google/cloud/sensors/bigquery.py          |  4 +-
 .../providers/google/cloud/sensors/gcs.py     |  2 +-
 astronomer/providers/http/sensors/http.py     |  4 +-
 .../microsoft/azure/operators/data_factory.py |  8 +-
 .../microsoft/azure/sensors/data_factory.py   |  4 +-
 .../providers/microsoft/azure/sensors/wasb.py |  6 +-
 .../snowflake/operators/snowflake.py          |  6 +-
 astronomer/providers/utils/__init__.py        |  0
 astronomer/providers/utils/typing_compat.py   | 17 ++++
 tests/utils/test_typing_compat.py             | 19 +++++
 33 files changed, 203 insertions(+), 93 deletions(-)
 create mode 100755 .circleci/scripts/pre_commit_context_typing_compat.py
 create mode 100644 astronomer/providers/utils/__init__.py
 create mode 100644 astronomer/providers/utils/typing_compat.py
 create mode 100644 tests/utils/test_typing_compat.py

diff --git a/.circleci/scripts/pre_commit_context_typing_compat.py b/.circleci/scripts/pre_commit_context_typing_compat.py
new file mode 100755
index 000000000..fb0093d2d
--- /dev/null
+++ b/.circleci/scripts/pre_commit_context_typing_compat.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+"""
+Pre-commit hook to verify ``airflow.utils.context.Context`` is not imported in provider modules.
+
+# TODO: This pre-commit hook can be removed once the repo has a minimum Apache Airflow requirement of 2.2.3+.
+""" +from __future__ import annotations + +import os +from ast import ImportFrom, NodeVisitor, parse +from pathlib import Path, PosixPath + +ASTRONOMER_PROVIDERS_SOURCES_ROOT = Path(__file__).parents[2] +PROVIDERS_ROOT = ASTRONOMER_PROVIDERS_SOURCES_ROOT / "astronomer" / "providers" +TYPING_COMPAT_PATH = "astronomer/providers/utils/typing_compat.py" + + +class ImportCrawler(NodeVisitor): + """AST crawler to determine if a module has an incompatible `airflow.utils.context.Context` import.""" + + def __init__(self) -> None: + self.has_incompatible_context_imports = False + + def visit_ImportFrom(self, node: ImportFrom) -> None: + """Visit an ImportFrom node to determine if `airflow.utils.context.Context` is imported directly.""" + if self.has_incompatible_context_imports: + return + + for alias in node.names: + if f"{node.module}.{alias.name}" == "airflow.utils.context.Context": + if not self.has_incompatible_context_imports: + self.has_incompatible_context_imports = True + + +def get_all_provider_files() -> list[PosixPath]: + """Retrieve all eligible provider module files.""" + provider_files = [] + for (root, _, file_names) in os.walk(PROVIDERS_ROOT): + for file_name in file_names: + file_path = Path(root, file_name) + if ( + file_path.is_file() + and file_path.name.endswith(".py") + and TYPING_COMPAT_PATH not in str(file_path) + ): + provider_files.append(file_path) + + return provider_files + + +def find_incompatible_context_imports(file_paths: list[PosixPath]) -> list[str]: + """Retrieve any provider files that import `airflow.utils.context.Context` directly.""" + incompatible_context_imports = [] + for file_path in file_paths: + file_ast = parse(file_path.read_text(), filename=file_path.name) + crawler = ImportCrawler() + crawler.visit(file_ast) + if crawler.has_incompatible_context_imports: + incompatible_context_imports.append(str(file_path)) + + return incompatible_context_imports + + +if __name__ == "__main__": + provider_files = get_all_provider_files() + files_needing_typing_compat = find_incompatible_context_imports(provider_files) + + if len(files_needing_typing_compat) > 0: + error_message = ( + "The following files are importing `airflow.utils.context.Context`. " + "This is not compatible with the minimum `apache-airflow` requirement of this repository. 
" + "Please use `astronomer.providers.utils.typing_compat.Context` instead.\n\n\t{}".format( + "\n\t".join(files_needing_typing_compat) + ) + ) + + raise SystemExit(error_message) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a246e1fe7..0105c6980 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -138,3 +138,8 @@ repos: files: ^setup.cfg$|^README.rst$ pass_filenames: false entry: .circleci/scripts/pre_commit_readme_extra.py + - id: check-context-typing-compat + name: Ensure modules use local typing compat for airflow.utils.context.Context + entry: .circleci/scripts/pre_commit_context_typing_compat.py + language: python + pass_filenames: false diff --git a/astronomer/providers/amazon/aws/operators/batch.py b/astronomer/providers/amazon/aws/operators/batch.py index 07656f65d..e343c1b16 100644 --- a/astronomer/providers/amazon/aws/operators/batch.py +++ b/astronomer/providers/amazon/aws/operators/batch.py @@ -11,9 +11,9 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.operators.batch import BatchOperator -from airflow.utils.context import Context from astronomer.providers.amazon.aws.triggers.batch import BatchOperatorTrigger +from astronomer.providers.utils.typing_compat import Context class BatchOperatorAsync(BatchOperator): @@ -53,7 +53,7 @@ class BatchOperatorAsync(BatchOperator): | ``waiter = waiters.get_waiter("JobComplete")`` """ - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: """ Airflow runs this method on the worker and defers using the trigger. Submit the job and get the job_id using which we defer and poll in trigger @@ -79,7 +79,7 @@ def execute(self, context: "Context") -> None: method_name="execute_complete", ) - def execute_complete(self, context: Dict[str, Any], event: Dict[str, Any]) -> None: + def execute_complete(self, context: Context, event: Dict[str, Any]) -> None: """ Callback for when the trigger fires - returns immediately. Relies on trigger to throw an exception, otherwise it assumes execution was diff --git a/astronomer/providers/amazon/aws/operators/emr.py b/astronomer/providers/amazon/aws/operators/emr.py index bfc49e193..b7ec8cc5e 100644 --- a/astronomer/providers/amazon/aws/operators/emr.py +++ b/astronomer/providers/amazon/aws/operators/emr.py @@ -3,9 +3,9 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator -from airflow.utils.context import Context from astronomer.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger +from astronomer.providers.utils.typing_compat import Context class EmrContainerOperatorAsync(EmrContainerOperator): @@ -28,7 +28,7 @@ class EmrContainerOperatorAsync(EmrContainerOperator): Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state. """ - def execute(self, context: "Context") -> None: + def execute(self, context: Context) -> None: """Deferred and give control to trigger""" hook = EmrContainerHook(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id) job_id = hook.submit_job( @@ -52,7 +52,7 @@ def execute(self, context: "Context") -> None: method_name="execute_complete", ) - def execute_complete(self, context: Dict[str, Any], event: Dict[str, Any]) -> str: + def execute_complete(self, context: Context, event: Dict[str, Any]) -> str: """ Callback for when the trigger fires - returns immediately. 
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/amazon/aws/operators/redshift_cluster.py b/astronomer/providers/amazon/aws/operators/redshift_cluster.py
index 7d05c669e..8d6ce2b71 100644
--- a/astronomer/providers/amazon/aws/operators/redshift_cluster.py
+++ b/astronomer/providers/amazon/aws/operators/redshift_cluster.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import Any, Optional
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
@@ -11,9 +11,7 @@
 from astronomer.providers.amazon.aws.triggers.redshift_cluster import (
     RedshiftClusterTrigger,
 )
-
-if TYPE_CHECKING:
-    from airflow.utils.context import Context
+from astronomer.providers.utils.typing_compat import Context
 
 
 class RedshiftDeleteClusterOperatorAsync(RedshiftDeleteClusterOperator):
@@ -44,7 +42,7 @@ def __init__(
         self.poll_interval = poll_interval
         super().__init__(**kwargs)
 
-    def execute(self, context: "Context") -> None:
+    def execute(self, context: Context) -> None:
         """
         Logic that the operator uses to correctly identify which trigger to
         execute, and defer execution as expected.
@@ -74,7 +72,7 @@ def execute(self, context: "Context") -> None:
                 "Unable to delete cluster since cluster is currently in status: %s", cluster_state
             )
 
-    def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
+    def execute_complete(self, context: Context, event: Any = None) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -114,7 +112,7 @@ def __init__(
         self.poll_interval = poll_interval
         super().__init__(**kwargs)
 
-    def execute(self, context: "Context") -> None:
+    def execute(self, context: Context) -> None:
         """
         Logic that the operator uses to correctly identify which trigger to
         execute, and defer execution as expected.
@@ -138,7 +136,7 @@ def execute(self, context: "Context") -> None:
                 "Unable to resume cluster since cluster is currently in status: %s", cluster_state
             )
 
-    def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
+    def execute_complete(self, context: Context, event: Any = None) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -180,7 +178,7 @@ def __init__(
         self.poll_interval = poll_interval
         super().__init__(**kwargs)
 
-    def execute(self, context: "Context") -> None:
+    def execute(self, context: Context) -> None:
         """
         Logic that the operator uses to correctly identify which trigger to
         execute, and defer execution as expected.
@@ -204,7 +202,7 @@ def execute(self, context: "Context") -> None:
                 "Unable to pause cluster since cluster is currently in status: %s", cluster_state
             )
 
-    def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
+    def execute_complete(self, context: Context, event: Any = None) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/amazon/aws/operators/redshift_data.py b/astronomer/providers/amazon/aws/operators/redshift_data.py
index cd80f3bb3..be0df6a5f 100644
--- a/astronomer/providers/amazon/aws/operators/redshift_data.py
+++ b/astronomer/providers/amazon/aws/operators/redshift_data.py
@@ -2,10 +2,10 @@
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
-from airflow.utils.context import Context
 
 from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from astronomer.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
+from astronomer.providers.utils.typing_compat import Context
 
 
 class RedshiftDataOperatorAsync(RedshiftDataOperator):
@@ -32,7 +32,7 @@ def __init__(
         self.poll_interval = poll_interval
         super().__init__(**kwargs)
 
-    def execute(self, context: "Context") -> None:
+    def execute(self, context: Context) -> None:
         """
         Makes a sync call to RedshiftDataHook, executes the query and gets back the list of query_ids and
         defers trigger to poll for the status for the queries executed.
@@ -54,7 +54,7 @@ def execute(self, context: "Context") -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: "Context", event: Any = None) -> None:
+    def execute_complete(self, context: Context, event: Any = None) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/amazon/aws/operators/redshift_sql.py b/astronomer/providers/amazon/aws/operators/redshift_sql.py
index 4730a6390..00c69f9a2 100644
--- a/astronomer/providers/amazon/aws/operators/redshift_sql.py
+++ b/astronomer/providers/amazon/aws/operators/redshift_sql.py
@@ -1,11 +1,11 @@
-from typing import Any, Dict, cast
+from typing import Any, cast
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.operators.redshift_sql import RedshiftSQLOperator
-from airflow.utils.context import Context
 
 from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from astronomer.providers.amazon.aws.triggers.redshift_sql import RedshiftSQLTrigger
+from astronomer.providers.utils.typing_compat import Context
 
 
 class RedshiftSQLOperatorAsync(RedshiftSQLOperator):
@@ -30,7 +30,7 @@ def __init__(
         self.poll_interval = poll_interval
         super().__init__(**kwargs)
 
-    def execute(self, context: "Context") -> None:
+    def execute(self, context: Context) -> None:
         """
         Makes a sync call to RedshiftDataHook and execute the query and gets back the query_ids list and
         defers trigger to poll for the status for the query executed
@@ -38,7 +38,7 @@ def execute(self, context: "Context") -> None:
         redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
         query_ids, response = redshift_data_hook.execute_query(sql=cast(str, self.sql), params=self.params)
         if response.get("status") == "error":
-            self.execute_complete({}, response)
+            self.execute_complete(cast(Context, {}), response)
             return
         context["ti"].xcom_push(key="return_value", value=query_ids)
         self.defer(
@@ -52,7 +52,7 @@ def execute(self, context: "Context") -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
+    def execute_complete(self, context: Context, event: Any = None) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/amazon/aws/sensors/batch.py b/astronomer/providers/amazon/aws/sensors/batch.py
index e7aaa368c..89c4a8225 100644
--- a/astronomer/providers/amazon/aws/sensors/batch.py
+++ b/astronomer/providers/amazon/aws/sensors/batch.py
@@ -3,9 +3,9 @@
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.sensors.batch import BatchSensor
-from airflow.utils.context import Context
 
 from astronomer.providers.amazon.aws.triggers.batch import BatchSensorTrigger
+from astronomer.providers.utils.typing_compat import Context
 
 
 class BatchSensorAsync(BatchSensor):
@@ -34,7 +34,7 @@ def __init__(
         self.poll_interval = poll_interval
         super().__init__(**kwargs)
 
-    def execute(self, context: "Context") -> None:
+    def execute(self, context: Context) -> None:
         """Defers trigger class to poll for state of the job run until it reaches a failure or a success state"""
         self.defer(
             timeout=timedelta(seconds=self.timeout),
@@ -47,7 +47,7 @@ def execute(self, context: "Context") -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[str, Any], event: Dict[str, Any]) -> None:
+    def execute_complete(self, context: Context, event: Dict[str, Any]) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/amazon/aws/sensors/emr.py b/astronomer/providers/amazon/aws/sensors/emr.py
index 3d6cd2d2b..2e0aefd2d 100644
--- a/astronomer/providers/amazon/aws/sensors/emr.py
+++ b/astronomer/providers/amazon/aws/sensors/emr.py
@@ -7,13 +7,13 @@
     EmrJobFlowSensor,
     EmrStepSensor,
 )
-from airflow.utils.context import Context
 
 from astronomer.providers.amazon.aws.triggers.emr import (
     EmrContainerSensorTrigger,
     EmrJobFlowSensorTrigger,
     EmrStepSensorTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 
 class EmrContainerSensorAsync(EmrContainerSensor):
@@ -45,7 +45,7 @@ def execute(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[Any, Any], event: Dict[str, str]) -> None:
+    def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -92,7 +92,7 @@ def execute(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[str, Any], event: Dict[str, Any]) -> None:
+    def execute_complete(self, context: Context, event: Dict[str, Any]) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -145,7 +145,7 @@ def execute(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[Any, Any], event: Dict[str, str]) -> None:
+    def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/amazon/aws/sensors/redshift_cluster.py b/astronomer/providers/amazon/aws/sensors/redshift_cluster.py
index 0627a567b..a253ce869 100644
--- a/astronomer/providers/amazon/aws/sensors/redshift_cluster.py
+++ b/astronomer/providers/amazon/aws/sensors/redshift_cluster.py
@@ -3,11 +3,11 @@
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor
-from airflow.utils.context import Context
 
 from astronomer.providers.amazon.aws.triggers.redshift_cluster import (
     RedshiftClusterSensorTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 
 class RedshiftClusterSensorAsync(RedshiftClusterSensor):
@@ -41,7 +41,7 @@ def execute(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: "Context", event: Optional[Dict[Any, Any]] = None) -> None:
+    def execute_complete(self, context: Context, event: Optional[Dict[Any, Any]] = None) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/amazon/aws/sensors/s3.py b/astronomer/providers/amazon/aws/sensors/s3.py
index 76f9e6b89..a28239be2 100644
--- a/astronomer/providers/amazon/aws/sensors/s3.py
+++ b/astronomer/providers/amazon/aws/sensors/s3.py
@@ -1,18 +1,18 @@
 import typing
 import warnings
 from datetime import timedelta
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast
+from typing import Any, Callable, List, Optional, Sequence, Union, cast
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.sensors.s3 import S3KeysUnchangedSensor
 from airflow.sensors.base import BaseSensorOperator
-from airflow.utils.context import Context
 
 from astronomer.providers.amazon.aws.triggers.s3 import (
     S3KeysUnchangedTrigger,
     S3KeyTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 
 class S3KeySensorAsync(BaseSensorOperator):
@@ -91,7 +91,7 @@ def execute(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[str, Any], event: Any = None) -> Optional[bool]:
+    def execute_complete(self, context: Context, event: Any = None) -> Optional[bool]:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -188,7 +188,7 @@ def execute(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[str, Any], event: Any = None) -> None:
+    def execute_complete(self, context: Context, event: Any = None) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/apache/hive/sensors/hive_partition.py b/astronomer/providers/apache/hive/sensors/hive_partition.py
index cd6993363..34ba199ce 100644
--- a/astronomer/providers/apache/hive/sensors/hive_partition.py
+++ b/astronomer/providers/apache/hive/sensors/hive_partition.py
@@ -1,13 +1,13 @@
 from datetime import timedelta
-from typing import Any, Dict, Optional
+from typing import Dict, Optional
 
 from airflow.exceptions import AirflowException
 from airflow.providers.apache.hive.sensors.hive_partition import HivePartitionSensor
-from airflow.utils.context import Context
 
 from astronomer.providers.apache.hive.triggers.hive_partition import (
     HivePartitionTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 
 class HivePartitionSensorAsync(HivePartitionSensor):
@@ -54,7 +54,7 @@ def execute(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[str, str]] = None) -> str:
+    def execute_complete(self, context: Context, event: Optional[Dict[str, str]] = None) -> str:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/apache/hive/sensors/named_hive_partition.py b/astronomer/providers/apache/hive/sensors/named_hive_partition.py
index 2ead1d41d..136a2a3f1 100644
--- a/astronomer/providers/apache/hive/sensors/named_hive_partition.py
+++ b/astronomer/providers/apache/hive/sensors/named_hive_partition.py
@@ -5,11 +5,11 @@
 from airflow.providers.apache.hive.sensors.named_hive_partition import (
     NamedHivePartitionSensor,
 )
-from airflow.utils.context import Context
 
 from astronomer.providers.apache.hive.triggers.named_hive_partition import (
     NamedHivePartitionTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 
 class NamedHivePartitionSensorAsync(NamedHivePartitionSensor):
@@ -41,7 +41,7 @@ class NamedHivePartitionSensorAsync(NamedHivePartitionSensor):
     :param metastore_conn_id: Metastore thrift service connection id.
     """
 
-    def execute(self, context: "Context") -> None:
+    def execute(self, context: Context) -> None:
         """Submit a job to Hive and defer"""
         if not self.partition_names:
             raise ValueError("Partition array can't be empty")
@@ -55,7 +55,7 @@ def execute(self, context: "Context") -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: "Context", event: Optional[Dict[str, str]] = None) -> None:
+    def execute_complete(self, context: Context, event: Optional[Dict[str, str]] = None) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/apache/livy/operators/livy.py b/astronomer/providers/apache/livy/operators/livy.py
index 23788daff..b62508bfc 100644
--- a/astronomer/providers/apache/livy/operators/livy.py
+++ b/astronomer/providers/apache/livy/operators/livy.py
@@ -3,9 +3,9 @@
 
 from airflow.exceptions import AirflowException
 from airflow.providers.apache.livy.operators.livy import LivyOperator
-from airflow.utils.context import Context
 
 from astronomer.providers.apache.livy.triggers.livy import LivyTrigger
+from astronomer.providers.utils.typing_compat import Context
 
 
 class LivyOperatorAsync(LivyOperator):
@@ -40,7 +40,7 @@ class LivyOperatorAsync(LivyOperator):
         See Tenacity documentation at https://github.com/jd/tenacity
     """
 
-    def execute(self, context: "Context") -> None:
+    def execute(self, context: Context) -> None:
         """
         Airflow runs this method on the worker and defers using the trigger.
         Submit the job and get the job_id using which we defer and poll in trigger
@@ -61,7 +61,7 @@ def execute(self, context: "Context") -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[Any, Any], event: Dict[str, Any]) -> Any:
+    def execute_complete(self, context: Context, event: Dict[str, Any]) -> Any:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/cncf/kubernetes/operators/kubernetes_pod.py b/astronomer/providers/cncf/kubernetes/operators/kubernetes_pod.py
index 154d99035..5cbc29ef3 100644
--- a/astronomer/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/astronomer/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -5,7 +5,6 @@
 from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
     KubernetesPodOperator,
 )
-from airflow.utils.context import Context
 from kubernetes.client import models as k8s
 from pendulum import DateTime
 
@@ -13,6 +12,7 @@
     PodLaunchTimeoutException,
     WaitContainerTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 
 class PodNotFoundException(AirflowException):
diff --git a/astronomer/providers/core/sensors/external_task.py b/astronomer/providers/core/sensors/external_task.py
index dd268e1a8..a2c716cbe 100644
--- a/astronomer/providers/core/sensors/external_task.py
+++ b/astronomer/providers/core/sensors/external_task.py
@@ -3,13 +3,13 @@
 
 from airflow.exceptions import AirflowException
 from airflow.sensors.external_task import ExternalTaskSensor
-from airflow.utils.context import Context
 from airflow.utils.session import provide_session
 
 from astronomer.providers.core.triggers.external_task import (
     DagStateTrigger,
     TaskStateTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
diff --git a/astronomer/providers/core/sensors/filesystem.py b/astronomer/providers/core/sensors/filesystem.py
index 30c1a4cfa..e965139a9 100644
--- a/astronomer/providers/core/sensors/filesystem.py
+++ b/astronomer/providers/core/sensors/filesystem.py
@@ -4,9 +4,9 @@
 
 from airflow.hooks.filesystem import FSHook
 from airflow.sensors.filesystem import FileSensor
-from airflow.utils.context import Context
 
 from astronomer.providers.core.triggers.filesystem import FileTrigger
+from astronomer.providers.utils.typing_compat import Context
 
 
 class FileSensorAsync(FileSensor):
@@ -41,7 +41,7 @@ def execute(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[str, Any]]) -> None:
+    def execute_complete(self, context: Context, event: Optional[Dict[str, Any]]) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/databricks/operators/databricks.py b/astronomer/providers/databricks/operators/databricks.py
index 1ea6af696..64ddeb6f5 100644
--- a/astronomer/providers/databricks/operators/databricks.py
+++ b/astronomer/providers/databricks/operators/databricks.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any
+from typing import Any
 
 from airflow.exceptions import AirflowException
 from airflow.providers.databricks.operators.databricks import (
@@ -9,13 +9,11 @@
 )
 
 from astronomer.providers.databricks.triggers.databricks import DatabricksTrigger
-
-if TYPE_CHECKING:
-    from airflow.utils.context import Context
+from astronomer.providers.utils.typing_compat import Context
 
 
 class DatabricksSubmitRunOperatorAsync(DatabricksSubmitRunOperator):  # noqa: D101
-    def execute(self, context: "Context") -> None:
+    def execute(self, context: Context) -> None:
         """
         Execute the Databricks trigger, and defer execution as expected. It
         makes two non-async API calls to submit the run, and retrieve the run page URL. It also pushes these
@@ -57,7 +55,7 @@ def execute(self, context: "Context") -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: "Context", event: Any = None) -> None:
+    def execute_complete(self, context: Context, event: Any = None) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -71,7 +69,7 @@ def execute_complete(self, context: "Context", event: Any = None) -> None:
 
 
 class DatabricksRunNowOperatorAsync(DatabricksRunNowOperator):  # noqa: D101
-    def execute(self, context: "Context") -> None:
+    def execute(self, context: Context) -> None:
         """
         Logic that the operator uses to execute the Databricks trigger,
         and defer execution as expected. It makes two non-async API calls to
@@ -112,7 +110,7 @@ def execute(self, context: "Context") -> None:
         )
 
     def execute_complete(
-        self, context: "Context", event: Any = None
+        self, context: Context, event: Any = None
     ) -> None:  # pylint: disable=unused-argument
         """
         Callback for when the trigger fires - returns immediately.
diff --git a/astronomer/providers/dbt/cloud/operators/dbt.py b/astronomer/providers/dbt/cloud/operators/dbt.py
index 0f02a402c..c41e3c5c4 100644
--- a/astronomer/providers/dbt/cloud/operators/dbt.py
+++ b/astronomer/providers/dbt/cloud/operators/dbt.py
@@ -1,14 +1,12 @@
 import time
-from typing import TYPE_CHECKING, Any, Dict
+from typing import Any, Dict
 
 from airflow import AirflowException
 from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook
 from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator
 
 from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger
-
-if TYPE_CHECKING:  # pragma: no cover
-    from airflow.utils.context import Context
+from astronomer.providers.utils.typing_compat import Context
 
 
 class DbtCloudRunJobOperatorAsync(DbtCloudRunJobOperator):
diff --git a/astronomer/providers/dbt/cloud/sensors/dbt.py b/astronomer/providers/dbt/cloud/sensors/dbt.py
index 5e3330fd1..7c018a0c7 100644
--- a/astronomer/providers/dbt/cloud/sensors/dbt.py
+++ b/astronomer/providers/dbt/cloud/sensors/dbt.py
@@ -1,13 +1,11 @@
 import time
-from typing import TYPE_CHECKING, Any, Dict
+from typing import Any, Dict
 
 from airflow import AirflowException
 from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor
 
 from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger
-
-if TYPE_CHECKING:  # pragma: no cover
-    from airflow.utils.context import Context
+from astronomer.providers.utils.typing_compat import Context
 
 
 class DbtCloudJobRunSensorAsync(DbtCloudJobRunSensor):
diff --git a/astronomer/providers/google/cloud/operators/bigquery.py b/astronomer/providers/google/cloud/operators/bigquery.py
index d12e4969d..82c9bedc8 100644
--- a/astronomer/providers/google/cloud/operators/bigquery.py
+++ b/astronomer/providers/google/cloud/operators/bigquery.py
@@ -11,7 +11,6 @@
     BigQueryIntervalCheckOperator,
     BigQueryValueCheckOperator,
 )
-from airflow.utils.context import Context
 from google.api_core.exceptions import Conflict
 
 from astronomer.providers.google.cloud.triggers.bigquery import (
@@ -21,6 +20,7 @@
     BigQueryIntervalCheckTrigger,
     BigQueryValueCheckTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 BIGQUERY_JOB_DETAILS_LINK_FMT = "https://console.cloud.google.com/bigquery?j={job_id}"
diff --git a/astronomer/providers/google/cloud/operators/dataproc.py b/astronomer/providers/google/cloud/operators/dataproc.py
index 37a2bec82..a1d246902 100644
--- a/astronomer/providers/google/cloud/operators/dataproc.py
+++ b/astronomer/providers/google/cloud/operators/dataproc.py
@@ -14,7 +14,6 @@
     DataprocSubmitJobOperator,
     DataprocUpdateClusterOperator,
 )
-from airflow.utils.context import Context
 from google.api_core.exceptions import AlreadyExists
 
 from astronomer.providers.google.cloud.triggers.dataproc import (
@@ -22,6 +21,7 @@
     DataprocDeleteClusterTrigger,
     DataProcSubmitTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 
 class DataprocCreateClusterOperatorAsync(DataprocCreateClusterOperator):
@@ -113,7 +113,7 @@ def execute(self, context: Context) -> None:  # type: ignore[override]
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[str, Any]] = None) -> Any:
+    def execute_complete(self, context: Context, event: Optional[Dict[str, Any]] = None) -> Any:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -167,7 +167,7 @@ def __init__(
         if self.timeout is None:
             self.timeout: float = 24 * 60 * 60
 
-    def execute(self, context: "Context") -> None:
+    def execute(self, context: Context) -> None:
         """Call delete cluster API and defer to wait for cluster to completely deleted"""
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         self.log.info("Deleting cluster: %s", self.cluster_name)
@@ -196,7 +196,7 @@ def execute(self, context: "Context") -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[str, Any]] = None) -> Any:
+    def execute_complete(self, context: Context, event: Optional[Dict[str, Any]] = None) -> Any:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -239,7 +239,7 @@ class DataprocSubmitJobOperatorAsync(DataprocSubmitJobOperator):
     :param cancel_on_kill: Flag which indicates whether cancel the hook's job or not, when on_kill is called
     """
 
-    def execute(self, context: "Context") -> None:
+    def execute(self, context: Context) -> None:
         """
         Airflow runs this method on the worker and defers using the trigger.
         Submit the job and get the job_id using which we defer and poll in trigger
@@ -269,7 +269,7 @@ def execute(self, context: "Context") -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[str, str]] = None) -> str:
+    def execute_complete(self, context: Context, event: Optional[Dict[str, str]] = None) -> str:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/google/cloud/operators/kubernetes_engine.py b/astronomer/providers/google/cloud/operators/kubernetes_engine.py
index 68b8b45e6..719a2f182 100644
--- a/astronomer/providers/google/cloud/operators/kubernetes_engine.py
+++ b/astronomer/providers/google/cloud/operators/kubernetes_engine.py
@@ -8,12 +8,12 @@
 from airflow.providers.google.cloud.operators.kubernetes_engine import (
     GKEStartPodOperator,
 )
-from airflow.utils.context import Context
 from kubernetes.client import models as k8s
 
 from astronomer.providers.google.cloud.triggers.kubernetes_engine import (
     GKEStartPodTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 
 class GKEStartPodOperatorAsync(KubernetesPodOperator):
@@ -80,7 +80,7 @@ def __init__(
         self.pod_namespace: str = ""
         self.poll_interval = poll_interval
 
-    def _get_or_create_pod(self, context: "Context") -> None:
+    def _get_or_create_pod(self, context: Context) -> None:
         """A wrapper to fetch GKE config and get or create a pod"""
         with GKEStartPodOperator.get_gke_config_file(
             gcp_conn_id=self.gcp_conn_id,
@@ -97,7 +97,7 @@ def _get_or_create_pod(self, context: "Context") -> None:
         self.pod_name = self.pod.metadata.name
         self.pod_namespace = self.pod.metadata.namespace
 
-    def execute(self, context: "Context") -> None:
+    def execute(self, context: Context) -> None:
         """Look for a pod, if not found then create one and defer"""
         self._get_or_create_pod(context)
         self.log.info("Created pod=%s in namespace=%s", self.pod_name, self.pod_namespace)
diff --git a/astronomer/providers/google/cloud/sensors/bigquery.py b/astronomer/providers/google/cloud/sensors/bigquery.py
index 392696762..9da088ef8 100644
--- a/astronomer/providers/google/cloud/sensors/bigquery.py
+++ b/astronomer/providers/google/cloud/sensors/bigquery.py
@@ -4,11 +4,11 @@
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor
-from airflow.utils.context import Context
 
 from astronomer.providers.google.cloud.triggers.bigquery import (
     BigQueryTableExistenceTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 
 class BigQueryTableExistenceSensorAsync(BigQueryTableExistenceSensor):
@@ -66,7 +66,7 @@ def execute(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[str, str]] = None) -> str:
+    def execute_complete(self, context: Context, event: Optional[Dict[str, str]] = None) -> str:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/google/cloud/sensors/gcs.py b/astronomer/providers/google/cloud/sensors/gcs.py
index 37d0affbb..496fe4646 100644
--- a/astronomer/providers/google/cloud/sensors/gcs.py
+++ b/astronomer/providers/google/cloud/sensors/gcs.py
@@ -10,7 +10,6 @@
     GCSObjectUpdateSensor,
     GCSUploadSessionCompleteSensor,
 )
-from airflow.utils.context import Context
 
 from astronomer.providers.google.cloud.triggers.gcs import (
     GCSBlobTrigger,
@@ -18,6 +17,7 @@
     GCSPrefixBlobTrigger,
     GCSUploadSessionTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 
 class GCSObjectExistenceSensorAsync(GCSObjectExistenceSensor):
diff --git a/astronomer/providers/http/sensors/http.py b/astronomer/providers/http/sensors/http.py
index 8f665d4bf..783aca7ac 100644
--- a/astronomer/providers/http/sensors/http.py
+++ b/astronomer/providers/http/sensors/http.py
@@ -3,9 +3,9 @@
 
 from airflow.providers.http.hooks.http import HttpHook
 from airflow.providers.http.sensors.http import HttpSensor
-from airflow.utils.context import Context
 
 from astronomer.providers.http.triggers.http import HttpTrigger
+from astronomer.providers.utils.typing_compat import Context
 
 
 class HttpSensorAsync(HttpSensor):
@@ -106,7 +106,7 @@ def execute(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[Any, Any]] = None) -> None:
+    def execute_complete(self, context: Context, event: Optional[Dict[Any, Any]] = None) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/microsoft/azure/operators/data_factory.py b/astronomer/providers/microsoft/azure/operators/data_factory.py
index 7bd69683d..9eed0a019 100644
--- a/astronomer/providers/microsoft/azure/operators/data_factory.py
+++ b/astronomer/providers/microsoft/azure/operators/data_factory.py
@@ -1,16 +1,16 @@
 import time
-from typing import Any, Dict
+from typing import Dict
 
 from airflow.exceptions import AirflowException
 from airflow.providers.microsoft.azure.hooks.data_factory import AzureDataFactoryHook
 from airflow.providers.microsoft.azure.operators.data_factory import (
     AzureDataFactoryRunPipelineOperator,
 )
-from airflow.utils.context import Context
 
 from astronomer.providers.microsoft.azure.triggers.data_factory import (
     AzureDataFactoryTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 
 class AzureDataFactoryRunPipelineOperatorAsync(AzureDataFactoryRunPipelineOperator):
@@ -45,7 +45,7 @@ class AzureDataFactoryRunPipelineOperatorAsync(AzureDataFactoryRunPipelineOperat
         Used only if ``wait_for_termination``
     """
 
-    def execute(self, context: "Context") -> None:
+    def execute(self, context: Context) -> None:
         """Submits a job which generates a run_id and gets deferred"""
         hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
         response = hook.run_pipeline(
@@ -75,7 +75,7 @@ def execute(self, context: "Context") -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[Any, Any], event: Dict[str, str]) -> None:
+    def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/microsoft/azure/sensors/data_factory.py b/astronomer/providers/microsoft/azure/sensors/data_factory.py
index bc5a19871..2fff92dfc 100644
--- a/astronomer/providers/microsoft/azure/sensors/data_factory.py
+++ b/astronomer/providers/microsoft/azure/sensors/data_factory.py
@@ -5,11 +5,11 @@
 from airflow.providers.microsoft.azure.sensors.data_factory import (
     AzureDataFactoryPipelineRunStatusSensor,
 )
-from airflow.utils.context import Context
 
 from astronomer.providers.microsoft.azure.triggers.data_factory import (
     ADFPipelineRunStatusSensorTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 
 class AzureDataFactoryPipelineRunStatusSensorAsync(AzureDataFactoryPipelineRunStatusSensor):
@@ -46,7 +46,7 @@ def execute(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[Any, Any], event: Dict[str, str]) -> None:
+    def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/microsoft/azure/sensors/wasb.py b/astronomer/providers/microsoft/azure/sensors/wasb.py
index 4303ea7ca..b4e0ef294 100644
--- a/astronomer/providers/microsoft/azure/sensors/wasb.py
+++ b/astronomer/providers/microsoft/azure/sensors/wasb.py
@@ -6,12 +6,12 @@
     WasbBlobSensor,
     WasbPrefixSensor,
 )
-from airflow.utils.context import Context
 
 from astronomer.providers.microsoft.azure.triggers.wasb import (
     WasbBlobSensorTrigger,
     WasbPrefixSensorTrigger,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 
 class WasbBlobSensorAsync(WasbBlobSensor):
@@ -56,7 +56,7 @@ def execute(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[Any, Any], event: Dict[str, str]) -> None:
+    def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
@@ -122,7 +122,7 @@ def execute(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: Dict[Any, Any], event: Dict[str, str]) -> None:
+    def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
         """
         Callback for when the trigger fires - returns immediately.
         Relies on trigger to throw an exception, otherwise it assumes execution was
diff --git a/astronomer/providers/snowflake/operators/snowflake.py b/astronomer/providers/snowflake/operators/snowflake.py
index a70b4eb78..53e6007b0 100644
--- a/astronomer/providers/snowflake/operators/snowflake.py
+++ b/astronomer/providers/snowflake/operators/snowflake.py
@@ -4,7 +4,6 @@
 
 from airflow.exceptions import AirflowException
 from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator
-from airflow.utils.context import Context
 
 from astronomer.providers.snowflake.hooks.snowflake import SnowflakeHookAsync
 from astronomer.providers.snowflake.hooks.snowflake_sql_api import (
@@ -15,6 +14,7 @@
     SnowflakeTrigger,
     get_db_hook,
 )
+from astronomer.providers.utils.typing_compat import Context
 
 
 class SnowflakeOperatorAsync(SnowflakeOperator):
@@ -124,7 +124,7 @@ def execute(self, context: Context) -> None:
         )
 
     def execute_complete(
-        self, context: Dict[str, Any], event: Optional[Dict[str, Union[str, List[str]]]] = None
+        self, context: Context, event: Optional[Dict[str, Union[str, List[str]]]] = None
     ) -> None:
         """
         Callback for when the trigger fires - returns immediately.
@@ -263,7 +263,7 @@ def execute(self, context: Context) -> None:
         )
 
     def execute_complete(
-        self, context: Dict[str, Any], event: Optional[Dict[str, Union[str, List[str]]]] = None
+        self, context: Context, event: Optional[Dict[str, Union[str, List[str]]]] = None
     ) -> None:
         """
         Callback for when the trigger fires - returns immediately.
diff --git a/astronomer/providers/utils/__init__.py b/astronomer/providers/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/astronomer/providers/utils/typing_compat.py b/astronomer/providers/utils/typing_compat.py
new file mode 100644
index 000000000..0ed6bc2e2
--- /dev/null
+++ b/astronomer/providers/utils/typing_compat.py
@@ -0,0 +1,17 @@
+from typing import Any, MutableMapping
+
+# The ``airflow.utils.context.Context`` class was not available in Apache Airflow until 2.2.3. This class is
+# typically used as the typing for the ``context`` arg in operators and sensors. However, using this class for
+# typing outside of TYPE_CHECKING in modules sets an implicit minimum requirement for Apache Airflow 2.2.3,
+# which is more recent than the current minimum requirement of Apache Airflow 2.2.0.
+#
+# TODO: Remove this once the repo has a minimum Apache Airflow requirement of 2.2.3+.
+try:
+    from airflow.utils.context import Context
+except ModuleNotFoundError:
+
+    class Context(MutableMapping[str, Any]):  # type: ignore[no-redef]
+        """Placeholder typing class for ``airflow.utils.context.Context``."""
+
+
+__all__ = ["Context"]
diff --git a/tests/utils/test_typing_compat.py b/tests/utils/test_typing_compat.py
new file mode 100644
index 000000000..554e37e32
--- /dev/null
+++ b/tests/utils/test_typing_compat.py
@@ -0,0 +1,19 @@
+from airflow.version import version as airflow_version
+from packaging.version import Version
+
+from astronomer.providers.utils.typing_compat import Context
+
+
+def test_context_typing_compat():
+    """
+    Ensure the ``Context`` class is imported correctly in the typing_compat module based on the Apache Airflow
+    version.
+
+    The ``airflow.utils.context.Context`` class was not available in Apache Airflow until 2.2.3. Therefore,
+    if the Apache Airflow version installed is older than 2.2.3, the placeholder ``Context`` class defined in
+    the typing_compat module should be used instead.
+    """
+    if Version(airflow_version).release >= (2, 2, 3):
+        assert Context.__module__ == "airflow.utils.context"
+    else:
+        assert Context.__module__ == "astronomer.providers.utils.typing_compat"
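
For illustration, here is a minimal sketch (not part of the patch) of how a provider module consumes the compat shim after this change. The operator name below is hypothetical; the `execute`/`execute_complete` signatures mirror the ones updated throughout this PR:

    from typing import Any, Dict

    from airflow.models.baseoperator import BaseOperator

    # Import the shim rather than ``airflow.utils.context.Context`` directly, so
    # the module stays importable on Apache Airflow releases older than 2.2.3.
    from astronomer.providers.utils.typing_compat import Context


    class ExampleAsyncOperator(BaseOperator):
        """Hypothetical deferrable operator typed via the compat shim."""

        def execute(self, context: Context) -> None:
            # Runs on the worker; ``context`` behaves like a mapping on either
            # code path of the shim.
            self.log.info("Running task %s", context["ti"].task_id)

        def execute_complete(self, context: Context, event: Dict[str, Any]) -> None:
            # Callback for when the trigger fires.
            self.log.info("Trigger returned event: %s", event)

The pre-commit hook added above enforces this pattern repo-wide by rejecting any direct `airflow.utils.context.Context` import outside the typing_compat module itself.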