From 7a42f48e1860cfc6f9b097c2b8d4c673faaac92b Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Wed, 31 Jul 2019 10:24:00 -0700 Subject: [PATCH] cls and paging (#106) --- .../AzurePagingMethodTemplate.cshtml | 10 +++- .../AcceptanceTests/asynctests/test_paging.py | 9 +++ test/azure/AcceptanceTests/test_paging.py | 8 +++ .../_paging_operations_async.py | 55 +++++++++++++++---- .../paging/operations/_paging_operations.py | 55 +++++++++++++++---- .../_storage_accounts_operations_async.py | 10 +++- .../_storage_accounts_operations.py | 10 +++- 7 files changed, 129 insertions(+), 28 deletions(-) diff --git a/src/azure/Templates/AzurePagingMethodTemplate.cshtml b/src/azure/Templates/AzurePagingMethodTemplate.cshtml index 023ab078067..799bab8c68a 100644 --- a/src/azure/Templates/AzurePagingMethodTemplate.cshtml +++ b/src/azure/Templates/AzurePagingMethodTemplate.cshtml @@ -125,7 +125,10 @@ else { @: def extract_data(response): @: deserialized = self._deserialize('@(Model.PagedResponseClass.Name)', response) -@: return @(nextLinkName), iter(deserialized.@(Model.PagedMetadata.ItemProp.Name)) +@: list_of_elem = deserialized.@(Model.PagedMetadata.ItemProp.Name) +@: if cls: +@: list_of_elem = cls(list_of_elem) +@: return @(nextLinkName), iter(list_of_elem) @EmptyLine @: def get_next(next_link=None): @: request = prepare_request(next_link) @@ -136,7 +139,10 @@ else { @: async def extract_data_async(response): @: deserialized = self._deserialize('@(Model.PagedResponseClass.Name)', response) -@: return @(nextLinkName), AsyncList(deserialized.@(Model.PagedMetadata.ItemProp.Name)) +@: list_of_elem = deserialized.@(Model.PagedMetadata.ItemProp.Name) +@: if cls: +@: list_of_elem = cls(list_of_elem) +@: return @(nextLinkName), AsyncList(list_of_elem) @EmptyLine @: async def get_next_async(next_link=None): @: request = prepare_request(next_link) diff --git a/test/azure/AcceptanceTests/asynctests/test_paging.py b/test/azure/AcceptanceTests/asynctests/test_paging.py index 7443a378703..5b88471960c 100644 --- a/test/azure/AcceptanceTests/asynctests/test_paging.py +++ b/test/azure/AcceptanceTests/asynctests/test_paging.py @@ -57,6 +57,15 @@ async def paging_client(): async with AutoRestPagingTestService(cred, base_url="http://localhost:3000") as client: yield client +@pytest.mark.asyncio +async def test_paging_cls(paging_client): + def cb(list_of_obj): + for obj in list_of_obj: + obj.marked = True + return list_of_obj + async for obj in paging_client.paging.get_single_pages(cls=cb): + assert obj.marked + @pytest.mark.asyncio async def test_paging_happy_path(paging_client): diff --git a/test/azure/AcceptanceTests/test_paging.py b/test/azure/AcceptanceTests/test_paging.py index 53161644830..528776b34e1 100644 --- a/test/azure/AcceptanceTests/test_paging.py +++ b/test/azure/AcceptanceTests/test_paging.py @@ -57,6 +57,14 @@ def paging_client(): with AutoRestPagingTestService(cred, base_url="http://localhost:3000") as client: yield client +def test_paging_cls(paging_client): + def cb(list_of_obj): + for obj in list_of_obj: + obj.marked = True + return list_of_obj + pages = paging_client.paging.get_single_pages(cls=cb) + assert all(obj.marked for obj in pages) + def test_paging_happy_path(paging_client): pages = paging_client.paging.get_single_pages() diff --git a/test/azure/Expected/AcceptanceTests/Paging/paging/aio/operations_async/_paging_operations_async.py b/test/azure/Expected/AcceptanceTests/Paging/paging/aio/operations_async/_paging_operations_async.py index ac062034d33..2cd673c48e7 100644 --- a/test/azure/Expected/AcceptanceTests/Paging/paging/aio/operations_async/_paging_operations_async.py +++ b/test/azure/Expected/AcceptanceTests/Paging/paging/aio/operations_async/_paging_operations_async.py @@ -73,7 +73,10 @@ def prepare_request(next_link=None): async def extract_data_async(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, AsyncList(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, AsyncList(list_of_elem) async def get_next_async(next_link=None): request = prepare_request(next_link) @@ -143,7 +146,10 @@ def prepare_request(next_link=None): async def extract_data_async(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, AsyncList(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, AsyncList(list_of_elem) async def get_next_async(next_link=None): request = prepare_request(next_link) @@ -214,7 +220,10 @@ def prepare_request(next_link=None): async def extract_data_async(response): deserialized = self._deserialize('OdataProductResult', response) - return deserialized.odatanext_link, AsyncList(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.odatanext_link, AsyncList(list_of_elem) async def get_next_async(next_link=None): request = prepare_request(next_link) @@ -291,7 +300,10 @@ def prepare_request(next_link=None): async def extract_data_async(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, AsyncList(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, AsyncList(list_of_elem) async def get_next_async(next_link=None): request = prepare_request(next_link) @@ -343,7 +355,10 @@ def prepare_request(next_link=None): async def extract_data_async(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, AsyncList(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, AsyncList(list_of_elem) async def get_next_async(next_link=None): request = prepare_request(next_link) @@ -396,7 +411,10 @@ def prepare_request(next_link=None): async def extract_data_async(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, AsyncList(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, AsyncList(list_of_elem) async def get_next_async(next_link=None): request = prepare_request(next_link) @@ -447,7 +465,10 @@ def prepare_request(next_link=None): async def extract_data_async(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, AsyncList(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, AsyncList(list_of_elem) async def get_next_async(next_link=None): request = prepare_request(next_link) @@ -498,7 +519,10 @@ def prepare_request(next_link=None): async def extract_data_async(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, AsyncList(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, AsyncList(list_of_elem) async def get_next_async(next_link=None): request = prepare_request(next_link) @@ -549,7 +573,10 @@ def prepare_request(next_link=None): async def extract_data_async(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, AsyncList(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, AsyncList(list_of_elem) async def get_next_async(next_link=None): request = prepare_request(next_link) @@ -615,7 +642,10 @@ def prepare_request(next_link=None): async def extract_data_async(response): deserialized = self._deserialize('OdataProductResult', response) - return deserialized.odatanext_link, AsyncList(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.odatanext_link, AsyncList(list_of_elem) async def get_next_async(next_link=None): request = prepare_request(next_link) @@ -687,7 +717,10 @@ def prepare_request(next_link=None): async def extract_data_async(response): deserialized = self._deserialize('OdataProductResult', response) - return deserialized.odatanext_link, AsyncList(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.odatanext_link, AsyncList(list_of_elem) async def get_next_async(next_link=None): request = prepare_request(next_link) diff --git a/test/azure/Expected/AcceptanceTests/Paging/paging/operations/_paging_operations.py b/test/azure/Expected/AcceptanceTests/Paging/paging/operations/_paging_operations.py index cd8b75ea238..4b798a91065 100644 --- a/test/azure/Expected/AcceptanceTests/Paging/paging/operations/_paging_operations.py +++ b/test/azure/Expected/AcceptanceTests/Paging/paging/operations/_paging_operations.py @@ -71,7 +71,10 @@ def prepare_request(next_link=None): def extract_data(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, iter(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, iter(list_of_elem) def get_next(next_link=None): request = prepare_request(next_link) @@ -140,7 +143,10 @@ def prepare_request(next_link=None): def extract_data(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, iter(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, iter(list_of_elem) def get_next(next_link=None): request = prepare_request(next_link) @@ -210,7 +216,10 @@ def prepare_request(next_link=None): def extract_data(response): deserialized = self._deserialize('OdataProductResult', response) - return deserialized.odatanext_link, iter(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.odatanext_link, iter(list_of_elem) def get_next(next_link=None): request = prepare_request(next_link) @@ -286,7 +295,10 @@ def prepare_request(next_link=None): def extract_data(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, iter(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, iter(list_of_elem) def get_next(next_link=None): request = prepare_request(next_link) @@ -337,7 +349,10 @@ def prepare_request(next_link=None): def extract_data(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, iter(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, iter(list_of_elem) def get_next(next_link=None): request = prepare_request(next_link) @@ -389,7 +404,10 @@ def prepare_request(next_link=None): def extract_data(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, iter(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, iter(list_of_elem) def get_next(next_link=None): request = prepare_request(next_link) @@ -439,7 +457,10 @@ def prepare_request(next_link=None): def extract_data(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, iter(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, iter(list_of_elem) def get_next(next_link=None): request = prepare_request(next_link) @@ -489,7 +510,10 @@ def prepare_request(next_link=None): def extract_data(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, iter(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, iter(list_of_elem) def get_next(next_link=None): request = prepare_request(next_link) @@ -539,7 +563,10 @@ def prepare_request(next_link=None): def extract_data(response): deserialized = self._deserialize('ProductResult', response) - return deserialized.next_link, iter(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, iter(list_of_elem) def get_next(next_link=None): request = prepare_request(next_link) @@ -604,7 +631,10 @@ def prepare_request(next_link=None): def extract_data(response): deserialized = self._deserialize('OdataProductResult', response) - return deserialized.odatanext_link, iter(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.odatanext_link, iter(list_of_elem) def get_next(next_link=None): request = prepare_request(next_link) @@ -675,7 +705,10 @@ def prepare_request(next_link=None): def extract_data(response): deserialized = self._deserialize('OdataProductResult', response) - return deserialized.odatanext_link, iter(deserialized.values) + list_of_elem = deserialized.values + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.odatanext_link, iter(list_of_elem) def get_next(next_link=None): request = prepare_request(next_link) diff --git a/test/azure/Expected/AcceptanceTests/StorageManagementClient/storage/aio/operations_async/_storage_accounts_operations_async.py b/test/azure/Expected/AcceptanceTests/StorageManagementClient/storage/aio/operations_async/_storage_accounts_operations_async.py index e04bb2baef0..00922f43a5d 100644 --- a/test/azure/Expected/AcceptanceTests/StorageManagementClient/storage/aio/operations_async/_storage_accounts_operations_async.py +++ b/test/azure/Expected/AcceptanceTests/StorageManagementClient/storage/aio/operations_async/_storage_accounts_operations_async.py @@ -461,7 +461,10 @@ def prepare_request(next_link=None): async def extract_data_async(response): deserialized = self._deserialize('StorageAccountListResult', response) - return deserialized.next_link, AsyncList(deserialized.value) + list_of_elem = deserialized.value + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, AsyncList(list_of_elem) async def get_next_async(next_link=None): request = prepare_request(next_link) @@ -523,7 +526,10 @@ def prepare_request(next_link=None): async def extract_data_async(response): deserialized = self._deserialize('StorageAccountListResult', response) - return deserialized.next_link, AsyncList(deserialized.value) + list_of_elem = deserialized.value + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, AsyncList(list_of_elem) async def get_next_async(next_link=None): request = prepare_request(next_link) diff --git a/test/azure/Expected/AcceptanceTests/StorageManagementClient/storage/operations/_storage_accounts_operations.py b/test/azure/Expected/AcceptanceTests/StorageManagementClient/storage/operations/_storage_accounts_operations.py index 2b88568fb46..38a87a9c660 100644 --- a/test/azure/Expected/AcceptanceTests/StorageManagementClient/storage/operations/_storage_accounts_operations.py +++ b/test/azure/Expected/AcceptanceTests/StorageManagementClient/storage/operations/_storage_accounts_operations.py @@ -460,7 +460,10 @@ def prepare_request(next_link=None): def extract_data(response): deserialized = self._deserialize('StorageAccountListResult', response) - return deserialized.next_link, iter(deserialized.value) + list_of_elem = deserialized.value + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, iter(list_of_elem) def get_next(next_link=None): request = prepare_request(next_link) @@ -521,7 +524,10 @@ def prepare_request(next_link=None): def extract_data(response): deserialized = self._deserialize('StorageAccountListResult', response) - return deserialized.next_link, iter(deserialized.value) + list_of_elem = deserialized.value + if cls: + list_of_elem = cls(list_of_elem) + return deserialized.next_link, iter(list_of_elem) def get_next(next_link=None): request = prepare_request(next_link)