Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add callback to process Azure Service Bus message contents #41601

Merged
merged 4 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 51 additions & 8 deletions airflow/providers/microsoft/azure/hooks/asb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusSender
from typing import TYPE_CHECKING, Any, Callable

from azure.servicebus import (
ServiceBusClient,
ServiceBusMessage,
ServiceBusReceivedMessage,
ServiceBusReceiver,
ServiceBusSender,
)
from azure.servicebus.management import QueueProperties, ServiceBusAdministrationClient

from airflow.hooks.base import BaseHook
Expand All @@ -28,6 +34,9 @@
get_sync_default_azure_credential,
)

MessageCallback = Callable[[ServiceBusMessage], None]


if TYPE_CHECKING:
from azure.identity import DefaultAzureCredential

Expand Down Expand Up @@ -270,14 +279,21 @@ def send_batch_message(sender: ServiceBusSender, messages: list[str]):
sender.send_messages(batch_message)

def receive_message(
self, queue_name, max_message_count: int | None = 1, max_wait_time: float | None = None
self,
queue_name: str,
max_message_count: int | None = 1,
max_wait_time: float | None = None,
message_callback: MessageCallback | None = None,
):
"""
Receive a batch of messages at once in a specified Queue name.

:param queue_name: The name of the queue name or a QueueProperties with name.
:param max_message_count: Maximum number of messages in the batch.
:param max_wait_time: Maximum time to wait in seconds for the first message to arrive.
:param message_callback: Optional callback to process each message. If not provided, then
the message will be logged and completed. If provided, and throws an exception, the
message will be abandoned for future redelivery.
"""
if queue_name is None:
raise TypeError("Queue name cannot be None.")
Expand All @@ -289,15 +305,15 @@ def receive_message(
max_message_count=max_message_count, max_wait_time=max_wait_time
)
for msg in received_msgs:
self.log.info(msg)
receiver.complete_message(msg)
self._process_message(msg, message_callback, receiver)

def receive_subscription_message(
self,
topic_name: str,
subscription_name: str,
max_message_count: int | None,
max_wait_time: float | None,
message_callback: MessageCallback | None = None,
):
"""
Receive a batch of subscription message at once.
Expand Down Expand Up @@ -326,5 +342,32 @@ def receive_subscription_message(
max_message_count=max_message_count, max_wait_time=max_wait_time
)
for msg in received_msgs:
self.log.info(msg)
subscription_receiver.complete_message(msg)
self._process_message(msg, message_callback, subscription_receiver)

def _process_message(
self,
msg: ServiceBusReceivedMessage,
message_callback: MessageCallback | None,
receiver: ServiceBusReceiver,
):
"""
Process the message by calling the message_callback or logging the message.

:param msg: The message to process.
:param message_callback: Optional callback to process each message. If not provided, then
the message will be logged and completed. If provided, and throws an exception, the
message will be abandoned for future redelivery.
:param receiver: The receiver that received the message.
"""
if message_callback is None:
self.log.info(msg)
potiuk marked this conversation as resolved.
Show resolved Hide resolved
receiver.complete_message(msg)
else:
try:
message_callback(msg)
except Exception as e:
self.log.error("Error processing message: %s", e)
receiver.abandon_message(msg)
raise e
else:
receiver.complete_message(msg)
26 changes: 23 additions & 3 deletions airflow/providers/microsoft/azure/operators/asb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Sequence
from typing import TYPE_CHECKING, Any, Callable, Sequence

from azure.core.exceptions import ResourceNotFoundError

Expand All @@ -26,10 +26,13 @@
if TYPE_CHECKING:
import datetime

from azure.servicebus import ServiceBusMessage
from azure.servicebus.management._models import AuthorizationRule

from airflow.utils.context import Context

MessageCallback = Callable[[ServiceBusMessage], None]


class AzureServiceBusCreateQueueOperator(BaseOperator):
"""
Expand Down Expand Up @@ -140,6 +143,9 @@ class AzureServiceBusReceiveMessageOperator(BaseOperator):
:param max_wait_time: Maximum time to wait in seconds for the first message to arrive.
:param azure_service_bus_conn_id: Reference to the
:ref: `Azure Service Bus connection <howto/connection:azure_service_bus>`.
:param message_callback: Optional callback to process each message. If not provided, then
the message will be logged and completed. If provided, and throws an exception, the
message will be abandoned for future redelivery.
"""

template_fields: Sequence[str] = ("queue_name",)
Expand All @@ -152,13 +158,15 @@ def __init__(
azure_service_bus_conn_id: str = "azure_service_bus_default",
max_message_count: int = 10,
max_wait_time: float = 5,
message_callback: MessageCallback | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.queue_name = queue_name
self.azure_service_bus_conn_id = azure_service_bus_conn_id
self.max_message_count = max_message_count
self.max_wait_time = max_wait_time
self.message_callback = message_callback

def execute(self, context: Context) -> None:
"""Receive Message in specific queue in Service Bus namespace by connecting to Service Bus client."""
Expand All @@ -167,7 +175,10 @@ def execute(self, context: Context) -> None:

# Receive message
hook.receive_message(
self.queue_name, max_message_count=self.max_message_count, max_wait_time=self.max_wait_time
self.queue_name,
max_message_count=self.max_message_count,
max_wait_time=self.max_wait_time,
message_callback=self.message_callback,
)


Expand Down Expand Up @@ -515,6 +526,9 @@ class ASBReceiveSubscriptionMessageOperator(BaseOperator):
an empty list will be returned.
:param azure_service_bus_conn_id: Reference to the
:ref:`Azure Service Bus connection <howto/connection:azure_service_bus>`.
:param message_callback: Optional callback to process each message. If not provided, then
the message will be logged and completed. If provided, and throws an exception, the
message will be abandoned for future redelivery.
"""

template_fields: Sequence[str] = ("topic_name", "subscription_name")
Expand All @@ -528,6 +542,7 @@ def __init__(
max_message_count: int | None = 1,
max_wait_time: float | None = 5,
azure_service_bus_conn_id: str = "azure_service_bus_default",
message_callback: MessageCallback | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -536,6 +551,7 @@ def __init__(
self.max_message_count = max_message_count
self.max_wait_time = max_wait_time
self.azure_service_bus_conn_id = azure_service_bus_conn_id
self.message_callback = message_callback

def execute(self, context: Context) -> None:
"""Receive Message in specific queue in Service Bus namespace by connecting to Service Bus client."""
Expand All @@ -544,7 +560,11 @@ def execute(self, context: Context) -> None:

# Receive message
hook.receive_subscription_message(
self.topic_name, self.subscription_name, self.max_message_count, self.max_wait_time
self.topic_name,
self.subscription_name,
self.max_message_count,
self.max_wait_time,
message_callback=self.message_callback,
)


Expand Down
63 changes: 62 additions & 1 deletion tests/providers/microsoft/azure/hooks/test_asb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@
# under the License.
from __future__ import annotations

from typing import Any
from unittest import mock

import pytest

try:
from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusMessageBatch
from azure.servicebus import (
ServiceBusClient,
ServiceBusMessage,
ServiceBusMessageBatch,
)
from azure.servicebus.management import ServiceBusAdministrationClient
except ImportError:
pytest.skip("Azure Service Bus not available", allow_module_level=True)
Expand Down Expand Up @@ -265,6 +270,31 @@ def test_receive_message(self, mock_sb_client, mock_service_bus_message):
]
mock_sb_client.assert_has_calls(expected_calls)

@mock.patch("azure.servicebus.ServiceBusReceivedMessage")
@mock.patch(f"{MODULE}.MessageHook.get_conn", autospec=True)
def test_receive_message_callback(self, mock_sb_client, mock_service_bus_message):
"""
Test `receive_message` hook function and assert the function with mock value,
mock the azure service bus `receive_messages` function
"""
hook = MessageHook(azure_service_bus_conn_id=self.conn_id)

mock_sb_client.return_value.__enter__.return_value.get_queue_receiver.return_value.__enter__.return_value.receive_messages.return_value = [
mock_service_bus_message
]

received_messages = []

def message_callback(msg: Any) -> None:
nonlocal received_messages
print("received message:", msg)
received_messages.append(msg)

hook.receive_message(self.queue_name, message_callback=message_callback)

assert len(received_messages) == 1
assert received_messages[0] == mock_service_bus_message

@mock.patch(f"{MODULE}.MessageHook.get_conn")
def test_receive_message_exception(self, mock_sb_client):
"""
Expand Down Expand Up @@ -300,6 +330,37 @@ def test_receive_subscription_message(self, mock_sb_client):
]
mock_sb_client.assert_has_calls(expected_calls)

@mock.patch("azure.servicebus.ServiceBusReceivedMessage")
@mock.patch(f"{MODULE}.MessageHook.get_conn")
def test_receive_subscription_message_callback(self, mock_sb_client, mock_sb_message):
"""
Test `receive_subscription_message` hook function and assert the function with mock value,
mock the azure service bus `receive_message` function of subscription
"""
subscription_name = "subscription_1"
topic_name = "topic_name"
max_message_count = 10
max_wait_time = 5
hook = MessageHook(azure_service_bus_conn_id=self.conn_id)

mock_sb_client.return_value.__enter__.return_value.get_subscription_receiver.return_value.__enter__.return_value.receive_messages.return_value = [
mock_sb_message,
mock_sb_message,
]

received_messages = []

def message_callback(msg: ServiceBusMessage) -> None:
nonlocal received_messages
print("received message:", msg)
received_messages.append(msg)

hook.receive_subscription_message(
topic_name, subscription_name, max_message_count, max_wait_time, message_callback=message_callback
)

assert len(received_messages) == 2

@pytest.mark.parametrize(
"mock_subscription_name, mock_topic_name, mock_max_count, mock_wait_time",
[("subscription_1", None, None, None), (None, "topic_1", None, None)],
Expand Down
69 changes: 69 additions & 0 deletions tests/providers/microsoft/azure/operators/test_asb.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,30 @@ def test_receive_message_queue(self, mock_get_conn):
]
mock_get_conn.assert_has_calls(expected_calls)

@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
def test_receive_message_queue_callback(self, mock_get_conn):
"""
Test AzureServiceBusReceiveMessageOperator by mock connection, values
and the service bus receive message
"""
mock_service_bus_message = ServiceBusMessage("Test message")
mock_get_conn.return_value.__enter__.return_value.get_queue_receiver.return_value.__enter__.return_value.receive_messages.return_value = [
mock_service_bus_message
]

messages_received = []

def message_callback(msg):
messages_received.append(msg)
print(msg)

asb_receive_queue_operator = AzureServiceBusReceiveMessageOperator(
task_id="asb_receive_message_queue", queue_name=QUEUE_NAME, message_callback=message_callback
)
asb_receive_queue_operator.execute(None)
assert len(messages_received) == 1
assert messages_received[0] == mock_service_bus_message


class TestABSTopicCreateOperator:
def test_init(self):
Expand Down Expand Up @@ -430,6 +454,51 @@ def test_receive_message_queue(self, mock_get_conn):
]
mock_get_conn.assert_has_calls(expected_calls)

@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
def test_receive_message_queue_callback(self, mock_get_conn):
"""
Test ASBReceiveSubscriptionMessageOperator by mock connection, values
and the service bus receive message
"""

mock_sb_message0 = ServiceBusMessage("Test message 0")
mock_sb_message1 = ServiceBusMessage("Test message 1")
mock_get_conn.return_value.__enter__.return_value.get_subscription_receiver.return_value.__enter__.return_value.receive_messages.return_value = [
mock_sb_message0,
mock_sb_message1,
]

messages_received = []

def message_callback(msg):
messages_received.append(msg)
print(msg)

asb_subscription_receive_message = ASBReceiveSubscriptionMessageOperator(
task_id="asb_subscription_receive_message",
topic_name=TOPIC_NAME,
subscription_name=SUBSCRIPTION_NAME,
max_message_count=10,
message_callback=message_callback,
)

asb_subscription_receive_message.execute(None)
expected_calls = [
mock.call()
.__enter__()
.get_subscription_receiver(SUBSCRIPTION_NAME, TOPIC_NAME)
.__enter__()
.receive_messages(max_message_count=10, max_wait_time=5)
.get_subscription_receiver(SUBSCRIPTION_NAME, TOPIC_NAME)
.__exit__()
.mock_call()
.__exit__
]
mock_get_conn.assert_has_calls(expected_calls)
assert len(messages_received) == 2
assert messages_received[0] == mock_sb_message0
assert messages_received[1] == mock_sb_message1


class TestASBTopicDeleteOperator:
def test_init(self):
Expand Down