Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cosmos] Reconfigure retry policy #7544

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def send(self, request, **kwargs): # type: ignore
allow_redirects=False,
**kwargs)

except urllib3.exceptions.NewConnectionError as err:
except (urllib3.exceptions.NewConnectionError, urllib3.exceptions.ConnectTimeoutError) as err:
error = ServiceRequestError(err, error=err)
except requests.exceptions.ReadTimeout as err:
error = ServiceResponseError(err, error=err)
Expand Down
2 changes: 2 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from ._retry_utility import ConnectionRetryPolicy
from .container import ContainerProxy
from .cosmos_client import CosmosClient
from .database import DatabaseProxy
Expand Down Expand Up @@ -56,5 +57,6 @@
"SSLConfiguration",
"TriggerOperation",
"TriggerType",
"ConnectionRetryPolicy",
)
__version__ = VERSION
37 changes: 24 additions & 13 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@
"""
from typing import Dict, Any, Optional
import six
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry # pylint: disable=import-error
from azure.core.paging import ItemPaged # type: ignore
from azure.core import PipelineClient # type: ignore
from azure.core.pipeline.transport import RequestsTransport
from azure.core.pipeline.policies import ( # type: ignore
HTTPPolicy,
ContentDecodePolicy,
HeadersPolicy,
UserAgentPolicy,
Expand All @@ -51,6 +50,7 @@
from . import _synchronized_request as synchronized_request
from . import _global_endpoint_manager as global_endpoint_manager
from ._routing import routing_map_provider
from ._retry_utility import ConnectionRetryPolicy
from . import _session
from . import _utils
from .partition_key import _Undefined, _Empty
Expand Down Expand Up @@ -151,15 +151,24 @@ def __init__(
self._useMultipleWriteLocations = False
self._global_endpoint_manager = global_endpoint_manager._GlobalEndpointManager(self)

# creating a requests session used for connection pooling and re-used by all requests
requests_session = requests.Session()

transport = None
if self.connection_policy.ConnectionRetryConfiguration is not None:
adapter = HTTPAdapter(max_retries=self.connection_policy.ConnectionRetryConfiguration)
requests_session.mount('http://', adapter)
requests_session.mount('https://', adapter)
transport = RequestsTransport(session=requests_session)
retry_policy = None
if isinstance(self.connection_policy.ConnectionRetryConfiguration, HTTPPolicy):
retry_policy = self.connection_policy.ConnectionRetryConfiguration
elif isinstance(self.connection_policy.ConnectionRetryConfiguration, int):
retry_policy = ConnectionRetryPolicy(total=self.connection_policy.ConnectionRetryConfiguration)
elif isinstance(self.connection_policy.ConnectionRetryConfiguration, Retry):
# Convert a urllib3 retry policy to a Pipeline policy
retry_policy = ConnectionRetryPolicy(
retry_total=self.connection_policy.ConnectionRetryConfiguration.total,
retry_connect=self.connection_policy.ConnectionRetryConfiguration.connect,
retry_read=self.connection_policy.ConnectionRetryConfiguration.read,
retry_status=self.connection_policy.ConnectionRetryConfiguration.status,
retry_backoff_max=self.connection_policy.ConnectionRetryConfiguration.BACKOFF_MAX,
retry_on_status_codes=list(self.connection_policy.ConnectionRetryConfiguration.status_forcelist),
retry_backoff_factor=self.connection_policy.ConnectionRetryConfiguration.backoff_factor
)
bryevdv marked this conversation as resolved.
Show resolved Hide resolved
else:
TypeError("Unsupported retry policy. Must be an azure.cosmos.ConnectionRetryPolicy, int, or urllib3.Retry")

proxies = kwargs.pop('proxies', {})
if self.connection_policy.ProxyConfiguration and self.connection_policy.ProxyConfiguration.Host:
Expand All @@ -173,11 +182,13 @@ def __init__(
ProxyPolicy(proxies=proxies),
UserAgentPolicy(base_user_agent=_utils.get_user_agent(), **kwargs),
ContentDecodePolicy(),
retry_policy,
CustomHookPolicy(**kwargs),
DistributedTracingPolicy(),
NetworkTraceLoggingPolicy(**kwargs),
]

transport = kwargs.pop("transport", None)
self.pipeline_client = PipelineClient(url_connection, "empty-config", transport=transport, policies=policies)

# Query compatibility mode.
Expand All @@ -188,7 +199,7 @@ def __init__(
# Routing map provider
self._routing_map_provider = routing_map_provider.SmartRoutingMapProvider(self)

database_account = self._global_endpoint_manager._GetDatabaseAccount()
database_account = self._global_endpoint_manager._GetDatabaseAccount(**kwargs)
self._global_endpoint_manager.force_refresh(database_account)

@property
Expand Down
18 changes: 9 additions & 9 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,17 @@ def force_refresh(self, database_account):
self.refresh_needed = True
self.refresh_endpoint_list(database_account)

def refresh_endpoint_list(self, database_account):
def refresh_endpoint_list(self, database_account, **kwargs):
with self.refresh_lock:
# if refresh is not needed or refresh is already taking place, return
if not self.refresh_needed:
return
try:
self._refresh_endpoint_list_private(database_account)
self._refresh_endpoint_list_private(database_account, **kwargs)
except Exception as e:
raise e

def _refresh_endpoint_list_private(self, database_account=None):
def _refresh_endpoint_list_private(self, database_account=None, **kwargs):
if database_account:
self.location_cache.perform_on_database_account_read(database_account)
self.refresh_needed = False
Expand All @@ -107,18 +107,18 @@ def _refresh_endpoint_list_private(self, database_account=None):
and self.location_cache.current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms
):
if not database_account:
database_account = self._GetDatabaseAccount()
database_account = self._GetDatabaseAccount(**kwargs)
self.location_cache.perform_on_database_account_read(database_account)
self.last_refresh_time = self.location_cache.current_time_millis()
self.refresh_needed = False

def _GetDatabaseAccount(self):
def _GetDatabaseAccount(self, **kwargs):
"""Gets the database account first by using the default endpoint, and if that doesn't returns
use the endpoints for the preferred locations in the order they are specified to get
the database account.
"""
try:
database_account = self._GetDatabaseAccountStub(self.DefaultEndpoint)
database_account = self._GetDatabaseAccountStub(self.DefaultEndpoint, **kwargs)
return database_account
# If for any reason(non-globaldb related), we are not able to get the database
# account from the above call to GetDatabaseAccount, we would try to get this
Expand All @@ -130,18 +130,18 @@ def _GetDatabaseAccount(self):
for location_name in self.PreferredLocations:
locational_endpoint = _GlobalEndpointManager.GetLocationalEndpoint(self.DefaultEndpoint, location_name)
try:
database_account = self._GetDatabaseAccountStub(locational_endpoint)
database_account = self._GetDatabaseAccountStub(locational_endpoint, **kwargs)
return database_account
except errors.CosmosHttpResponseError:
pass

return None

def _GetDatabaseAccountStub(self, endpoint):
def _GetDatabaseAccountStub(self, endpoint, **kwargs):
"""Stub for getting database account from the client
which can be used for mocking purposes as well.
"""
return self.Client.GetDatabaseAccount(endpoint)
return self.Client.GetDatabaseAccount(endpoint, **kwargs)

@staticmethod
def GetLocationalEndpoint(default_endpoint, location_name):
Expand Down
88 changes: 88 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

import time

from azure.core.exceptions import AzureError, ClientAuthenticationError
from azure.core.pipeline.policies import RetryPolicy

from . import errors
from . import _endpoint_discovery_retry_policy
from . import _resource_throttle_retry_policy
Expand Down Expand Up @@ -64,6 +67,8 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs):
)
while True:
try:
client_timeout = kwargs.get('timeout')
start_time = time.time()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be outside the loop?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @johanste
I don't think so - we calculate the time spent for each loop iteration and deduct it from the absolute timeout, which we then feed back into kwargs on line 122 - so I think it's correct to do this with each iteration

if args:
result = ExecuteFunction(function, global_endpoint_manager, *args, **kwargs)
else:
Expand Down Expand Up @@ -113,9 +118,92 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs):

# Wait for retry_after_in_milliseconds time before the next retry
time.sleep(retry_policy.retry_after_in_milliseconds / 1000.0)
if client_timeout:
kwargs['timeout'] = client_timeout - (time.time() - start_time)
if kwargs['timeout'] <= 0:
raise errors.CosmosClientTimeoutError()


def ExecuteFunction(function, *args, **kwargs):
    """Call ``function`` with the given arguments and hand back its result.

    This thin indirection exists so that the call site can be replaced
    (mocked) in tests.

    :param function: The callable to invoke.
    :param args: Positional arguments forwarded to ``function``.
    :param kwargs: Keyword arguments forwarded to ``function``.
    :return: Whatever ``function`` returns.
    """
    outcome = function(*args, **kwargs)
    return outcome


def _configure_timeout(request, absolute, per_request):
    # type: (azure.core.pipeline.PipelineRequest, Optional[int], int) -> None
    """Set the ``connection_timeout`` option for the next attempt of a request.

    :param request: The pipeline request whose ``context.options`` is updated.
    :param absolute: Remaining client-side timeout for the whole operation,
        or None when no absolute timeout was requested.
    :param per_request: Per-attempt (socket) timeout; a falsy value means
        none was provided.
    :raises: ~azure.cosmos.errors.CosmosClientTimeoutError if ``absolute`` is
        given and is zero or negative.
    """
    if absolute is None:
        # No client timeout: fall back to the socket timeout, if there is one.
        if per_request:
            request.context.options['connection_timeout'] = per_request
        return

    if absolute <= 0:
        # The overall budget is already used up.
        raise errors.CosmosClientTimeoutError()

    # With both values present the shorter one wins; otherwise the client
    # timeout alone bounds the attempt.
    effective = min(per_request, absolute) if per_request else absolute
    request.context.options['connection_timeout'] = effective


class ConnectionRetryPolicy(RetryPolicy):
    """Retry policy that also enforces an absolute client-side timeout
    spanning every retry attempt of a single request.
    """

    def __init__(self, **kwargs):
        # Options explicitly passed as None are dropped before being
        # forwarded to the base policy.
        supplied = {}
        for name, value in kwargs.items():
            if value is not None:
                supplied[name] = value
        super(ConnectionRetryPolicy, self).__init__(**supplied)

    def send(self, request):
        """Send the PipelineRequest to the next policy, retrying when the retry
        settings allow it, while enforcing an absolute client-side timeout
        across all attempts.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: Returns the PipelineResponse or raises error if maximum retries exceeded.
        :rtype: ~azure.core.pipeline.PipelineResponse
        :raises: ~azure.core.exceptions.AzureError if maximum retries exceeded.
        :raises: ~azure.cosmos.CosmosClientTimeoutError if specified timeout exceeded.
        :raises: ~azure.core.exceptions.ClientAuthenticationError if authentication fails.
        """
        # Both timeouts are removed from the options here; the effective
        # value is written back as 'connection_timeout' before each attempt.
        absolute_timeout = request.context.options.pop('timeout', None)
        per_request_timeout = request.context.options.pop('connection_timeout', 0)

        last_error = None
        response = None
        settings = self.configure_retries(request.context.options)
        keep_trying = True
        while keep_trying:
            try:
                attempt_started = time.time()
                _configure_timeout(request, absolute_timeout, per_request_timeout)

                response = self.next.send(request)
                if not self.is_retry(settings, response):
                    break
                keep_trying = self.increment(settings, response=response)
                if not keep_trying:
                    break
                self.sleep(settings, request.context.transport, response=response)
            except ClientAuthenticationError:  # pylint:disable=try-except-raise
                # the authentication policy failed such that the client's request can't
                # succeed--we'll never have a response to it, so propagate the exception
                raise
            except errors.CosmosClientTimeoutError as timeout_error:
                # Attach what is known about earlier attempts before propagating.
                timeout_error.inner_exception = last_error
                timeout_error.response = response
                timeout_error.history = settings['history']
                raise
            except AzureError as err:
                last_error = err
                if self._is_method_retryable(settings, request.http_request):
                    keep_trying = self.increment(settings, response=request, error=err)
                    if keep_trying:
                        self.sleep(settings, request.context.transport)
                        continue
                raise err
            finally:
                # Runs on every exit from the try block (including after the
                # back-off sleep), so the whole attempt is charged against
                # the remaining absolute timeout.
                elapsed = time.time() - attempt_started
                if absolute_timeout:
                    absolute_timeout -= elapsed

        self.update_context(response.context, settings)
        return response
11 changes: 9 additions & 2 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"""

import json
import time

from six.moves.urllib.parse import urlparse
import six
Expand Down Expand Up @@ -96,7 +97,13 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin
connection_timeout = kwargs.pop("connection_timeout", connection_timeout / 1000.0)

# Every request tries to perform a refresh
global_endpoint_manager.refresh_endpoint_list(None)
client_timeout = kwargs.get('timeout')
start_time = time.time()
global_endpoint_manager.refresh_endpoint_list(None, **kwargs)
if client_timeout is not None:
kwargs['timeout'] = client_timeout - (time.time() - start_time)
if kwargs['timeout'] <= 0:
raise errors.CosmosClientTimeoutError()

if request_params.endpoint_override:
base_url = request_params.endpoint_override
Expand Down Expand Up @@ -149,7 +156,7 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin
return (response.stream_download(pipeline_client._pipeline), headers)

data = response.body()
if not six.PY2:
if data and not six.PY2:
# python 3 compatible: convert data from byte to unicode string
data = data.decode("utf-8")

Expand Down
26 changes: 23 additions & 3 deletions sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
"""Create, read, and delete databases in the Azure Cosmos DB SQL API service.
"""

from typing import Any, Dict, Mapping, Optional, Union, cast, Iterable, List
from typing import Any, Dict, Mapping, Optional, Union, cast, Iterable, List # pylint: disable=unused-import

import six
from azure.core.tracing.decorator import distributed_trace # type: ignore

from ._cosmos_client_connection import CosmosClientConnection
from ._base import build_options
from ._retry_utility import ConnectionRetryPolicy
from .database import DatabaseProxy
from .documents import ConnectionPolicy, DatabaseAccount
from .errors import CosmosResourceNotFoundError
Expand Down Expand Up @@ -96,11 +97,25 @@ def _build_connection_policy(kwargs):

# Retry config
retry = kwargs.pop('retry_options', None) or policy.RetryOptions
retry._max_retry_attempt_count = kwargs.pop('retry_total', None) or retry._max_retry_attempt_count
total_retries = kwargs.pop('retry_total', None)
retry._max_retry_attempt_count = total_retries or retry._max_retry_attempt_count
retry._fixed_retry_interval_in_milliseconds = kwargs.pop('retry_fixed_interval', None) or \
retry._fixed_retry_interval_in_milliseconds
retry._max_wait_time_in_seconds = kwargs.pop('retry_backoff_max', None) or retry._max_wait_time_in_seconds
max_backoff = kwargs.pop('retry_backoff_max', None)
retry._max_wait_time_in_seconds = max_backoff or retry._max_wait_time_in_seconds
policy.RetryOptions = retry
connection_retry = kwargs.pop('connection_retry_policy', None) or policy.ConnectionRetryConfiguration
if not connection_retry:
connection_retry = ConnectionRetryPolicy(
retry_total=total_retries,
retry_connect=kwargs.pop('retry_connect', None),
retry_read=kwargs.pop('retry_read', None),
retry_status=kwargs.pop('retry_status', None),
retry_backoff_max=max_backoff,
retry_on_status_codes=kwargs.pop('retry_on_status_codes', []),
retry_backoff_factor=kwargs.pop('retry_backoff_factor', 0.8),
)
policy.ConnectionRetryConfiguration = connection_retry

return policy

Expand Down Expand Up @@ -130,6 +145,11 @@ class CosmosClient(object):
*retry_total* - Maximum retry attempts.
*retry_backoff_max* - Maximum retry wait time in seconds.
*retry_fixed_interval* - Fixed retry interval in milliseconds.
*retry_read* - Maximum number of socket read retry attempts.
*retry_connect* - Maximum number of connection error retry attempts.
*retry_status* - Maximum number of retry attempts on error status codes.
*retry_on_status_codes* - A list of specific status codes to retry on.
*retry_backoff_factor* - Factor to calculate wait time between retry attempts.
*enable_endpoint_discovery* - Enable endpoint discovery for geo-replicated database accounts. Default is True.
*preferred_locations* - The preferred locations for geo-replicated database accounts.
When `enable_endpoint_discovery` is true and `preferred_locations` is non-empty,
Expand Down
6 changes: 4 additions & 2 deletions sdk/cosmos/azure-cosmos/azure/cosmos/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,10 @@ class ConnectionPolicy(object): # pylint: disable=too-many-instance-attributes
:ivar boolean UseMultipleWriteLocations:
Flag to enable writes on any locations (regions) for geo-replicated database accounts
in the azure Cosmos service.
:ivar (int or requests.packages.urllib3.util.retry) ConnectionRetryConfiguration:
Retry Configuration to be used for urllib3 connection retries.
:ivar ConnectionRetryConfiguration:
Retry Configuration to be used for connection retries.
:vartype ConnectionRetryConfiguration:
int or azure.cosmos.ConnectionRetryPolicy or requests.packages.urllib3.util.retry
"""

__defaultRequestTimeout = 60000 # milliseconds
Expand Down
10 changes: 10 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,13 @@ class CosmosResourceExistsError(ResourceExistsError, CosmosHttpResponseError):

class CosmosAccessConditionFailedError(CosmosHttpResponseError):
"""An error response with status code 412."""


class CosmosClientTimeoutError(AzureError):
    """Raised when an operation does not complete within the specified timeout."""

    def __init__(self, **kwargs):
        # Both attributes start out empty.
        self.response = None
        self.history = None
        super(CosmosClientTimeoutError, self).__init__(
            "Client operation failed to complete within specified timeout.",
            **kwargs
        )
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/dev_requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
-e ../../../tools/azure-sdk-tools
-e ../../core/azure-core
Loading