Skip to content

Commit

Permalink
[Cosmos] Test fixes for logging policy and retry policy (Azure#29141)
Browse files Browse the repository at this point in the history
* fix logging and retry tests

* Update test_retry_policy.py

* Update test_retry_policy.py

* Update test_retry_policy.py

* added policy to async client, fixed test

* Update test_retry_policy.py

* Update test_retry_policy.py
  • Loading branch information
simorenoh authored Mar 9, 2023
1 parent 62d2639 commit 394ddf1
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .. import _default_retry_policy
from .. import _session_retry_policy
from .. import _gone_retry_policy
from .. import _timeout_failover_retry_policy


# pylint: disable=protected-access
Expand Down Expand Up @@ -70,6 +71,9 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg
client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, *args
)
partition_key_range_gone_retry_policy = _gone_retry_policy.PartitionKeyRangeGoneRetryPolicy(client, *args)
timeout_failover_retry_policy = _timeout_failover_retry_policy._TimeoutFailoverRetryPolicy(
client.connection_policy, global_endpoint_manager, *args
)

while True:
try:
Expand Down Expand Up @@ -105,6 +109,8 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg
retry_policy = sessionRetry_policy
elif exceptions._partition_range_is_gone(e):
retry_policy = partition_key_range_gone_retry_policy
elif e.status_code == StatusCodes.REQUEST_TIMEOUT or e.status_code == StatusCodes.SERVICE_UNAVAILABLE:
retry_policy = timeout_failover_retry_policy
else:
retry_policy = defaultRetry_policy

Expand Down
22 changes: 11 additions & 11 deletions sdk/cosmos/azure-cosmos/test/test_cosmos_http_logging_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def setUpClass(cls):

def test_default_http_logging_policy(self):
# Test if we can log into from creating a database
test_db = self.client_default.create_database(id="database_test")
self.client_default.create_database(id="database_test")
assert all(m.levelname == 'INFO' for m in self.mock_handler_default.messages)
messages_request = self.mock_handler_default.messages[0].message.split("\n")
messages_response = self.mock_handler_default.messages[1].message.split("\n")
Expand All @@ -99,36 +99,36 @@ def test_default_http_logging_policy(self):

def test_cosmos_http_logging_policy(self):
# Test if we can log into from creating a database
test_db = self.client_diagnostic.create_database(id="database_test")
self.client_diagnostic.create_database(id="database_test")
assert all(m.levelname == 'INFO' for m in self.mock_handler_diagnostic.messages)
messages_request = self.mock_handler_diagnostic.messages[2].message.split("\n")
messages_response = self.mock_handler_diagnostic.messages[3].message.split("\n")
messages_request = self.mock_handler_diagnostic.messages[3].message.split("\n")
messages_response = self.mock_handler_diagnostic.messages[4].message.split("\n")
elapsed_time = self.mock_handler_diagnostic.messages[2].message.split("\n")
assert "/dbs" in messages_request[0]
assert messages_request[1] == "Request method: 'POST'"
assert 'Request headers:' in messages_request[2]
assert messages_request[13] == 'A body is sent with the request'
assert messages_response[0] == 'Response status: 201'
assert messages_response[1] == 'Response status reason: Created'
assert "Elapsed Time:" in messages_response[2]
assert "Response headers" in messages_response[3]
assert "Elapsed time in seconds:" in elapsed_time[0]
assert "Response headers" in messages_response[1]

self.mock_handler_diagnostic.reset()
# now test in case of an error
try:
test_db_error = self.client_diagnostic.create_database(id="database_test")
self.client_diagnostic.create_database(id="database_test")
except:
pass
assert all(m.levelname == 'INFO' for m in self.mock_handler_diagnostic.messages)
messages_request = self.mock_handler_diagnostic.messages[0].message.split("\n")
messages_response = self.mock_handler_diagnostic.messages[1].message.split("\n")
elapsed_time = self.mock_handler_diagnostic.messages[2].message.split("\n")
assert "/dbs" in messages_request[0]
assert messages_request[1] == "Request method: 'POST'"
assert 'Request headers:' in messages_request[2]
assert messages_request[13] == 'A body is sent with the request'
assert messages_response[0] == 'Response status: 409'
assert messages_response[1] == 'Response status reason: Conflict'
assert "Elapsed Time:" in messages_response[2]
assert "Response headers" in messages_response[14]
assert "Elapsed time in seconds:" in elapsed_time[0]
assert "Response headers" in messages_response[1]

# delete database
self.client_diagnostic.delete_database("database_test")
Expand Down
66 changes: 30 additions & 36 deletions sdk/cosmos/azure-cosmos/test/test_retry_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_resource_throttle_retry_policy_default_retry_after(self):
connection_policy = Test_retry_policy_tests.connectionPolicy
connection_policy.RetryOptions = retry_options.RetryOptions(5)

self.OriginalExecuteFunction = _retry_utility.ExecuteFunction
self.original_execute_function = _retry_utility.ExecuteFunction
try:
_retry_utility.ExecuteFunction = self._MockExecuteFunction

Expand All @@ -94,13 +94,13 @@ def test_resource_throttle_retry_policy_default_retry_after(self):
self.assertGreaterEqual(self.created_collection.client_connection.last_response_headers[HttpHeaders.ThrottleRetryWaitTimeInMs],
connection_policy.RetryOptions.MaxRetryAttemptCount * self.retry_after_in_milliseconds)
finally:
_retry_utility.ExecuteFunction = self.OriginalExecuteFunction
_retry_utility.ExecuteFunction = self.original_execute_function

def test_resource_throttle_retry_policy_fixed_retry_after(self):
connection_policy = Test_retry_policy_tests.connectionPolicy
connection_policy.RetryOptions = retry_options.RetryOptions(5, 2000)

self.OriginalExecuteFunction = _retry_utility.ExecuteFunction
self.original_execute_function = _retry_utility.ExecuteFunction
try:
_retry_utility.ExecuteFunction = self._MockExecuteFunction

Expand All @@ -117,13 +117,13 @@ def test_resource_throttle_retry_policy_fixed_retry_after(self):
connection_policy.RetryOptions.MaxRetryAttemptCount * connection_policy.RetryOptions.FixedRetryIntervalInMilliseconds)

finally:
_retry_utility.ExecuteFunction = self.OriginalExecuteFunction
_retry_utility.ExecuteFunction = self.original_execute_function

def test_resource_throttle_retry_policy_max_wait_time(self):
connection_policy = Test_retry_policy_tests.connectionPolicy
connection_policy.RetryOptions = retry_options.RetryOptions(5, 2000, 3)

self.OriginalExecuteFunction = _retry_utility.ExecuteFunction
self.original_execute_function = _retry_utility.ExecuteFunction
try:
_retry_utility.ExecuteFunction = self._MockExecuteFunction

Expand All @@ -138,7 +138,7 @@ def test_resource_throttle_retry_policy_max_wait_time(self):
self.assertGreaterEqual(self.created_collection.client_connection.last_response_headers[HttpHeaders.ThrottleRetryWaitTimeInMs],
connection_policy.RetryOptions.MaxWaitTimeInSeconds * 1000)
finally:
_retry_utility.ExecuteFunction = self.OriginalExecuteFunction
_retry_utility.ExecuteFunction = self.original_execute_function

def test_resource_throttle_retry_policy_query(self):
connection_policy = Test_retry_policy_tests.connectionPolicy
Expand All @@ -150,7 +150,7 @@ def test_resource_throttle_retry_policy_query(self):

self.created_collection.create_item(body=document_definition)

self.OriginalExecuteFunction = _retry_utility.ExecuteFunction
self.original_execute_function = _retry_utility.ExecuteFunction
try:
_retry_utility.ExecuteFunction = self._MockExecuteFunction

Expand All @@ -169,7 +169,7 @@ def test_resource_throttle_retry_policy_query(self):
self.assertGreaterEqual(self.created_collection.client_connection.last_response_headers[HttpHeaders.ThrottleRetryWaitTimeInMs],
connection_policy.RetryOptions.MaxRetryAttemptCount * self.retry_after_in_milliseconds)
finally:
_retry_utility.ExecuteFunction = self.OriginalExecuteFunction
_retry_utility.ExecuteFunction = self.original_execute_function

@pytest.mark.xfail
def test_default_retry_policy_for_query(self):
Expand All @@ -182,10 +182,9 @@ def test_default_retry_policy_for_query(self):

self.created_collection.create_item(body=document_definition_1)
self.created_collection.create_item(body=document_definition_2)

self.original_execute_function = _retry_utility.ExecuteFunction
try:
original_execute_function = _retry_utility.ExecuteFunction
mf = self.MockExecuteFunctionConnectionReset(original_execute_function)
mf = self.MockExecuteFunctionConnectionReset(self.original_execute_function)
_retry_utility.ExecuteFunction = mf

docs = self.created_collection.query_items(query="Select * from c", max_item_count=1, enable_cross_partition_query=True)
Expand All @@ -200,7 +199,7 @@ def test_default_retry_policy_for_query(self):
else:
self.assertEqual(mf.counter, 18)
finally:
_retry_utility.ExecuteFunction = original_execute_function
_retry_utility.ExecuteFunction = self.original_execute_function

self.created_collection.delete_item(item=result_docs[0], partition_key=result_docs[0]['id'])
self.created_collection.delete_item(item=result_docs[1], partition_key=result_docs[1]['id'])
Expand All @@ -211,18 +210,17 @@ def test_default_retry_policy_for_read(self):
'key': 'value'}

created_document = self.created_collection.create_item(body=document_definition)

self.original_execute_function = _retry_utility.ExecuteFunction
try:
original_execute_function = _retry_utility.ExecuteFunction
mf = self.MockExecuteFunctionConnectionReset(original_execute_function)
mf = self.MockExecuteFunctionConnectionReset(self.original_execute_function)
_retry_utility.ExecuteFunction = mf

doc = self.created_collection.read_item(item=created_document['id'], partition_key=created_document['id'])
self.assertEqual(doc['id'], 'doc')
self.assertEqual(mf.counter, 3)

finally:
_retry_utility.ExecuteFunction = original_execute_function
_retry_utility.ExecuteFunction = self.original_execute_function

self.created_collection.delete_item(item=created_document, partition_key=created_document['id'])

Expand All @@ -232,8 +230,8 @@ def test_default_retry_policy_for_create(self):
'key': 'value'}

try:
original_execute_function = _retry_utility.ExecuteFunction
mf = self.MockExecuteFunctionConnectionReset(original_execute_function)
self.original_execute_function = _retry_utility.ExecuteFunction
mf = self.MockExecuteFunctionConnectionReset(self.original_execute_function)
_retry_utility.ExecuteFunction = mf

created_document = {}
Expand All @@ -252,25 +250,26 @@ def test_default_retry_policy_for_create(self):
mf.counter = mf.counter - 3
self.assertEqual(mf.counter, 1)
finally:
_retry_utility.ExecuteFunction = original_execute_function
_retry_utility.ExecuteFunction = self.original_execute_function

def test_timeout_failover_retry_policy_for_read(self):
document_definition = {'id': 'failoverDoc',
'name': 'sample document',
'key': 'value'}

created_document = self.created_collection.create_item(body=document_definition)
self.original_execute_function = _retry_utility.ExecuteFunction
try:
original_execute_function = _retry_utility.ExecuteFunction
mf = self.MockExecuteFunctionTimeout(original_execute_function)

mf = self.MockExecuteFunctionTimeout(self.original_execute_function)
_retry_utility.ExecuteFunction = mf
doc = self.created_collection.read_item(item=created_document['id'], partition_key=created_document['id'])
self.assertEqual(doc['id'], 'doc')
try:
doc = self.created_collection.read_item(item=created_document['id'],
partition_key=created_document['id'])
self.assertEqual(doc['id'], 'doc')
except exceptions.CosmosHttpResponseError as err:
self.assertEqual(err.status_code, 408)
finally:
_retry_utility.ExecuteFunction = original_execute_function

self.created_collection.delete_item(item=created_document, partition_key=created_document['id'])
_retry_utility.ExecuteFunction = self.original_execute_function

def _MockExecuteFunction(self, function, *args, **kwargs):
response = test_config.FakeResponse({HttpHeaders.RetryAfterInMilliseconds: self.retry_after_in_milliseconds})
Expand All @@ -285,15 +284,10 @@ def __init__(self, org_func):
self.counter = 0

def __call__(self, func, *args, **kwargs):
self.counter += 1

if self.counter > 1:
return self.org_func(func, *args, **kwargs)
else:
raise exceptions.CosmosHttpResponseError(
status_code=408,
message="Timeout",
response=test_config.FakeResponse({}))
raise exceptions.CosmosHttpResponseError(
status_code=408,
message="Timeout",
response=test_config.FakeResponse({}))

class MockExecuteFunctionConnectionReset(object):

Expand Down

0 comments on commit 394ddf1

Please sign in to comment.