Skip to content

Commit

Permalink
Adding MSGraphOperator in Microsoft Azure provider (#38111)
Browse files Browse the repository at this point in the history
* refactor: Initial commit contains the new MSGraphOperator

* refactor: Extracted common method into Base class for patching airflow connection and request adapter + make multiple patches into one context manager Python 3.8 compatible

* refactor: Refactored some typing issues related to msgraph

* refactor: Added some docstrings and fixed additional typing issues

* refactor: Fixed more static checks

* refactor: Added license on top of test serializer and fixed import

* Revert "refactor: Added license on top of test serializer and fixed import"

This reverts commit 04d6b85.

* refactor: Added license on top of serializer files and fixed additional static checks

* refactor: Added new line at end of json test files

* refactor: Try fixing docstrings on operator and serializer

* refactor: Replaced NoneType with None

* refactor: Made type unions Python 3.8 compatible

* refactor: Reformatted some files to comply with static checks formatting

* refactor: Reformatted base to comply with static checks formatting

* refactor: Added msgraph-core dependency to provider.yaml

* refactor: Added msgraph integration info to provider.yaml

* refactor: Added init in resources

* fix: Fixed typing of response_handler

* refactor: Added assertions on conn_id, tenant_id, client_id and client_secret

* refactor: Fixed some static checks

* Revert "refactor: Added assertions on conn_id, tenant_id, client_id and client_secret"

This reverts commit 88aa7dc.

* refactor: Changed imports in hook as we don't use mockito anymore we don't need the module before constructor

* refactor: Renamed test methods

* refactor: Replace List type with list

* refactor: Moved docstring as one line

* refactor: Fixed typing for tests and added test for response_handler

* refactor: Refactored tests

* fix: Fixed MS Graph logo filename

* refactor: Fixed additional static checks remarks

* refactor: Added white line in type checking block

* refactor: Added msgraph-core dependency to provider_dependencies.json

* refactor: Updated docstring on response handler

* refactor: Moved ResponseHandler and Serializer to triggers module

* docs: Added documentation on how to use the MSGraphAsyncOperator

* docs: Fixed END tag in examples

* refactor: Removed docstring from CallableResponseHandler

* refactor: Ignore UP031 Use format specifiers instead of percent format as this is not possible here the way the DAG is evaluated in Airflow (due to XCom's)

* Revert "refactor: Removed docstring from CallableResponseHandler"

This reverts commit 6a14ebe.

* refactor: Simplified docstring on CallableResponseHandler

* refactor: Updated provider.yaml to add reference of msgraph to how-to-guide

* refactor: Updated docstrings on operator and trigger

* refactor: Fixed additional static checks

* refactor: Ignore UP031 Use format specifiers instead of percent format as this is not possible here the way the DAG is evaluated in Airflow (due to XCom's)

* refactor: Added param to docstring ResponseHandler

* refactor: Updated pyproject.toml as main

* refactor: Reformatted docstrings in trigger

* refactor: Removed unused serialization module

* fix: Fixed execution of consecutive tasks in execute_operator method

* refactor: Added customizable pagination_function parameter to Operator and made operator PowerBI compatible

* refactor: Reformatted operator and trigger

* refactor: Added check if query_parameters is not None

* refactor: Removed typing of top and odata_count

* refactor: Ignore type for tenant_id (this is an issue in the ClientSecretCredential class)

* refactor: Changed docstring on MSGraphTrigger

* refactor: Changed docstring on MSGraphTrigger

* refactor: Added docstring to handle_response_async method

* refactor: Fixed docstring to imperative for handle_response_async method

* refactor: Try quoting Sharepoint so it doesn't get spell checked

* refactor: Try double quoting Sharepoint so it doesn't get spell checked

* refactor: Always get a new event loop and close it after test is done

* refactor: Reordered imports from contextlib

* refactor: Added Sharepoint to spelling_wordlist.txt

* refactor: Removed connection-type for KiotaRequestAdapterHook

* refactor: Refactored encoded_query_parameters

* refactor: Suppress ImportError

* refactor: Added return type to paginate method

* refactor: Updated paging_function type in MSGraphAsyncOperator

* refactor: Pass the method name from method reference instead of hard coded string which is re-factor friendly

* refactor: Changed return type of paginate method

* refactor: Added MSGraphSensor which easily allows us to poll PowerBI statuses

* refactor: Moved BytesIO and Context to type checking block for MSGraphSensor

* refactor: Added noqa check on pull_execute_complete method of MSGraphOperator

* fix: Fixed test_serialize of TestMSGraphTrigger

* refactor: Added docstring to MSGraphSensor and updated the docstring of the MSGraphAsyncOperator

* refactor: Reformatted docstring of MSGraphSensor

* refactor: Added white line at end of status.json file to keep static check happy

* refactor: Removed timeout parameter from constructor MSGraphSensor as it is already defined in the BaseSensorOperator

* fix: Added missing return for async_poke in MSGraphSensor

* Revert "refactor: Added noqa check on pull_execute_complete method of MSGraphOperator"

This reverts commit ca6f92c.

* refactor: Reorganised imports on MSGraphSensor

* refactor: Reformatted TestMSGraphSensor

* refactor: Added MSGraph sensor integration name in provider.yaml

* refactor: Updated apache-airflow version to at least 2.7.0 in provider.yaml of microsoft-azure provider

* refactor: Exclude microsoft-azure from compatibility check with airflow 2.6.0 as version 2.7.0 will at least be required

* refactor: Also updated the apache-airflow dependency version from 2.6.0 to 2.7.0 for microsoft-azure provider in provider_dependencies.json

* refactor: Reformatted global_constants.py

* refactor: Add logging statements for proxies and authority related stuff

* fix: Fixed exclusion of microsoft.azure dependency in global_constants.py

* refactor: Some Azure related imports should be ignored when running Airflow 2.6.0 or lower

* refactor: Import of ADLSListOperator should be ignored when running Airflow 2.6.0 or lower

* refactor: Moved optional provider imports that should be ignored when running Airflow 2.6.0 or lower at top of file

* refactor: Fixed the event loop closed issue when executing long running tests on the MSGraphOperator

* refactor: Extracted reusable mock_context method

* refactor: Moved import of Session into type checking block

* refactor: Updated the TestMSGraphSensor

* refactor: Reformatted the mock_context method

* refactor: Try implementing cached connections on MSGraphTrigger

* docs: Added example for the MSGraphSensor and additional examples on how you can use the operator for PowerBI

* Revert "refactor: Try implementing cached connections on MSGraphTrigger"

This reverts commit 693975e.

* fix: Fixed serialization of event payload as xcom_value for the MSGraphSensor

* refactor: TestMSGraphAsyncOperator should be allowed to run as a db test

* Revert "refactor: TestMSGraphAsyncOperator should be allowed to run as a db test"

This reverts commit c7a06db.

* refactor: TestMSGraphAsyncOperator should be allowed to run as a db test

* refactor: Also added result_processor to MSGraphSensor

* refactor: Fixed template_fields in operator, trigger and sensor

---------

Co-authored-by: David Blain <[email protected]>
  • Loading branch information
dabla and davidblain-infrabel authored Apr 14, 2024
1 parent ac1f744 commit 1c9a660
Show file tree
Hide file tree
Showing 28 changed files with 1,973 additions and 10 deletions.
8 changes: 7 additions & 1 deletion airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.microsoft.azure.hooks.wasb import WasbHook

try:
from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
except ModuleNotFoundError as e:
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowOptionalProviderFeatureException(e)

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down
10 changes: 8 additions & 2 deletions airflow/providers/google/cloud/transfers/adls_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,14 @@
from typing import TYPE_CHECKING, Sequence

from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
from airflow.providers.microsoft.azure.operators.adls import ADLSListOperator

try:
from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
from airflow.providers.microsoft.azure.operators.adls import ADLSListOperator
except ModuleNotFoundError as e:
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowOptionalProviderFeatureException(e)

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@

from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.microsoft.azure.hooks.wasb import WasbHook

try:
from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
except ModuleNotFoundError as e:
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowOptionalProviderFeatureException(e)

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url, gcs_object_is_directory
from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook

try:
from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook
except ModuleNotFoundError as e:
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowOptionalProviderFeatureException(e)

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down
208 changes: 208 additions & 0 deletions airflow/providers/microsoft/azure/hooks/msgraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING
from urllib.parse import urljoin, urlparse

import httpx
from azure.identity import ClientSecretCredential
from httpx import Timeout
from kiota_authentication_azure.azure_identity_authentication_provider import (
AzureIdentityAuthenticationProvider,
)
from kiota_http.httpx_request_adapter import HttpxRequestAdapter
from msgraph_core import GraphClientFactory
from msgraph_core._enums import APIVersion, NationalClouds

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
from kiota_abstractions.request_adapter import RequestAdapter

from airflow.models import Connection


class KiotaRequestAdapterHook(BaseHook):
"""
A Microsoft Graph API interaction hook, a Wrapper around KiotaRequestAdapter.
https://github.com/microsoftgraph/msgraph-sdk-python-core
:param conn_id: The HTTP Connection ID to run the trigger against.
:param timeout: The HTTP timeout being used by the KiotaRequestAdapter (default is None).
When no timeout is specified or set to None then no HTTP timeout is applied on each request.
:param proxies: A Dict defining the HTTP proxies to be used (default is None).
:param api_version: The API version of the Microsoft Graph API to be used (default is v1).
You can pass an enum named APIVersion which has 2 possible members v1 and beta,
or you can pass a string as "v1.0" or "beta".
"""

cached_request_adapters: dict[str, tuple[APIVersion, RequestAdapter]] = {}
default_conn_name: str = "msgraph_default"

def __init__(
self,
conn_id: str = default_conn_name,
timeout: float | None = None,
proxies: dict | None = None,
api_version: APIVersion | str | None = None,
):
super().__init__()
self.conn_id = conn_id
self.timeout = timeout
self.proxies = proxies
self._api_version = self.resolve_api_version_from_value(api_version)

@property
def api_version(self) -> APIVersion:
self.get_conn() # Make sure config has been loaded through get_conn to have correct api version!
return self._api_version

@staticmethod
def resolve_api_version_from_value(
api_version: APIVersion | str, default: APIVersion | None = None
) -> APIVersion:
if isinstance(api_version, APIVersion):
return api_version
return next(
filter(lambda version: version.value == api_version, APIVersion),
default,
)

def get_api_version(self, config: dict) -> APIVersion:
if self._api_version is None:
return self.resolve_api_version_from_value(
api_version=config.get("api_version"), default=APIVersion.v1
)
return self._api_version

@staticmethod
def get_host(connection: Connection) -> str:
if connection.schema and connection.host:
return f"{connection.schema}://{connection.host}"
return NationalClouds.Global.value

@staticmethod
def format_no_proxy_url(url: str) -> str:
if "://" not in url:
url = f"all://{url}"
return url

@classmethod
def to_httpx_proxies(cls, proxies: dict) -> dict:
proxies = proxies.copy()
if proxies.get("http"):
proxies["http://"] = proxies.pop("http")
if proxies.get("https"):
proxies["https://"] = proxies.pop("https")
if proxies.get("no"):
for url in proxies.pop("no", "").split(","):
proxies[cls.format_no_proxy_url(url.strip())] = None
return proxies

@classmethod
def to_msal_proxies(cls, authority: str | None, proxies: dict):
if authority:
no_proxies = proxies.get("no")
if no_proxies:
for url in no_proxies.split(","):
domain_name = urlparse(url).path.replace("*", "")
if authority.endswith(domain_name):
return None
return proxies

def get_conn(self) -> RequestAdapter:
if not self.conn_id:
raise AirflowException("Failed to create the KiotaRequestAdapterHook. No conn_id provided!")

api_version, request_adapter = self.cached_request_adapters.get(self.conn_id, (None, None))

if not request_adapter:
connection = self.get_connection(conn_id=self.conn_id)
client_id = connection.login
client_secret = connection.password
config = connection.extra_dejson if connection.extra else {}
tenant_id = config.get("tenant_id")
api_version = self.get_api_version(config)
host = self.get_host(connection)
base_url = config.get("base_url", urljoin(host, api_version.value))
authority = config.get("authority")
proxies = self.proxies or config.get("proxies", {})
msal_proxies = self.to_msal_proxies(authority=authority, proxies=proxies)
httpx_proxies = self.to_httpx_proxies(proxies=proxies)
scopes = config.get("scopes", ["https://graph.microsoft.com/.default"])
verify = config.get("verify", True)
trust_env = config.get("trust_env", False)
disable_instance_discovery = config.get("disable_instance_discovery", False)
allowed_hosts = (config.get("allowed_hosts", authority) or "").split(",")

self.log.info(
"Creating Microsoft Graph SDK client %s for conn_id: %s",
api_version.value,
self.conn_id,
)
self.log.info("Host: %s", host)
self.log.info("Base URL: %s", base_url)
self.log.info("Tenant id: %s", tenant_id)
self.log.info("Client id: %s", client_id)
self.log.info("Client secret: %s", client_secret)
self.log.info("API version: %s", api_version.value)
self.log.info("Scope: %s", scopes)
self.log.info("Verify: %s", verify)
self.log.info("Timeout: %s", self.timeout)
self.log.info("Trust env: %s", trust_env)
self.log.info("Authority: %s", authority)
self.log.info("Disable instance discovery: %s", disable_instance_discovery)
self.log.info("Allowed hosts: %s", allowed_hosts)
self.log.info("Proxies: %s", proxies)
self.log.info("MSAL Proxies: %s", msal_proxies)
self.log.info("HTTPX Proxies: %s", httpx_proxies)
credentials = ClientSecretCredential(
tenant_id=tenant_id, # type: ignore
client_id=connection.login,
client_secret=connection.password,
authority=authority,
proxies=msal_proxies,
disable_instance_discovery=disable_instance_discovery,
connection_verify=verify,
)
http_client = GraphClientFactory.create_with_default_middleware(
api_version=api_version,
client=httpx.AsyncClient(
proxies=httpx_proxies,
timeout=Timeout(timeout=self.timeout),
verify=verify,
trust_env=trust_env,
),
host=host,
)
auth_provider = AzureIdentityAuthenticationProvider(
credentials=credentials,
scopes=scopes,
allowed_hosts=allowed_hosts,
)
request_adapter = HttpxRequestAdapter(
authentication_provider=auth_provider,
http_client=http_client,
base_url=base_url,
)
self.cached_request_adapters[self.conn_id] = (api_version, request_adapter)
self._api_version = api_version
return request_adapter
Loading

0 comments on commit 1c9a660

Please sign in to comment.