Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "[fix #1726] Use boto3 for SQS async requests" #1799

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 164 additions & 38 deletions kombu/asynchronous/aws/sqs/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

from __future__ import annotations

from kombu.asynchronous import get_event_loop
from vine import transform

from kombu.asynchronous.aws.connection import AsyncAWSQueryConnection

from .ext import boto3
from .message import AsyncMessage
from .queue import AsyncQueue

__all__ = ('AsyncSQSConnection',)

Expand All @@ -18,50 +21,173 @@ def __init__(self, sqs_connection, debug=0, region=None, **kwargs):
raise ImportError('boto3 is not installed')
super().__init__(
sqs_connection,
region_name=region,
debug=debug,
region_name=region, debug=debug,
**kwargs
)
self.hub = kwargs.get('hub') or get_event_loop()

def _async_sqs_request(self, api, callback, *args, **kwargs):
"""Makes an asynchronous request to an SQS API.

Arguments:
---------
api -- The name of the API, e.g. 'receive_message'.
callback -- The callback to pass the response to when it is available.
*args, **kwargs -- The arguments and keyword arguments to pass to the
SQS API. Those are API dependent and can be found in the boto3
documentation.
"""
# Define a method to execute the SQS API synchronously.
def sqs_request(api, callback, args, kwargs):
method = getattr(self.sqs_connection, api)
resp = method(*args, **kwargs)
if callback:
callback(resp)

# Hand off the request to the event loop to execute it asynchronously.
self.hub.call_soon(sqs_request, api, callback, args, kwargs)

def create_queue(self, queue_name,
visibility_timeout=None, callback=None):
params = {'QueueName': queue_name}
if visibility_timeout:
params['DefaultVisibilityTimeout'] = format(
visibility_timeout, 'd',
)
return self.get_object('CreateQueue', params,
callback=callback)

def delete_queue(self, queue, force_deletion=False, callback=None):
return self.get_status('DeleteQueue', None, queue.id,
callback=callback)

def get_queue_url(self, queue):
res = self.sqs_connection.get_queue_url(QueueName=queue)
return res['QueueUrl']

def get_queue_attributes(self, queue, attribute='All', callback=None):
return self.get_object(
'GetQueueAttributes', {'AttributeName': attribute},
queue.id, callback=callback,
)

def set_queue_attribute(self, queue, attribute, value, callback=None):
return self.get_status(
'SetQueueAttribute',
{'Attribute.Name': attribute, 'Attribute.Value': value},
queue.id, callback=callback,
)

def receive_message(
self, queue_url, number_messages=1, visibility_timeout=None,
self, queue, queue_url, number_messages=1, visibility_timeout=None,
attributes=('ApproximateReceiveCount',), wait_time_seconds=None,
callback=None
):
kwargs = {
"QueueUrl": queue_url,
"MaxNumberOfMessages": number_messages,
"MessageAttributeNames": attributes,
"WaitTimeSeconds": wait_time_seconds,
}
params = {'MaxNumberOfMessages': number_messages}
if visibility_timeout:
kwargs["VisibilityTimeout"] = visibility_timeout
params['VisibilityTimeout'] = visibility_timeout
if attributes:
attrs = {}
for idx, attr in enumerate(attributes):
attrs['AttributeName.' + str(idx + 1)] = attr
params.update(attrs)
if wait_time_seconds is not None:
params['WaitTimeSeconds'] = wait_time_seconds
return self.get_list(
'ReceiveMessage', params, [('Message', AsyncMessage)],
queue_url, callback=callback, parent=queue,
)

def delete_message(self, queue, receipt_handle, callback=None):
return self.delete_message_from_handle(
queue, receipt_handle, callback,
)

def delete_message_batch(self, queue, messages, callback=None):
params = {}
for i, m in enumerate(messages):
prefix = f'DeleteMessageBatchRequestEntry.{i + 1}'
params.update({
f'{prefix}.Id': m.id,
f'{prefix}.ReceiptHandle': m.receipt_handle,
})
return self.get_object(
'DeleteMessageBatch', params, queue.id,
verb='POST', callback=callback,
)

def delete_message_from_handle(self, queue, receipt_handle,
callback=None):
return self.get_status(
'DeleteMessage', {'ReceiptHandle': receipt_handle},
queue, callback=callback,
)

def send_message(self, queue, message_content,
delay_seconds=None, callback=None):
params = {'MessageBody': message_content}
if delay_seconds:
params['DelaySeconds'] = int(delay_seconds)
return self.get_object(
'SendMessage', params, queue.id,
verb='POST', callback=callback,
)

return self._async_sqs_request('receive_message', callback, **kwargs)
def send_message_batch(self, queue, messages, callback=None):
params = {}
for i, msg in enumerate(messages):
prefix = f'SendMessageBatchRequestEntry.{i + 1}'
params.update({
f'{prefix}.Id': msg[0],
f'{prefix}.MessageBody': msg[1],
f'{prefix}.DelaySeconds': msg[2],
})
return self.get_object(
'SendMessageBatch', params, queue.id,
verb='POST', callback=callback,
)

def delete_message(self, queue_url, receipt_handle, callback=None):
return self._async_sqs_request('delete_message', callback,
QueueUrl=queue_url,
ReceiptHandle=receipt_handle)
def change_message_visibility(self, queue, receipt_handle,
visibility_timeout, callback=None):
return self.get_status(
'ChangeMessageVisibility',
{'ReceiptHandle': receipt_handle,
'VisibilityTimeout': visibility_timeout},
queue.id, callback=callback,
)

def change_message_visibility_batch(self, queue, messages, callback=None):
params = {}
for i, t in enumerate(messages):
pre = f'ChangeMessageVisibilityBatchRequestEntry.{i + 1}'
params.update({
f'{pre}.Id': t[0].id,
f'{pre}.ReceiptHandle': t[0].receipt_handle,
f'{pre}.VisibilityTimeout': t[1],
})
return self.get_object(
'ChangeMessageVisibilityBatch', params, queue.id,
verb='POST', callback=callback,
)

def get_all_queues(self, prefix='', callback=None):
params = {}
if prefix:
params['QueueNamePrefix'] = prefix
return self.get_list(
'ListQueues', params, [('QueueUrl', AsyncQueue)],
callback=callback,
)

def get_queue(self, queue_name, callback=None):
# TODO Does not support owner_acct_id argument
return self.get_all_queues(
queue_name,
transform(self._on_queue_ready, callback, queue_name),
)
lookup = get_queue

def _on_queue_ready(self, name, queues):
return next(
(q for q in queues if q.url.endswith(name)), None,
)

def get_dead_letter_source_queues(self, queue, callback=None):
return self.get_list(
'ListDeadLetterSourceQueues', {'QueueUrl': queue.url},
[('QueueUrl', AsyncQueue)],
callback=callback,
)

def add_permission(self, queue, label, aws_account_id, action_name,
callback=None):
return self.get_status(
'AddPermission',
{'Label': label,
'AWSAccountId': aws_account_id,
'ActionName': action_name},
queue.id, callback=callback,
)

def remove_permission(self, queue, label, callback=None):
return self.get_status(
'RemovePermission', {'Label': label}, queue.id, callback=callback,
)
130 changes: 130 additions & 0 deletions kombu/asynchronous/aws/sqs/queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Amazon SQS queue implementation."""

from __future__ import annotations

from vine import transform

from .message import AsyncMessage

_all__ = ['AsyncQueue']


def list_first(rs):
"""Get the first item in a list, or None if list empty."""
return rs[0] if len(rs) == 1 else None


class AsyncQueue:
"""Async SQS Queue."""

def __init__(self, connection=None, url=None, message_class=AsyncMessage):
self.connection = connection
self.url = url
self.message_class = message_class
self.visibility_timeout = None

def _NA(self, *args, **kwargs):
raise NotImplementedError()
count_slow = dump = save_to_file = save_to_filename = save = \
save_to_s3 = load_from_s3 = load_from_file = load_from_filename = \
load = clear = _NA

def get_attributes(self, attributes='All', callback=None):
return self.connection.get_queue_attributes(
self, attributes, callback,
)

def set_attribute(self, attribute, value, callback=None):
return self.connection.set_queue_attribute(
self, attribute, value, callback,
)

def get_timeout(self, callback=None, _attr='VisibilityTimeout'):
return self.get_attributes(
_attr, transform(
self._coerce_field_value, callback, _attr, int,
),
)

def _coerce_field_value(self, key, type, response):
return type(response[key])

def set_timeout(self, visibility_timeout, callback=None):
return self.set_attribute(
'VisibilityTimeout', visibility_timeout,
transform(
self._on_timeout_set, callback,
)
)

def _on_timeout_set(self, visibility_timeout):
if visibility_timeout:
self.visibility_timeout = visibility_timeout
return self.visibility_timeout

def add_permission(self, label, aws_account_id, action_name,
callback=None):
return self.connection.add_permission(
self, label, aws_account_id, action_name, callback,
)

def remove_permission(self, label, callback=None):
return self.connection.remove_permission(self, label, callback)

def read(self, visibility_timeout=None, wait_time_seconds=None,
callback=None):
return self.get_messages(
1, visibility_timeout,
wait_time_seconds=wait_time_seconds,
callback=transform(list_first, callback),
)

def write(self, message, delay_seconds=None, callback=None):
return self.connection.send_message(
self, message.get_body_encoded(), delay_seconds,
callback=transform(self._on_message_sent, callback, message),
)

def write_batch(self, messages, callback=None):
return self.connection.send_message_batch(
self, messages, callback=callback,
)

def _on_message_sent(self, orig_message, new_message):
orig_message.id = new_message.id
orig_message.md5 = new_message.md5
return new_message

def get_messages(self, num_messages=1, visibility_timeout=None,
attributes=None, wait_time_seconds=None, callback=None):
return self.connection.receive_message(
self, number_messages=num_messages,
visibility_timeout=visibility_timeout,
attributes=attributes,
wait_time_seconds=wait_time_seconds,
callback=callback,
)

def delete_message(self, message, callback=None):
return self.connection.delete_message(self, message, callback)

def delete_message_batch(self, messages, callback=None):
return self.connection.delete_message_batch(
self, messages, callback=callback,
)

def change_message_visibility_batch(self, messages, callback=None):
return self.connection.change_message_visibility_batch(
self, messages, callback=callback,
)

def delete(self, callback=None):
return self.connection.delete_queue(self, callback=callback)

def count(self, page_size=10, vtimeout=10, callback=None,
_attr='ApproximateNumberOfMessages'):
return self.get_attributes(
_attr, callback=transform(
self._coerce_field_value, callback, _attr, int,
),
)
12 changes: 9 additions & 3 deletions kombu/transport/SQS.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,10 +613,10 @@ def _get_from_sqs(self, queue_name, queue_url,
Uses long polling and returns :class:`~vine.promises.promise`.
"""
return connection.receive_message(
queue_url,
number_messages=count,
queue_name, queue_url, number_messages=count,
wait_time_seconds=self.wait_time_seconds,
callback=callback)
callback=callback,
)

def _restore(self, message,
unwanted_delivery_info=('sqs_message', 'sqs_queue')):
Expand Down Expand Up @@ -674,6 +674,12 @@ def _purge(self, queue):

def close(self):
super().close()
# if self._asynsqs:
# try:
# self.asynsqs().close()
# except AttributeError as exc: # FIXME ???
# if "can't set attribute" not in str(exc):
# raise

def new_sqs_client(self, region, access_key_id,
secret_access_key, session_token=None):
Expand Down
Loading