Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make use_training_labels positional required #11529

Merged
merged 5 commits into from
May 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions sdk/formrecognizer/azure-ai-formrecognizer/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
- `FormField` does not have a page_number.
- `begin_recognize_receipts` APIs now return `RecognizedReceipt` instead of `USReceipt`
- `USReceiptType` is renamed to `ReceiptType`
- `use_training_labels` is now a required positional param in the `begin_training` APIs.
- `stream` and `url` parameters found on methods for `FormRecognizerClient` have been renamed to `form` and `form_url`, respectively.
For recognize receipt methods, parameters have been renamed to `receipt` and `receipt_url`.



**New features**

- Authentication using `azure-identity` credentials now supported
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# pylint: disable=protected-access

from typing import (
Optional,
Any,
Iterable,
Union,
Expand Down Expand Up @@ -79,8 +78,8 @@ def __init__(self, endpoint, credential, **kwargs):
)

@distributed_trace
def begin_train_model(self, training_files_url, use_training_labels=False, **kwargs):
# type: (str, Optional[bool], Any) -> LROPoller
def begin_train_model(self, training_files_url, use_training_labels, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the changelog?

# type: (str, bool, Any) -> LROPoller
"""Create and train a custom model. The request must include a `training_files_url` parameter that is an
externally accessible Azure storage blob container Uri (preferably a Shared Access Signature Uri).
Models are trained using documents that are of the following content type - 'application/pdf',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# pylint: disable=protected-access

from typing import (
Optional,
Any,
Union,
AsyncIterable,
Expand Down Expand Up @@ -87,7 +86,7 @@ def __init__(
async def train_model(
self,
training_files_url: str,
use_training_labels: Optional[bool] = False,
use_training_labels: bool,
**kwargs: Any
) -> CustomFormModel:
"""Create and train a custom model. The request must include a `training_files_url` parameter that is an
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def train_model_without_labels(self):
) as form_training_client:

# `use_training_labels` is a required argument of train_model; pass `False` to train without labels
model = await form_training_client.train_model(self.container_sas_url)
model = await form_training_client.train_model(self.container_sas_url, use_training_labels=False)

# Custom model information
print("Model ID: {}".format(model.model_id))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def train_model_without_labels(self):
form_training_client = FormTrainingClient(self.endpoint, AzureKeyCredential(self.key))

# `use_training_labels` is a required argument of begin_train_model; pass `False` to train without labels
poller = form_training_client.begin_train_model(self.container_sas_url)
poller = form_training_client.begin_train_model(self.container_sas_url, use_training_labels=False)
model = poller.result()

# Custom model information
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_auto_detect_unsupported_stream_content(self, resource_group, location,
def test_custom_form_damaged_file(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

with self.assertRaises(HttpResponseError):
Expand All @@ -73,7 +73,7 @@ def test_custom_form_damaged_file(self, client, container_sas_url):
def test_custom_form_unlabeled(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

with open(self.form_jpg, "rb") as stream:
Expand All @@ -98,7 +98,7 @@ def test_custom_form_unlabeled(self, client, container_sas_url):
def test_custom_form_multipage_unlabeled(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

with open(self.multipage_invoice_pdf, "rb") as stream:
Expand Down Expand Up @@ -179,7 +179,7 @@ def test_custom_form_multipage_labeled(self, client, container_sas_url):
def test_custom_form_unlabeled_transform(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

responses = []
Expand Down Expand Up @@ -216,7 +216,7 @@ def callback(raw_response, _, headers):
def test_custom_form_multipage_unlabeled_transform(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

responses = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def test_auto_detect_unsupported_stream_content(self, resource_group, loca
async def test_custom_form_damaged_file(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

with self.assertRaises(HttpResponseError):
form = await fr_client.recognize_custom_forms(
Expand All @@ -73,7 +73,7 @@ async def test_custom_form_damaged_file(self, client, container_sas_url):
async def test_custom_form_unlabeled(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

with open(self.form_jpg, "rb") as fd:
myfile = fd.read()
Expand All @@ -94,7 +94,7 @@ async def test_custom_form_unlabeled(self, client, container_sas_url):
async def test_custom_form_multipage_unlabeled(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

with open(self.multipage_invoice_pdf, "rb") as fd:
myfile = fd.read()
Expand Down Expand Up @@ -168,7 +168,7 @@ async def test_custom_form_multipage_labeled(self, client, container_sas_url):
async def test_form_unlabeled_transform(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

responses = []

Expand Down Expand Up @@ -204,7 +204,7 @@ def callback(raw_response, _, headers):
async def test_custom_forms_multipage_unlabeled_transform(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

responses = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_custom_form_bad_url(self, client, container_sas_url):
def test_custom_form_unlabeled(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

poller = fr_client.begin_recognize_custom_forms_from_url(model.model_id, self.form_url_jpg)
Expand All @@ -89,7 +89,7 @@ def test_custom_form_unlabeled(self, client, container_sas_url):
def test_form_multipage_unlabeled(self, client, container_sas_url, blob_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

poller = fr_client.begin_recognize_custom_forms_from_url(
Expand Down Expand Up @@ -159,7 +159,7 @@ def test_form_multipage_labeled(self, client, container_sas_url, blob_sas_url):
def test_custom_form_unlabeled_transform(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

responses = []
Expand Down Expand Up @@ -193,7 +193,7 @@ def callback(raw_response, _, headers):
def test_custom_form_multipage_unlabeled_transform(self, client, container_sas_url, blob_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

responses = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def test_form_bad_url(self, client, container_sas_url):
async def test_form_unlabeled(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

form = await fr_client.recognize_custom_forms_from_url(model.model_id, self.form_url_jpg)

Expand All @@ -85,7 +85,7 @@ async def test_form_unlabeled(self, client, container_sas_url):
async def test_custom_form_multipage_unlabeled(self, client, container_sas_url, blob_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

forms = await fr_client.recognize_custom_forms_from_url(
model.model_id,
Expand Down Expand Up @@ -148,7 +148,7 @@ async def test_form_multipage_labeled(self, client, container_sas_url, blob_sas_
async def test_form_unlabeled_transform(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

responses = []

Expand Down Expand Up @@ -181,7 +181,7 @@ def callback(raw_response, _, headers):
async def test_multipage_unlabeled_transform(self, client, container_sas_url, blob_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

responses = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_mgmt_model_labeled(self, client, container_sas_url):
@GlobalTrainingAccountPreparer()
def test_mgmt_model_unlabeled(self, client, container_sas_url):

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
unlabeled_model_from_train = poller.result()

unlabeled_model_from_get = client.get_custom_model(unlabeled_model_from_train.model_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def test_mgmt_model_labeled(self, client, container_sas_url):
@GlobalFormRecognizerAccountPreparer()
@GlobalTrainingAccountPreparer()
async def test_mgmt_model_unlabeled(self, client, container_sas_url):
unlabeled_model_from_train = await client.train_model(container_sas_url)
unlabeled_model_from_train = await client.train_model(container_sas_url, use_training_labels=False)

unlabeled_model_from_get = await client.get_custom_model(unlabeled_model_from_train.model_id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ class TestTraining(FormRecognizerTest):
def test_training_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key):
client = FormTrainingClient(form_recognizer_account, AzureKeyCredential("xxxx"))
with self.assertRaises(ClientAuthenticationError):
poller = client.begin_train_model("xx")
poller = client.begin_train_model("xx", use_training_labels=False)

@GlobalFormRecognizerAccountPreparer()
@GlobalTrainingAccountPreparer()
def test_training(self, client, container_sas_url):

poller = client.begin_train_model(training_files_url=container_sas_url)
poller = client.begin_train_model(training_files_url=container_sas_url, use_training_labels=False)
model = poller.result()

self.assertIsNotNone(model.model_id)
Expand All @@ -52,7 +52,7 @@ def test_training(self, client, container_sas_url):
@GlobalTrainingAccountPreparer(multipage=True)
def test_training_multipage(self, client, container_sas_url):

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

self.assertIsNotNone(model.model_id)
Expand Down Expand Up @@ -83,7 +83,7 @@ def callback(response):
raw_response.append(raw_model)
raw_response.append(custom_model)

poller = client.begin_train_model(training_files_url=container_sas_url, cls=callback)
poller = client.begin_train_model(training_files_url=container_sas_url, use_training_labels=False, cls=callback)
model = poller.result()

raw_model = raw_response[0]
Expand All @@ -102,7 +102,7 @@ def callback(response):
raw_response.append(raw_model)
raw_response.append(custom_model)

poller = client.begin_train_model(container_sas_url, cls=callback)
poller = client.begin_train_model(container_sas_url, use_training_labels=False, cls=callback)
model = poller.result()

raw_model = raw_response[0]
Expand Down Expand Up @@ -199,16 +199,16 @@ def callback(response):
@GlobalTrainingAccountPreparer()
def test_training_with_files_filter(self, client, container_sas_url):

poller = client.begin_train_model(training_files_url=container_sas_url, include_sub_folders=True)
poller = client.begin_train_model(training_files_url=container_sas_url, use_training_labels=False, include_sub_folders=True)
model = poller.result()
self.assertEqual(len(model.training_documents), 6)
self.assertEqual(model.training_documents[-1].document_name, "subfolder/Form_6.jpg") # we traversed subfolders

poller = client.begin_train_model(container_sas_url, prefix="subfolder", include_sub_folders=True)
poller = client.begin_train_model(container_sas_url, use_training_labels=False, prefix="subfolder", include_sub_folders=True)
model = poller.result()
self.assertEqual(len(model.training_documents), 1)
self.assertEqual(model.training_documents[0].document_name, "subfolder/Form_6.jpg") # we filtered for only subfolders

with self.assertRaises(HttpResponseError):
poller = client.begin_train_model(training_files_url=container_sas_url, prefix="xxx")
poller = client.begin_train_model(training_files_url=container_sas_url, use_training_labels=False, prefix="xxx")
model = poller.result()
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ class TestTrainingAsync(AsyncFormRecognizerTest):
async def test_training_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key):
client = FormTrainingClient(form_recognizer_account, AzureKeyCredential("xxxx"))
with self.assertRaises(ClientAuthenticationError):
result = await client.train_model("xx")
result = await client.train_model("xx", use_training_labels=False)

@GlobalFormRecognizerAccountPreparer()
@GlobalTrainingAccountPreparer()
async def test_training(self, client, container_sas_url):

model = await client.train_model(training_files_url=container_sas_url)
model = await client.train_model(
training_files_url=container_sas_url,
use_training_labels=False)

self.assertIsNotNone(model.model_id)
self.assertIsNotNone(model.created_on)
Expand All @@ -51,7 +53,7 @@ async def test_training(self, client, container_sas_url):
@GlobalTrainingAccountPreparer(multipage=True)
async def test_training_multipage(self, client, container_sas_url):

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

self.assertIsNotNone(model.model_id)
self.assertIsNotNone(model.created_on)
Expand Down Expand Up @@ -81,7 +83,10 @@ def callback(response):
raw_response.append(raw_model)
raw_response.append(custom_model)

model = await client.train_model(training_files_url=container_sas_url, cls=callback)
model = await client.train_model(
training_files_url=container_sas_url,
use_training_labels=False,
cls=callback)

raw_model = raw_response[0]
custom_model = raw_response[1]
Expand All @@ -99,7 +104,7 @@ def callback(response):
raw_response.append(raw_model)
raw_response.append(custom_model)

model = await client.train_model(container_sas_url, cls=callback)
model = await client.train_model(container_sas_url, use_training_labels=False, cls=callback)

raw_model = raw_response[0]
custom_model = raw_response[1]
Expand Down Expand Up @@ -189,13 +194,13 @@ def callback(response):
@GlobalTrainingAccountPreparer()
async def test_training_with_files_filter(self, client, container_sas_url):

model = await client.train_model(training_files_url=container_sas_url, include_sub_folders=True)
model = await client.train_model(training_files_url=container_sas_url, use_training_labels=False, include_sub_folders=True)
self.assertEqual(len(model.training_documents), 6)
self.assertEqual(model.training_documents[-1].document_name, "subfolder/Form_6.jpg") # we traversed subfolders

model = await client.train_model(container_sas_url, prefix="subfolder", include_sub_folders=True)
model = await client.train_model(container_sas_url, use_training_labels=False, prefix="subfolder", include_sub_folders=True)
self.assertEqual(len(model.training_documents), 1)
self.assertEqual(model.training_documents[0].document_name, "subfolder/Form_6.jpg") # we filtered for only subfolders

with self.assertRaises(HttpResponseError):
model = await client.train_model(training_files_url=container_sas_url, prefix="xxx")
model = await client.train_model(training_files_url=container_sas_url, use_training_labels=False, prefix="xxx")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update `sample_train_model_without_labels` (both the sync and async versions)? They don't set `use_training_labels`, and they also include a comment saying that the default is `use_training_labels=False`.

Can you also run the samples that call train_model (or begin_train_model), just to make 100% sure they pass with this new change? thanks!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we missing tests that use training labels (`use_training_labels=True`)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have the tests; they just didn't need to be updated.