From 09f0ab00abf13620a27c3af8d90761d55ef857de Mon Sep 17 00:00:00 2001 From: Kashif Khan <361477+kashifkhan@users.noreply.github.com> Date: Mon, 13 Nov 2023 16:55:16 -0600 Subject: [PATCH] [AMQP] Fix Filter Set Encoding For 2 Char Length Session id (#32860) * fix for len 2 string * fix for char length * fix pylint * fix to keep the right data value * pylint * switch order * raise error * encode unit tests * get behavior in line with uamqp * modified to add any value * narrow exception * live test * changelog --- .../azure/eventhub/_pyamqp/_encode.py | 21 +++++++++------ .../pyamqp_tests/unittest/test_encode.py | 17 ++++++++++++ sdk/servicebus/azure-servicebus/CHANGELOG.md | 8 ++---- .../azure/servicebus/_pyamqp/_encode.py | 21 +++++++++------ .../azure-servicebus/tests/test_sessions.py | 27 +++++++++++++++++++ 5 files changed, 72 insertions(+), 22 deletions(-) create mode 100644 sdk/eventhub/azure-eventhub/tests/pyamqp_tests/unittest/test_encode.py diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_encode.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_encode.py index 17baa5a490f2..c1112d9f6a14 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_encode.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_encode.py @@ -765,14 +765,19 @@ def encode_filter_set(value): else: if isinstance(name, str): name = name.encode("utf-8") # type: ignore - try: - descriptor, filter_value = data - described_filter = { - TYPE: AMQPTypes.described, - VALUE: ({TYPE: AMQPTypes.symbol, VALUE: descriptor}, filter_value), - } - except ValueError: - described_filter = data + if isinstance(data, (str, bytes)): + described_filter = data # type: ignore + # handle the situation when data is a tuple or list of length 2 + else: + try: + descriptor, filter_value = data + described_filter = { + TYPE: AMQPTypes.described, + VALUE: ({TYPE: AMQPTypes.symbol, VALUE: descriptor}, filter_value), + } + # if its not a type that is known, raise the error from the server + except (ValueError, TypeError): + described_filter = data cast(List, fields[VALUE]).append( ({TYPE: AMQPTypes.symbol, VALUE: name}, described_filter) diff --git a/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/unittest/test_encode.py b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/unittest/test_encode.py new file mode 100644 index 000000000000..a73a6c266bac --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/pyamqp_tests/unittest/test_encode.py @@ -0,0 +1,17 @@ +import pytest +from azure.eventhub._pyamqp._encode import encode_filter_set + +@pytest.mark.parametrize("value,expected", [ + ({b'com.microsoft:session-filter': 'ababa'}, 'ababa'), + ({b'com.microsoft:session-filter': 'abab'}, 'abab'), + ({b'com.microsoft:session-filter': 'aba'}, 'aba'), + ({b'com.microsoft:session-filter': 'ab'}, 'ab'), + ({b'com.microsoft:session-filter': 'a'}, 'a'), + ({b'com.microsoft:session-filter': 1}, 1), +]) +def test_valid_filter_encode(value, expected): + fields = encode_filter_set(value) + assert len(fields) == 2 + assert fields['VALUE'][0][1] == expected + + diff --git a/sdk/servicebus/azure-servicebus/CHANGELOG.md b/sdk/servicebus/azure-servicebus/CHANGELOG.md index c84b170d5d67..ef6024b2e29e 100644 --- a/sdk/servicebus/azure-servicebus/CHANGELOG.md +++ b/sdk/servicebus/azure-servicebus/CHANGELOG.md @@ -1,14 +1,10 @@ # Release History -## 7.11.4 (Unreleased) - -### Features Added - -### Breaking Changes +## 7.11.4 (2023-11-13) ### Bugs Fixed -### Other Changes +- Fixed a bug where a two character count session id was being incorrectly parsed by azure amqp. ## 7.11.3 (2023-10-11) diff --git a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_encode.py b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_encode.py index 17baa5a490f2..c1112d9f6a14 100644 --- a/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_encode.py +++ b/sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_encode.py @@ -765,14 +765,19 @@ def encode_filter_set(value): else: if isinstance(name, str): name = name.encode("utf-8") # type: ignore - try: - descriptor, filter_value = data - described_filter = { - TYPE: AMQPTypes.described, - VALUE: ({TYPE: AMQPTypes.symbol, VALUE: descriptor}, filter_value), - } - except ValueError: - described_filter = data + if isinstance(data, (str, bytes)): + described_filter = data # type: ignore + # handle the situation when data is a tuple or list of length 2 + else: + try: + descriptor, filter_value = data + described_filter = { + TYPE: AMQPTypes.described, + VALUE: ({TYPE: AMQPTypes.symbol, VALUE: descriptor}, filter_value), + } + # if its not a type that is known, raise the error from the server + except (ValueError, TypeError): + described_filter = data cast(List, fields[VALUE]).append( ({TYPE: AMQPTypes.symbol, VALUE: name}, described_filter) diff --git a/sdk/servicebus/azure-servicebus/tests/test_sessions.py b/sdk/servicebus/azure-servicebus/tests/test_sessions.py index 8ddc1786af36..494e85ec2b3e 100644 --- a/sdk/servicebus/azure-servicebus/tests/test_sessions.py +++ b/sdk/servicebus/azure-servicebus/tests/test_sessions.py @@ -1300,3 +1300,30 @@ def test_session_non_session_send_to_session_queue_should_fail(self, uamqp_trans message = ServiceBusMessage("This should be an invalid non session message") with pytest.raises(ServiceBusError): sender.send_messages(message) + + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedServiceBusResourceGroupPreparer(name_prefix='servicebustest') + @CachedServiceBusNamespacePreparer(name_prefix='servicebustest') + @ServiceBusQueuePreparer(name_prefix='servicebustest', requires_session=True) + @pytest.mark.parametrize("uamqp_transport", uamqp_transport_params, ids=uamqp_transport_ids) + @ArgPasser() + def test_session_id_str_bytes(self, uamqp_transport, *, servicebus_namespace_connection_string=None, servicebus_queue=None, **kwargs): + + with ServiceBusClient.from_connection_string( + servicebus_namespace_connection_string, logging_enable=False, uamqp_transport=uamqp_transport) as sb_client: + + sessions = [] + start_time = utc_now() + for i in range(5): + sessions.append('a' * (i + 1)) + + for session_id in sessions: + with sb_client.get_queue_sender(servicebus_queue.name) as sender: + message = ServiceBusMessage("Test message no. {}".format(i), session_id=session_id) + sender.send_messages(message) + for session_id in sessions: + with sb_client.get_queue_receiver(servicebus_queue.name, session_id=session_id) as receiver: + messages = receiver.receive_messages(max_wait_time=10) + assert len(messages) == 1 + assert messages[0].session_id == session_id