diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py index 84b2252bd25ae..56abfa155da7c 100644 --- a/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -18,9 +18,11 @@ from __future__ import annotations import json +from contextlib import suppress from http import HTTPStatus from io import BytesIO -from typing import TYPE_CHECKING, Any, Callable +from json import JSONDecodeError +from typing import TYPE_CHECKING, Any from urllib.parse import quote, urljoin, urlparse import httpx @@ -51,18 +53,17 @@ from airflow.models import Connection -class CallableResponseHandler(ResponseHandler): - """ - CallableResponseHandler executes the passed callable_function with response as parameter. - - param callable_function: Function that is applied to the response. - """ +class DefaultResponseHandler(ResponseHandler): + """DefaultResponseHandler returns JSON payload or content in bytes or response headers.""" - def __init__( - self, - callable_function: Callable[[NativeResponseType, dict[str, ParsableFactory | None] | None], Any], - ): - self.callable_function = callable_function + @staticmethod + def get_value(response: NativeResponseType) -> Any: + with suppress(JSONDecodeError): + return response.json() + content = response.content + if not content: + return {key: value for key, value in response.headers.items()} + return content async def handle_response_async( self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None @@ -73,7 +74,7 @@ async def handle_response_async( param response: The type of the native response object. param error_map: The error dict to use in case of a failed request. """ - value = self.callable_function(response, error_map) + value = self.get_value(response) if response.status_code not in {200, 201, 202, 204, 302}: message = value or response.reason_phrase status_code = HTTPStatus(response.status_code) @@ -269,20 +270,18 @@ async def run( self, url: str = "", response_type: ResponseType | None = None, - response_handler: Callable[ - [NativeResponseType, dict[str, ParsableFactory | None] | None], Any - ] = lambda response, error_map: response.json(), path_parameters: dict[str, Any] | None = None, method: str = "GET", query_parameters: dict[str, QueryParams] | None = None, headers: dict[str, str] | None = None, data: dict[str, Any] | str | BytesIO | None = None, ): + self.log.info("Executing url '%s' as '%s'", url, method) + response = await self.get_conn().send_primitive_async( request_info=self.request_information( url=url, response_type=response_type, - response_handler=response_handler, path_parameters=path_parameters, method=method, query_parameters=query_parameters, @@ -293,7 +292,7 @@ async def run( error_map=self.error_mapping(), ) - self.log.debug("response: %s", response) + self.log.info("response: %s", response) return response @@ -301,9 +300,6 @@ def request_information( self, url: str, response_type: ResponseType | None = None, - response_handler: Callable[ - [NativeResponseType, dict[str, ParsableFactory | None] | None], Any - ] = lambda response, error_map: response.json(), path_parameters: dict[str, Any] | None = None, method: str = "GET", query_parameters: dict[str, QueryParams] | None = None, @@ -323,12 +319,11 @@ def request_information( request_information.url_template = f"{{+baseurl}}/{self.normalize_url(url)}" if not response_type: request_information.request_options[ResponseHandlerOption.get_key()] = ResponseHandlerOption( - response_handler=CallableResponseHandler(response_handler) + response_handler=DefaultResponseHandler() ) headers = {**self.DEFAULT_HEADERS, **headers} if headers else self.DEFAULT_HEADERS for header_name, header_value in headers.items(): request_information.headers.try_add(header_name=header_name, header_value=header_value) - self.log.info("data: %s", data) if isinstance(data, BytesIO) or isinstance(data, bytes) or isinstance(data, str): request_information.content = data elif data: diff --git a/airflow/providers/microsoft/azure/operators/msgraph.py b/airflow/providers/microsoft/azure/operators/msgraph.py index 6411f9cc4ac2d..39ca32d2b6106 100644 --- a/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/airflow/providers/microsoft/azure/operators/msgraph.py @@ -39,8 +39,6 @@ from kiota_abstractions.request_adapter import ResponseType from kiota_abstractions.request_information import QueryParams - from kiota_abstractions.response_handler import NativeResponseType - from kiota_abstractions.serialization import ParsableFactory from msgraph_core import APIVersion from airflow.utils.context import Context @@ -59,9 +57,6 @@ class MSGraphAsyncOperator(BaseOperator): :param url: The url being executed on the Microsoft Graph API (templated). :param response_type: The expected return type of the response as a string. Possible value are: `bytes`, `str`, `int`, `float`, `bool` and `datetime` (default is None). - :param response_handler: Function to convert the native HTTPX response returned by the hook (default is - lambda response, error_map: response.json()). The default expression will convert the native response - to JSON. If response_type parameter is specified, then the response_handler will be ignored. :param method: The HTTP method being used to do the REST call (default is GET). :param conn_id: The HTTP Connection ID to run the operator against (templated). :param key: The key that will be used to store `XCom's` ("return_value" is default). @@ -94,9 +89,6 @@ def __init__( *, url: str, response_type: ResponseType | None = None, - response_handler: Callable[ - [NativeResponseType, dict[str, ParsableFactory | None] | None], Any - ] = lambda response, error_map: response.json(), path_parameters: dict[str, Any] | None = None, url_template: str | None = None, method: str = "GET", @@ -116,7 +108,6 @@ def __init__( super().__init__(**kwargs) self.url = url self.response_type = response_type - self.response_handler = response_handler self.path_parameters = path_parameters self.url_template = url_template self.method = method @@ -134,7 +125,6 @@ def __init__( self.results: list[Any] | None = None def execute(self, context: Context) -> None: - self.log.info("Executing url '%s' as '%s'", self.url, self.method) self.defer( trigger=MSGraphTrigger( url=self.url, @@ -167,14 +157,14 @@ def execute_complete( self.log.debug("context: %s", context) if event: - self.log.info("%s completed with %s: %s", self.task_id, event.get("status"), event) + self.log.debug("%s completed with %s: %s", self.task_id, event.get("status"), event) if event.get("status") == "failure": raise AirflowException(event.get("message")) response = event.get("response") - self.log.info("response: %s", response) + self.log.debug("response: %s", response) if response: response = self.serializer.deserialize(response) @@ -281,7 +271,6 @@ def trigger_next_link(self, response, method_name="execute_complete") -> None: url=url, query_parameters=query_parameters, response_type=self.response_type, - response_handler=self.response_handler, conn_id=self.conn_id, timeout=self.timeout, proxies=self.proxies, diff --git a/airflow/providers/microsoft/azure/sensors/msgraph.py b/airflow/providers/microsoft/azure/sensors/msgraph.py index ffbf244dbe88c..3e1b10cbeb1e6 100644 --- a/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -17,33 +17,25 @@ # under the License. from __future__ import annotations -import asyncio -import json from typing import TYPE_CHECKING, Any, Callable, Sequence +from airflow.exceptions import AirflowException from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, ResponseSerializer -from airflow.sensors.base import BaseSensorOperator, PokeReturnValue +from airflow.sensors.base import BaseSensorOperator +from airflow.triggers.temporal import TimeDeltaTrigger if TYPE_CHECKING: + from datetime import timedelta from io import BytesIO from kiota_abstractions.request_information import QueryParams - from kiota_abstractions.response_handler import NativeResponseType - from kiota_abstractions.serialization import ParsableFactory from kiota_http.httpx_request_adapter import ResponseType from msgraph_core import APIVersion - from airflow.triggers.base import TriggerEvent from airflow.utils.context import Context -def default_event_processor(context: Context, event: TriggerEvent) -> bool: - if event.payload["status"] == "success": - return json.loads(event.payload["response"])["status"] == "Succeeded" - return False - - class MSGraphSensor(BaseSensorOperator): """ A Microsoft Graph API sensor which allows you to poll an async REST call to the Microsoft Graph API. @@ -51,9 +43,6 @@ class MSGraphSensor(BaseSensorOperator): :param url: The url being executed on the Microsoft Graph API (templated). :param response_type: The expected return type of the response as a string. Possible value are: `bytes`, `str`, `int`, `float`, `bool` and `datetime` (default is None). - :param response_handler: Function to convert the native HTTPX response returned by the hook (default is - lambda response, error_map: response.json()). The default expression will convert the native response - to JSON. If response_type parameter is specified, then the response_handler will be ignored. :param method: The HTTP method being used to do the REST call (default is GET). :param conn_id: The HTTP Connection ID to run the operator against (templated). :param proxies: A dict defining the HTTP proxies to be used (default is None). @@ -85,9 +74,6 @@ def __init__( self, url: str, response_type: ResponseType | None = None, - response_handler: Callable[ - [NativeResponseType, dict[str, ParsableFactory | None] | None], Any - ] = lambda response, error_map: response.json(), path_parameters: dict[str, Any] | None = None, url_template: str | None = None, method: str = "GET", @@ -97,15 +83,15 @@ def __init__( conn_id: str = KiotaRequestAdapterHook.default_conn_name, proxies: dict | None = None, api_version: APIVersion | None = None, - event_processor: Callable[[Context, TriggerEvent], bool] = default_event_processor, + event_processor: Callable[[Context, Any], bool] = lambda context, e: e.get("status") == "Succeeded", result_processor: Callable[[Context, Any], Any] = lambda context, result: result, serializer: type[ResponseSerializer] = ResponseSerializer, + retry_delay: timedelta | float = 60, **kwargs, ): - super().__init__(**kwargs) + super().__init__(retry_delay=retry_delay, **kwargs) self.url = url self.response_type = response_type - self.response_handler = response_handler self.path_parameters = path_parameters self.url_template = url_template self.method = method @@ -119,45 +105,73 @@ def __init__( self.result_processor = result_processor self.serializer = serializer() - @property - def trigger(self): - return MSGraphTrigger( - url=self.url, - response_type=self.response_type, - response_handler=self.response_handler, - path_parameters=self.path_parameters, - url_template=self.url_template, - method=self.method, - query_parameters=self.query_parameters, - headers=self.headers, - data=self.data, - conn_id=self.conn_id, - timeout=self.timeout, - proxies=self.proxies, - api_version=self.api_version, - serializer=type(self.serializer), + def execute(self, context: Context): + self.defer( + trigger=MSGraphTrigger( + url=self.url, + response_type=self.response_type, + path_parameters=self.path_parameters, + url_template=self.url_template, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + api_version=self.api_version, + serializer=type(self.serializer), + ), + method_name=self.execute_complete.__name__, ) - async def async_poke(self, context: Context) -> bool | PokeReturnValue: - self.log.info("Sensor triggered") + def retry_execute( + self, + context: Context, + ) -> Any: + self.execute(context=context) + + def execute_complete( + self, + context: Context, + event: dict[Any, Any] | None = None, + ) -> Any: + """ + Execute callback when MSGraphSensor finishes execution. + + This method gets executed automatically when MSGraphTrigger completes its execution. + """ + self.log.debug("context: %s", context) + + if event: + self.log.debug("%s completed with %s: %s", self.task_id, event.get("status"), event) + + if event.get("status") == "failure": + raise AirflowException(event.get("message")) + + response = event.get("response") + + self.log.debug("response: %s", response) - async for event in self.trigger.run(): - self.log.debug("event: %s", event) + if response: + response = self.serializer.deserialize(response) - is_done = self.event_processor(context, event) + self.log.debug("deserialize response: %s", response) - self.log.debug("is_done: %s", is_done) + is_done = self.event_processor(context, response) - response = self.serializer.deserialize(event.payload["response"]) + self.log.debug("is_done: %s", is_done) - self.log.debug("deserialize event: %s", response) + if is_done: + result = self.result_processor(context, response) - result = self.result_processor(context, response) + self.log.debug("processed response: %s", result) - self.log.debug("result: %s", result) + return result - return PokeReturnValue(is_done=is_done, xcom_value=result) - return PokeReturnValue(is_done=True) + self.defer( + trigger=TimeDeltaTrigger(self.retry_delay), + method_name=self.retry_execute.__name__, + ) - def poke(self, context) -> bool | PokeReturnValue: - return asyncio.run(self.async_poke(context)) + return None diff --git a/airflow/providers/microsoft/azure/triggers/msgraph.py b/airflow/providers/microsoft/azure/triggers/msgraph.py index 1848f969f8431..4b9ccb7a71716 100644 --- a/airflow/providers/microsoft/azure/triggers/msgraph.py +++ b/airflow/providers/microsoft/azure/triggers/msgraph.py @@ -27,7 +27,6 @@ TYPE_CHECKING, Any, AsyncIterator, - Callable, Sequence, ) from uuid import UUID @@ -43,8 +42,6 @@ from kiota_abstractions.request_adapter import RequestAdapter from kiota_abstractions.request_information import QueryParams - from kiota_abstractions.response_handler import NativeResponseType - from kiota_abstractions.serialization import ParsableFactory from kiota_http.httpx_request_adapter import ResponseType from msgraph_core import APIVersion @@ -89,9 +86,6 @@ class MSGraphTrigger(BaseTrigger): :param url: The url being executed on the Microsoft Graph API (templated). :param response_type: The expected return type of the response as a string. Possible value are: `bytes`, `str`, `int`, `float`, `bool` and `datetime` (default is None). - :param response_handler: Function to convert the native HTTPX response returned by the hook (default is - lambda response, error_map: response.json()). The default expression will convert the native response - to JSON. If response_type parameter is specified, then the response_handler will be ignored. :param method: The HTTP method being used to do the REST call (default is GET). :param conn_id: The HTTP Connection ID to run the operator against (templated). :param timeout: The HTTP timeout being used by the `KiotaRequestAdapter` (default is None). @@ -119,9 +113,6 @@ def __init__( self, url: str, response_type: ResponseType | None = None, - response_handler: Callable[ - [NativeResponseType, dict[str, ParsableFactory | None] | None], Any - ] = lambda response, error_map: response.json(), path_parameters: dict[str, Any] | None = None, url_template: str | None = None, method: str = "GET", @@ -143,7 +134,6 @@ def __init__( ) self.url = url self.response_type = response_type - self.response_handler = response_handler self.path_parameters = path_parameters self.url_template = url_template self.method = method @@ -207,7 +197,6 @@ async def run(self) -> AsyncIterator[TriggerEvent]: response = await self.hook.run( url=self.url, response_type=self.response_type, - response_handler=self.response_handler, path_parameters=self.path_parameters, method=self.method, query_parameters=self.query_parameters, diff --git a/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst index 817b14f783142..342bf542762ac 100644 --- a/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst +++ b/docs/apache-airflow-providers-microsoft-azure/operators/msgraph.rst @@ -64,6 +64,14 @@ Below is an example of using this operator to get PowerBI workspaces info. :start-after: [START howto_operator_powerbi_workspaces_info] :end-before: [END howto_operator_powerbi_workspaces_info] +Below is an example of using this operator to refresh PowerBI dataset. + +.. exampleinclude:: /../../tests/system/providers/microsoft/azure/example_powerbi.py + :language: python + :dedent: 0 + :start-after: [START howto_operator_powerbi_refresh_dataset] + :end-before: [END howto_operator_powerbi_refresh_dataset] + Reference --------- diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py b/tests/providers/microsoft/azure/hooks/test_msgraph.py index 9d2db07acf709..71d280a1971da 100644 --- a/tests/providers/microsoft/azure/hooks/test_msgraph.py +++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio +from json import JSONDecodeError from unittest.mock import patch import pytest @@ -24,12 +25,17 @@ from msgraph_core import APIVersion, NationalClouds from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowNotFoundException -from airflow.providers.microsoft.azure.hooks.msgraph import CallableResponseHandler, KiotaRequestAdapterHook +from airflow.providers.microsoft.azure.hooks.msgraph import ( + DefaultResponseHandler, + KiotaRequestAdapterHook, +) from tests.providers.microsoft.conftest import ( get_airflow_connection, + load_file, load_json, mock_connection, mock_json_response, + mock_response, ) @@ -95,45 +101,53 @@ def test_encoded_query_parameters(self): class TestResponseHandler: - def test_handle_response_async_when_ok(self): + def test_default_response_handler_when_json(self): users = load_json("resources", "users.json") response = mock_json_response(200, users) - actual = asyncio.run( - CallableResponseHandler(lambda response, error_map: response.json()).handle_response_async( - response, None - ) - ) + actual = asyncio.run(DefaultResponseHandler().handle_response_async(response, None)) assert isinstance(actual, dict) assert actual == users + def test_default_response_handler_when_not_json(self): + response = mock_json_response(200, JSONDecodeError("", "", 0)) + + actual = asyncio.run(DefaultResponseHandler().handle_response_async(response, None)) + + assert actual == {} + + def test_default_response_handler_when_content(self): + users = load_file("resources", "users.json").encode() + response = mock_response(200, users) + + actual = asyncio.run(DefaultResponseHandler().handle_response_async(response, None)) + + assert isinstance(actual, bytes) + assert actual == users + + def test_default_response_handler_when_no_content_but_headers(self): + response = mock_response(200, headers={"RequestId": "ffb6096e-d409-4826-aaeb-b5d4b165dc4d"}) + + actual = asyncio.run(DefaultResponseHandler().handle_response_async(response, None)) + + assert isinstance(actual, dict) + assert actual["requestid"] == "ffb6096e-d409-4826-aaeb-b5d4b165dc4d" + def test_handle_response_async_when_bad_request(self): response = mock_json_response(400, {}) with pytest.raises(AirflowBadRequest): - asyncio.run( - CallableResponseHandler(lambda response, error_map: response.json()).handle_response_async( - response, None - ) - ) + asyncio.run(DefaultResponseHandler().handle_response_async(response, None)) def test_handle_response_async_when_not_found(self): response = mock_json_response(404, {}) with pytest.raises(AirflowNotFoundException): - asyncio.run( - CallableResponseHandler(lambda response, error_map: response.json()).handle_response_async( - response, None - ) - ) + asyncio.run(DefaultResponseHandler().handle_response_async(response, None)) def test_handle_response_async_when_internal_server_error(self): response = mock_json_response(500, {}) with pytest.raises(AirflowException): - asyncio.run( - CallableResponseHandler(lambda response, error_map: response.json()).handle_response_async( - response, None - ) - ) + asyncio.run(DefaultResponseHandler().handle_response_async(response, None)) diff --git a/tests/providers/microsoft/azure/sensors/test_msgraph.py b/tests/providers/microsoft/azure/sensors/test_msgraph.py index 50fd2474ab454..e257984affb1a 100644 --- a/tests/providers/microsoft/azure/sensors/test_msgraph.py +++ b/tests/providers/microsoft/azure/sensors/test_msgraph.py @@ -16,9 +16,12 @@ # under the License. from __future__ import annotations +import json + from airflow.providers.microsoft.azure.sensors.msgraph import MSGraphSensor +from airflow.triggers.base import TriggerEvent from tests.providers.microsoft.azure.base import Base -from tests.providers.microsoft.conftest import load_json, mock_context, mock_json_response +from tests.providers.microsoft.conftest import load_json, mock_json_response class TestMSGraphSensor(Base): @@ -35,10 +38,16 @@ def test_execute(self): result_processor=lambda context, result: result["id"], timeout=350.0, ) - actual = sensor.execute(context=mock_context(task=sensor)) - assert isinstance(actual, str) - assert actual == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef" + results, events = self.execute_operator(sensor) + + assert isinstance(results, str) + assert results == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef" + assert len(events) == 1 + assert isinstance(events[0], TriggerEvent) + assert events[0].payload["status"] == "success" + assert events[0].payload["type"] == "builtins.dict" + assert events[0].payload["response"] == json.dumps(status) def test_template_fields(self): sensor = MSGraphSensor( diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index dfba931023901..aa3c48c5d7e63 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -20,12 +20,13 @@ import json import random import string +from json import JSONDecodeError from os.path import dirname, join from typing import TYPE_CHECKING, Any, Iterable, TypeVar from unittest.mock import MagicMock import pytest -from httpx import Response +from httpx import Headers, Response from msgraph_core import APIVersion from airflow.models import Connection @@ -89,18 +90,21 @@ def mock_connection(schema: str | None = None, host: str | None = None) -> Conne def mock_json_response(status_code, *contents) -> Response: response = MagicMock(spec=Response) response.status_code = status_code + response.headers = Headers({}) + response.content = b"" if contents: - contents = list(contents) - response.json.side_effect = lambda: contents.pop(0) + response.json.side_effect = list(contents) else: response.json.return_value = None return response -def mock_response(status_code, content: Any = None) -> Response: +def mock_response(status_code, content: Any = None, headers: dict | None = None) -> Response: response = MagicMock(spec=Response) response.status_code = status_code + response.headers = Headers(headers or {}) response.content = content + response.json.side_effect = JSONDecodeError("", "", 0) return response diff --git a/tests/system/providers/microsoft/azure/example_powerbi.py b/tests/system/providers/microsoft/azure/example_powerbi.py index cbee9a62af0c4..0a1bfde54a7a9 100644 --- a/tests/system/providers/microsoft/azure/example_powerbi.py +++ b/tests/system/providers/microsoft/azure/example_powerbi.py @@ -66,7 +66,36 @@ ).expand(path_parameters=workspaces_info_task.output) # [END howto_sensor_powerbi_scan_status] + # [START howto_operator_powerbi_refresh_dataset] + refresh_dataset_task = MSGraphAsyncOperator( + task_id="refresh_dataset", + conn_id="powerbi_api", + url="myorg/groups/{workspaceId}/datasets/{datasetId}/refreshes", + method="POST", + path_parameters={ + "workspaceId": "9a7e14c6-9a7d-4b4c-b0f2-799a85e60a51", + "datasetId": "ffb6096e-d409-4826-aaeb-b5d4b165dc4d", + }, + data={"type": "full"}, # Needed for enhanced refresh + result_processor=lambda context, response: response["requestid"], + ) + + refresh_dataset_history_task = MSGraphSensor( + task_id="refresh_dataset_history", + conn_id="powerbi_api", + url="myorg/groups/{workspaceId}/datasets/{datasetId}/refreshes/{refreshId}", + path_parameters={ + "workspaceId": "9a7e14c6-9a7d-4b4c-b0f2-799a85e60a51", + "datasetId": "ffb6096e-d409-4826-aaeb-b5d4b165dc4d", + "refreshId": refresh_dataset_task.output, + }, + timeout=350.0, + event_processor=lambda context, event: event["status"] == "Completed", + ) + # [END howto_operator_powerbi_refresh_dataset] + workspaces_task >> workspaces_info_task >> check_workspace_status_task + refresh_dataset_task >> refresh_dataset_history_task from tests.system.utils.watcher import watcher