From 6e22b40c4efffecfb34168b76de6ab3895327ddb Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Thu, 23 Jul 2020 10:48:26 -0700 Subject: [PATCH] close client session in async tests (#12656) --- .../tests/test_content_async.py | 151 ++++---- .../tests/test_content_from_url.py | 6 +- .../tests/test_content_from_url_async.py | 84 +++-- .../tests/test_copy_model_async.py | 104 +++--- .../tests/test_custom_forms_async.py | 341 ++++++++++-------- .../tests/test_custom_forms_from_url_async.py | 289 ++++++++------- .../tests/test_mgmt_async.py | 155 ++++---- .../tests/test_receipt_async.py | 168 +++++---- .../tests/test_receipt_from_url_async.py | 149 ++++---- .../tests/test_samples_async.py | 24 +- .../tests/test_training.py | 1 + .../tests/test_training_async.py | 114 +++--- 12 files changed, 870 insertions(+), 716 deletions(-) diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_async.py index 10ff0cba14f2..9a2dc3cd153a 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_async.py @@ -29,34 +29,38 @@ async def test_content_bad_endpoint(self, resource_group, location, form_recogni myfile = fd.read() with self.assertRaises(ServiceRequestError): client = FormRecognizerClient("http://notreal.azure.com", AzureKeyCredential(form_recognizer_account_key)) - poller = await client.begin_recognize_content(myfile) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content(myfile) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_content_authentication_successful_key(self, client): with open(self.invoice_pdf, "rb") as fd: myfile = fd.read() - poller = await client.begin_recognize_content(myfile) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content(myfile) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_content_authentication_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - poller = await client.begin_recognize_content(b"xxx", content_type="application/pdf") - result = await poller.result() + async with client: + poller = await client.begin_recognize_content(b"xxx", content_type="application/pdf") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_passing_enum_content_type(self, client): with open(self.invoice_pdf, "rb") as fd: myfile = fd.read() - poller = await client.begin_recognize_content( - myfile, - content_type=FormContentType.APPLICATION_PDF - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content( + myfile, + content_type=FormContentType.APPLICATION_PDF + ) + result = await poller.result() self.assertIsNotNone(result) @GlobalFormRecognizerAccountPreparer() @@ -64,50 +68,55 @@ async def test_passing_enum_content_type(self, client): async def test_damaged_file_passed_as_bytes(self, client): damaged_pdf = b"\x25\x50\x44\x46\x55\x55\x55" # still has correct bytes to be recognized as PDF with self.assertRaises(HttpResponseError): - poller = await client.begin_recognize_content( - damaged_pdf, - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content( + damaged_pdf, + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_damaged_file_bytes_fails_autodetect_content_type(self, client): damaged_pdf = b"\x50\x44\x46\x55\x55\x55" # doesn't match any magic file numbers with self.assertRaises(ValueError): - poller = await client.begin_recognize_content( - damaged_pdf, - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content( + damaged_pdf, + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_damaged_file_passed_as_bytes_io(self, client): damaged_pdf = BytesIO(b"\x25\x50\x44\x46\x55\x55\x55") # still has correct bytes to be recognized as PDF with self.assertRaises(HttpResponseError): - poller = await client.begin_recognize_content( - damaged_pdf, - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content( + damaged_pdf, + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_damaged_file_bytes_io_fails_autodetect(self, client): damaged_pdf = BytesIO(b"\x50\x44\x46\x55\x55\x55") # doesn't match any magic file numbers with self.assertRaises(ValueError): - poller = await client.begin_recognize_content( - damaged_pdf, - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content( + damaged_pdf, + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_blank_page(self, client): with open(self.blank_pdf, "rb") as fd: blank = fd.read() - poller = await client.begin_recognize_content( - blank, - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content( + blank, + ) + result = await poller.result() self.assertIsNotNone(result) @GlobalFormRecognizerAccountPreparer() @@ -116,18 +125,20 @@ async def test_passing_bad_content_type_param_passed(self, client): with open(self.invoice_pdf, "rb") as fd: myfile = fd.read() with self.assertRaises(ValueError): - poller = await client.begin_recognize_content( - myfile, - content_type="application/jpeg" - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content( + myfile, + content_type="application/jpeg" + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_content_stream_passing_url(self, client): with self.assertRaises(TypeError): - poller = await client.begin_recognize_content("https://badurl.jpg", content_type="application/json") - result = await poller.result() + async with client: + poller = await client.begin_recognize_content("https://badurl.jpg", content_type="application/json") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() @@ -136,10 +147,11 @@ async def test_auto_detect_unsupported_stream_content(self, client): myfile = fd.read() with self.assertRaises(ValueError): - poller = await client.begin_recognize_content( - myfile - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content( + myfile + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() @@ -155,8 +167,9 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_layout) - poller = await client.begin_recognize_content(myform, cls=callback) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content(myform, cls=callback) + result = await poller.result() raw_response = responses[0] layout = responses[1] page_results = raw_response.analyze_result.page_results @@ -171,8 +184,9 @@ async def test_content_stream_pdf(self, client): with open(self.invoice_pdf, "rb") as fd: myform = fd.read() - poller = await client.begin_recognize_content(myform) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content(myform) + result = await poller.result() self.assertEqual(len(result), 1) layout = result[0] self.assertEqual(layout.page_number, 1) @@ -195,8 +209,9 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_layout) - poller = await client.begin_recognize_content(myform, cls=callback) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content(myform, cls=callback) + result = await poller.result() raw_response = responses[0] layout = responses[1] page_results = raw_response.analyze_result.page_results @@ -211,8 +226,9 @@ async def test_content_stream_jpg(self, client): with open(self.form_jpg, "rb") as fd: myform = fd.read() - poller = await client.begin_recognize_content(myform) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content(myform) + result = await poller.result() self.assertEqual(len(result), 1) layout = result[0] self.assertEqual(layout.page_number, 1) @@ -229,8 +245,9 @@ async def test_content_stream_jpg(self, client): async def test_content_multipage(self, client): with open(self.multipage_invoice_pdf, "rb") as fd: invoice = fd.read() - poller = await client.begin_recognize_content(invoice) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content(invoice) + result = await poller.result() self.assertEqual(len(result), 3) self.assertFormPagesHasValues(result) @@ -249,8 +266,9 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_layout) - poller = await client.begin_recognize_content(myform, cls=callback) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content(myform, cls=callback) + result = await poller.result() raw_response = responses[0] layout = responses[1] page_results = raw_response.analyze_result.page_results @@ -265,13 +283,13 @@ def callback(raw_response, _, headers): async def test_content_continuation_token(self, client): with open(self.form_jpg, "rb") as fd: myfile = fd.read() - initial_poller = await client.begin_recognize_content(myfile) - cont_token = initial_poller.continuation_token() - - poller = await client.begin_recognize_content(myfile, continuation_token=cont_token) - result = await poller.result() - self.assertIsNotNone(result) - await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error + async with client: + initial_poller = await client.begin_recognize_content(myfile) + cont_token = initial_poller.continuation_token() + poller = await client.begin_recognize_content(myfile, continuation_token=cont_token) + result = await poller.result() + self.assertIsNotNone(result) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error @GlobalFormRecognizerAccountPreparer() @@ -279,8 +297,9 @@ async def test_content_continuation_token(self, client): async def test_content_multipage_table_span_pdf(self, client): with open(self.multipage_table_pdf, "rb") as fd: myfile = fd.read() - poller = await client.begin_recognize_content(myfile) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content(myfile) + result = await poller.result() self.assertEqual(len(result), 2) layout = result[0] self.assertEqual(layout.page_number, 1) @@ -312,9 +331,9 @@ def callback(raw_response, _, headers): extracted_layout = prepare_content_result(analyze_result) responses.append(analyze_result) responses.append(extracted_layout) - - poller = await client.begin_recognize_content(myform, cls=callback) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content(myform, cls=callback) + result = await poller.result() raw_response = responses[0] layout = responses[1] page_results = raw_response.analyze_result.page_results diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url.py index 6726de20b83f..6c22b94259b1 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url.py @@ -23,10 +23,10 @@ class TestContentFromUrl(FormRecognizerTest): @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() def test_content_encoded_url(self, client): - try: + with pytest.raises(HttpResponseError) as e: poller = client.begin_recognize_content_from_url("https://fakeuri.com/blank%20space") - except HttpResponseError as e: - self.assertIn("https://fakeuri.com/blank%20space", e.response.request.body) + client.close() + self.assertIn("https://fakeuri.com/blank%20space", e.value.response.request.body) @GlobalFormRecognizerAccountPreparer() def test_content_url_bad_endpoint(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url_async.py index 0c94b7a9fb66..10e727bf120f 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url_async.py @@ -24,37 +24,41 @@ class TestContentFromUrlAsync(AsyncFormRecognizerTest): @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_content_encoded_url(self, client): - try: + with pytest.raises(HttpResponseError) as e: poller = await client.begin_recognize_content_from_url("https://fakeuri.com/blank%20space") - except HttpResponseError as e: - self.assertIn("https://fakeuri.com/blank%20space", e.response.request.body) + await client.close() + self.assertIn("https://fakeuri.com/blank%20space", e.value.response.request.body) @GlobalFormRecognizerAccountPreparer() async def test_content_url_bad_endpoint(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): with self.assertRaises(ServiceRequestError): client = FormRecognizerClient("http://notreal.azure.com", AzureKeyCredential(form_recognizer_account_key)) - poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_content_url_auth_successful_key(self, client): - poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_content_url_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_content_bad_url(self, client): with self.assertRaises(HttpResponseError): - poller = await client.begin_recognize_content_from_url("https://badurl.jpg") - result = await poller.result() + async with client: + poller = await client.begin_recognize_content_from_url("https://badurl.jpg") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() @@ -63,8 +67,9 @@ async def test_content_url_pass_stream(self, client): receipt = fd.read(4) # makes the recording smaller with self.assertRaises(HttpResponseError): - poller = await client.begin_recognize_content_from_url(receipt) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content_from_url(receipt) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() @@ -77,8 +82,9 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_layout) - poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf, cls=callback) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf, cls=callback) + result = await poller.result() raw_response = responses[0] layout = responses[1] page_results = raw_response.analyze_result.page_results @@ -90,8 +96,9 @@ def callback(raw_response, _, headers): @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_content_url_pdf(self, client): - poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf) + result = await poller.result() self.assertEqual(len(result), 1) layout = result[0] self.assertEqual(layout.page_number, 1) @@ -111,8 +118,9 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_layout) - poller = await client.begin_recognize_content_from_url(self.form_url_jpg, cls=callback) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content_from_url(self.form_url_jpg, cls=callback) + result = await poller.result() raw_response = responses[0] layout = responses[1] page_results = raw_response.analyze_result.page_results @@ -124,8 +132,9 @@ def callback(raw_response, _, headers): @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_content_url_jpg(self, client): - poller = await client.begin_recognize_content_from_url(self.form_url_jpg) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content_from_url(self.form_url_jpg) + result = await poller.result() self.assertEqual(len(result), 1) layout = result[0] self.assertEqual(layout.page_number, 1) @@ -140,8 +149,9 @@ async def test_content_url_jpg(self, client): @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_content_multipage_url(self, client): - poller = await client.begin_recognize_content_from_url(self.multipage_url_pdf) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content_from_url(self.multipage_url_pdf) + result = await poller.result() self.assertEqual(len(result), 3) self.assertFormPagesHasValues(result) @@ -156,8 +166,9 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_layout) - poller = await client.begin_recognize_content_from_url(self.multipage_url_pdf, cls=callback) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content_from_url(self.multipage_url_pdf, cls=callback) + result = await poller.result() raw_response = responses[0] layout = responses[1] page_results = raw_response.analyze_result.page_results @@ -170,19 +181,21 @@ def callback(raw_response, _, headers): @GlobalClientPreparer() @pytest.mark.live_test_only async def test_content_continuation_token(self, client): - initial_poller = await client.begin_recognize_content_from_url(self.form_url_jpg) - cont_token = initial_poller.continuation_token() + async with client: + initial_poller = await client.begin_recognize_content_from_url(self.form_url_jpg) + cont_token = initial_poller.continuation_token() - poller = await client.begin_recognize_content_from_url(self.form_url_jpg, continuation_token=cont_token) - result = await poller.result() - self.assertIsNotNone(result) - await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error + poller = await client.begin_recognize_content_from_url(self.form_url_jpg, continuation_token=cont_token) + result = await poller.result() + self.assertIsNotNone(result) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_content_multipage_table_span_pdf(self, client): - poller = await client.begin_recognize_content_from_url(self.multipage_table_url_pdf) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content_from_url(self.multipage_table_url_pdf) + result = await poller.result() self.assertEqual(len(result), 2) layout = result[0] self.assertEqual(layout.page_number, 1) @@ -212,8 +225,9 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_layout) - poller = await client.begin_recognize_content_from_url(self.multipage_table_url_pdf, cls=callback) - result = await poller.result() + async with client: + poller = await client.begin_recognize_content_from_url(self.multipage_table_url_pdf, cls=callback) + result = await poller.result() raw_response = responses[0] layout = responses[1] page_results = raw_response.analyze_result.page_results diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_copy_model_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_copy_model_async.py index 58e7fa1e2796..b9f3ef1242a8 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_copy_model_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_copy_model_async.py @@ -24,27 +24,29 @@ class TestCopyModelAsync(AsyncFormRecognizerTest): @GlobalClientPreparer(training=True) async def test_copy_model_none_model_id(self, client, container_sas_url): with self.assertRaises(ValueError): - await client.begin_copy_model(model_id=None, target={}) + async with client: + await client.begin_copy_model(model_id=None, target={}) @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True) async def test_copy_model_empty_model_id(self, client, container_sas_url): with self.assertRaises(ValueError): - await client.begin_copy_model(model_id="", target={}) + async with client: + await client.begin_copy_model(model_id="", target={}) @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True, copy=True) async def test_copy_model_successful(self, client, container_sas_url, location, resource_id): + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() - training_poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await training_poller.result() + target = await client.get_copy_authorization(resource_region=location, resource_id=resource_id) - target = await client.get_copy_authorization(resource_region=location, resource_id=resource_id) + copy_poller = await client.begin_copy_model(model.model_id, target=target) + copy = await copy_poller.result() - copy_poller = await client.begin_copy_model(model.model_id, target=target) - copy = await copy_poller.result() - - copied_model = await client.get_custom_model(copy.model_id) + copied_model = await client.get_custom_model(copy.model_id) self.assertEqual(copy.status, "ready") self.assertIsNotNone(copy.training_started_on) @@ -56,53 +58,53 @@ async def test_copy_model_successful(self, client, container_sas_url, location, @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True, copy=True) async def test_copy_model_fail(self, client, container_sas_url, location, resource_id): + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() - training_poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await training_poller.result() - - # give an incorrect region - target = await client.get_copy_authorization(resource_region="eastus", resource_id=resource_id) + # give an incorrect region + target = await client.get_copy_authorization(resource_region="eastus", resource_id=resource_id) - with pytest.raises(HttpResponseError) as e: - poller = await client.begin_copy_model(model.model_id, target=target) - copy = await poller.result() - self.assertIsNotNone(e.value.error.code) - self.assertIsNotNone(e.value.error.message) + with pytest.raises(HttpResponseError) as e: + poller = await client.begin_copy_model(model.model_id, target=target) + copy = await poller.result() + self.assertIsNotNone(e.value.error.code) + self.assertIsNotNone(e.value.error.message) @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True, copy=True) async def test_copy_model_fail_bad_model_id(self, client, container_sas_url, location, resource_id): pytest.skip("service team will tell us when to enable this test") + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() - poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await poller.result() - - target = await client.get_copy_authorization(resource_region=location, resource_id=resource_id) + target = await client.get_copy_authorization(resource_region=location, resource_id=resource_id) - with self.assertRaises(HttpResponseError): - # give bad model_id - poller = await client.begin_copy_model("00000000-0000-0000-0000-000000000000", target=target) - copy = await poller.result() + with self.assertRaises(HttpResponseError): + # give bad model_id + poller = await client.begin_copy_model("00000000-0000-0000-0000-000000000000", target=target) + copy = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True, copy=True) async def test_copy_model_transform(self, client, container_sas_url, location, resource_id): - - training_poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await training_poller.result() - - target = await client.get_copy_authorization(resource_region=location, resource_id=resource_id) - - raw_response = [] - def callback(response, _, headers): copy_result = client._client._deserialize(CopyOperationResult, response) model_info = CustomFormModelInfo._from_generated(copy_result, target["modelId"]) raw_response.append(copy_result) raw_response.append(model_info) - poller = await client.begin_copy_model(model.model_id, target=target, cls=callback) - copy = await poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() + + target = await client.get_copy_authorization(resource_region=location, resource_id=resource_id) + + raw_response = [] + + poller = await client.begin_copy_model(model.model_id, target=target, cls=callback) + copy = await poller.result() actual = raw_response[0] copy = raw_response[1] @@ -114,8 +116,8 @@ def callback(response, _, headers): @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True, copy=True) async def test_copy_authorization(self, client, container_sas_url, location, resource_id): - - target = await client.get_copy_authorization(resource_region="eastus", resource_id=resource_id) + async with client: + target = await client.get_copy_authorization(resource_region="eastus", resource_id=resource_id) self.assertIsNotNone(target["modelId"]) self.assertIsNotNone(target["accessToken"]) @@ -127,18 +129,18 @@ async def test_copy_authorization(self, client, container_sas_url, location, res @GlobalClientPreparer(training=True, copy=True) @pytest.mark.live_test_only async def test_copy_continuation_token(self, client, container_sas_url, location, resource_id): + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() - poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await poller.result() + target = await client.get_copy_authorization(resource_region=location, resource_id=resource_id) - target = await client.get_copy_authorization(resource_region=location, resource_id=resource_id) + initial_poller = await client.begin_copy_model(model.model_id, target=target) + cont_token = initial_poller.continuation_token() + poller = await client.begin_copy_model(model.model_id, target=target, continuation_token=cont_token) + result = await poller.result() + self.assertIsNotNone(result) - initial_poller = await client.begin_copy_model(model.model_id, target=target) - cont_token = initial_poller.continuation_token() - poller = await client.begin_copy_model(model.model_id, target=target, continuation_token=cont_token) - result = await poller.result() - self.assertIsNotNone(result) - - copied_model = await client.get_custom_model(result.model_id) - self.assertIsNotNone(copied_model) - await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error + copied_model = await client.get_custom_model(result.model_id) + self.assertIsNotNone(copied_model) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_async.py index 490732439f08..f1333b0b7e88 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_async.py @@ -26,13 +26,15 @@ class TestCustomFormsAsync(AsyncFormRecognizerTest): async def test_custom_form_none_model_id(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(ValueError): - await client.begin_recognize_custom_forms(model_id=None, form=b"xx") + async with client: + await client.begin_recognize_custom_forms(model_id=None, form=b"xx") @GlobalFormRecognizerAccountPreparer() async def test_custom_form_empty_model_id(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(ValueError): - await client.begin_recognize_custom_forms(model_id="", form=b"xx") + async with client: + await client.begin_recognize_custom_forms(model_id="", form=b"xx") @GlobalFormRecognizerAccountPreparer() async def test_custom_form_bad_endpoint(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -40,23 +42,26 @@ async def test_custom_form_bad_endpoint(self, resource_group, location, form_rec myfile = fd.read() with self.assertRaises(ServiceRequestError): client = FormRecognizerClient("http://notreal.azure.com", AzureKeyCredential(form_recognizer_account_key)) - poller = await client.begin_recognize_custom_forms(model_id="xx", form=myfile) - result = await poller.result() + async with client: + poller = await client.begin_recognize_custom_forms(model_id="xx", form=myfile) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_authentication_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - poller = await client.begin_recognize_custom_forms(model_id="xx", form=b"xx", content_type="image/jpeg") - result = await poller.result() + async with client: + poller = await client.begin_recognize_custom_forms(model_id="xx", form=b"xx", content_type="image/jpeg") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_passing_unsupported_url_content_type(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(TypeError): - poller = await client.begin_recognize_custom_forms(model_id="xx", form="https://badurl.jpg", content_type="application/json") - result = await poller.result() + async with client: + poller = await client.begin_recognize_custom_forms(model_id="xx", form="https://badurl.jpg", content_type="application/json") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_auto_detect_unsupported_stream_content(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -66,42 +71,46 @@ async def test_auto_detect_unsupported_stream_content(self, resource_group, loca myfile = fd.read() with self.assertRaises(ValueError): - poller = await client.begin_recognize_custom_forms( - model_id="xxx", - form=myfile, - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_custom_forms( + model_id="xxx", + form=myfile, + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True) async def test_custom_form_damaged_file(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - - training_poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await training_poller.result() - - with self.assertRaises(HttpResponseError): - poller = await fr_client.begin_recognize_custom_forms( - model.model_id, - b"\x25\x50\x44\x46\x55\x55\x55", - ) - result = await poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() + + with self.assertRaises(HttpResponseError): + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + b"\x25\x50\x44\x46\x55\x55\x55", + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True) async def test_custom_form_unlabeled_blank_page(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - - poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await poller.result() - with open(self.blank_pdf, "rb") as fd: blank = fd.read() - poller = await fr_client.begin_recognize_custom_forms( - model.model_id, - blank - ) - form = await poller.result() + + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + blank + ) + form = await poller.result() self.assertEqual(len(form), 1) self.assertEqual(form[0].page_range.first_page_number, 1) @@ -112,17 +121,19 @@ async def test_custom_form_unlabeled_blank_page(self, client, container_sas_url) @GlobalClientPreparer(training=True) async def test_custom_form_labeled_blank_page(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - - poller = await client.begin_training(container_sas_url, use_training_labels=True) - model = await poller.result() - with open(self.blank_pdf, "rb") as fd: blank = fd.read() - poller = await fr_client.begin_recognize_custom_forms( - model.model_id, - blank - ) - form = await poller.result() + + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + blank + ) + form = await poller.result() self.assertEqual(len(form), 1) self.assertEqual(form[0].page_range.first_page_number, 1) @@ -134,14 +145,16 @@ async def test_custom_form_labeled_blank_page(self, client, container_sas_url): async def test_custom_form_unlabeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - training_poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await training_poller.result() - with open(self.form_jpg, "rb") as fd: myfile = fd.read() - poller = await fr_client.begin_recognize_custom_forms(model.model_id, myfile, content_type=FormContentType.IMAGE_JPEG) - form = await poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms(model.model_id, myfile, content_type=FormContentType.IMAGE_JPEG) + form = await poller.result() self.assertEqual(form[0].form_type, "form-0") self.assertFormPagesHasValues(form[0].pages) for label, field in form[0].fields.items(): @@ -155,19 +168,20 @@ async def test_custom_form_unlabeled(self, client, container_sas_url): @GlobalClientPreparer(training=True, multipage=True) async def test_custom_form_multipage_unlabeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - - training_poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await training_poller.result() - with open(self.multipage_invoice_pdf, "rb") as fd: myfile = fd.read() - poller = await fr_client.begin_recognize_custom_forms( - model.model_id, - myfile, - content_type=FormContentType.APPLICATION_PDF - ) - forms = await poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + myfile, + content_type=FormContentType.APPLICATION_PDF + ) + forms = await poller.result() for form in forms: if form.form_type is None: @@ -186,14 +200,16 @@ async def test_custom_form_multipage_unlabeled(self, client, container_sas_url): async def test_custom_form_labeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - training_poller = await client.begin_training(container_sas_url, use_training_labels=True) - model = await training_poller.result() - with open(self.form_jpg, "rb") as fd: myfile = fd.read() - poller = await fr_client.begin_recognize_custom_forms(model.model_id, myfile, content_type=FormContentType.IMAGE_JPEG) - form = await poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await training_poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms(model.model_id, myfile, content_type=FormContentType.IMAGE_JPEG) + form = await poller.result() self.assertEqual(form[0].form_type, "form-"+model.model_id) self.assertFormPagesHasValues(form[0].pages) @@ -207,22 +223,23 @@ async def test_custom_form_labeled(self, client, container_sas_url): @GlobalClientPreparer(training=True, multipage=True) async def test_custom_form_multipage_labeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - - training_poller = await client.begin_training( - container_sas_url, - use_training_labels=True - ) - model = await training_poller.result() - with open(self.multipage_invoice_pdf, "rb") as fd: myfile = fd.read() - poller = await fr_client.begin_recognize_custom_forms( - model.model_id, - myfile, - content_type=FormContentType.APPLICATION_PDF - ) - forms = await poller.result() + async with client: + training_poller = await client.begin_training( + container_sas_url, + use_training_labels=True + ) + model = await training_poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + myfile, + content_type=FormContentType.APPLICATION_PDF + ) + forms = await poller.result() for form in forms: self.assertEqual(form.form_type, "form-"+model.model_id) @@ -239,9 +256,6 @@ async def test_custom_form_multipage_labeled(self, client, container_sas_url): async def test_form_unlabeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - training_poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await training_poller.result() - responses = [] def callback(raw_response, _, headers): @@ -253,13 +267,18 @@ def callback(raw_response, _, headers): with open(self.form_jpg, "rb") as fd: myfile = fd.read() - poller = await fr_client.begin_recognize_custom_forms( - model.model_id, - myfile, - include_field_elements=True, - cls=callback - ) - form = await poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + myfile, + include_field_elements=True, + cls=callback + ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] @@ -277,9 +296,6 @@ def callback(raw_response, _, headers): async def test_custom_forms_multipage_unlabeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - training_poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await training_poller.result() - responses = [] def callback(raw_response, _, headers): @@ -291,13 +307,18 @@ def callback(raw_response, _, headers): with open(self.multipage_invoice_pdf, "rb") as fd: myfile = fd.read() - poller = await fr_client.begin_recognize_custom_forms( - model.model_id, - myfile, - include_field_elements=True, - cls=callback - ) - form = await poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + myfile, + include_field_elements=True, + cls=callback + ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] read_results = actual.analyze_result.read_results @@ -316,9 +337,6 @@ def callback(raw_response, _, headers): async def test_form_labeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - training_polling = await client.begin_training(container_sas_url, use_training_labels=True) - model = await training_polling.result() - responses = [] def callback(raw_response, _, headers): @@ -330,13 +348,18 @@ def callback(raw_response, _, headers): with open(self.form_jpg, "rb") as fd: myfile = fd.read() - poller = await fr_client.begin_recognize_custom_forms( - model.model_id, - myfile, - include_field_elements=True, - cls=callback - ) - form = await poller.result() + async with client: + training_polling = await client.begin_training(container_sas_url, use_training_labels=True) + model = await training_polling.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + myfile, + include_field_elements=True, + cls=callback + ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] @@ -354,9 +377,6 @@ def callback(raw_response, _, headers): async def test_custom_forms_multipage_labeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - training_poller = await client.begin_training(container_sas_url, use_training_labels=True) - model = await training_poller.result() - responses = [] def callback(raw_response, _, headers): @@ -368,13 +388,18 @@ def callback(raw_response, _, headers): with open(self.multipage_invoice_pdf, "rb") as fd: myfile = fd.read() - poller = await fr_client.begin_recognize_custom_forms( - model.model_id, - myfile, - include_field_elements=True, - cls=callback - ) - form = await poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await training_poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + myfile, + include_field_elements=True, + cls=callback + ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] @@ -394,35 +419,34 @@ def callback(raw_response, _, headers): @pytest.mark.live_test_only async def test_custom_form_continuation_token(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - - poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await poller.result() - with open(self.form_jpg, "rb") as fd: myfile = fd.read() - initial_poller = await fr_client.begin_recognize_custom_forms( - model.model_id, - myfile - ) - - cont_token = initial_poller.continuation_token() - poller = await fr_client.begin_recognize_custom_forms( - model.model_id, - myfile, - continuation_token=cont_token - ) - result = await poller.result() - self.assertIsNotNone(result) - await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error + + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() + + async with fr_client: + initial_poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + myfile + ) + + cont_token = initial_poller.continuation_token() + poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + myfile, + continuation_token=cont_token + ) + result = await poller.result() + self.assertIsNotNone(result) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True, multipage2=True) async def test_custom_form_multipage_vendor_set_unlabeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await poller.result() - responses = [] def callback(raw_response, _, headers): @@ -434,13 +458,18 @@ def callback(raw_response, _, headers): with open(self.multipage_vendor_pdf, "rb") as fd: myfile = fd.read() - poller = await fr_client.begin_recognize_custom_forms( - model.model_id, - myfile, - include_field_elements=True, - cls=callback - ) - form = await poller.result() + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + myfile, + include_field_elements=True, + cls=callback + ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] read_results = actual.analyze_result.read_results @@ -459,27 +488,29 @@ def callback(raw_response, _, headers): async def test_custom_form_multipage_vendor_set_labeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = await client.begin_training(container_sas_url, use_training_labels=True) - model = await poller.result() - responses = [] + with open(self.multipage_vendor_pdf, "rb") as fd: + myfile = fd.read() + def callback(raw_response, _, headers): analyze_result = fr_client._client._deserialize(AnalyzeOperationResult, raw_response) form = prepare_form_result(analyze_result, model.model_id) responses.append(analyze_result) responses.append(form) - with open(self.multipage_vendor_pdf, "rb") as fd: - myfile = fd.read() - - poller = await fr_client.begin_recognize_custom_forms( - model.model_id, - myfile, - include_field_elements=True, - cls=callback - ) - form = await poller.result() + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + myfile, + include_field_elements=True, + cls=callback + ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] read_results = actual.analyze_result.read_results diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_from_url_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_from_url_async.py index 083a05e41d12..983e301f6f4f 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_from_url_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_from_url_async.py @@ -24,47 +24,52 @@ class TestCustomFormsFromUrlAsync(AsyncFormRecognizerTest): @GlobalFormRecognizerAccountPreparer() async def test_custom_forms_encoded_url(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) - try: - poller = await client.begin_recognize_custom_forms_from_url( - model_id="00000000-0000-0000-0000-000000000000", - form_url="https://fakeuri.com/blank%20space" - ) - except HttpResponseError as e: - self.assertIn("https://fakeuri.com/blank%20space", e.response.request.body) + with pytest.raises(HttpResponseError) as e: + async with client: + poller = await client.begin_recognize_custom_forms_from_url( + model_id="00000000-0000-0000-0000-000000000000", + form_url="https://fakeuri.com/blank%20space" + ) + self.assertIn("https://fakeuri.com/blank%20space", e.value.response.request.body) @GlobalFormRecognizerAccountPreparer() async def test_custom_form_none_model_id(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(ValueError): - await client.begin_recognize_custom_forms_from_url(model_id=None, form_url="https://badurl.jpg") + async with client: + await client.begin_recognize_custom_forms_from_url(model_id=None, form_url="https://badurl.jpg") @GlobalFormRecognizerAccountPreparer() async def test_custom_form_empty_model_id(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(ValueError): - await client.begin_recognize_custom_forms_from_url(model_id="", form_url="https://badurl.jpg") + async with client: + await client.begin_recognize_custom_forms_from_url(model_id="", form_url="https://badurl.jpg") @GlobalFormRecognizerAccountPreparer() async def test_custom_form_url_bad_endpoint(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): with self.assertRaises(ServiceRequestError): client = FormRecognizerClient("http://notreal.azure.com", AzureKeyCredential(form_recognizer_account_key)) - poller = await client.begin_recognize_custom_forms_from_url(model_id="xx", form_url=self.form_url_jpg) - result = await poller.result() + async with client: + poller = await client.begin_recognize_custom_forms_from_url(model_id="xx", form_url=self.form_url_jpg) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_url_authentication_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - poller = await client.begin_recognize_custom_forms_from_url(model_id="xx", form_url=self.form_url_jpg) - result = await poller.result() + async with client: + poller = await client.begin_recognize_custom_forms_from_url(model_id="xx", form_url=self.form_url_jpg) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_passing_bad_url(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with pytest.raises(HttpResponseError) as e: - poller = await client.begin_recognize_custom_forms_from_url(model_id="xx", form_url="https://badurl.jpg") - result = await poller.result() + async with client: + poller = await client.begin_recognize_custom_forms_from_url(model_id="xx", form_url="https://badurl.jpg") + result = await poller.result() self.assertIsNotNone(e.value.error.code) self.assertIsNotNone(e.value.error.message) @@ -74,37 +79,42 @@ async def test_pass_stream_into_url(self, resource_group, location, form_recogni with open(self.unsupported_content_py, "rb") as fd: with self.assertRaises(HttpResponseError): - poller = await client.begin_recognize_custom_forms_from_url( - model_id="xxx", - form_url=fd, - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_custom_forms_from_url( + model_id="xxx", + form_url=fd, + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True) async def test_form_bad_url(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - training_poller = await client.begin_training(container_sas_url, use_training_labels=True) - model = await training_poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await training_poller.result() - with self.assertRaises(HttpResponseError): - poller = await fr_client.begin_recognize_custom_forms_from_url( - model.model_id, - form_url="https://badurl.jpg" - ) - result = await poller.result() + with self.assertRaises(HttpResponseError): + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + form_url="https://badurl.jpg" + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True) async def test_form_unlabeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - training_poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await training_poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() - poller = await fr_client.begin_recognize_custom_forms_from_url(model.model_id, self.form_url_jpg) - form = await poller.result() + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms_from_url(model.model_id, self.form_url_jpg) + form = await poller.result() self.assertEqual(form[0].form_type, "form-0") self.assertFormPagesHasValues(form[0].pages) @@ -120,14 +130,16 @@ async def test_form_unlabeled(self, client, container_sas_url): async def test_custom_form_multipage_unlabeled(self, client, container_sas_url, blob_sas_url): fr_client = client.get_form_recognizer_client() - training_poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await training_poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() - poller = await fr_client.begin_recognize_custom_forms_from_url( - model.model_id, - blob_sas_url, - ) - forms = await poller.result() + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + blob_sas_url, + ) + forms = await poller.result() for form in forms: if form.form_type is None: @@ -146,11 +158,13 @@ async def test_custom_form_multipage_unlabeled(self, client, container_sas_url, async def test_form_labeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - training_poller = await client.begin_training(container_sas_url, use_training_labels=True) - model = await training_poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await training_poller.result() - poller = await fr_client.begin_recognize_custom_forms_from_url(model.model_id, self.form_url_jpg) - form = await poller.result() + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms_from_url(model.model_id, self.form_url_jpg) + form = await poller.result() self.assertEqual(form[0].form_type, "form-"+model.model_id) self.assertFormPagesHasValues(form[0].pages) @@ -165,17 +179,19 @@ async def test_form_labeled(self, client, container_sas_url): async def test_form_multipage_labeled(self, client, container_sas_url, blob_sas_url): fr_client = client.get_form_recognizer_client() - training_poller = await client.begin_training( - container_sas_url, - use_training_labels=True - ) - model = await training_poller.result() + async with client: + training_poller = await client.begin_training( + container_sas_url, + use_training_labels=True + ) + model = await training_poller.result() - poller = await fr_client.begin_recognize_custom_forms_from_url( - model.model_id, - blob_sas_url - ) - forms = await poller.result() + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + blob_sas_url + ) + forms = await poller.result() for form in forms: self.assertEqual(form.form_type, "form-"+model.model_id) @@ -190,10 +206,6 @@ async def test_form_multipage_labeled(self, client, container_sas_url, blob_sas_ @GlobalClientPreparer(training=True) async def test_form_unlabeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - - training_poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await training_poller.result() - responses = [] def callback(raw_response, _, headers): @@ -202,13 +214,18 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(form) - poller = await fr_client.begin_recognize_custom_forms_from_url( - model.model_id, - self.form_url_jpg, - include_field_elements=True, - cls=callback - ) - form = await poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + self.form_url_jpg, + include_field_elements=True, + cls=callback + ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] @@ -226,9 +243,6 @@ def callback(raw_response, _, headers): async def test_multipage_unlabeled_transform(self, client, container_sas_url, blob_sas_url): fr_client = client.get_form_recognizer_client() - training_poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await training_poller.result() - responses = [] def callback(raw_response, _, headers): @@ -237,14 +251,19 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(form) - poller = await fr_client.begin_recognize_custom_forms_from_url( - model.model_id, - blob_sas_url, - include_field_elements=True, - cls=callback - ) + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + blob_sas_url, + include_field_elements=True, + cls=callback + ) - form = await poller.result() + form = await poller.result() actual = responses[0] recognized_form = responses[1] read_results = actual.analyze_result.read_results @@ -262,9 +281,6 @@ def callback(raw_response, _, headers): async def test_form_labeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - training_poller = await client.begin_training(container_sas_url, use_training_labels=True) - model = await training_poller.result() - responses = [] def callback(raw_response, _, headers): @@ -273,13 +289,18 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(form) - poller = await fr_client.begin_recognize_custom_forms_from_url( - model.model_id, - self.form_url_jpg, - include_field_elements=True, - cls=callback - ) - form = await poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await training_poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + self.form_url_jpg, + include_field_elements=True, + cls=callback + ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] @@ -297,9 +318,6 @@ def callback(raw_response, _, headers): async def test_multipage_labeled_transform(self, client, container_sas_url, blob_sas_url): fr_client = client.get_form_recognizer_client() - training_poller = await client.begin_training(container_sas_url, use_training_labels=True) - model = await training_poller.result() - responses = [] def callback(raw_response, _, headers): @@ -308,13 +326,18 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(form) - poller = await fr_client.begin_recognize_custom_forms_from_url( - model.model_id, - blob_sas_url, - include_field_elements=True, - cls=callback - ) - form = await poller.result() + async with client: + training_poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await training_poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + blob_sas_url, + include_field_elements=True, + cls=callback + ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] @@ -335,31 +358,30 @@ def callback(raw_response, _, headers): async def test_custom_form_continuation_token(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await poller.result() - - initial_poller = await fr_client.begin_recognize_custom_forms_from_url( - model.model_id, - self.form_url_jpg - ) - cont_token = initial_poller.continuation_token() - poller = await fr_client.begin_recognize_custom_forms_from_url( - model.model_id, - self.form_url_jpg, - continuation_token=cont_token - ) - result = await poller.result() - self.assertIsNotNone(result) - await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() + + async with fr_client: + initial_poller = await fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + self.form_url_jpg + ) + cont_token = initial_poller.continuation_token() + poller = await fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + self.form_url_jpg, + continuation_token=cont_token + ) + result = await poller.result() + self.assertIsNotNone(result) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True, multipage2=True, blob_sas_url=True) async def test_custom_form_multipage_vendor_set_unlabeled_transform(self, client, container_sas_url, blob_sas_url): fr_client = client.get_form_recognizer_client() - poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await poller.result() - responses = [] def callback(raw_response, _, headers): @@ -368,13 +390,18 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(form) - poller = await fr_client.begin_recognize_custom_forms_from_url( - model.model_id, - blob_sas_url, - include_field_elements=True, - cls=callback - ) - form = await poller.result() + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + blob_sas_url, + include_field_elements=True, + cls=callback + ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] read_results = actual.analyze_result.read_results @@ -393,9 +420,6 @@ def callback(raw_response, _, headers): async def test_custom_form_multipage_vendor_set_labeled_transform(self, client, container_sas_url, blob_sas_url): fr_client = client.get_form_recognizer_client() - poller = await client.begin_training(container_sas_url, use_training_labels=True) - model = await poller.result() - responses = [] def callback(raw_response, _, headers): @@ -404,13 +428,18 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(form) - poller = await fr_client.begin_recognize_custom_forms_from_url( - model.model_id, - blob_sas_url, - include_field_elements=True, - cls=callback - ) - form = await poller.result() + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await poller.result() + + async with fr_client: + poller = await fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + blob_sas_url, + include_field_elements=True, + cls=callback + ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] read_results = actual.analyze_result.read_results diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_mgmt_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_mgmt_async.py index 6278cf8861d1..6f24ce442e39 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_mgmt_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_mgmt_async.py @@ -26,63 +26,73 @@ async def test_active_directory_auth_async(self): token = self.generate_oauth_token() endpoint = self.get_oauth_endpoint() client = FormTrainingClient(endpoint, token) - props = await client.get_account_properties() + async with client: + props = await client.get_account_properties() self.assertIsNotNone(props) @GlobalFormRecognizerAccountPreparer() async def test_account_properties_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormTrainingClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - result = await client.get_account_properties() + async with client: + result = await client.get_account_properties() @GlobalFormRecognizerAccountPreparer() async def test_get_model_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormTrainingClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - result = await client.get_custom_model("xx") + async with client: + result = await client.get_custom_model("xx") @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_get_model_empty_model_id(self, client): with self.assertRaises(ValueError): - result = await client.get_custom_model("") + async with client: + result = await client.get_custom_model("") @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_get_model_none_model_id(self, client): with self.assertRaises(ValueError): - result = await client.get_custom_model(None) + async with client: + result = await client.get_custom_model(None) @GlobalFormRecognizerAccountPreparer() async def test_list_model_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormTrainingClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - result = client.list_custom_models() - async for res in result: - test = res + async with client: + result = client.list_custom_models() + async for res in result: + test = res @GlobalFormRecognizerAccountPreparer() async def test_delete_model_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormTrainingClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - result = await client.delete_model("xx") + async with client: + result = await client.delete_model("xx") @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_delete_model_none_model_id(self, client): with self.assertRaises(ValueError): - result = await client.delete_model(None) + async with client: + result = await client.delete_model(None) @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_delete_model_empty_model_id(self, client): with self.assertRaises(ValueError): - result = await client.delete_model("") + async with client: + result = await client.delete_model("") @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_account_properties(self, client): - properties = await client.get_account_properties() + async with client: + properties = await client.get_account_properties() self.assertIsNotNone(properties.custom_model_limit) self.assertIsNotNone(properties.custom_model_count) @@ -90,70 +100,71 @@ async def test_account_properties(self, client): @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True) async def test_mgmt_model_labeled(self, client, container_sas_url): - - poller = await client.begin_training(container_sas_url, use_training_labels=True) - labeled_model_from_train = await poller.result() - labeled_model_from_get = await client.get_custom_model(labeled_model_from_train.model_id) - - self.assertEqual(labeled_model_from_train.model_id, labeled_model_from_get.model_id) - self.assertEqual(labeled_model_from_train.status, labeled_model_from_get.status) - self.assertEqual(labeled_model_from_train.training_started_on, labeled_model_from_get.training_started_on) - self.assertEqual(labeled_model_from_train.training_completed_on, labeled_model_from_get.training_completed_on) - self.assertEqual(labeled_model_from_train.errors, labeled_model_from_get.errors) - for a, b in zip(labeled_model_from_train.training_documents, labeled_model_from_get.training_documents): - self.assertEqual(a.document_name, b.document_name) - self.assertEqual(a.errors, b.errors) - self.assertEqual(a.page_count, b.page_count) - self.assertEqual(a.status, b.status) - for a, b in zip(labeled_model_from_train.submodels, labeled_model_from_get.submodels): - for field1, field2 in zip(a.fields.items(), b.fields.items()): - self.assertEqual(a.fields[field1[0]].name, b.fields[field2[0]].name) - self.assertEqual(a.fields[field1[0]].accuracy, b.fields[field2[0]].accuracy) - - models_list = client.list_custom_models() - async for model in models_list: - self.assertIsNotNone(model.model_id) - self.assertIsNotNone(model.status) - self.assertIsNotNone(model.training_started_on) - self.assertIsNotNone(model.training_completed_on) - - await client.delete_model(labeled_model_from_train.model_id) - - with self.assertRaises(ResourceNotFoundError): - await client.get_custom_model(labeled_model_from_train.model_id) + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=True) + labeled_model_from_train = await poller.result() + labeled_model_from_get = await client.get_custom_model(labeled_model_from_train.model_id) + + self.assertEqual(labeled_model_from_train.model_id, labeled_model_from_get.model_id) + self.assertEqual(labeled_model_from_train.status, labeled_model_from_get.status) + self.assertEqual(labeled_model_from_train.training_started_on, labeled_model_from_get.training_started_on) + self.assertEqual(labeled_model_from_train.training_completed_on, labeled_model_from_get.training_completed_on) + self.assertEqual(labeled_model_from_train.errors, labeled_model_from_get.errors) + for a, b in zip(labeled_model_from_train.training_documents, labeled_model_from_get.training_documents): + self.assertEqual(a.document_name, b.document_name) + self.assertEqual(a.errors, b.errors) + self.assertEqual(a.page_count, b.page_count) + self.assertEqual(a.status, b.status) + for a, b in zip(labeled_model_from_train.submodels, labeled_model_from_get.submodels): + for field1, field2 in zip(a.fields.items(), b.fields.items()): + self.assertEqual(a.fields[field1[0]].name, b.fields[field2[0]].name) + self.assertEqual(a.fields[field1[0]].accuracy, b.fields[field2[0]].accuracy) + + models_list = client.list_custom_models() + async for model in models_list: + self.assertIsNotNone(model.model_id) + self.assertIsNotNone(model.status) + self.assertIsNotNone(model.training_started_on) + self.assertIsNotNone(model.training_completed_on) + + await client.delete_model(labeled_model_from_train.model_id) + + with self.assertRaises(ResourceNotFoundError): + await client.get_custom_model(labeled_model_from_train.model_id) @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True) async def test_mgmt_model_unlabeled(self, client, container_sas_url): - poller = await client.begin_training(container_sas_url, use_training_labels=False) - unlabeled_model_from_train = await poller.result() - unlabeled_model_from_get = await client.get_custom_model(unlabeled_model_from_train.model_id) - - self.assertEqual(unlabeled_model_from_train.model_id, unlabeled_model_from_get.model_id) - self.assertEqual(unlabeled_model_from_train.status, unlabeled_model_from_get.status) - self.assertEqual(unlabeled_model_from_train.training_started_on, unlabeled_model_from_get.training_started_on) - self.assertEqual(unlabeled_model_from_train.training_completed_on, unlabeled_model_from_get.training_completed_on) - self.assertEqual(unlabeled_model_from_train.errors, unlabeled_model_from_get.errors) - for a, b in zip(unlabeled_model_from_train.training_documents, unlabeled_model_from_get.training_documents): - self.assertEqual(a.document_name, b.document_name) - self.assertEqual(a.errors, b.errors) - self.assertEqual(a.page_count, b.page_count) - self.assertEqual(a.status, b.status) - for a, b in zip(unlabeled_model_from_train.submodels, unlabeled_model_from_get.submodels): - for field1, field2 in zip(a.fields.items(), b.fields.items()): - self.assertEqual(a.fields[field1[0]].label, b.fields[field2[0]].label) - - models_list = client.list_custom_models() - async for model in models_list: - self.assertIsNotNone(model.model_id) - self.assertIsNotNone(model.status) - self.assertIsNotNone(model.training_started_on) - self.assertIsNotNone(model.training_completed_on) - - await client.delete_model(unlabeled_model_from_train.model_id) - - with self.assertRaises(ResourceNotFoundError): - await client.get_custom_model(unlabeled_model_from_train.model_id) + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=False) + unlabeled_model_from_train = await poller.result() + unlabeled_model_from_get = await client.get_custom_model(unlabeled_model_from_train.model_id) + + self.assertEqual(unlabeled_model_from_train.model_id, unlabeled_model_from_get.model_id) + self.assertEqual(unlabeled_model_from_train.status, unlabeled_model_from_get.status) + self.assertEqual(unlabeled_model_from_train.training_started_on, unlabeled_model_from_get.training_started_on) + self.assertEqual(unlabeled_model_from_train.training_completed_on, unlabeled_model_from_get.training_completed_on) + self.assertEqual(unlabeled_model_from_train.errors, unlabeled_model_from_get.errors) + for a, b in zip(unlabeled_model_from_train.training_documents, unlabeled_model_from_get.training_documents): + self.assertEqual(a.document_name, b.document_name) + self.assertEqual(a.errors, b.errors) + self.assertEqual(a.page_count, b.page_count) + self.assertEqual(a.status, b.status) + for a, b in zip(unlabeled_model_from_train.submodels, unlabeled_model_from_get.submodels): + for field1, field2 in zip(a.fields.items(), b.fields.items()): + self.assertEqual(a.fields[field1[0]].label, b.fields[field2[0]].label) + + models_list = client.list_custom_models() + async for model in models_list: + self.assertIsNotNone(model.model_id) + self.assertIsNotNone(model.status) + self.assertIsNotNone(model.training_started_on) + self.assertIsNotNone(model.training_completed_on) + + await client.delete_model(unlabeled_model_from_train.model_id) + + with self.assertRaises(ResourceNotFoundError): + await client.get_custom_model(unlabeled_model_from_train.model_id) @GlobalFormRecognizerAccountPreparer() async def test_get_form_recognizer_client(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_async.py index 4b9cfc8a98dc..7c7e76545f03 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_async.py @@ -30,34 +30,38 @@ async def test_receipt_bad_endpoint(self, resource_group, location, form_recogni myfile = fd.read() with self.assertRaises(ServiceRequestError): client = FormRecognizerClient("http://notreal.azure.com", AzureKeyCredential(form_recognizer_account_key)) - poller = await client.begin_recognize_receipts(myfile) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts(myfile) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_authentication_successful_key(self, client): with open(self.receipt_jpg, "rb") as fd: myfile = fd.read() - poller = await client.begin_recognize_receipts(myfile) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts(myfile) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_authentication_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - poller = await client.begin_recognize_receipts(b"xx", content_type="image/jpeg") - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts(b"xx", content_type="image/jpeg") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_passing_enum_content_type(self, client): with open(self.receipt_png, "rb") as fd: myfile = fd.read() - poller = await client.begin_recognize_receipts( - myfile, - content_type=FormContentType.IMAGE_PNG - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts( + myfile, + content_type=FormContentType.IMAGE_PNG + ) + result = await poller.result() self.assertIsNotNone(result) @GlobalFormRecognizerAccountPreparer() @@ -65,40 +69,44 @@ async def test_passing_enum_content_type(self, client): async def test_damaged_file_passed_as_bytes(self, client): damaged_pdf = b"\x25\x50\x44\x46\x55\x55\x55" # still has correct bytes to be recognized as PDF with self.assertRaises(HttpResponseError): - poller = await client.begin_recognize_receipts( - damaged_pdf, - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts( + damaged_pdf, + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_damaged_file_bytes_fails_autodetect_content_type(self, client): damaged_pdf = b"\x50\x44\x46\x55\x55\x55" # doesn't match any magic file numbers with self.assertRaises(ValueError): - poller = await client.begin_recognize_receipts( - damaged_pdf, - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts( + damaged_pdf, + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_damaged_file_passed_as_bytes_io(self, client): damaged_pdf = BytesIO(b"\x25\x50\x44\x46\x55\x55\x55") # still has correct bytes to be recognized as PDF with self.assertRaises(HttpResponseError): - poller = await client.begin_recognize_receipts( - damaged_pdf, - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts( + damaged_pdf, + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_damaged_file_bytes_io_fails_autodetect(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) damaged_pdf = BytesIO(b"\x50\x44\x46\x55\x55\x55") # doesn't match any magic file numbers with self.assertRaises(ValueError): - poller = await client.begin_recognize_receipts( - damaged_pdf, - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts( + damaged_pdf, + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() @@ -106,10 +114,11 @@ async def test_blank_page(self, client): with open(self.blank_pdf, "rb") as fd: blank = fd.read() - poller = await client.begin_recognize_receipts( - blank, - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts( + blank, + ) + result = await poller.result() self.assertIsNotNone(result) @GlobalFormRecognizerAccountPreparer() @@ -118,18 +127,20 @@ async def test_passing_bad_content_type_param_passed(self, client): with open(self.receipt_jpg, "rb") as fd: myfile = fd.read() with self.assertRaises(ValueError): - poller = await client.begin_recognize_receipts( - myfile, - content_type="application/jpeg" - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts( + myfile, + content_type="application/jpeg" + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_passing_unsupported_url_content_type(self, client): with self.assertRaises(TypeError): - poller = await client.begin_recognize_receipts("https://badurl.jpg", content_type="application/json") - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts("https://badurl.jpg", content_type="application/json") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() @@ -138,10 +149,11 @@ async def test_auto_detect_unsupported_stream_content(self, client): myfile = fd.read() with self.assertRaises(ValueError): - poller = await client.begin_recognize_receipts( - myfile, - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts( + myfile, + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() @@ -158,12 +170,13 @@ def callback(raw_response, _, headers): with open(self.receipt_png, "rb") as fd: myfile = fd.read() - poller = await client.begin_recognize_receipts( - receipt=myfile, - include_field_elements=True, - cls=callback - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts( + receipt=myfile, + include_field_elements=True, + cls=callback + ) + result = await poller.result() raw_response = responses[0] returned_model = responses[1] @@ -212,12 +225,13 @@ def callback(raw_response, _, headers): with open(self.receipt_jpg, "rb") as fd: myfile = fd.read() - poller = await client.begin_recognize_receipts( - receipt=myfile, - include_field_elements=True, - cls=callback - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts( + receipt=myfile, + include_field_elements=True, + cls=callback + ) + result = await poller.result() raw_response = responses[0] returned_model = responses[1] @@ -260,8 +274,9 @@ async def test_receipt_jpg(self, client): with open(self.receipt_jpg, "rb") as fd: receipt = fd.read() - poller = await client.begin_recognize_receipts(receipt) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts(receipt) + result = await poller.result() self.assertEqual(len(result), 1) receipt = result[0] @@ -288,8 +303,9 @@ async def test_receipt_png(self, client): with open(self.receipt_png, "rb") as fd: receipt = fd.read() - poller = await client.begin_recognize_receipts(receipt) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts(receipt) + result = await poller.result() self.assertEqual(len(result), 1) receipt = result[0] self.assertEqual(receipt.fields.get("MerchantAddress").value, '123 Main Street Redmond, WA 98052') @@ -311,8 +327,9 @@ async def test_receipt_png(self, client): async def test_receipt_jpg_include_field_elements(self, client): with open(self.receipt_jpg, "rb") as fd: receipt = fd.read() - poller = await client.begin_recognize_receipts(receipt, include_field_elements=True) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts(receipt, include_field_elements=True) + result = await poller.result() self.assertEqual(len(result), 1) receipt = result[0] @@ -331,8 +348,9 @@ async def test_receipt_jpg_include_field_elements(self, client): async def test_receipt_multipage(self, client): with open(self.multipage_invoice_pdf, "rb") as fd: receipt = fd.read() - poller = await client.begin_recognize_receipts(receipt, include_field_elements=True) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts(receipt, include_field_elements=True) + result = await poller.result() self.assertEqual(len(result), 3) receipt = result[0] @@ -374,12 +392,13 @@ def callback(raw_response, _, headers): with open(self.multipage_invoice_pdf, "rb") as fd: myfile = fd.read() - poller = await client.begin_recognize_receipts( - receipt=myfile, - include_field_elements=True, - cls=callback - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts( + receipt=myfile, + include_field_elements=True, + cls=callback + ) + result = await poller.result() raw_response = responses[0] returned_model = responses[1] @@ -427,9 +446,10 @@ async def test_receipt_continuation_token(self, client): with open(self.receipt_jpg, "rb") as fd: receipt = fd.read() - initial_poller = await client.begin_recognize_receipts(receipt) - cont_token = initial_poller.continuation_token() - poller = await client.begin_recognize_receipts(receipt, continuation_token=cont_token) - result = await poller.result() - self.assertIsNotNone(result) - await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error + async with client: + initial_poller = await client.begin_recognize_receipts(receipt) + cont_token = initial_poller.continuation_token() + poller = await client.begin_recognize_receipts(receipt, continuation_token=cont_token) + result = await poller.result() + self.assertIsNotNone(result) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_from_url_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_from_url_async.py index 8c5e549386a6..632f2dde5506 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_from_url_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_from_url_async.py @@ -27,12 +27,13 @@ async def test_polling_interval(self, resource_group, location, form_recognizer_ client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key), polling_interval=7) self.assertEqual(client._client._config.polling_interval, 7) - poller = await client.begin_recognize_receipts_from_url(self.receipt_url_jpg, polling_interval=6) - await poller.wait() - self.assertEqual(poller._polling_method._timeout, 6) - poller2 = await client.begin_recognize_receipts_from_url(self.receipt_url_jpg) - await poller2.wait() - self.assertEqual(poller2._polling_method._timeout, 7) # goes back to client default + async with client: + poller = await client.begin_recognize_receipts_from_url(self.receipt_url_jpg, polling_interval=6) + await poller.wait() + self.assertEqual(poller._polling_method._timeout, 6) + poller2 = await client.begin_recognize_receipts_from_url(self.receipt_url_jpg) + await poller2.wait() + self.assertEqual(poller2._polling_method._timeout, 7) # goes back to client default @pytest.mark.live_test_only @GlobalFormRecognizerAccountPreparer() @@ -40,52 +41,57 @@ async def test_active_directory_auth_async(self): token = self.generate_oauth_token() endpoint = self.get_oauth_endpoint() client = FormRecognizerClient(endpoint, token) - poller = await client.begin_recognize_receipts_from_url( - self.receipt_url_jpg - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts_from_url( + self.receipt_url_jpg + ) + result = await poller.result() self.assertIsNotNone(result) @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_receipts_encoded_url(self, client): - try: - poller = await client.begin_recognize_receipts_from_url("https://fakeuri.com/blank%20space") - except HttpResponseError as e: - self.assertIn("https://fakeuri.com/blank%20space", e.response.request.body) + with pytest.raises(HttpResponseError) as e: + async with client: + poller = await client.begin_recognize_receipts_from_url("https://fakeuri.com/blank%20space") + self.assertIn("https://fakeuri.com/blank%20space", e.value.response.request.body) @GlobalFormRecognizerAccountPreparer() async def test_receipt_url_bad_endpoint(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): with self.assertRaises(ServiceRequestError): client = FormRecognizerClient("http://notreal.azure.com", AzureKeyCredential(form_recognizer_account_key)) - poller = await client.begin_recognize_receipts_from_url( - self.receipt_url_jpg - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts_from_url( + self.receipt_url_jpg + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_receipt_url_auth_successful_key(self, client): - poller = await client.begin_recognize_receipts_from_url( - self.receipt_url_jpg - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts_from_url( + self.receipt_url_jpg + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_receipt_url_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - poller = await client.begin_recognize_receipts_from_url( - self.receipt_url_jpg - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts_from_url( + self.receipt_url_jpg + ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_receipt_bad_url(self, client): with self.assertRaises(HttpResponseError): - poller = await client.begin_recognize_receipts_from_url("https://badurl.jpg") - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts_from_url("https://badurl.jpg") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() @@ -95,8 +101,9 @@ async def test_receipt_url_pass_stream(self, client): receipt = fd.read(4) # makes the recording smaller with self.assertRaises(HttpResponseError): - poller = await client.begin_recognize_receipts_from_url(receipt) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts_from_url(receipt) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() @@ -110,12 +117,13 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_receipt) - poller = await client.begin_recognize_receipts_from_url( - self.receipt_url_jpg, - include_field_elements=True, - cls=callback - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts_from_url( + self.receipt_url_jpg, + include_field_elements=True, + cls=callback + ) + result = await poller.result() raw_response = responses[0] returned_model = responses[1] @@ -161,12 +169,13 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_receipt) - poller = await client.begin_recognize_receipts_from_url( - self.receipt_url_png, - include_field_elements=True, - cls=callback - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts_from_url( + self.receipt_url_png, + include_field_elements=True, + cls=callback + ) + result = await poller.result() raw_response = responses[0] returned_model = responses[1] @@ -205,11 +214,12 @@ def callback(raw_response, _, headers): @GlobalClientPreparer() async def test_receipt_url_include_field_elements(self, client): - poller = await client.begin_recognize_receipts_from_url( - self.receipt_url_jpg, - include_field_elements=True - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts_from_url( + self.receipt_url_jpg, + include_field_elements=True + ) + result = await poller.result() self.assertEqual(len(result), 1) receipt = result[0] @@ -227,10 +237,11 @@ async def test_receipt_url_include_field_elements(self, client): @GlobalClientPreparer() async def test_receipt_url_jpg(self, client): - poller = await client.begin_recognize_receipts_from_url( - self.receipt_url_jpg - ) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts_from_url( + self.receipt_url_jpg + ) + result = await poller.result() self.assertEqual(len(result), 1) receipt = result[0] @@ -255,8 +266,9 @@ async def test_receipt_url_jpg(self, client): @GlobalClientPreparer() async def test_receipt_url_png(self, client): - poller = await client.begin_recognize_receipts_from_url(self.receipt_url_png) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts_from_url(self.receipt_url_png) + result = await poller.result() self.assertEqual(len(result), 1) receipt = result[0] @@ -278,8 +290,9 @@ async def test_receipt_url_png(self, client): @GlobalClientPreparer() async def test_receipt_multipage_url(self, client): - poller = await client.begin_recognize_receipts_from_url(self.multipage_url_pdf, include_field_elements=True) - result = await poller.result() + async with client: + poller = await client.begin_recognize_receipts_from_url(self.multipage_url_pdf, include_field_elements=True) + result = await poller.result() self.assertEqual(len(result), 3) receipt = result[0] @@ -319,13 +332,14 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_receipt) - poller = await client.begin_recognize_receipts_from_url( - self.multipage_url_pdf, - include_field_elements=True, - cls=callback - ) + async with client: + poller = await client.begin_recognize_receipts_from_url( + self.multipage_url_pdf, + include_field_elements=True, + cls=callback + ) - result = await poller.result() + result = await poller.result() raw_response = responses[0] returned_model = responses[1] actual = raw_response.analyze_result.document_results @@ -369,9 +383,10 @@ def callback(raw_response, _, headers): @pytest.mark.live_test_only async def test_receipt_continuation_token(self, client): - initial_poller = await client.begin_recognize_receipts_from_url(self.receipt_url_jpg) - cont_token = initial_poller.continuation_token() - poller = await client.begin_recognize_receipts_from_url(self.receipt_url_jpg, continuation_token=cont_token) - result = await poller.result() - self.assertIsNotNone(result) - await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error + async with client: + initial_poller = await client.begin_recognize_receipts_from_url(self.receipt_url_jpg) + cont_token = initial_poller.continuation_token() + poller = await client.begin_recognize_receipts_from_url(self.receipt_url_jpg, continuation_token=cont_token) + result = await poller.result() + self.assertIsNotNone(result) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_samples_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_samples_async.py index 64fb40ed2f3d..d23a8eed33b5 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_samples_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_samples_async.py @@ -62,8 +62,9 @@ async def test_sample_get_bounding_boxes_async(self, resource_group, location, f os.environ['CONTAINER_SAS_URL'] = self.get_settings_value("FORM_RECOGNIZER_STORAGE_CONTAINER_SAS_URL") ftc = FormTrainingClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) container_sas_url = os.environ['CONTAINER_SAS_URL'] - poller = await ftc.begin_training(container_sas_url, use_training_labels=False) - model = await poller.result() + async with ftc: + poller = await ftc.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() os.environ['CUSTOM_TRAINED_MODEL_ID'] = model.model_id _test_file('sample_get_bounding_boxes_async.py', form_recognizer_account, form_recognizer_account_key) @@ -83,8 +84,9 @@ async def test_sample_recognize_custom_forms_async(self, resource_group, locatio os.environ['CONTAINER_SAS_URL'] = self.get_settings_value("FORM_RECOGNIZER_STORAGE_CONTAINER_SAS_URL") ftc = FormTrainingClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) container_sas_url = os.environ['CONTAINER_SAS_URL'] - poller = await ftc.begin_training(container_sas_url, use_training_labels=False) - model = await poller.result() + async with ftc: + poller = await ftc.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() os.environ['CUSTOM_TRAINED_MODEL_ID'] = model.model_id _test_file('sample_recognize_custom_forms_async.py', form_recognizer_account, form_recognizer_account_key) @@ -121,8 +123,9 @@ async def test_sample_copy_model_async(self, resource_group, location, form_reco os.environ['CONTAINER_SAS_URL'] = self.get_settings_value("FORM_RECOGNIZER_STORAGE_CONTAINER_SAS_URL") ftc = FormTrainingClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) container_sas_url = os.environ['CONTAINER_SAS_URL'] - poller = await ftc.begin_training(container_sas_url, use_training_labels=False) - model = await poller.result() + async with ftc: + poller = await ftc.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() os.environ['AZURE_SOURCE_MODEL_ID'] = model.model_id os.environ["AZURE_FORM_RECOGNIZER_TARGET_ENDPOINT"] = form_recognizer_account os.environ["AZURE_FORM_RECOGNIZER_TARGET_KEY"] = form_recognizer_account_key @@ -141,10 +144,11 @@ async def test_sample_differentiate_output_models_trained_with_and_without_label os.environ['CONTAINER_SAS_URL'] = self.get_settings_value("FORM_RECOGNIZER_STORAGE_CONTAINER_SAS_URL") ftc = FormTrainingClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) container_sas_url = os.environ['CONTAINER_SAS_URL'] - poller = await ftc.begin_training(container_sas_url, use_training_labels=False) - unlabeled_model = await poller.result() - poller = await ftc.begin_training(container_sas_url, use_training_labels=True) - labeled_model = await poller.result() + async with ftc: + poller = await ftc.begin_training(container_sas_url, use_training_labels=False) + unlabeled_model = await poller.result() + poller = await ftc.begin_training(container_sas_url, use_training_labels=True) + labeled_model = await poller.result() os.environ["ID_OF_MODEL_TRAINED_WITH_LABELS"] = labeled_model.model_id os.environ["ID_OF_MODEL_TRAINED_WITHOUT_LABELS"] = unlabeled_model.model_id _test_file('sample_differentiate_output_models_trained_with_and_without_labels_async.py', diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training.py index 10b016cb0e46..4dcb467d603d 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training.py @@ -35,6 +35,7 @@ def check_poll_value(poll): poller2 = client.begin_training(training_files_url=container_sas_url, use_training_labels=False) poller2.wait() check_poll_value(poller2._polling_method._timeout) # goes back to client default + client.close() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training_async.py index 4a0d3a51111b..18f570cdfb76 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training_async.py @@ -36,33 +36,37 @@ def check_poll_value(poll): poller2 = await client.begin_training(training_files_url=container_sas_url, use_training_labels=False) await poller2.wait() check_poll_value(poller2._polling_method._timeout) # goes back to client default + await client.close() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer() async def test_training_encoded_url(self, client): with self.assertRaises(HttpResponseError): - poller = await client.begin_training( - training_files_url="https://fakeuri.com/blank%20space", - use_training_labels=False - ) - self.assertIn("https://fakeuri.com/blank%20space", poller._polling_method._initial_response.http_request.body) - await poller.wait() + async with client: + poller = await client.begin_training( + training_files_url="https://fakeuri.com/blank%20space", + use_training_labels=False + ) + self.assertIn("https://fakeuri.com/blank%20space", poller._polling_method._initial_response.http_request.body) + await poller.wait() @GlobalFormRecognizerAccountPreparer() async def test_training_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormTrainingClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - poller = await client.begin_training("xx", use_training_labels=False) - result = await poller.result() + async with client: + poller = await client.begin_training("xx", use_training_labels=False) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True) async def test_training(self, client, container_sas_url): - poller = await client.begin_training( - training_files_url=container_sas_url, - use_training_labels=False) - model = await poller.result() + async with client: + poller = await client.begin_training( + training_files_url=container_sas_url, + use_training_labels=False) + model = await poller.result() self.assertIsNotNone(model.model_id) self.assertIsNotNone(model.training_started_on) @@ -83,9 +87,9 @@ async def test_training(self, client, container_sas_url): @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True, multipage=True) async def test_training_multipage(self, client, container_sas_url): - - poller = await client.begin_training(container_sas_url, use_training_labels=False) - model = await poller.result() + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() self.assertIsNotNone(model.model_id) self.assertIsNotNone(model.training_started_on) @@ -115,11 +119,12 @@ def callback(response): raw_response.append(raw_model) raw_response.append(custom_model) - poller = await client.begin_training( - training_files_url=container_sas_url, - use_training_labels=False, - cls=callback) - model = await poller.result() + async with client: + poller = await client.begin_training( + training_files_url=container_sas_url, + use_training_labels=False, + cls=callback) + model = await poller.result() raw_model = raw_response[0] custom_model = raw_response[1] @@ -137,8 +142,9 @@ def callback(response): raw_response.append(raw_model) raw_response.append(custom_model) - poller = await client.begin_training(container_sas_url, use_training_labels=False, cls=callback) - model = await poller.result() + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=False, cls=callback) + model = await poller.result() raw_model = raw_response[0] custom_model = raw_response[1] @@ -147,9 +153,9 @@ def callback(response): @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True) async def test_training_with_labels(self, client, container_sas_url): - - poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=True) - model = await poller.result() + async with client: + poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=True) + model = await poller.result() self.assertIsNotNone(model.model_id) self.assertIsNotNone(model.training_started_on) @@ -170,9 +176,9 @@ async def test_training_with_labels(self, client, container_sas_url): @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True, multipage=True) async def test_training_multipage_with_labels(self, client, container_sas_url): - - poller = await client.begin_training(container_sas_url, use_training_labels=True) - model = await poller.result() + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await poller.result() self.assertIsNotNone(model.model_id) self.assertIsNotNone(model.training_started_on) @@ -203,8 +209,9 @@ def callback(response): raw_response.append(raw_model) raw_response.append(custom_model) - poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=True, cls=callback) - model = await poller.result() + async with client: + poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=True, cls=callback) + model = await poller.result() raw_model = raw_response[0] custom_model = raw_response[1] @@ -222,8 +229,9 @@ def callback(response): raw_response.append(raw_model) raw_response.append(custom_model) - poller = await client.begin_training(container_sas_url, use_training_labels=True, cls=callback) - model = await poller.result() + async with client: + poller = await client.begin_training(container_sas_url, use_training_labels=True, cls=callback) + model = await poller.result() raw_model = raw_response[0] custom_model = raw_response[1] @@ -232,31 +240,31 @@ def callback(response): @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True) async def test_training_with_files_filter(self, client, container_sas_url): + async with client: + poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=False, include_sub_folders=True) + model = await poller.result() + self.assertEqual(len(model.training_documents), 6) + self.assertEqual(model.training_documents[-1].document_name, "subfolder/Form_6.jpg") # we traversed subfolders - poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=False, include_sub_folders=True) - model = await poller.result() - self.assertEqual(len(model.training_documents), 6) - self.assertEqual(model.training_documents[-1].document_name, "subfolder/Form_6.jpg") # we traversed subfolders - - poller = await client.begin_training(container_sas_url, use_training_labels=False, prefix="subfolder", include_sub_folders=True) - model = await poller.result() - self.assertEqual(len(model.training_documents), 1) - self.assertEqual(model.training_documents[0].document_name, "subfolder/Form_6.jpg") # we filtered for only subfolders - - with pytest.raises(HttpResponseError) as e: - poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=False, prefix="xxx") + poller = await client.begin_training(container_sas_url, use_training_labels=False, prefix="subfolder", include_sub_folders=True) model = await poller.result() - self.assertIsNotNone(e.value.error.code) - self.assertIsNotNone(e.value.error.message) + self.assertEqual(len(model.training_documents), 1) + self.assertEqual(model.training_documents[0].document_name, "subfolder/Form_6.jpg") # we filtered for only subfolders + + with pytest.raises(HttpResponseError) as e: + poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=False, prefix="xxx") + model = await poller.result() + self.assertIsNotNone(e.value.error.code) + self.assertIsNotNone(e.value.error.message) @GlobalFormRecognizerAccountPreparer() @GlobalClientPreparer(training=True) @pytest.mark.live_test_only async def test_training_continuation_token(self, client, container_sas_url): - - initial_poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=False) - cont_token = initial_poller.continuation_token() - poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=False, continuation_token=cont_token) - result = await poller.result() - self.assertIsNotNone(result) - await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error + async with client: + initial_poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=False) + cont_token = initial_poller.continuation_token() + poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=False, continuation_token=cont_token) + result = await poller.result() + self.assertIsNotNone(result) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error