Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add deferrable implementation in HTTPSensor #36904

Merged
merged 28 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
105aae4
add defferable to httpsensor
vatsrahul1001 Jan 19, 2024
83cd8c6
removing not required fields
vatsrahul1001 Jan 19, 2024
fb7a5be
adding a diff trigger for sensor
vatsrahul1001 Jan 19, 2024
2f3e09d
adding sensor triggerer
vatsrahul1001 Jan 19, 2024
a9e3339
adding http sensor triggerer
vatsrahul1001 Jan 19, 2024
273c92e
refactor trigger class file
vatsrahul1001 Jan 19, 2024
b1e38d2
add defferable to httpsensor
vatsrahul1001 Jan 19, 2024
6e7e49b
removing not required fields
vatsrahul1001 Jan 19, 2024
5b4e70d
adding a diff trigger for sensor
vatsrahul1001 Jan 19, 2024
7a7ee8f
adding sensor triggerer
vatsrahul1001 Jan 19, 2024
facb5ee
adding http sensor triggerer
vatsrahul1001 Jan 19, 2024
b4d6b10
refactor trigger class file
vatsrahul1001 Jan 19, 2024
a6ad584
Apply suggestions from code review
vatsrahul1001 Jan 20, 2024
33f2277
Merge branch 'add_deferrable_to_httpsensor' of github.com:astronomer/…
vatsrahul1001 Jan 20, 2024
4b06b2f
implementing review comments
vatsrahul1001 Jan 20, 2024
82be69d
Merge branch 'main' into add_deferrable_to_httpsensor
vatsrahul1001 Jan 20, 2024
56819db
Merge branch 'main' into add_deferrable_to_httpsensor
vatsrahul1001 Jan 22, 2024
eaa51cf
Apply suggestions from code review
vatsrahul1001 Jan 22, 2024
176c5f9
implementing review comments
vatsrahul1001 Jan 22, 2024
73b3d76
Update tests/providers/http/sensors/test_http.py
vatsrahul1001 Jan 22, 2024
f9a0e26
updating docs with async
vatsrahul1001 Jan 22, 2024
5faae31
Merge branch 'add_deferrable_to_httpsensor' of github.com:astronomer/…
vatsrahul1001 Jan 22, 2024
ccf5dd5
fixing docs
vatsrahul1001 Jan 22, 2024
c0260e7
Merge branch 'main' into add_deferrable_to_httpsensor
vatsrahul1001 Jan 22, 2024
4ab9621
Merge branch 'main' into add_deferrable_to_httpsensor
vatsrahul1001 Jan 22, 2024
bc6e53c
Merge branch 'main' into add_deferrable_to_httpsensor
vatsrahul1001 Jan 22, 2024
249afd5
renaming test from async to deferrable
vatsrahul1001 Jan 23, 2024
c6d851b
Merge branch 'main' into add_deferrable_to_httpsensor
vatsrahul1001 Jan 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions airflow/providers/http/sensors/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
# under the License.
from __future__ import annotations

from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.http.hooks.http import HttpHook
from airflow.providers.http.triggers.http import HttpSensorTrigger
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
Expand Down Expand Up @@ -78,6 +81,8 @@ def response_check(response, task_instance):
:param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``)
:param tcp_keep_alive_interval: The TCP Keep Alive interval parameter (corresponds to
``socket.TCP_KEEPINTVL``)
:param deferrable: If waiting for completion, whether to defer the task until done,
default is ``False``
"""

template_fields: Sequence[str] = ("endpoint", "request_params", "headers")
Expand All @@ -97,6 +102,7 @@ def __init__(
tcp_keep_alive_idle: int = 120,
tcp_keep_alive_count: int = 20,
tcp_keep_alive_interval: int = 30,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -114,6 +120,7 @@ def __init__(
self.tcp_keep_alive_idle = tcp_keep_alive_idle
self.tcp_keep_alive_count = tcp_keep_alive_count
self.tcp_keep_alive_interval = tcp_keep_alive_interval
self.deferrable = deferrable

def poke(self, context: Context) -> bool:
from airflow.utils.operator_helpers import determine_kwargs
Expand All @@ -135,9 +142,12 @@ def poke(self, context: Context) -> bool:
headers=self.headers,
extra_options=self.extra_options,
)

if self.response_check:
kwargs = determine_kwargs(self.response_check, [response], context)

return self.response_check(response, **kwargs)

except AirflowException as exc:
if str(exc).startswith(self.response_error_codes_allowlist):
return False
Expand All @@ -148,3 +158,24 @@ def poke(self, context: Context) -> bool:
raise exc

return True

def execute(self, context: Context) -> None:
if not self.deferrable or self.response_check:
super().execute(context=context)
elif not self.poke(context):
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=HttpSensorTrigger(
endpoint=self.endpoint,
http_conn_id=self.http_conn_id,
data=self.request_params,
headers=self.headers,
method=self.method,
extra_options=self.extra_options,
poke_interval=self.poke_interval,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
self.log.info("%s completed successfully.", self.task_id)
72 changes: 72 additions & 0 deletions airflow/providers/http/triggers/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import asyncio
import base64
import pickle
from typing import TYPE_CHECKING, Any, AsyncIterator
Expand All @@ -24,6 +25,7 @@
from requests.cookies import RequestsCookieJar
from requests.structures import CaseInsensitiveDict

from airflow.exceptions import AirflowException
from airflow.providers.http.hooks.http import HttpAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent

Expand Down Expand Up @@ -124,3 +126,73 @@ async def _convert_response(client_response: ClientResponse) -> requests.Respons
cookies.set(k, v)
response.cookies = cookies
return response


class HttpSensorTrigger(BaseTrigger):
"""
A trigger that fires when the request to a URL returns a non-404 status code.

:param endpoint: The relative part of the full url
:param http_conn_id: The HTTP Connection ID to run the sensor against
:param method: The HTTP request method to use
:param data: payload to be uploaded or aiohttp parameters
:param headers: The HTTP headers to be added to the GET request
:param extra_options: Additional kwargs to pass when creating a request.
For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)``
:param poke_interval: Time to sleep using asyncio
"""

def __init__(
self,
endpoint: str | None = None,
http_conn_id: str = "http_default",
method: str = "GET",
data: dict[str, Any] | str | None = None,
headers: dict[str, str] | None = None,
extra_options: dict[str, Any] | None = None,
poke_interval: float = 5.0,
):
super().__init__()
self.endpoint = endpoint
self.method = method
self.data = data
self.headers = headers
self.extra_options = extra_options or {}
self.http_conn_id = http_conn_id
self.poke_interval = poke_interval

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes HttpTrigger arguments and classpath."""
return (
"airflow.providers.http.triggers.http.HttpSensorTrigger",
{
"endpoint": self.endpoint,
"data": self.data,
"headers": self.headers,
"extra_options": self.extra_options,
"http_conn_id": self.http_conn_id,
"poke_interval": self.poke_interval,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
"""Makes a series of asynchronous http calls via an http hook."""
hook = self._get_async_hook()
while True:
try:
await hook.run(
endpoint=self.endpoint,
data=self.data,
headers=self.headers,
extra_options=self.extra_options,
)
yield TriggerEvent(True)
except AirflowException as exc:
if str(exc).startswith("404"):
await asyncio.sleep(self.poke_interval)

def _get_async_hook(self) -> HttpAsyncHook:
return HttpAsyncHook(
method=self.method,
http_conn_id=self.http_conn_id,
)
54 changes: 53 additions & 1 deletion tests/providers/http/sensors/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@
import pytest
import requests

from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException
from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException, TaskDeferred
from airflow.models.dag import DAG
from airflow.providers.http.operators.http import HttpOperator
from airflow.providers.http.sensors.http import HttpSensor
from airflow.providers.http.triggers.http import HttpSensorTrigger
from airflow.utils.timezone import datetime

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -330,3 +331,54 @@ def test_sensor(self):
dag=self.dag,
)
sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)


class TestHttpSensorAsync:
@mock.patch("airflow.providers.http.sensors.http.HttpSensor.defer")
@mock.patch(
"airflow.providers.http.sensors.http.HttpSensor.poke",
return_value=True,
)
def test_execute_finished_before_deferred(
self,
mock_poke,
mock_defer,
):
"""
Asserts that a task is not deferred when task is already finished
"""

task = HttpSensor(task_id="run_now", endpoint="test-endpoint", deferrable=True)

task.execute({})
assert not mock_defer.called

@mock.patch(
"airflow.providers.http.sensors.http.HttpSensor.poke",
return_value=False,
)
def test_execute_is_deferred(self, mock_poke):
"""
Asserts that a task is deferred and a HttpTrigger will be fired
when the HttpSensor is executed in deferrable mode.
"""

task = HttpSensor(task_id="run_now", endpoint="test-endpoint", deferrable=True)

with pytest.raises(TaskDeferred) as exc:
task.execute({})

assert isinstance(exc.value.trigger, HttpSensorTrigger), "Trigger is not a HttpTrigger"

@mock.patch("airflow.providers.http.sensors.http.HttpSensor.defer")
@mock.patch("airflow.sensors.base.BaseSensorOperator.execute")
def test_sensor_not_defer(self, mock_execute, mock_defer):
vatsrahul1001 marked this conversation as resolved.
Show resolved Hide resolved
vatsrahul1001 marked this conversation as resolved.
Show resolved Hide resolved
task = HttpSensor(
task_id="run_now",
endpoint="test-endpoint",
response_check=lambda response: "httpbin" in response.text,
deferrable=True,
)
task.execute({})
mock_execute.assert_called_once()
mock_defer.assert_not_called()