Skip to content

Commit

Permalink
Implemented MSGraphSensor as a deferrable sensor (#39304)
Browse files Browse the repository at this point in the history
* refactor: Implement default response handler method and added test when JSON decode error occurs

* refactor: Reformatted some code to comply to static checks

* refactor: Changed debugging level to debug for printing response in operator

* docs: Added example on how to refresh a PowerBI dataset using the MSGraphAsyncOperator

* refactor: Changed some info logging statements to debug

* refactor: Changed some info logging statements to debug

* fix: Fixed mock_json_response

* refactor: Return content if response is not a JSON

* refactor: Make sure the operator passes the response_handler to the triggerer

* refactor: Should use get instead of directly _getitem_ brackets as payload could not have a response key if call isn't done

* refactor: If event has status failure then the sensor should stop the async poke

* refactor: Changed default_event_processor as not all responses have the status key present

* refactor: Changed default_event_processor as not all responses have the status key present

* refactor: Removed response_handler parameter as lambda cannot be serialized by MSGraphTrigger

* refactor: Changed some logging statements

* refactor: Updated PowerBI dataset refresh example

* refactor: Fixed 2 static check errors

* refactor: Refactored MSGraphSensor as a real async sensor

* refactor: Changed logging level of sensor statements back to debug

* refactor: Fixed 2 static checks

* refactor: Changed docstring hook

* refactor: Put docstring on one line

---------

Co-authored-by: David Blain <[email protected]>
  • Loading branch information
dabla and davidblain-infrabel authored May 5, 2024
1 parent 0e6c0ab commit a61f393
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 129 deletions.
41 changes: 18 additions & 23 deletions airflow/providers/microsoft/azure/hooks/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
from __future__ import annotations

import json
from contextlib import suppress
from http import HTTPStatus
from io import BytesIO
from typing import TYPE_CHECKING, Any, Callable
from json import JSONDecodeError
from typing import TYPE_CHECKING, Any
from urllib.parse import quote, urljoin, urlparse

import httpx
Expand Down Expand Up @@ -51,18 +53,17 @@
from airflow.models import Connection


class CallableResponseHandler(ResponseHandler):
"""
CallableResponseHandler executes the passed callable_function with response as parameter.
param callable_function: Function that is applied to the response.
"""
class DefaultResponseHandler(ResponseHandler):
"""DefaultResponseHandler returns JSON payload or content in bytes or response headers."""

def __init__(
self,
callable_function: Callable[[NativeResponseType, dict[str, ParsableFactory | None] | None], Any],
):
self.callable_function = callable_function
@staticmethod
def get_value(response: NativeResponseType) -> Any:
with suppress(JSONDecodeError):
return response.json()
content = response.content
if not content:
return {key: value for key, value in response.headers.items()}
return content

async def handle_response_async(
self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None
Expand All @@ -73,7 +74,7 @@ async def handle_response_async(
param response: The type of the native response object.
param error_map: The error dict to use in case of a failed request.
"""
value = self.callable_function(response, error_map)
value = self.get_value(response)
if response.status_code not in {200, 201, 202, 204, 302}:
message = value or response.reason_phrase
status_code = HTTPStatus(response.status_code)
Expand Down Expand Up @@ -269,20 +270,18 @@ async def run(
self,
url: str = "",
response_type: ResponseType | None = None,
response_handler: Callable[
[NativeResponseType, dict[str, ParsableFactory | None] | None], Any
] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
method: str = "GET",
query_parameters: dict[str, QueryParams] | None = None,
headers: dict[str, str] | None = None,
data: dict[str, Any] | str | BytesIO | None = None,
):
self.log.info("Executing url '%s' as '%s'", url, method)

response = await self.get_conn().send_primitive_async(
request_info=self.request_information(
url=url,
response_type=response_type,
response_handler=response_handler,
path_parameters=path_parameters,
method=method,
query_parameters=query_parameters,
Expand All @@ -293,17 +292,14 @@ async def run(
error_map=self.error_mapping(),
)

self.log.debug("response: %s", response)
self.log.info("response: %s", response)

return response

def request_information(
self,
url: str,
response_type: ResponseType | None = None,
response_handler: Callable[
[NativeResponseType, dict[str, ParsableFactory | None] | None], Any
] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
method: str = "GET",
query_parameters: dict[str, QueryParams] | None = None,
Expand All @@ -323,12 +319,11 @@ def request_information(
request_information.url_template = f"{{+baseurl}}/{self.normalize_url(url)}"
if not response_type:
request_information.request_options[ResponseHandlerOption.get_key()] = ResponseHandlerOption(
response_handler=CallableResponseHandler(response_handler)
response_handler=DefaultResponseHandler()
)
headers = {**self.DEFAULT_HEADERS, **headers} if headers else self.DEFAULT_HEADERS
for header_name, header_value in headers.items():
request_information.headers.try_add(header_name=header_name, header_value=header_value)
self.log.info("data: %s", data)
if isinstance(data, BytesIO) or isinstance(data, bytes) or isinstance(data, str):
request_information.content = data
elif data:
Expand Down
15 changes: 2 additions & 13 deletions airflow/providers/microsoft/azure/operators/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@

from kiota_abstractions.request_adapter import ResponseType
from kiota_abstractions.request_information import QueryParams
from kiota_abstractions.response_handler import NativeResponseType
from kiota_abstractions.serialization import ParsableFactory
from msgraph_core import APIVersion

from airflow.utils.context import Context
Expand All @@ -59,9 +57,6 @@ class MSGraphAsyncOperator(BaseOperator):
:param url: The url being executed on the Microsoft Graph API (templated).
:param response_type: The expected return type of the response as a string. Possible value are: `bytes`,
`str`, `int`, `float`, `bool` and `datetime` (default is None).
:param response_handler: Function to convert the native HTTPX response returned by the hook (default is
lambda response, error_map: response.json()). The default expression will convert the native response
to JSON. If response_type parameter is specified, then the response_handler will be ignored.
:param method: The HTTP method being used to do the REST call (default is GET).
:param conn_id: The HTTP Connection ID to run the operator against (templated).
:param key: The key that will be used to store `XCom's` ("return_value" is default).
Expand Down Expand Up @@ -94,9 +89,6 @@ def __init__(
*,
url: str,
response_type: ResponseType | None = None,
response_handler: Callable[
[NativeResponseType, dict[str, ParsableFactory | None] | None], Any
] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
url_template: str | None = None,
method: str = "GET",
Expand All @@ -116,7 +108,6 @@ def __init__(
super().__init__(**kwargs)
self.url = url
self.response_type = response_type
self.response_handler = response_handler
self.path_parameters = path_parameters
self.url_template = url_template
self.method = method
Expand All @@ -134,7 +125,6 @@ def __init__(
self.results: list[Any] | None = None

def execute(self, context: Context) -> None:
self.log.info("Executing url '%s' as '%s'", self.url, self.method)
self.defer(
trigger=MSGraphTrigger(
url=self.url,
Expand Down Expand Up @@ -167,14 +157,14 @@ def execute_complete(
self.log.debug("context: %s", context)

if event:
self.log.info("%s completed with %s: %s", self.task_id, event.get("status"), event)
self.log.debug("%s completed with %s: %s", self.task_id, event.get("status"), event)

if event.get("status") == "failure":
raise AirflowException(event.get("message"))

response = event.get("response")

self.log.info("response: %s", response)
self.log.debug("response: %s", response)

if response:
response = self.serializer.deserialize(response)
Expand Down Expand Up @@ -281,7 +271,6 @@ def trigger_next_link(self, response, method_name="execute_complete") -> None:
url=url,
query_parameters=query_parameters,
response_type=self.response_type,
response_handler=self.response_handler,
conn_id=self.conn_id,
timeout=self.timeout,
proxies=self.proxies,
Expand Down
118 changes: 66 additions & 52 deletions airflow/providers/microsoft/azure/sensors/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,32 @@
# under the License.
from __future__ import annotations

import asyncio
import json
from typing import TYPE_CHECKING, Any, Callable, Sequence

from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, ResponseSerializer
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import TimeDeltaTrigger

if TYPE_CHECKING:
from datetime import timedelta
from io import BytesIO

from kiota_abstractions.request_information import QueryParams
from kiota_abstractions.response_handler import NativeResponseType
from kiota_abstractions.serialization import ParsableFactory
from kiota_http.httpx_request_adapter import ResponseType
from msgraph_core import APIVersion

from airflow.triggers.base import TriggerEvent
from airflow.utils.context import Context


def default_event_processor(context: Context, event: TriggerEvent) -> bool:
if event.payload["status"] == "success":
return json.loads(event.payload["response"])["status"] == "Succeeded"
return False


class MSGraphSensor(BaseSensorOperator):
"""
A Microsoft Graph API sensor which allows you to poll an async REST call to the Microsoft Graph API.
:param url: The url being executed on the Microsoft Graph API (templated).
:param response_type: The expected return type of the response as a string. Possible value are: `bytes`,
`str`, `int`, `float`, `bool` and `datetime` (default is None).
:param response_handler: Function to convert the native HTTPX response returned by the hook (default is
lambda response, error_map: response.json()). The default expression will convert the native response
to JSON. If response_type parameter is specified, then the response_handler will be ignored.
:param method: The HTTP method being used to do the REST call (default is GET).
:param conn_id: The HTTP Connection ID to run the operator against (templated).
:param proxies: A dict defining the HTTP proxies to be used (default is None).
Expand Down Expand Up @@ -85,9 +74,6 @@ def __init__(
self,
url: str,
response_type: ResponseType | None = None,
response_handler: Callable[
[NativeResponseType, dict[str, ParsableFactory | None] | None], Any
] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
url_template: str | None = None,
method: str = "GET",
Expand All @@ -97,15 +83,15 @@ def __init__(
conn_id: str = KiotaRequestAdapterHook.default_conn_name,
proxies: dict | None = None,
api_version: APIVersion | None = None,
event_processor: Callable[[Context, TriggerEvent], bool] = default_event_processor,
event_processor: Callable[[Context, Any], bool] = lambda context, e: e.get("status") == "Succeeded",
result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
serializer: type[ResponseSerializer] = ResponseSerializer,
retry_delay: timedelta | float = 60,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(retry_delay=retry_delay, **kwargs)
self.url = url
self.response_type = response_type
self.response_handler = response_handler
self.path_parameters = path_parameters
self.url_template = url_template
self.method = method
Expand All @@ -119,45 +105,73 @@ def __init__(
self.result_processor = result_processor
self.serializer = serializer()

@property
def trigger(self):
return MSGraphTrigger(
url=self.url,
response_type=self.response_type,
response_handler=self.response_handler,
path_parameters=self.path_parameters,
url_template=self.url_template,
method=self.method,
query_parameters=self.query_parameters,
headers=self.headers,
data=self.data,
conn_id=self.conn_id,
timeout=self.timeout,
proxies=self.proxies,
api_version=self.api_version,
serializer=type(self.serializer),
def execute(self, context: Context):
self.defer(
trigger=MSGraphTrigger(
url=self.url,
response_type=self.response_type,
path_parameters=self.path_parameters,
url_template=self.url_template,
method=self.method,
query_parameters=self.query_parameters,
headers=self.headers,
data=self.data,
conn_id=self.conn_id,
timeout=self.timeout,
proxies=self.proxies,
api_version=self.api_version,
serializer=type(self.serializer),
),
method_name=self.execute_complete.__name__,
)

async def async_poke(self, context: Context) -> bool | PokeReturnValue:
self.log.info("Sensor triggered")
def retry_execute(
self,
context: Context,
) -> Any:
self.execute(context=context)

def execute_complete(
self,
context: Context,
event: dict[Any, Any] | None = None,
) -> Any:
"""
Execute callback when MSGraphSensor finishes execution.
This method gets executed automatically when MSGraphTrigger completes its execution.
"""
self.log.debug("context: %s", context)

if event:
self.log.debug("%s completed with %s: %s", self.task_id, event.get("status"), event)

if event.get("status") == "failure":
raise AirflowException(event.get("message"))

response = event.get("response")

self.log.debug("response: %s", response)

async for event in self.trigger.run():
self.log.debug("event: %s", event)
if response:
response = self.serializer.deserialize(response)

is_done = self.event_processor(context, event)
self.log.debug("deserialize response: %s", response)

self.log.debug("is_done: %s", is_done)
is_done = self.event_processor(context, response)

response = self.serializer.deserialize(event.payload["response"])
self.log.debug("is_done: %s", is_done)

self.log.debug("deserialize event: %s", response)
if is_done:
result = self.result_processor(context, response)

result = self.result_processor(context, response)
self.log.debug("processed response: %s", result)

self.log.debug("result: %s", result)
return result

return PokeReturnValue(is_done=is_done, xcom_value=result)
return PokeReturnValue(is_done=True)
self.defer(
trigger=TimeDeltaTrigger(self.retry_delay),
method_name=self.retry_execute.__name__,
)

def poke(self, context) -> bool | PokeReturnValue:
return asyncio.run(self.async_poke(context))
return None
Loading

0 comments on commit a61f393

Please sign in to comment.