Skip to content

Commit

Permalink
Deprecate HttpSensorAsync (#1459)
Browse files Browse the repository at this point in the history
This PR deprecates HttpSensorAsync by proxying it to its counterpart in the Airflow OSS HTTP provider.

closes:#1420
---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
vatsrahul1001 and pre-commit-ci[bot] authored Jan 29, 2024
1 parent 6167579 commit 63b0421
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 214 deletions.
18 changes: 15 additions & 3 deletions astronomer/providers/core/sensors/external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import warnings
from typing import TYPE_CHECKING, Any

from airflow.providers.http.hooks.http import HttpHook
from airflow.providers.http.sensors.http import HttpSensor
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.utils.session import provide_session

Expand All @@ -12,7 +14,6 @@
ExternalDeploymentTaskTrigger,
TaskStateTrigger,
)
from astronomer.providers.http.sensors.http import HttpSensorAsync
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context

Expand Down Expand Up @@ -101,10 +102,10 @@ def get_execution_dates(self, context: Context) -> list[datetime.datetime]:
return execution_dates


class ExternalDeploymentTaskSensorAsync(HttpSensorAsync):
class ExternalDeploymentTaskSensorAsync(HttpSensor):
"""
External deployment task sensor Make HTTP call and poll for the response state of externally
deployed DAG task to complete. Inherits from HttpSensorAsync, the host should be external deployment url, header
deployed DAG task to complete. Inherits from HttpSensor, the host should be external deployment url, header
with access token
.. seealso::
Expand All @@ -125,6 +126,17 @@ class ExternalDeploymentTaskSensorAsync(HttpSensorAsync):
:param poke_interval: Time in seconds that the job should wait between tries
""" # noqa

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.hook = HttpHook(
method=self.method,
http_conn_id=self.http_conn_id,
tcp_keep_alive=self.tcp_keep_alive,
tcp_keep_alive_idle=self.tcp_keep_alive_idle,
tcp_keep_alive_count=self.tcp_keep_alive_count,
tcp_keep_alive_interval=self.tcp_keep_alive_interval,
)

def execute(self, context: Context) -> None:
"""Defers trigger class to poll for state of the job run until it reaches a failure state or success state"""
self.defer(
Expand Down
19 changes: 8 additions & 11 deletions astronomer/providers/http/hooks/http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import warnings
from typing import TYPE_CHECKING, Any, Callable

import aiohttp
Expand All @@ -15,17 +16,8 @@

class HttpHookAsync(BaseHook):
"""
Interact with HTTP servers using Python Async.
:param method: the API method to be called
:param http_conn_id: http connection id that has the base
API url i.e https://www.google.com/ and optional authentication credentials. Default
headers can also be specified in the Extra field in json format.
:param auth_type: The auth type for the service
:param keep_response: Keep the aiohttp response returned by run method without releasing it.
Use it with caution. Without properly releasing response, it might cause "Unclosed connection" error.
See https://github.com/astronomer/astronomer-providers/issues/909
:type auth_type: AuthBase of python aiohttp lib
This class is deprecated and will be removed in 2.0.0.
Use :class:`~airflow.providers.http.hooks.http.HttpAsyncHook` instead.
"""

conn_name_attr = "http_conn_id"
Expand All @@ -43,6 +35,11 @@ def __init__(
*,
keep_response: bool = False,
) -> None:
warnings.warn(
("This class is deprecated. " "Use `~airflow.providers.http.hooks.http.HttpAsyncHook` instead."),
DeprecationWarning,
stacklevel=2,
)
self.http_conn_id = http_conn_id
self.method = method.upper()
self.base_url: str = ""
Expand Down
130 changes: 17 additions & 113 deletions astronomer/providers/http/sensors/http.py
Original file line number Diff line number Diff line change
@@ -1,129 +1,33 @@
import warnings
from datetime import timedelta
from typing import Any, Dict, Optional
from typing import Any

from airflow.providers.http.hooks.http import HttpHook
from airflow.providers.http.sensors.http import HttpSensor

from astronomer.providers.http.triggers.http import HttpTrigger
from astronomer.providers.utils.sensor_util import poke
from astronomer.providers.utils.typing_compat import Context


class HttpSensorAsync(HttpSensor):
"""
Executes a HTTP GET statement and returns False on failure caused by
404 Not Found or `response_check` returning False.
.. note::
If ``response_check`` is passed, the sync version of the sensor will be used.
The response check can access the template context to the operator:
.. code-block:: python
def response_check(response, task_instance):
# The task_instance is injected, so you can pull data form xcom
# Other context variables such as dag, ds, execution_date are also available.
xcom_data = task_instance.xcom_pull(task_ids="pushing_task")
# In practice you would do something more sensible with this data..
print(xcom_data)
return True
HttpSensorAsync(task_id="my_http_sensor", ..., response_check=response_check)
:param http_conn_id: The Connection ID to run the sensor against
:type http_conn_id: str
:param method: The HTTP request method to use
:type method: str
:param endpoint: The relative part of the full url
:type endpoint: str
:param request_params: The parameters to be added to the GET url
:type request_params: a dictionary of string key/value pairs
:param headers: The HTTP headers to be added to the GET request
:type headers: a dictionary of string key/value pairs
:param response_check: A check against the 'requests' response object.
The callable takes the response object as the first positional argument
and optionally any number of keyword arguments available in the context dictionary.
It should return True for 'pass' and False otherwise.
Currently if this parameter is specified then sync version of the sensor will be used.
:type response_check: A lambda or defined function.
:param extra_options: Extra options for the 'requests' library, see the
'requests' documentation (options to modify timeout, ssl, etc.)
:type extra_options: A dictionary of options, where key is string and value
depends on the option that's being modified.
:param tcp_keep_alive: Enable TCP Keep Alive for the connection.
:param tcp_keep_alive_idle: The TCP Keep Alive Idle parameter (corresponds to ``socket.TCP_KEEPIDLE``).
:param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``)
:param tcp_keep_alive_interval: The TCP Keep Alive interval parameter (corresponds to
``socket.TCP_KEEPINTVL``)
This class is deprecated.
Use :class:`~airflow.providers.http.sensors.http.HttpSensor`
and set the `deferrable` param to `True` instead.
"""

def __init__(
self,
*,
endpoint: str,
poll_interval: float = 5,
**kwargs: Any,
) -> None:
self.endpoint = endpoint
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This class is deprecated. "
"Use `airflow.providers.http.sensors.http.HttpSensor` "
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
# TODO: Remove once deprecated
if kwargs.get("poke_interval") is None:
self.poke_interval = poll_interval
if kwargs.get("poll_interval"):
warnings.warn(
"Argument `poll_interval` is deprecated and will be removed "
"in a future release. Please use `poke_interval` instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(endpoint=endpoint, **kwargs)
try:
# for apache-airflow-providers-http>=4.0.0
self.hook = HttpHook(
method=self.method,
http_conn_id=self.http_conn_id,
tcp_keep_alive=self.tcp_keep_alive,
tcp_keep_alive_idle=self.tcp_keep_alive_idle,
tcp_keep_alive_count=self.tcp_keep_alive_count,
tcp_keep_alive_interval=self.tcp_keep_alive_interval,
)
except AttributeError:
# for apache-airflow-providers-http<4.0.0
# Since the hook is an instance variable of the operator, we need no action.
pass

def execute(self, context: Context) -> None:
"""
Logic that the sensor uses to correctly identify which trigger to
execute, and defer execution as expected.
"""
# TODO: We can't currently serialize arbitrary function
# Maybe we set method_name as users function??? to run it again
# and evaluate the response.
if self.response_check:
self.log.warning("Since response_check param is passed, using the sync version of the sensor.")
super().execute(context=context)
elif not poke(self, context):
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=HttpTrigger(
endpoint=self.endpoint,
http_conn_id=self.http_conn_id,
method=self.hook.method, # TODO: Fix this to directly get method from ctor
data=self.request_params,
headers=self.headers,
extra_options=self.extra_options,
poke_interval=self.poke_interval,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: Optional[Dict[Any, Any]] = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
self.log.info("%s completed successfully.", self.task_id)
return None
kwargs["poke_interval"] = kwargs.pop("poll_interval")
super().__init__(*args, deferrable=True, **kwargs)
22 changes: 11 additions & 11 deletions astronomer/providers/http/triggers/http.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import warnings
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union

from airflow.exceptions import AirflowException
Expand All @@ -9,17 +10,8 @@

class HttpTrigger(BaseTrigger):
"""
A trigger that fires when the request to a URL returns a non-404 status code
:param endpoint: The relative part of the full url
:param http_conn_id: The HTTP Connection ID to run the sensor against
:param method: The HTTP request method to use
:param data: payload to be uploaded or aiohttp parameters
:param headers: The HTTP headers to be added to the GET request
:type headers: a dictionary of string key/value pairs
:param extra_options: Additional kwargs to pass when creating a request.
For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)``
:param poke_interval: Time to sleep using asyncio
This class is deprecated and will be removed in 2.0.0.
Use :class:`~airflow.providers.http.triggers.http.HttpSensorTrigger` instead.
"""

def __init__(
Expand All @@ -32,6 +24,14 @@ def __init__(
extra_options: Optional[Dict[str, Any]] = None,
poke_interval: float = 5.0,
):
warnings.warn(
(
"This class is deprecated. "
"Use `~airflow.providers.http.triggers.http.HttpSensorTrigger` instead."
),
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.endpoint = endpoint
self.method = method
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ google =
gcloud-aio-storage
gcloud-aio-bigquery
http =
apache-airflow-providers-http
apache-airflow-providers-http>=4.9.0
microsoft.azure =
apache-airflow-providers-microsoft-azure>=8.5.1
sftp =
Expand Down Expand Up @@ -124,7 +124,7 @@ all =
apache-airflow-providers-cncf-kubernetes>=4
apache-airflow-providers-databricks>=6.1.0
apache-airflow-providers-google>=10.14.0
apache-airflow-providers-http
apache-airflow-providers-http>=4.9.0
apache-airflow-providers-snowflake>=5.3.0
apache-airflow-providers-sftp
apache-airflow-providers-microsoft-azure>=8.5.1
Expand Down
78 changes: 4 additions & 74 deletions tests/http/sensors/test_http.py
Original file line number Diff line number Diff line change
@@ -1,86 +1,16 @@
from unittest import mock

import pytest
from airflow.exceptions import TaskDeferred
from airflow.providers.http.sensors.http import HttpSensor

from astronomer.providers.http.sensors.http import HttpSensorAsync
from astronomer.providers.http.triggers.http import HttpTrigger

MODULE = "astronomer.providers.http.sensors.http"


class TestHttpSensorAsync:
@mock.patch(f"{MODULE}.HttpSensorAsync.defer")
@mock.patch(
f"{MODULE}.HttpSensorAsync.poke",
return_value=True,
)
def test_http_sensor_async_finish_before_deferred(
self,
mock_poke,
mock_defer,
):
"""
Asserts that a task is deferred and a HttpTrigger will be fired
when the HttpSensorAsync is executed.
"""

task = HttpSensorAsync(
task_id="run_now",
endpoint="test-endpoint",
)

task.execute({})
assert not mock_defer.called

@mock.patch(
f"{MODULE}.HttpSensorAsync.poke",
return_value=False,
)
def test_http_run_now_sensor_async(self, mock_poke):
"""
Asserts that a task is deferred and a HttpTrigger will be fired
when the HttpSensorAsync is executed.
"""

task = HttpSensorAsync(
task_id="run_now",
endpoint="test-endpoint",
)

with pytest.raises(TaskDeferred) as exc:
task.execute({})

assert isinstance(exc.value.trigger, HttpTrigger), "Trigger is not a HttpTrigger"

@mock.patch("astronomer.providers.http.sensors.http.HttpSensorAsync.defer")
@mock.patch("airflow.sensors.base.BaseSensorOperator.execute")
def test_sensor_not_defer(self, mock_execute, mock_defer):
task = HttpSensorAsync(
task_id="run_now",
endpoint="test-endpoint",
response_check=lambda response: "httpbin" in response.text,
)
task.execute({})
mock_execute.assert_called_once()
mock_defer.assert_not_called()

@mock.patch("airflow.providers.http.sensors.http.HttpSensor.poke")
def test_sensor_defer(self, mock_poke):
mock_poke.return_value = False
def test_init(self):
task = HttpSensorAsync(task_id="run_now", endpoint="test-endpoint")
with pytest.raises(TaskDeferred) as exc:
task.execute({})
assert isinstance(exc.value.trigger, HttpTrigger), "Trigger is not a HttpTrigger"

@mock.patch("astronomer.providers.http.sensors.http.HttpSensor")
def test_http_sensor_async_hook_initialisation_attribute_error(self, mock_http_sensor):
"""
Asserts that an attribute error that may be raised across different versions of the HTTP provider is handled
while initialising the hook in the sensor.
"""
mock_http_sensor.side_effect = AttributeError()
HttpSensorAsync(task_id="check_hook_initialisation", endpoint="")
assert isinstance(task, HttpSensor)
assert task.deferrable is True

def test_poll_interval_deprecation_warning(self):
"""Test DeprecationWarning for HttpSensorAsync by setting param poll_interval"""
Expand Down

0 comments on commit 63b0421

Please sign in to comment.