diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 265d108e178b..41d9adab68e9 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -178,6 +178,18 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches if options.get("offerThroughput"): headers[http_constants.HttpHeaders.OfferThroughput] = options["offerThroughput"] + if options.get("contentType"): + headers[http_constants.HttpHeaders.ContentType] = options['contentType'] + + if options.get("isQueryPlanRequest"): + headers[http_constants.HttpHeaders.IsQueryPlanRequest] = options['isQueryPlanRequest'] + + if options.get("supportedQueryFeatures"): + headers[http_constants.HttpHeaders.SupportedQueryFeatures] = options['supportedQueryFeatures'] + + if options.get("queryVersion"): + headers[http_constants.HttpHeaders.QueryVersion] = options['queryVersion'] + if "partitionKey" in options: # if partitionKey value is Undefined, serialize it as [{}] to be consistent with other SDKs. if options.get("partitionKey") is partition_key._Undefined: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index eaeff4440c36..a1360e24aa54 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2625,6 +2625,7 @@ def __QueryFeed( options=None, partition_key_range_id=None, response_hook=None, + is_query_plan=False, **kwargs ): """Query for more than one Azure Cosmos resources. @@ -2639,6 +2640,9 @@ def __QueryFeed( The request options for the request. :param str partition_key_range_id: Specifies partition key range id. + :param function response_hook: + :param bool is_query_plan: + Specifies whether the call is to fetch the query plan :rtype: list @@ -2664,7 +2668,8 @@ def __GetBodiesFromQueryResult(result): # Copy to make sure that default_headers won't be changed. 
if query is None: # Query operations will use ReadEndpoint even though it uses GET(for feed requests) - request_params = _request_object.RequestObject(typ, documents._OperationType.ReadFeed) + request_params = _request_object.RequestObject(typ, + documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed) headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, options, partition_key_range_id) result, self.last_response_headers = self.__Get(path, request_params, headers, **kwargs) if response_hook: @@ -2674,6 +2679,8 @@ def __GetBodiesFromQueryResult(result): query = self.__CheckAndUnifyQueryFormat(query) - initial_headers[http_constants.HttpHeaders.IsQuery] = "true" + if not is_query_plan: + initial_headers[http_constants.HttpHeaders.IsQuery] = "true" + if ( self._query_compatibility_mode == CosmosClientConnection._QueryCompatibilityMode.Default or self._query_compatibility_mode == CosmosClientConnection._QueryCompatibilityMode.Query @@ -2694,6 +2702,36 @@ def __GetBodiesFromQueryResult(result): return __GetBodiesFromQueryResult(result) + def _GetQueryPlanThroughGateway(self, query, resource_link, **kwargs): + supported_query_features = (documents._QueryFeature.Aggregate + "," + + documents._QueryFeature.CompositeAggregate + "," + + documents._QueryFeature.Distinct + "," + + documents._QueryFeature.MultipleOrderBy + "," + + documents._QueryFeature.OffsetAndLimit + "," + + documents._QueryFeature.OrderBy + "," + + documents._QueryFeature.Top) + + options = { + "contentType": runtime_constants.MediaTypes.Json, + "isQueryPlanRequest": True, + "supportedQueryFeatures": supported_query_features, + "queryVersion": http_constants.Versions.QueryVersion + } + + resource_link = base.TrimBeginningAndEndingSlashes(resource_link) + path = base.GetPathFromLink(resource_link, "docs") + resource_id = base.GetResourceIdOrFullNameFromLink(resource_link) + + return self.__QueryFeed(path, + "docs", + resource_id, + lambda r: r, + None, + query, + options, + is_query_plan=True, + **kwargs) + def __CheckAndUnifyQueryFormat(self, query_body): """Checks and unifies the format of the query body. diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_default_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_default_retry_policy.py index 6b5e52769193..0d8a8c7eb9ac 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_default_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_default_retry_policy.py @@ -57,7 +57,8 @@ def __init__(self, *args): def needsRetry(self, error_code): if error_code in DefaultRetryPolicy.CONNECTION_ERROR_CODES: if self.args: - if (self.args[3].method == "GET") or (http_constants.HttpHeaders.IsQuery in self.args[3].headers): + if (self.args[3].method == "GET") or (http_constants.HttpHeaders.IsQuery in self.args[3].headers) \ + or (http_constants.HttpHeaders.IsQueryPlanRequest in self.args[3].headers): return True return False return True diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py index 8fec53bad54a..ee4981d06bd8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py @@ -25,7 +25,6 @@ from collections import deque from .. import _retry_utility from .. import http_constants -from .. 
import _base # pylint: disable=protected-access @@ -171,100 +170,3 @@ def __init__(self, client, options, fetch_function): def _fetch_next_block(self): while super(_DefaultQueryExecutionContext, self)._has_more_pages() and not self._buffer: return self._fetch_items_helper_with_retries(self._fetch_function) - - -class _MultiCollectionQueryExecutionContext(_QueryExecutionContextBase): - """ - This class is used if it is client side partitioning - """ - - def __init__(self, client, options, database_link, query, partition_key): - """ - Constructor - :param CosmosClient client: - :param dict options: - The request options for the request. - :param str database_link: database self link or ID based link - :param (str or dict) query: - Partition_key (str): partition key for the query - - """ - super(_MultiCollectionQueryExecutionContext, self).__init__(client, options) - - self._current_collection_index = 0 - self._collection_links = [] - self._collection_links_length = 0 - - self._query = query - self._client = client - - partition_resolver = client.GetPartitionResolver(database_link) - - if partition_resolver is None: - raise ValueError(client.PartitionResolverErrorMessage) - - self._collection_links = partition_resolver.ResolveForRead(partition_key) - - self._collection_links_length = len(self._collection_links) - - if self._collection_links is None: - raise ValueError("_collection_links is None.") - - if self._collection_links_length <= 0: - raise ValueError("_collection_links_length is not greater than 0.") - - # Creating the QueryFeed for the first collection - path = _base.GetPathFromLink(self._collection_links[self._current_collection_index], "docs") - collection_id = _base.GetResourceIdOrFullNameFromLink(self._collection_links[self._current_collection_index]) - - self._current_collection_index += 1 - - def fetch_fn(options): - return client.QueryFeed(path, collection_id, query, options) - - self._fetch_function = fetch_fn - - def _has_more_pages(self): - return ( - not self._has_started - or self._continuation - or (self._collection_links and self._current_collection_index < self._collection_links_length) - ) - - def _fetch_next_block(self): - """Fetches the next block of query results. - - This iterates fetches the next block of results from the current collection link. - Once the current collection results were exhausted. It moves to the next collection link. - - :return: - List of fetched items. 
- :rtype: list - """ - # Fetch next block of results by executing the query against the current document collection - fetched_items = self._fetch_items_helper_with_retries(self._fetch_function) - - # If there are multiple document collections to query for(in case of partitioning), - # keep looping through each one of them, creating separate feed queries for each - # collection and fetching the items - while not fetched_items: - if self._collection_links and self._current_collection_index < self._collection_links_length: - path = _base.GetPathFromLink(self._collection_links[self._current_collection_index], "docs") - collection_id = _base.GetResourceIdOrFullNameFromLink( - self._collection_links[self._current_collection_index] - ) - - self._continuation = None - self._has_started = False - - def fetch_fn(options): - return self._client.QueryFeed(path, collection_id, self._query, options) - - self._fetch_function = fetch_fn - - fetched_items = self._fetch_items_helper_with_retries(self._fetch_function) - self._current_collection_index += 1 - else: - break - - return fetched_items diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/endpoint_component.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/endpoint_component.py index d876abaee8c0..254c4264224d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/endpoint_component.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/endpoint_component.py @@ -22,6 +22,10 @@ """Internal class for query execution endpoint component implementation in the Azure Cosmos database service. """ import numbers +import copy +import hashlib +import json +import six from azure.cosmos._execution_context.aggregators import ( _AverageAggregator, @@ -75,6 +79,86 @@ def next(self): raise StopIteration +class _QueryExecutionDistinctOrderedEndpointComponent(_QueryExecutionEndpointComponent): + """Represents an endpoint in handling distinct query. + + It returns only those values not already returned. + """ + def __init__(self, execution_context): + super(_QueryExecutionDistinctOrderedEndpointComponent, self).__init__(execution_context) + self.last_result = None + + def next(self): + res = next(self._execution_context) + while self.last_result == res: + res = next(self._execution_context) + self.last_result = res + return res + + +class _QueryExecutionDistinctUnorderedEndpointComponent(_QueryExecutionEndpointComponent): + """Represents an endpoint in handling distinct query. + + It returns only those values not already returned. 
+ """ + def __init__(self, execution_context): + super(_QueryExecutionDistinctUnorderedEndpointComponent, self).__init__(execution_context) + self.last_result = set() + + def make_hash(self, value): + if isinstance(value, (set, tuple, list)): + return tuple([self.make_hash(v) for v in value]) + if not isinstance(value, dict): + if isinstance(value, numbers.Number): + return float(value) + return value + new_value = copy.deepcopy(value) + for k, v in new_value.items(): + new_value[k] = self.make_hash(v) + + return tuple(frozenset(sorted(new_value.items()))) + + def next(self): + res = next(self._execution_context) + + json_repr = json.dumps(self.make_hash(res)) + if six.PY3: + json_repr = json_repr.encode("utf-8") + + hash_object = hashlib.sha1(json_repr) + hashed_result = hash_object.hexdigest() + + while hashed_result in self.last_result: + res = next(self._execution_context) + json_repr = json.dumps(self.make_hash(res)) + if six.PY3: + json_repr = json_repr.encode("utf-8") + + hash_object = hashlib.sha1(json_repr) + hashed_result = hash_object.hexdigest() + self.last_result.add(hashed_result) + return res + + +class _QueryExecutionOffsetEndpointComponent(_QueryExecutionEndpointComponent): + """Represents an endpoint in handling offset query. + + It returns results offset by as many results as offset arg specified. + """ + def __init__(self, execution_context, offset_count): + super(_QueryExecutionOffsetEndpointComponent, self).__init__(execution_context) + self._offset_count = offset_count + + def next(self): + while self._offset_count > 0: + res = next(self._execution_context) + if res is not None: + self._offset_count -= 1 + else: + raise StopIteration + return next(self._execution_context) + + class _QueryExecutionAggregateEndpointComponent(_QueryExecutionEndpointComponent): """Represents an endpoint in handling aggregate query. 
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 2c37510a63a4..49a5c14befb5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -25,11 +25,12 @@ import json from six.moves import xrange from azure.cosmos.errors import CosmosHttpResponseError +from azure.cosmos._execution_context import multi_execution_aggregator from azure.cosmos._execution_context.base_execution_context import _QueryExecutionContextBase from azure.cosmos._execution_context.base_execution_context import _DefaultQueryExecutionContext from azure.cosmos._execution_context.query_execution_info import _PartitionedQueryExecutionInfo from azure.cosmos._execution_context import endpoint_component -from azure.cosmos._execution_context import multi_execution_aggregator +from azure.cosmos.documents import _DistinctType from azure.cosmos.http_constants import StatusCodes, SubStatusCodes # pylint: disable=protected-access @@ -78,7 +79,9 @@ def next(self): return next(self._execution_context) except CosmosHttpResponseError as e: if _is_partitioned_execution_info(e): - query_execution_info = _get_partitioned_execution_info(e) + query_to_use = self._query if self._query is not None else "Select * from root r" + query_execution_info = _PartitionedQueryExecutionInfo(self._client._GetQueryPlanThroughGateway + (query_to_use, self._resource_link)) self._execution_context = self._create_pipelined_execution_context(query_execution_info) else: raise e @@ -99,7 +102,9 @@ def fetch_next_block(self): return self._execution_context.fetch_next_block() except CosmosHttpResponseError as e: if _is_partitioned_execution_info(e): - query_execution_info = _get_partitioned_execution_info(e) + query_to_use = self._query if self._query is not None else "Select * from root r" + query_execution_info = _PartitionedQueryExecutionInfo(self._client._GetQueryPlanThroughGateway + (query_to_use, self._resource_link)) self._execution_context = self._create_pipelined_execution_context(query_execution_info) else: raise e @@ -108,14 +113,20 @@ def fetch_next_block(self): def _create_pipelined_execution_context(self, query_execution_info): - assert self._resource_link, "code bug, resource_link has is required." - execution_context_aggregator = multi_execution_aggregator._MultiExecutionContextAggregator( - self._client, self._resource_link, self._query, self._options, query_execution_info - ) - return _PipelineExecutionContext( - self._client, self._options, execution_context_aggregator, query_execution_info - ) - + assert self._resource_link, "code bug, resource_link is required." 
+ if query_execution_info.has_aggregates() and not query_execution_info.has_select_value(): + if self._options and ("enableCrossPartitionQuery" in self._options + and self._options["enableCrossPartitionQuery"]): + raise CosmosHttpResponseError(StatusCodes.BAD_REQUEST, + "Cross partition query only supports 'VALUE <AggregateFunc>' for aggregates") + + execution_context_aggregator = multi_execution_aggregator._MultiExecutionContextAggregator(self._client, + self._resource_link, + self._query, + self._options, + query_execution_info) + return _PipelineExecutionContext(self._client, self._options, execution_context_aggregator, + query_execution_info) class _PipelineExecutionContext(_QueryExecutionContextBase): # pylint: disable=abstract-method @@ -140,13 +151,28 @@ def __init__(self, client, options, execution_context, query_execution_info): if order_by: self._endpoint = endpoint_component._QueryExecutionOrderByEndpointComponent(self._endpoint) + aggregates = query_execution_info.get_aggregates() + if aggregates: + self._endpoint = endpoint_component._QueryExecutionAggregateEndpointComponent(self._endpoint, aggregates) + + offset = query_execution_info.get_offset() + if offset is not None: + self._endpoint = endpoint_component._QueryExecutionOffsetEndpointComponent(self._endpoint, offset) + top = query_execution_info.get_top() if top is not None: self._endpoint = endpoint_component._QueryExecutionTopEndpointComponent(self._endpoint, top) - aggregates = query_execution_info.get_aggregates() - if aggregates: - self._endpoint = endpoint_component._QueryExecutionAggregateEndpointComponent(self._endpoint, aggregates) + limit = query_execution_info.get_limit() + if limit is not None: + self._endpoint = endpoint_component._QueryExecutionTopEndpointComponent(self._endpoint, limit) + + distinct_type = query_execution_info.get_distinct_type() + if distinct_type != _DistinctType.NoneType: + if distinct_type == _DistinctType.Ordered: + self._endpoint = endpoint_component._QueryExecutionDistinctOrderedEndpointComponent(self._endpoint) + else: + self._endpoint = endpoint_component._QueryExecutionDistinctUnorderedEndpointComponent(self._endpoint) def next(self): """Returns the next query result. 
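Editor's note: `_PipelineExecutionContext` wraps the base iterator stage by stage, and each newly added component pulls from the one added before it, so with the order above the offset stage sits closest to the raw results and the distinct stage ends up outermost. A simplified, generator-based sketch of that composition (illustrative stand-ins only, not the SDK endpoint classes):

```python
import itertools


def offset_stage(source, count):
    return itertools.islice(source, count, None)  # skip the first `count` results


def top_stage(source, count):
    return itertools.islice(source, count)        # stop after `count` results


def distinct_ordered_stage(source):
    last = object()                               # sentinel that equals nothing
    for item in source:
        if item != last:
            yield item
            last = item


# OFFSET 2 LIMIT 4 over an already-ordered stream with adjacent duplicates,
# with the distinct stage applied outermost:
raw = iter([1, 1, 2, 3, 3, 4, 5, 6])
pipeline = distinct_ordered_stage(top_stage(offset_stage(raw, 2), 4))
print(list(pipeline))  # [2, 3, 4]
```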
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/query_execution_info.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/query_execution_info.py index 617d56c81e8a..6c1f717c8703 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/query_execution_info.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/query_execution_info.py @@ -23,6 +23,7 @@ """ import six +from azure.cosmos.documents import _DistinctType class _PartitionedQueryExecutionInfo(object): @@ -32,7 +33,11 @@ class _PartitionedQueryExecutionInfo(object): """ QueryInfoPath = "queryInfo" + HasSelectValue = [QueryInfoPath, "hasSelectValue"] TopPath = [QueryInfoPath, "top"] + OffsetPath = [QueryInfoPath, "offset"] + LimitPath = [QueryInfoPath, "limit"] + DistinctTypePath = [QueryInfoPath, "distinctType"] OrderByPath = [QueryInfoPath, "orderBy"] AggregatesPath = [QueryInfoPath, "aggregates"] QueryRangesPath = "queryRanges" @@ -50,6 +55,21 @@ def get_top(self): """ return self._extract(_PartitionedQueryExecutionInfo.TopPath) + def get_limit(self): + """Returns the limit count (if any) or None + """ + return self._extract(_PartitionedQueryExecutionInfo.LimitPath) + + def get_offset(self): + """Returns the offset count (if any) or None + """ + return self._extract(_PartitionedQueryExecutionInfo.OffsetPath) + + def get_distinct_type(self): + """Returns the distinct type (if any) or None + """ + return self._extract(_PartitionedQueryExecutionInfo.DistinctTypePath) + def get_order_by(self): """Returns order by items (if any) or None """ @@ -74,6 +94,32 @@ def get_rewritten_query(self): rewrittenQuery = rewrittenQuery.replace("{documentdb-formattableorderbyquery-filter}", "true") return rewrittenQuery + def has_select_value(self): + return self._extract(self.HasSelectValue) + + def has_top(self): + return self.get_top() is not None + + def has_limit(self): + return self.get_limit() is not None + + def has_offset(self): + return self.get_offset() is not None + + def has_distinct_type(self): + return self.get_distinct_type() != _DistinctType.NoneType + + def has_order_by(self): + order_by = self.get_order_by() + return order_by is not None and len(order_by) > 0 + + def has_aggregates(self): + aggregates = self.get_aggregates() + return aggregates is not None and len(aggregates) > 0 + + def has_rewritten_query(self): + return self.get_rewritten_query() is not None + def _extract(self, path): item = self._query_execution_info diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py index d1cf600be217..76ee23451274 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py @@ -23,7 +23,6 @@ """ from azure.core.paging import PageIterator # type: ignore from azure.cosmos._execution_context import execution_dispatcher -from azure.cosmos._execution_context import base_execution_context # pylint: disable=protected-access @@ -54,7 +53,9 @@ def __init__( :param dict options: The request options for the request. :param method fetch_function: - :param str collection_link: + :param str resource_type: + The type of the resource being queried + :param str resource_link: If this is a Document query/feed collection_link is required. 
Example of `fetch_function`: @@ -73,20 +74,10 @@ def __init__( self._collection_link = collection_link self._database_link = database_link self._partition_key = partition_key - self._ex_context = self._create_execution_context() - super(QueryIterable, self).__init__(self._fetch_next, self._unpack, continuation_token=continuation_token) - - def _create_execution_context(self): - """instantiates the internal query execution context based. - """ - if self._database_link: - # client side partitioning query - return base_execution_context._MultiCollectionQueryExecutionContext( - self._client, self._options, self._database_link, self._query, self._partition_key - ) - return execution_dispatcher._ProxyQueryExecutionContext( + self._ex_context = execution_dispatcher._ProxyQueryExecutionContext( self._client, self._collection_link, self._query, self._options, self._fetch_function ) + super(QueryIterable, self).__init__(self._fetch_next, self._unpack, continuation_token=continuation_token) def _unpack(self, block): continuation = None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py index 4d32c3aee167..dbbdd227cf2d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py @@ -93,7 +93,7 @@ def _second_range_is_after_first_range(range1, range2): ##r.min < #previous_r.max return False - if range2.min == range2.min and range1.isMaxInclusive and range2.isMinInclusive: + if range2.min == range1.max and range1.isMaxInclusive and range2.isMinInclusive: # the inclusive ending endpoint of previous_r is the same as the inclusive beginning endpoint of r return False diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 73441d19f5ab..875c4e65335a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -190,11 +190,11 @@ def read_item( request_options = build_options(kwargs) response_hook = kwargs.pop('response_hook', None) - if partition_key: + if partition_key is not None: request_options["partitionKey"] = self._set_partition_key(partition_key) if populate_query_metrics is not None: request_options["populateQueryMetrics"] = populate_query_metrics - if post_trigger_include: + if post_trigger_include is not None: request_options["postTriggerInclude"] = post_trigger_include result = self.client_connection.ReadItem(document_link=doc_link, options=request_options, **kwargs) @@ -402,9 +402,9 @@ def replace_item( request_options["disableIdGeneration"] = True if populate_query_metrics is not None: request_options["populateQueryMetrics"] = populate_query_metrics - if pre_trigger_include: + if pre_trigger_include is not None: request_options["preTriggerInclude"] = pre_trigger_include - if post_trigger_include: + if post_trigger_include is not None: request_options["postTriggerInclude"] = post_trigger_include result = self.client_connection.ReplaceItem( @@ -446,9 +446,9 @@ def upsert_item( request_options["disableIdGeneration"] = True if populate_query_metrics is not None: request_options["populateQueryMetrics"] = populate_query_metrics - if pre_trigger_include: + if pre_trigger_include is not None: request_options["preTriggerInclude"] = pre_trigger_include - if post_trigger_include: + if post_trigger_include is not None: request_options["postTriggerInclude"] = post_trigger_include 
result = self.client_connection.UpsertItem( @@ -492,11 +492,11 @@ def create_item( request_options["disableAutomaticIdGeneration"] = True if populate_query_metrics: request_options["populateQueryMetrics"] = populate_query_metrics - if pre_trigger_include: + if pre_trigger_include is not None: request_options["preTriggerInclude"] = pre_trigger_include - if post_trigger_include: + if post_trigger_include is not None: request_options["postTriggerInclude"] = post_trigger_include - if indexing_directive: + if indexing_directive is not None: request_options["indexingDirective"] = indexing_directive result = self.client_connection.CreateItem( @@ -536,13 +536,13 @@ def delete_item( """ request_options = build_options(kwargs) response_hook = kwargs.pop('response_hook', None) - if partition_key: + if partition_key is not None: request_options["partitionKey"] = self._set_partition_key(partition_key) if populate_query_metrics is not None: request_options["populateQueryMetrics"] = populate_query_metrics - if pre_trigger_include: + if pre_trigger_include is not None: request_options["preTriggerInclude"] = pre_trigger_include - if post_trigger_include: + if post_trigger_include is not None: request_options["postTriggerInclude"] = post_trigger_include document_link = self._get_document_link(item) @@ -699,7 +699,7 @@ def get_conflict(self, conflict, partition_key, **kwargs): """ request_options = build_options(kwargs) response_hook = kwargs.pop('response_hook', None) - if partition_key: + if partition_key is not None: request_options["partitionKey"] = self._set_partition_key(partition_key) result = self.client_connection.ReadConflict( @@ -725,7 +725,7 @@ def delete_conflict(self, conflict, partition_key, **kwargs): """ request_options = build_options(kwargs) response_hook = kwargs.pop('response_hook', None) - if partition_key: + if partition_key is not None: request_options["partitionKey"] = self._set_partition_key(partition_key) result = self.client_connection.DeleteConflict( diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/database.py b/sdk/cosmos/azure-cosmos/azure/cosmos/database.py index cc5067d003db..41eef9799339 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/database.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/database.py @@ -196,15 +196,15 @@ def create_container( :name: create_container_with_settings """ definition = dict(id=id) # type: Dict[str, Any] - if partition_key: + if partition_key is not None: definition["partitionKey"] = partition_key - if indexing_policy: + if indexing_policy is not None: definition["indexingPolicy"] = indexing_policy - if default_ttl: + if default_ttl is not None: definition["defaultTtl"] = default_ttl - if unique_key_policy: + if unique_key_policy is not None: definition["uniqueKeyPolicy"] = unique_key_policy - if conflict_resolution_policy: + if conflict_resolution_policy is not None: definition["conflictResolutionPolicy"] = conflict_resolution_policy request_options = build_options(kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py index e71254442471..58eda8ec45b1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py @@ -401,7 +401,6 @@ def __init__(self): class _OperationType(object): """Represents the type of the operation """ - Create = "Create" Delete = "Delete" ExecuteJavaScript = "ExecuteJavaScript" @@ -413,6 +412,7 @@ class _OperationType(object): Recreate = "Recreate" Replace = "Replace" SqlQuery = "SqlQuery" + 
QueryPlan = "QueryPlan" Update = "Update" Upsert = "Upsert" @@ -438,3 +438,32 @@ def IsReadOnlyOperation(operationType): _OperationType.Query, _OperationType.SqlQuery, ) + + @staticmethod + def IsFeedOperation(operationType): + return operationType in ( + _OperationType.Create, + _OperationType.Upsert, + _OperationType.ReadFeed, + _OperationType.Query, + _OperationType.SqlQuery, + _OperationType.QueryPlan, + _OperationType.HeadFeed, + ) + +class _QueryFeature(object): + NoneQuery = "NoneQuery" + Aggregate = "Aggregate" + CompositeAggregate = "CompositeAggregate" + Distinct = "Distinct" + GroupBy = "GroupBy" + MultipleAggregates = "MultipleAggregates" + MultipleOrderBy = "MultipleOrderBy" + OffsetAndLimit = "OffsetAndLimit" + OrderBy = "OrderBy" + Top = "Top" + +class _DistinctType(object): + NoneType = "None" + Ordered = "Ordered" + Unordered = "Unordered" diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py index da327a708548..040d6d0c9ab8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py @@ -89,6 +89,9 @@ class HttpHeaders(object): # Query Query = "x-ms-documentdb-query" IsQuery = "x-ms-documentdb-isquery" + IsQueryPlanRequest = "x-ms-cosmos-is-query-plan-request" + SupportedQueryFeatures = "x-ms-cosmos-supported-query-features" + QueryVersion = "x-ms-cosmos-query-version" # Our custom DocDB headers Continuation = "x-ms-continuation" @@ -262,9 +265,9 @@ class CookieHeaders(object): class Versions(object): """Constants of versions. """ - CurrentVersion = "2018-12-31" SDKName = "azure-cosmos" + QueryVersion = "1.0" class Delimiters(object): diff --git a/sdk/cosmos/azure-cosmos/test/aggregate_tests.py b/sdk/cosmos/azure-cosmos/test/aggregate_tests.py index a10343334234..15e7f687d765 100644 --- a/sdk/cosmos/azure-cosmos/test/aggregate_tests.py +++ b/sdk/cosmos/azure-cosmos/test/aggregate_tests.py @@ -49,156 +49,156 @@ class _config: sum = 0 -class AggregateQueryTestSequenceMeta(type): - def __new__(mcs, name, bases, dict): - def _run_one(query, expected_result): - def test(self): - self._execute_query_and_validate_results(mcs.created_collection, query, expected_result) - - return test - - def _setup(): - if (not _config.master_key or not _config.host): - raise Exception( - "You must specify your Azure Cosmos account values for " - "'masterKey' and 'host' at the top of this class to run the " - "tests.") - - mcs.client = cosmos_client.CosmosClient( - _config.host, _config.master_key, "Session", connection_policy=_config.connection_policy) - created_db = test_config._test_config.create_database_if_not_exist(mcs.client) - mcs.created_collection = _create_collection(created_db) - - # test documents - document_definitions = [] - - values = [None, False, True, "abc", "cdfg", "opqrs", "ttttttt", "xyz", "oo", "ppp"] - for value in values: - d = {_config.PARTITION_KEY: value, 'id': str(uuid.uuid4())} - document_definitions.append(d) - - for i in xrange(_config.DOCS_WITH_SAME_PARTITION_KEY): - d = {_config.PARTITION_KEY: _config.UNIQUE_PARTITION_KEY, - 'resourceId': i, - _config.FIELD: i + 1, - 'id': str(uuid.uuid4())} - document_definitions.append(d) - - _config.docs_with_numeric_id = \ - _config.DOCUMENTS_COUNT - len(values) - _config.DOCS_WITH_SAME_PARTITION_KEY - for i in xrange(_config.docs_with_numeric_id): - d = {_config.PARTITION_KEY: i + 1, 'id': str(uuid.uuid4())} - document_definitions.append(d) - - _config.sum = _config.docs_with_numeric_id 
\ - * (_config.docs_with_numeric_id + 1) / 2.0 - - _insert_doc(mcs.created_collection, document_definitions) - - def _generate_test_configs(): - aggregate_query_format = 'SELECT VALUE {}(r.{}) FROM r WHERE {}' - aggregate_orderby_query_format = 'SELECT VALUE {}(r.{}) FROM r WHERE {} ORDER BY r.{}' - aggregate_configs = [ - ['AVG', _config.sum / _config.docs_with_numeric_id, - 'IS_NUMBER(r.{})'.format(_config.PARTITION_KEY)], - ['AVG', None, 'true'], - ['COUNT', _config.DOCUMENTS_COUNT, 'true'], - ['MAX', 'xyz', 'true'], - ['MIN', None, 'true'], - ['SUM', _config.sum, 'IS_NUMBER(r.{})'.format(_config.PARTITION_KEY)], - ['SUM', None, 'true'] - ] - for operator, expected, condition in aggregate_configs: - _all_tests.append([ - '{} {}'.format(operator, condition), - aggregate_query_format.format(operator, _config.PARTITION_KEY, condition), - expected]) - _all_tests.append([ - '{} {} OrderBy'.format(operator, condition), - aggregate_orderby_query_format.format(operator, _config.PARTITION_KEY, condition, - _config.PARTITION_KEY), - expected]) - - aggregate_single_partition_format = 'SELECT VALUE {}(r.{}) FROM r WHERE r.{} = \'{}\'' - aggregate_orderby_single_partition_format = 'SELECT {}(r.{}) FROM r WHERE r.{} = \'{}\'' - same_partiton_sum = _config.DOCS_WITH_SAME_PARTITION_KEY * (_config.DOCS_WITH_SAME_PARTITION_KEY + 1) / 2.0 - aggregate_single_partition_configs = [ - ['AVG', same_partiton_sum / _config.DOCS_WITH_SAME_PARTITION_KEY], - ['COUNT', _config.DOCS_WITH_SAME_PARTITION_KEY], - ['MAX', _config.DOCS_WITH_SAME_PARTITION_KEY], - ['MIN', 1], - ['SUM', same_partiton_sum] - ] - for operator, expected in aggregate_single_partition_configs: - _all_tests.append([ - '{} SinglePartition {}'.format(operator, 'SELECT VALUE'), - aggregate_single_partition_format.format( - operator, _config.FIELD, _config.PARTITION_KEY, _config.UNIQUE_PARTITION_KEY), expected]) - _all_tests.append([ - '{} SinglePartition {}'.format(operator, 'SELECT'), - aggregate_orderby_single_partition_format.format( - operator, _config.FIELD, _config.PARTITION_KEY, _config.UNIQUE_PARTITION_KEY), - Exception()]) - - def _run_all(): - for test_name, query, expected_result in _all_tests: - test_name = "test_%s" % test_name - dict[test_name] = _run_one(query, expected_result) - - def _create_collection(created_db): - # type: (Database) -> Container - created_collection = created_db.create_container( - id='aggregate tests collection ' + str(uuid.uuid4()), - indexing_policy={ - 'includedPaths': [ - { - 'path': '/', - 'indexes': [ - { - 'kind': 'Range', - 'dataType': 'Number' - }, - { - 'kind': 'Range', - 'dataType': 'String' - } - ] - } - ] - }, - partition_key=PartitionKey( - path='/{}'.format(_config.PARTITION_KEY), - kind=documents.PartitionKind.Hash, - ), - offer_throughput=10100 - ) - - return created_collection - - def _insert_doc(collection, document_definitions): - # type: (Container, Dict[str, Any]) -> [Dict[str, Any]] - created_docs = [] - for d in document_definitions: - print(d) - created_doc = collection.create_item(body=d) - created_docs.append(created_doc) - - return created_docs - - _all_tests = [] - - return type.__new__(mcs, name, bases, dict) +@pytest.mark.usefixtures("teardown") +class AggregationQueryTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._all_tests = [] + cls._setup() + cls._generate_test_configs() + + @classmethod + def _setup(cls): + if (not _config.master_key or not _config.host): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 
'host' at the top of this class to run the " + "tests.") + + cls.client = cosmos_client.CosmosClient( + _config.host, {'masterKey': _config.master_key}, "Session", connection_policy=_config.connection_policy) + created_db = test_config._test_config.create_database_if_not_exist(cls.client) + cls.created_collection = cls._create_collection(created_db) + + # test documents + document_definitions = [] + + values = [None, False, True, "abc", "cdfg", "opqrs", "ttttttt", "xyz", "oo", "ppp"] + for value in values: + d = {_config.PARTITION_KEY: value, 'id': str(uuid.uuid4())} + document_definitions.append(d) + + for i in xrange(_config.DOCS_WITH_SAME_PARTITION_KEY): + d = {_config.PARTITION_KEY: _config.UNIQUE_PARTITION_KEY, + 'resourceId': i, + _config.FIELD: i + 1, + 'id': str(uuid.uuid4())} + document_definitions.append(d) + + _config.docs_with_numeric_id = \ + _config.DOCUMENTS_COUNT - len(values) - _config.DOCS_WITH_SAME_PARTITION_KEY + for i in xrange(_config.docs_with_numeric_id): + d = {_config.PARTITION_KEY: i + 1, 'id': str(uuid.uuid4())} + document_definitions.append(d) + + _config.sum = _config.docs_with_numeric_id \ + * (_config.docs_with_numeric_id + 1) / 2.0 + + cls._insert_doc(cls.created_collection, document_definitions) + + @classmethod + def _generate_test_configs(cls): + aggregate_query_format = 'SELECT VALUE {}(r.{}) FROM r WHERE {}' + aggregate_orderby_query_format = 'SELECT VALUE {}(r.{}) FROM r WHERE {} ORDER BY r.{}' + aggregate_configs = [ + ['AVG', _config.sum / _config.docs_with_numeric_id, + 'IS_NUMBER(r.{})'.format(_config.PARTITION_KEY)], + ['AVG', None, 'true'], + ['COUNT', _config.DOCUMENTS_COUNT, 'true'], + ['MAX', 'xyz', 'true'], + ['MIN', None, 'true'], + ['SUM', _config.sum, 'IS_NUMBER(r.{})'.format(_config.PARTITION_KEY)], + ['SUM', None, 'true'] + ] + for operator, expected, condition in aggregate_configs: + cls._all_tests.append([ + '{} {}'.format(operator, condition), + aggregate_query_format.format(operator, _config.PARTITION_KEY, condition), + expected]) + cls._all_tests.append([ + '{} {} OrderBy'.format(operator, condition), + aggregate_orderby_query_format.format(operator, _config.PARTITION_KEY, condition, + _config.PARTITION_KEY), + expected]) + + aggregate_single_partition_format = 'SELECT VALUE {}(r.{}) FROM r WHERE r.{} = \'{}\'' + aggregate_orderby_single_partition_format = 'SELECT {}(r.{}) FROM r WHERE r.{} = \'{}\'' + same_partition_sum = _config.DOCS_WITH_SAME_PARTITION_KEY * (_config.DOCS_WITH_SAME_PARTITION_KEY + 1) / 2.0 + aggregate_single_partition_configs = [ + ['AVG', same_partition_sum / _config.DOCS_WITH_SAME_PARTITION_KEY], + ['COUNT', _config.DOCS_WITH_SAME_PARTITION_KEY], + ['MAX', _config.DOCS_WITH_SAME_PARTITION_KEY], + ['MIN', 1], + ['SUM', same_partition_sum] + ] + for operator, expected in aggregate_single_partition_configs: + cls._all_tests.append([ + '{} SinglePartition {}'.format(operator, 'SELECT VALUE'), + aggregate_single_partition_format.format( + operator, _config.FIELD, _config.PARTITION_KEY, _config.UNIQUE_PARTITION_KEY), expected]) + cls._all_tests.append([ + '{} SinglePartition {}'.format(operator, 'SELECT'), + aggregate_orderby_single_partition_format.format( + operator, _config.FIELD, _config.PARTITION_KEY, _config.UNIQUE_PARTITION_KEY), + Exception()]) + + def test_run_all(self): + for test_name, query, expected_result in self._all_tests: + test_name = "test_%s" % test_name + try: + self._run_one(query, expected_result) + print(test_name + ': ' + query + " PASSED") + except Exception as e: + print(test_name + ': ' + 
query + " FAILED") + raise e + + def _run_one(self, query, expected_result): + self._execute_query_and_validate_results(self.created_collection, query, expected_result) + + @classmethod + def _create_collection(cls, created_db): + # type: (Database) -> Container + created_collection = created_db.create_container( + id='aggregate tests collection ' + str(uuid.uuid4()), + indexing_policy={ + 'includedPaths': [ + { + 'path': '/', + 'indexes': [ + { + 'kind': 'Range', + 'dataType': 'Number' + }, + { + 'kind': 'Range', + 'dataType': 'String' + } + ] + } + ] + }, + partition_key=PartitionKey( + path='/{}'.format(_config.PARTITION_KEY), + kind=documents.PartitionKind.Hash, + ), + offer_throughput=10100 + ) + return created_collection + @classmethod + def _insert_doc(cls, collection, document_definitions): + # type: (Container, Dict[str, Any]) -> [Dict[str, Any]] + created_docs = [] + for d in document_definitions: + created_doc = collection.create_item(body=d) + created_docs.append(created_doc) + + return created_docs -@pytest.mark.usefixtures("teardown") -class AggregationQueryTest(with_metaclass(AggregateQueryTestSequenceMeta, unittest.TestCase)): def _execute_query_and_validate_results(self, collection, query, expected): # type: (Container, str, [Dict[str, Any]]) -> None - print('Running test with query: ' + query) # executes the query and validates the results against the expected results - options = {'enableCrossPartitionQuery': 'true'} - result_iterable = collection.query_items( query=query, enable_cross_partition_query=True @@ -239,5 +239,6 @@ def invokeNext(): else: _verify_result() + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/test/crud_tests.py b/sdk/cosmos/azure-cosmos/test/crud_tests.py index cd2d2934f90f..4bb00b7c78de 100644 --- a/sdk/cosmos/azure-cosmos/test/crud_tests.py +++ b/sdk/cosmos/azure-cosmos/test/crud_tests.py @@ -2104,7 +2104,6 @@ def test_absolute_client_timeout(self): with self.assertRaises(errors.CosmosClientTimeoutError): list(databases) - def test_query_iterable_functionality(self): def __create_resources(client): """Creates resources for this test. 
diff --git a/sdk/cosmos/azure-cosmos/test/query_tests.py b/sdk/cosmos/azure-cosmos/test/query_tests.py index 9fc10882ac60..8dbe6a01e815 100644 --- a/sdk/cosmos/azure-cosmos/test/query_tests.py +++ b/sdk/cosmos/azure-cosmos/test/query_tests.py @@ -2,7 +2,13 @@ import uuid import azure.cosmos.cosmos_client as cosmos_client import azure.cosmos._retry_utility as retry_utility +from azure.cosmos._execution_context.query_execution_info import _PartitionedQueryExecutionInfo +import azure.cosmos.errors as errors +from azure.cosmos.partition_key import PartitionKey +from azure.cosmos._execution_context.base_execution_context import _QueryExecutionContextBase +from azure.cosmos.documents import _DistinctType import pytest +import collections import test_config pytestmark = pytest.mark.cosmosEmulator @@ -184,11 +190,10 @@ def test_max_item_count_honored_in_order_by_query(self): max_item_count=1, enable_cross_partition_query=True ) - # 1 call to get query plans, 1 call to get pkr, 10 calls to one partion with the documents, 1 call each to other 4 partitions if 'localhost' in self.host or '127.0.0.1' in self.host: # TODO: Differing result between live and emulator - self.validate_query_requests_count(query_iterable, 16 * 2) + self.validate_query_requests_count(query_iterable, 15 * 2 + 1) else: - self.validate_query_requests_count(query_iterable, 17 * 2) + self.validate_query_requests_count(query_iterable, 17 * 2 + 1) query_iterable = created_collection.query_items( query=query, @@ -196,8 +201,7 @@ def test_max_item_count_honored_in_order_by_query(self): enable_cross_partition_query=True ) - # 1 call to get query plan 1 calls to one partition with the documents, 1 call each to other 4 partitions - self.validate_query_requests_count(query_iterable, 6 * 2) + self.validate_query_requests_count(query_iterable, 13) def validate_query_requests_count(self, query_iterable, expected_count): self.count = 0 @@ -213,6 +217,337 @@ def _MockExecuteFunction(self, function, *args, **kwargs): self.count += 1 return self.OriginalExecuteFunction(function, *args, **kwargs) + def test_get_query_plan_through_gateway(self): + created_collection = self.config.create_multi_partition_collection_with_custom_pk_if_not_exist(self.client) + self._validate_query_plan(query="Select top 10 value count(c.id) from c", + container_link=created_collection.container_link, + top=10, + order_by=[], + aggregate=['Count'], + select_value=True, + offset=None, + limit=None, + distinct=_DistinctType.NoneType) + + self._validate_query_plan(query="Select * from c order by c._ts offset 5 limit 10", + container_link=created_collection.container_link, + top=None, + order_by=['Ascending'], + aggregate=[], + select_value=False, + offset=5, + limit=10, + distinct=_DistinctType.NoneType) + + self._validate_query_plan(query="Select distinct value c.id from c order by c.id", + container_link=created_collection.container_link, + top=None, + order_by=['Ascending'], + aggregate=[], + select_value=True, + offset=None, + limit=None, + distinct=_DistinctType.Ordered) + + def _validate_query_plan(self, query, container_link, top, order_by, aggregate, select_value, offset, limit, distinct): + query_plan_dict = self.client.client_connection._GetQueryPlanThroughGateway(query, container_link) + query_execution_info = _PartitionedQueryExecutionInfo(query_plan_dict) + self.assertTrue(query_execution_info.has_rewritten_query()) + self.assertEquals(query_execution_info.has_distinct_type(), distinct != "None") + self.assertEquals(query_execution_info.get_distinct_type(), 
distinct) + self.assertEquals(query_execution_info.has_top(), top is not None) + self.assertEquals(query_execution_info.get_top(), top) + self.assertEquals(query_execution_info.has_order_by(), len(order_by) > 0) + self.assertListEqual(query_execution_info.get_order_by(), order_by) + self.assertEquals(query_execution_info.has_aggregates(), len(aggregate) > 0) + self.assertListEqual(query_execution_info.get_aggregates(), aggregate) + self.assertEquals(query_execution_info.has_select_value(), select_value) + self.assertEquals(query_execution_info.has_offset(), offset is not None) + self.assertEquals(query_execution_info.get_offset(), offset) + self.assertEquals(query_execution_info.has_limit(), limit is not None) + self.assertEquals(query_execution_info.get_limit(), limit) + + def test_unsupported_queries(self): + created_collection = self.config.create_multi_partition_collection_with_custom_pk_if_not_exist(self.client) + queries = ['SELECT COUNT(1) FROM c', 'SELECT COUNT(1) + 5 FROM c', 'SELECT COUNT(1) + SUM(c) FROM c'] + for query in queries: + query_iterable = created_collection.query_items(query=query, enable_cross_partition_query=True) + try: + list(query_iterable) + self.fail() + except errors.CosmosHttpResponseError as e: + self.assertEqual(e.status_code, 400) + + def test_query_with_non_overlapping_pk_ranges(self): + created_collection = self.config.create_multi_partition_collection_with_custom_pk_if_not_exist(self.client) + query_iterable = created_collection.query_items("select * from c where c.pk='1' or c.pk='2'", enable_cross_partition_query=True) + self.assertListEqual(list(query_iterable), []) + + def test_offset_limit(self): + created_collection = self.config.create_multi_partition_collection_with_custom_pk_if_not_exist(self.client) + max_item_counts = [0, 2, 5, 10] + values = [] + for i in range(10): + document_definition = {'pk': i, 'id': 'myId' + str(uuid.uuid4())} + values.append(created_collection.create_item(body=document_definition)['pk']) + + for max_item_count in max_item_counts: + self._validate_offset_limit(created_collection=created_collection, + query='SELECT * from c ORDER BY c.pk OFFSET 0 LIMIT 5', + max_item_count=max_item_count, + results=values[:5]) + + self._validate_offset_limit(created_collection=created_collection, + query='SELECT * from c ORDER BY c.pk OFFSET 5 LIMIT 10', + max_item_count=max_item_count, + results=values[5:]) + + self._validate_offset_limit(created_collection=created_collection, + query='SELECT * from c ORDER BY c.pk OFFSET 10 LIMIT 5', + max_item_count=max_item_count, + results=[]) + + self._validate_offset_limit(created_collection=created_collection, + query='SELECT * from c ORDER BY c.pk OFFSET 100 LIMIT 1', + max_item_count=max_item_count, + results=[]) + + def _validate_offset_limit(self, created_collection, query, max_item_count, results): + query_iterable = created_collection.query_items( + query=query, + enable_cross_partition_query=True, + max_item_count=max_item_count + ) + self.assertListEqual(list(map(lambda doc: doc['pk'], list(query_iterable))), results) + + def test_distinct(self): + created_database = self.config.create_database_if_not_exist(self.client) + distinct_field = 'distinct_field' + pk_field = "pk" + different_field = "different_field" + + created_collection = created_database.create_container( + id='collection with composite index ' + str(uuid.uuid4()), + partition_key=PartitionKey(path="/pk", kind="Hash"), + indexing_policy={ + "compositeIndexes": [ + [{"path": "/" + pk_field, "order": "ascending"}, {"path": 
"/" + distinct_field, "order": "ascending"}], + [{"path": "/" + distinct_field, "order": "ascending"}, {"path": "/" + pk_field, "order": "ascending"}] + ] + } + ) + documents = [] + for i in range(5): + j = i + while j > i - 5: + document_definition = {pk_field: i, 'id': str(uuid.uuid4()), distinct_field: j} + documents.append(created_collection.create_item(body=document_definition)) + document_definition = {pk_field: i, 'id': str(uuid.uuid4()), distinct_field: j} + documents.append(created_collection.create_item(body=document_definition)) + document_definition = {pk_field: i, 'id': str(uuid.uuid4())} + documents.append(created_collection.create_item(body=document_definition)) + j -= 1 + + padded_docs = self._pad_with_none(documents, distinct_field) + + self._validate_distinct(created_collection=created_collection, + query='SELECT distinct c.%s from c ORDER BY c.%s' % (distinct_field, distinct_field), + results=self._get_distinct_docs(self._get_order_by_docs(padded_docs, distinct_field, None), distinct_field, None, True), + is_select=False, + fields=[distinct_field]) + + self._validate_distinct(created_collection=created_collection, + query='SELECT distinct c.%s, c.%s from c ORDER BY c.%s, c.%s' % (distinct_field, pk_field, pk_field, distinct_field), + results=self._get_distinct_docs(self._get_order_by_docs(padded_docs, pk_field, distinct_field), distinct_field, pk_field, True), + is_select=False, + fields=[distinct_field, pk_field]) + + self._validate_distinct(created_collection=created_collection, + query='SELECT distinct c.%s, c.%s from c ORDER BY c.%s, c.%s' % (distinct_field, pk_field, distinct_field, pk_field), + results=self._get_distinct_docs(self._get_order_by_docs(padded_docs, distinct_field, pk_field), distinct_field, pk_field, True), + is_select=False, + fields=[distinct_field, pk_field]) + + self._validate_distinct(created_collection=created_collection, + query='SELECT distinct value c.%s from c ORDER BY c.%s' % (distinct_field, distinct_field), + results=self._get_distinct_docs(self._get_order_by_docs(padded_docs, distinct_field, None), distinct_field, None, True), + is_select=False, + fields=[distinct_field]) + + self._validate_distinct(created_collection=created_collection, + query='SELECT distinct c.%s from c' % (distinct_field), + results=self._get_distinct_docs(padded_docs, distinct_field, None, False), + is_select=True, + fields=[distinct_field]) + + self._validate_distinct(created_collection=created_collection, + query='SELECT distinct c.%s, c.%s from c' % (distinct_field, pk_field), + results=self._get_distinct_docs(padded_docs, distinct_field, pk_field, False), + is_select=True, + fields=[distinct_field, pk_field]) + + self._validate_distinct(created_collection=created_collection, + query='SELECT distinct value c.%s from c' % (distinct_field), + results=self._get_distinct_docs(padded_docs, distinct_field, None, True), + is_select=True, + fields=[distinct_field]) + + self._validate_distinct(created_collection=created_collection, + query='SELECT distinct c.%s from c ORDER BY c.%s' % (different_field, different_field), + results=[], + is_select=True, + fields=[different_field]) + + self._validate_distinct(created_collection=created_collection, + query='SELECT distinct c.%s from c' % (different_field), + results=['None'], + is_select=True, + fields=[different_field]) + + created_database.delete_container(created_collection.id) + + def _get_order_by_docs(self, documents, field1, field2): + if field2 is None: + return sorted(documents, key=lambda d: (d[field1] is not None, 
d[field1])) + else: + return sorted(documents, key=lambda d: (d[field1] is not None, d[field1], d[field2] is not None, d[field2])) + + def _get_distinct_docs(self, documents, field1, field2, is_order_by_or_value): + if field2 is None: + res = collections.OrderedDict.fromkeys(doc[field1] for doc in documents) + if is_order_by_or_value: + res = filter(lambda x: False if x is None else True, res) + else: + res = collections.OrderedDict.fromkeys(str(doc[field1]) + "," + str(doc[field2]) for doc in documents) + return list(res) + + def _pad_with_none(self, documents, field): + for doc in documents: + if field not in doc: + doc[field] = None + return documents + + def _validate_distinct(self, created_collection, query, results, is_select, fields): + query_iterable = created_collection.query_items( + query=query, + enable_cross_partition_query=True + ) + query_results = list(query_iterable) + + self.assertEquals(len(results), len(query_results)) + query_results_strings = [] + result_strings = [] + for i in range(len(results)): + query_results_strings.append(self._get_query_result_string(query_results[i], fields)) + result_strings.append(str(results[i])) + if is_select: + query_results_strings = sorted(query_results_strings) + result_strings = sorted(result_strings) + self.assertListEqual(result_strings, query_results_strings) + + def _get_query_result_string(self, query_result, fields): + if type(query_result) is not dict: + return str(query_result) + res = str(query_result[fields[0]] if fields[0] in query_result else None) + if len(fields) == 2: + res = res + "," + str(query_result[fields[1]] if fields[1] in query_result else None) + + return res + + def test_distinct_on_different_types_and_field_orders(self): + created_collection = self.config.create_multi_partition_collection_with_custom_pk_if_not_exist(self.client) + self.payloads = [ + {'f1': 1, 'f2': 'value', 'f3': 100000000000000000, 'f4': [1, 2, '3'], 'f5': {'f6': {'f7': 2}}}, + {'f2': '\'value', 'f4': [1.0, 2, '3'], 'f5': {'f6': {'f7': 2.0}}, 'f1': 1.0, 'f3': 100000000000000000.00}, + {'f3': 100000000000000000.0, 'f5': {'f6': {'f7': 2}}, 'f2': '\'value', 'f1': 1, 'f4': [1, 2.0, '3']} + ] + self.OriginalExecuteFunction = _QueryExecutionContextBase.next + _QueryExecutionContextBase.next = self._MockNextFunction + + self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct value c.f1 from c", + expected_results=[1], + get_mock_result=lambda x, i: (None, x[i]["f1"]) + ) + + self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct value c.f2 from c", + expected_results=['value', '\'value'], + get_mock_result=lambda x, i: (None, x[i]["f2"]) + ) + + self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct value c.f2 from c order by c.f2", + expected_results=['\'value', 'value'], + get_mock_result=lambda x, i: (x[i]["f2"], x[i]["f2"]) + ) + + self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct value c.f3 from c", + expected_results=[100000000000000000], + get_mock_result=lambda x, i: (None, x[i]["f3"]) + ) + + self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct value c.f4 from c", + expected_results=[[1, 2, '3']], + get_mock_result=lambda x, i: (None, x[i]["f4"]) + ) + + self._validate_distinct_on_different_types_and_field_orders( + 
collection=created_collection, + query="Select distinct value c.f5.f6 from c", + expected_results=[{'f7': 2}], + get_mock_result=lambda x, i: (None, x[i]["f5"]["f6"]) + ) + + self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct c.f1, c.f2, c.f3 from c", + expected_results=[self.payloads[0], self.payloads[1]], + get_mock_result=lambda x, i: (None, x[i]) + ) + + self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct c.f1, c.f2, c.f3 from c order by c.f1", + expected_results=[self.payloads[0], self.payloads[1]], + get_mock_result=lambda x, i: (i, x[i]) + ) + + _QueryExecutionContextBase.next = self.OriginalExecuteFunction + + def _validate_distinct_on_different_types_and_field_orders(self, collection, query, expected_results, get_mock_result): + self.count = 0 + self.get_mock_result = get_mock_result + query_iterable = collection.query_items(query, enable_cross_partition_query=True) + results = list(query_iterable) + for i in range(len(expected_results)): + if isinstance(results[i], dict): + self.assertDictEqual(results[i], expected_results[i]) + elif isinstance(results[i], list): + self.assertListEqual(results[i], expected_results[i]) + else: + self.assertEquals(results[i], expected_results[i]) + self.count = 0 + + def _MockNextFunction(self): + if self.count < len(self.payloads): + item, result = self.get_mock_result(self.payloads, self.count) + self.count += 1 + if item is not None: + return {'orderByItems': [{'item': item}], '_rid': 'fake_rid', 'payload': result} + else: + return result + else: + raise StopIteration + if __name__ == "__main__": unittest.main() \ No newline at end of file
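Editor's note: taken together, these changes make OFFSET/LIMIT and DISTINCT work end to end through `query_items`, mirroring what query_tests.py exercises. A usage sketch follows; the account URL, key, and database/container names are placeholders, the client constructor form copies the test code above, and the `get_database_client`/`get_container_client` accessors assume the 4.0 preview API:

```python
import azure.cosmos.cosmos_client as cosmos_client

# Placeholder endpoint and key; the constructor form matches the tests above.
client = cosmos_client.CosmosClient(
    "https://<account>.documents.azure.com:443/", {"masterKey": "<key>"}, "Session")
container = client.get_database_client("demo-db").get_container_client("demo")

for i in range(10):
    container.create_item(body={"id": "item{}".format(i), "pk": i % 3, "value": i})

# Served by the new offset/top endpoint components:
page = container.query_items(
    query="SELECT * FROM c ORDER BY c.value OFFSET 2 LIMIT 3",
    enable_cross_partition_query=True)

# Served by the new distinct endpoint components:
distinct_pks = container.query_items(
    query="SELECT DISTINCT VALUE c.pk FROM c",
    enable_cross_partition_query=True)

print([doc["value"] for doc in page])   # [2, 3, 4]
print(sorted(distinct_pks))             # [0, 1, 2]
```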