Apply PROVIDE_PROJECT_ID mypy workaround across Google provider (#39129)
There is a simple workaround, implemented several years ago, for the Google
provider `project_id` default value being PROVIDE_PROJECT_ID, which
satisfies mypy checks for project_id being set. The way
`fallback_to_default_project_id` works is that across all the
providers the project_id is actually set at runtime, even though
its default value is technically None.

This is a similar typing workaround to the one we use for NEW_SESSION in
the core of Airflow.

The workaround has not been applied consistently across all the
Google provider code, and it occasionally causes MyPy to complain
when a newer version of a Google library introduces stricter
type checking and expects project_id to be set.

This PR applies the workaround across all the Google provider
code.

This is, generally speaking, a no-op. Nothing changes, except that
MyPy is now aware that the project_id is actually going to be set
even if it is technically set to None.

(cherry picked from commit 90acbfb)
potiuk committed Apr 19, 2024
1 parent eee0478 commit e61cb8f
Showing 62 changed files with 425 additions and 347 deletions.
65 changes: 36 additions & 29 deletions airflow/providers/google/cloud/hooks/bigquery.py
@@ -57,7 +57,12 @@
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.cloud.utils.bigquery import bq_cast
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
GoogleBaseAsyncHook,
GoogleBaseHook,
get_field,
)

try:
from airflow.utils.hashlib_wrapper import md5
@@ -150,7 +155,7 @@ def get_service(self) -> Resource:
http_authorized = self._authorize()
return build("bigquery", "v2", http=http_authorized, cache_discovery=False)

def get_client(self, project_id: str | None = None, location: str | None = None) -> Client:
def get_client(self, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None) -> Client:
"""Get an authenticated BigQuery Client.
:param project_id: Project ID for the project which the client acts on behalf of.
@@ -203,7 +208,7 @@ def get_records(self, sql, parameters=None):
@staticmethod
def _resolve_table_reference(
table_resource: dict[str, Any],
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
dataset_id: str | None = None,
table_id: str | None = None,
) -> dict[str, Any]:
@@ -313,7 +318,7 @@ def table_partition_exists(
@GoogleBaseHook.fallback_to_default_project_id
def create_empty_table(
self,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
dataset_id: str | None = None,
table_id: str | None = None,
table_resource: dict[str, Any] | None = None,
@@ -427,7 +432,7 @@ def create_empty_table(
def create_empty_dataset(
self,
dataset_id: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
location: str | None = None,
dataset_reference: dict[str, Any] | None = None,
exists_ok: bool = True,
@@ -489,7 +494,7 @@ def create_empty_dataset(
def get_dataset_tables(
self,
dataset_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
max_results: int | None = None,
retry: Retry = DEFAULT_RETRY,
) -> list[dict[str, Any]]:
@@ -518,7 +523,7 @@ def get_dataset_tables(
def delete_dataset(
self,
dataset_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
delete_contents: bool = False,
retry: Retry = DEFAULT_RETRY,
) -> None:
@@ -567,7 +572,7 @@ def create_external_table(
description: str | None = None,
encryption_configuration: dict | None = None,
location: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> Table:
"""Create an external table in the dataset with data from Google Cloud Storage.
@@ -703,7 +708,7 @@ def update_table(
fields: list[str] | None = None,
dataset_id: str | None = None,
table_id: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> dict[str, Any]:
"""Change some fields of a table.
@@ -749,7 +754,7 @@ def patch_table(
self,
dataset_id: str,
table_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
description: str | None = None,
expiration_time: int | None = None,
external_data_configuration: dict | None = None,
@@ -906,7 +911,7 @@ def update_dataset(
fields: Sequence[str],
dataset_resource: dict[str, Any],
dataset_id: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
retry: Retry = DEFAULT_RETRY,
) -> Dataset:
"""Change some fields of a dataset.
@@ -952,7 +957,9 @@ def update_dataset(
),
category=AirflowProviderDeprecationWarning,
)
def patch_dataset(self, dataset_id: str, dataset_resource: dict, project_id: str | None = None) -> dict:
def patch_dataset(
self, dataset_id: str, dataset_resource: dict, project_id: str = PROVIDE_PROJECT_ID
) -> dict:
"""Patches information in an existing dataset.
It only replaces fields that are provided in the submitted dataset resource.
@@ -1000,7 +1007,7 @@ def patch_dataset(self, dataset_id: str, dataset_resource: dict, project_id: str
def get_dataset_tables_list(
self,
dataset_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
table_prefix: str | None = None,
max_results: int | None = None,
) -> list[dict[str, Any]]:
@@ -1037,7 +1044,7 @@ def get_dataset_tables_list(
@GoogleBaseHook.fallback_to_default_project_id
def get_datasets_list(
self,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
include_all: bool = False,
filter_: str | None = None,
max_results: int | None = None,
@@ -1087,7 +1094,7 @@ def get_datasets_list(
return datasets_list

@GoogleBaseHook.fallback_to_default_project_id
def get_dataset(self, dataset_id: str, project_id: str | None = None) -> Dataset:
def get_dataset(self, dataset_id: str, project_id: str = PROVIDE_PROJECT_ID) -> Dataset:
"""Fetch the dataset referenced by *dataset_id*.
:param dataset_id: The BigQuery Dataset ID
@@ -1111,7 +1118,7 @@ def run_grant_dataset_view_access(
view_dataset: str,
view_table: str,
view_project: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> dict[str, Any]:
"""Grant authorized view access of a dataset to a view table.
@@ -1163,7 +1170,7 @@ def run_grant_dataset_view_access(

@GoogleBaseHook.fallback_to_default_project_id
def run_table_upsert(
self, dataset_id: str, table_resource: dict[str, Any], project_id: str | None = None
self, dataset_id: str, table_resource: dict[str, Any], project_id: str = PROVIDE_PROJECT_ID
) -> dict[str, Any]:
"""Update a table if it exists, otherwise create a new one.
@@ -1220,7 +1227,7 @@ def delete_table(
self,
table_id: str,
not_found_ok: bool = True,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> None:
"""Delete an existing table from the dataset.
@@ -1287,7 +1294,7 @@ def list_rows(
selected_fields: list[str] | str | None = None,
page_token: str | None = None,
start_index: int | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
location: str | None = None,
retry: Retry = DEFAULT_RETRY,
return_iterator: bool = False,
@@ -1340,7 +1347,7 @@ def list_rows(
return list(iterator)

@GoogleBaseHook.fallback_to_default_project_id
def get_schema(self, dataset_id: str, table_id: str, project_id: str | None = None) -> dict:
def get_schema(self, dataset_id: str, table_id: str, project_id: str = PROVIDE_PROJECT_ID) -> dict:
"""Get the schema for a given dataset and table.
.. seealso:: https://cloud.google.com/bigquery/docs/reference/v2/tables#resource
@@ -1362,7 +1369,7 @@ def update_table_schema(
include_policy_tags: bool,
dataset_id: str,
table_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> dict[str, Any]:
"""Update fields within a schema for a given dataset and table.
@@ -1455,7 +1462,7 @@ def _remove_policy_tags(schema: list[dict[str, Any]]):
def poll_job_complete(
self,
job_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
location: str | None = None,
retry: Retry = DEFAULT_RETRY,
) -> bool:
@@ -1485,7 +1492,7 @@ def cancel_query(self) -> None:
def cancel_job(
self,
job_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
location: str | None = None,
) -> None:
"""Cancel a job and wait for cancellation to complete.
@@ -1529,7 +1536,7 @@ def cancel_job(
def get_job(
self,
job_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
location: str | None = None,
) -> CopyJob | QueryJob | LoadJob | ExtractJob | UnknownJob:
"""Retrieve a BigQuery job.
@@ -1560,7 +1567,7 @@ def insert_job(
self,
configuration: dict,
job_id: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
location: str | None = None,
nowait: bool = False,
retry: Retry = DEFAULT_RETRY,
@@ -3244,7 +3251,7 @@ async def get_job_instance(
)

async def _get_job(
self, job_id: str | None, project_id: str | None = None, location: str | None = None
self, job_id: str | None, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None
) -> CopyJob | QueryJob | LoadJob | ExtractJob | UnknownJob:
"""
Get BigQuery job by its ID, project ID and location.
@@ -3287,7 +3294,7 @@ def _get_job_sync(self, job_id, project_id, location):
return hook.get_job(job_id=job_id, project_id=project_id, location=location)

async def get_job_status(
self, job_id: str | None, project_id: str | None = None, location: str | None = None
self, job_id: str | None, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None
) -> dict[str, str]:
job = await self._get_job(job_id=job_id, project_id=project_id, location=location)
if job.state == "DONE":
@@ -3299,7 +3306,7 @@ async def get_job_status(
async def get_job_output(
self,
job_id: str | None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> dict[str, Any]:
"""Get the BigQuery job output for a given job ID asynchronously."""
async with ClientSession() as session:
@@ -3312,7 +3319,7 @@ async def create_job_for_partition_get(
self,
dataset_id: str | None,
table_id: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
):
"""Create a new job and get the job_id using gcloud-aio."""
async with ClientSession() as session:
9 changes: 7 additions & 2 deletions airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -49,7 +49,12 @@
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
GoogleBaseAsyncHook,
GoogleBaseHook,
get_field,
)
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -482,7 +487,7 @@ def __init__(
path_prefix: str,
instance_specification: str,
gcp_conn_id: str = "google_cloud_default",
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
sql_proxy_version: str | None = None,
sql_proxy_binary_path: str | None = None,
) -> None:
@@ -45,7 +45,11 @@
from googleapiclient.errors import HttpError

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
GoogleBaseAsyncHook,
GoogleBaseHook,
)

if TYPE_CHECKING:
from google.cloud.storage_transfer_v1.services.storage_transfer_service.pagers import (
@@ -503,7 +507,7 @@ def operations_contain_expected_statuses(
class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
"""Asynchronous hook for Google Storage Transfer Service."""

def __init__(self, project_id: str | None = None, **kwargs: Any) -> None:
def __init__(self, project_id: str = PROVIDE_PROJECT_ID, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.project_id = project_id
self._client: StorageTransferServiceAsyncClient | None = None
3 changes: 2 additions & 1 deletion airflow/providers/google/cloud/hooks/compute_ssh.py
@@ -29,6 +29,7 @@
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook
from airflow.providers.google.cloud.hooks.os_login import OSLoginHook
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.utils.types import NOTSET, ArgNotSet

@@ -109,7 +110,7 @@ def __init__(
instance_name: str | None = None,
zone: str | None = None,
user: str | None = "root",
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
hostname: str | None = None,
use_internal_ip: bool = False,
use_iap_tunnel: bool = False,
8 changes: 6 additions & 2 deletions airflow/providers/google/cloud/hooks/dataplex.py
@@ -36,7 +36,11 @@

from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
GoogleBaseAsyncHook,
GoogleBaseHook,
)

if TYPE_CHECKING:
from google.api_core.operation import Operation
@@ -665,7 +669,7 @@ def wait_for_data_scan_job(
self,
data_scan_id: str,
job_id: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
region: str | None = None,
wait_time: int = 10,
result_timeout: float | None = None,