Skip to content

Commit

Permalink
[formrecognizer] Add classifier administration methods (Azure#29466)
Browse files Browse the repository at this point in the history
* add DocumentClassifierDetails model

* add admin methods for document classifiers

* add samples

* pushing an empty commit to prove I have access to your fork

* add begin_build_document_classifier + tests + models + regen with poller

* add tests for classifiers list/get/delete + record

* add samples for begin_build_document_classifer and fix lint

* fix sample snippet docstrings

* update changelog

* fix rtype

* fix test-resources.json

* update sample

* fix documentation

---------

Co-authored-by: Krista Pratico <[email protected]>
  • Loading branch information
catalinaperalta and kristapratico authored Mar 28, 2023
1 parent c44dd70 commit 289f3be
Show file tree
Hide file tree
Showing 18 changed files with 1,464 additions and 13 deletions.
2 changes: 2 additions & 0 deletions sdk/formrecognizer/azure-ai-formrecognizer/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
- Added `features` keyword argument on `begin_analyze_document()` and `begin_analyze_document_from_url()`.
- Added `AnalysisFeature` enum with optional document analysis feature to enable.
- Added the following optional properties on `DocumentStyle` class: `similar_font_family`, `font_style`, `font_weight`, `color`, `background_color`.
- Added support for custom document classification on `DocumentModelAdministrationClient`: `begin_build_document_classifier`,
`list_document_classifiers`, `get_document_classifier`, and `delete_document_classifier`.

### Breaking Changes

Expand Down
2 changes: 1 addition & 1 deletion sdk/formrecognizer/azure-ai-formrecognizer/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "python",
"TagPrefix": "python/formrecognizer/azure-ai-formrecognizer",
"Tag": "python/formrecognizer/azure-ai-formrecognizer_ab3a99b236"
"Tag": "python/formrecognizer/azure-ai-formrecognizer_380d29abf3"
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
CurrencyValue,
CustomDocumentModelsDetails,
ModelBuildMode,
DocumentClassifierDetails,
DocumentField,
DocumentKeyValuePair,
DocumentKeyValueElement,
Expand All @@ -67,6 +68,11 @@
DocumentAnalysisInnerError,
TargetAuthorization,
)
from ._generated.models import ( # patched models
ClassifierDocumentTypeDetails,
AzureBlobFileListSource,
AzureBlobContentSource,
)
from ._api_versions import FormRecognizerApiVersion, DocumentAnalysisApiVersion


Expand Down Expand Up @@ -110,6 +116,7 @@
"CurrencyValue",
"CustomDocumentModelsDetails",
"ModelBuildMode",
"DocumentClassifierDetails",
"DocumentField",
"DocumentKeyValueElement",
"DocumentKeyValuePair",
Expand All @@ -134,6 +141,9 @@
"DocumentAnalysisError",
"DocumentAnalysisInnerError",
"TargetAuthorization",
"ClassifierDocumentTypeDetails",
"AzureBlobFileListSource",
"AzureBlobContentSource",
]

__VERSION__ = VERSION
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
Any,
Union,
List,
Optional,
Mapping,
)
from azure.core.credentials import AzureKeyCredential, TokenCredential
from azure.core.tracing.decorator import distributed_trace
Expand All @@ -26,13 +28,15 @@
from ._document_analysis_client import DocumentAnalysisClient
from ._models import (
ModelBuildMode,
DocumentClassifierDetails,
DocumentModelDetails,
DocumentModelSummary,
OperationDetails,
OperationSummary,
ResourceDetails,
TargetAuthorization,
)
from ._generated.models import ClassifierDocumentTypeDetails


class DocumentModelAdministrationClient(FormRecognizerClientBase):
Expand Down Expand Up @@ -509,6 +513,161 @@ def get_operation(self, operation_id: str, **kwargs: Any) -> OperationDetails:
api_version=self._api_version,
)

@distributed_trace
def begin_build_document_classifier(
self,
doc_types: Mapping[str, ClassifierDocumentTypeDetails],
*,
classifier_id: Optional[str] = None,
description: Optional[str] = None,
**kwargs: Any
) -> DocumentModelAdministrationLROPoller[DocumentClassifierDetails]:
"""Build a document classifier. For more information on how to build and train
a custom classifier model, see https://aka.ms/azsdk/formrecognizer/buildclassifiermodel.
:param doc_types: Required. Mapping of document types to classify against.
:keyword str classifier_id: Unique document classifier name.
If not specified, a classifier ID will be created for you.
:keyword str description: Document classifier description.
:return: An instance of an DocumentModelAdministrationLROPoller. Call `result()` on the poller
object to return a :class:`~azure.ai.formrecognizer.DocumentClassifierDetails`.
:rtype: ~azure.ai.formrecognizer.DocumentModelAdministrationLROPoller[DocumentClassifierDetails]
:raises ~azure.core.exceptions.HttpResponseError:
.. versionadded:: 2023-02-28-preview
The *begin_build_document_classifier* client method.
.. admonition:: Example:
.. literalinclude:: ../samples/v3.2/sample_build_classifier.py
:start-after: [START build_classifier]
:end-before: [END build_classifier]
:language: python
:dedent: 4
:caption: Build a document classifier.
"""
def callback(raw_response, _, headers): # pylint: disable=unused-argument
op_response = \
self._deserialize(self._generated_models.DocumentClassifierBuildOperationDetails, raw_response)
model_info = self._deserialize(self._generated_models.DocumentClassifierDetails, op_response.result)
return DocumentClassifierDetails._from_generated(model_info)

if self._api_version == DocumentAnalysisApiVersion.V2022_08_31:
raise ValueError("Method 'begin_build_document_classifier()' is only available for API version "
"V2023_02_28_PREVIEW and later")
cls = kwargs.pop("cls", callback)
continuation_token = kwargs.pop("continuation_token", None)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
if classifier_id is None:
classifier_id = str(uuid.uuid4())

return self._client.document_classifiers.begin_build_classifier(
build_request=self._generated_models.BuildDocumentClassifierRequest(
classifier_id=classifier_id,
description=description,
doc_types=doc_types,
),
cls=cls,
continuation_token=continuation_token,
polling=LROBasePolling(
timeout=polling_interval, lro_algorithms=[DocumentModelAdministrationPolling()], **kwargs
),
**kwargs
)

@distributed_trace
def get_document_classifier(self, classifier_id: str, **kwargs: Any) -> DocumentClassifierDetails:
"""Get a document classifier by its ID.
:param str classifier_id: Classifier identifier.
:return: DocumentClassifierDetails
:rtype: ~azure.ai.formrecognizer.DocumentClassifierDetails
:raises ~azure.core.exceptions.HttpResponseError or ~azure.core.exceptions.ResourceNotFoundError:
.. versionadded:: 2023-02-28-preview
The *get_document_classifier* client method.
.. admonition:: Example:
.. literalinclude:: ../samples/v3.2/sample_manage_classifiers.py
:start-after: [START get_document_classifier]
:end-before: [END get_document_classifier]
:language: python
:dedent: 4
:caption: Get a classifier by its ID.
"""

if not classifier_id:
raise ValueError("classifier_id cannot be None or empty.")

if self._api_version == DocumentAnalysisApiVersion.V2022_08_31:
raise ValueError("Method 'get_document_classifier()' is only available for API version "
"V2023_02_28_PREVIEW and later")
response = self._client.document_classifiers.get_classifier(classifier_id=classifier_id, **kwargs)
return DocumentClassifierDetails._from_generated(response)

@distributed_trace
def list_document_classifiers(self, **kwargs: Any) -> ItemPaged[DocumentClassifierDetails]:
"""List information for each document classifier, including its classifier ID,
description, and when it was created.
:return: Pageable of DocumentClassifierDetails.
:rtype: ~azure.core.paging.ItemPaged[DocumentClassifierDetails]
:raises ~azure.core.exceptions.HttpResponseError:
.. versionadded:: 2023-02-28-preview
The *list_document_classifiers* client method.
.. admonition:: Example:
.. literalinclude:: ../samples/v3.2/sample_manage_classifiers.py
:start-after: [START list_document_classifiers]
:end-before: [END list_document_classifiers]
:language: python
:dedent: 4
:caption: List all classifiers that were built successfully under the Form Recognizer resource.
"""

if self._api_version == DocumentAnalysisApiVersion.V2022_08_31:
raise ValueError("Method 'list_document_classifiers()' is only available for API version "
"V2023_02_28_PREVIEW and later")
return self._client.document_classifiers.list_classifiers( # type: ignore
cls=kwargs.pop(
"cls",
lambda objs: [DocumentClassifierDetails._from_generated(x) for x in objs],
),
**kwargs
)

@distributed_trace
def delete_document_classifier(self, classifier_id: str, **kwargs: Any) -> None:
"""Delete a document classifier.
:param str classifier_id: Classifier identifier.
:rtype: None
:raises ~azure.core.exceptions.HttpResponseError or ~azure.core.exceptions.ResourceNotFoundError:
.. versionadded:: 2023-02-28-preview
The *delete_document_classifier* client method.
.. admonition:: Example:
.. literalinclude:: ../samples/v3.2/sample_manage_classifiers.py
:start-after: [START delete_document_classifier]
:end-before: [END delete_document_classifier]
:language: python
:dedent: 4
:caption: Delete a classifier.
"""

if not classifier_id:
raise ValueError("classifier_id cannot be None or empty.")

if self._api_version == DocumentAnalysisApiVersion.V2022_08_31:
raise ValueError("Method 'delete_document_classifier()' is only available for API version "
"V2023_02_28_PREVIEW and later")
return self._client.document_classifiers.delete_classifier(classifier_id=classifier_id, **kwargs)

def get_document_analysis_client(self, **kwargs: Any) -> DocumentAnalysisClient:
"""Get an instance of a DocumentAnalysisClient from DocumentModelAdministrationClient.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# --------------------------------------------------------------------------
from typing import Any, AsyncIterable, Callable, Dict, IO, Optional, TypeVar, Union, cast

from .....aio._async_polling import AsyncDocumentModelAdministrationClientLROPoller
from azure.core.async_paging import AsyncItemPaged, AsyncList
from azure.core.exceptions import ClientAuthenticationError, HttpResponseError, ResourceExistsError, ResourceNotFoundError, map_error
from azure.core.pipeline import PipelineResponse
Expand Down Expand Up @@ -104,7 +105,7 @@ async def begin_build_classifier( # pylint: disable=inconsistent-return-stateme
self,
build_request: _models.BuildDocumentClassifierRequest,
**kwargs: Any
) -> AsyncLROPoller[None]:
) -> AsyncDocumentModelAdministrationClientLROPoller[None]:
"""Build document classifier.
Builds a custom document classifier.
Expand All @@ -120,8 +121,9 @@ async def begin_build_classifier( # pylint: disable=inconsistent-return-stateme
:paramtype polling: bool or ~azure.core.polling.AsyncPollingMethod
:keyword int polling_interval: Default waiting time between two polls for LRO operations if no
Retry-After header is present.
:return: An instance of AsyncLROPoller that returns either None or the result of cls(response)
:rtype: ~azure.core.polling.AsyncLROPoller[None]
:return: An instance of AsyncDocumentModelAdministrationClientLROPoller that returns either
None or the result of cls(response)
:rtype: ~.....aio._async_polling.AsyncDocumentModelAdministrationClientLROPoller[None]
:raises: ~azure.core.exceptions.HttpResponseError
"""
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
Expand Down Expand Up @@ -167,13 +169,13 @@ def get_long_running_output(pipeline_response):
elif polling is False: polling_method = cast(AsyncPollingMethod, AsyncNoPolling())
else: polling_method = polling
if cont_token:
return AsyncLROPoller.from_continuation_token(
return AsyncDocumentModelAdministrationClientLROPoller.from_continuation_token(
polling_method=polling_method,
continuation_token=cont_token,
client=self._client,
deserialization_callback=get_long_running_output
)
return AsyncLROPoller(self._client, raw_result, get_long_running_output, polling_method)
return AsyncDocumentModelAdministrationClientLROPoller(self._client, raw_result, get_long_running_output, polling_method)

begin_build_classifier.metadata = {'url': "/documentClassifiers:build"} # type: ignore

Expand Down
Loading

0 comments on commit 289f3be

Please sign in to comment.