Skip to content

Commit

Permalink
async client (#17366)
Browse files Browse the repository at this point in the history
* initial commit

* update samples

* updates from feedback

* johans feedback

* renaming to use job terminology

* update samples - optional src language

* samples hero scenarios (#16936)

* [samples] added 'batch_translation_async' sample

* [samples] added 'batch_translation_with_storage_async' sample

* [samples] added remianing async samples

* [samples] update file names

* [samples] added self to instance methods

* [samples][async] fix import textanalytics :)

* [samples] fix self. when calling instance methods

* [samples] fixed async check status to use AsyncItemPaged used in Async Client

* [samples] async -> some async operations instead of sync ones

* [samples][async] use async blob operations

* [samples][async] blob download async

* [samples][async] check_documents async

* [samples][async] added some missing await methods

* [async samples] change await time to recommended period

* [samples] updated async samples to comply with new changes

* [client] modify the 'create_translation_job()' method

* [models wrapping 'JobStatusDetail'

* [models wrapping] update client

* [models wrapping] formats mapping [documents, glossaries, storage]

* [models wrapping] batch document input

* [models wrapping] added document status

* [models wrapping] added support for list document status

* [models wrapping] remove unwanted code

* [models wrapping] forgot to add job id

* [model wrapping] list all jobs

* [pr review] extract job id

* [refactor] some refactoring

* [pr reviews] refactor

* [refactor] to_generated in glossary

* [refactor] to generated in StorageTargets

* [refactor] to_generated in BatchDocumentInput

* [refactor] to_generated in Glossary more

* [refactor] make code readable -> storage formats

* [integration] added support for wait_until_done

* [refactor] spelling error

* [integrate sdk] updated "wait_until_done" to use azure core poller instead of busy wait

* [sync client][wait till done] initial poller

* [sync client][wait poller] poller algorithm completed

* [PR comments] wrap model before returning

* [pr reviews] adjust parameter type to list

* [pr reviews] some linter checks

* [pr reviews] remove static methods, and fix private  naming convention in model functions

* [pr reviews] fix return type value

* [PR reviews] implement unimplemented inherited abstract method

* [pr reviews] fix python private method naming in client

* [pr review] handle case with no error happening

* [pr reviews] renaming stuff :)

* [pr reviews] renaming more stuff :)

* [pr reviews] refactor private functions and linting options

* [bug fix] poller -> handle non-standard status responses

* [#17289] remove supported storage sources method

* [bug fix] private method name

* [linting] _models.py

* [linting] client

* [linting] models

* [bug fix] wait until done pipeline response

* [linting] more client linting

* [linting] more client

* [linting] more in models

* more linting

* [linting] more linting things

* [linting] client

* [linting] more

* [linting] more

* [linting] for goodness sakes!

* [bug fix] models -> StorageTarget -> handle when user doesn't pass glossaries

* [linting] line-too-long

* [async client] submit translation job

* [async client] job status

* [async client] more function support added

* [async client] added document status

* [async client] add remaining methods

* [async client] use async poller

* [linting] asycn client

* [linting] more

Co-authored-by: Krista Pratico <[email protected]>
  • Loading branch information
Mohamed Shaban and kristapratico authored Mar 18, 2021
1 parent 6a7cedd commit b6e2230
Showing 1 changed file with 98 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,22 @@
from typing import Union, Any, List, TYPE_CHECKING
from azure.core.tracing.decorator_async import distributed_trace_async
from azure.core.tracing.decorator import distributed_trace
from azure.core.polling import AsyncLROPoller
from azure.core.polling.async_base_polling import AsyncLROBasePolling
from azure.core.async_paging import AsyncItemPaged
from .._generated.aio import BatchDocumentTranslationClient as _BatchDocumentTranslationClient
from .._user_agent import USER_AGENT
from .._models import JobStatusDetail, DocumentStatusDetail, BatchDocumentInput, FileFormat
from .._generated.models import (
BatchStatusDetail as _BatchStatusDetail,
)
from .._models import (
JobStatusDetail,
BatchDocumentInput,
FileFormat,
DocumentStatusDetail
)
from .._helpers import get_authentication_policy
from .._polling import TranslationPolling
if TYPE_CHECKING:
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.credentials import AzureKeyCredential
Expand Down Expand Up @@ -58,12 +69,27 @@ async def create_translation_job(self, batch, **kwargs):
:rtype: JobStatusDetail
"""

return await self._client.document_translation.begin_submit_batch_request(
inputs=batch,
# submit translation job
response_headers = await self._client.document_translation._submit_batch_request_initial( # pylint: disable=protected-access
# pylint: disable=protected-access
inputs=BatchDocumentInput._to_generated_list(batch),
cls=lambda pipeline_response, _, response_headers: response_headers,
polling=True,
**kwargs
)

def get_job_id(response_headers):
# extract job id.
operation_location_header = response_headers['Operation-Location']
return operation_location_header.split('/')[-1]

# get job id from response header
job_id = get_job_id(response_headers)

# get job status
return await self.get_job_status(job_id)


@distributed_trace_async
async def get_job_status(self, job_id, **kwargs):
# type: (str, **Any) -> JobStatusDetail
Expand All @@ -74,7 +100,9 @@ async def get_job_status(self, job_id, **kwargs):
:rtype: ~azure.ai.documenttranslation.JobStatusDetail
"""

return await self._client.document_translation.get_operation_status(job_id, **kwargs)
job_status = await self._client.document_translation.get_operation_status(job_id, **kwargs)
# pylint: disable=protected-access
return JobStatusDetail._from_generated(job_status)

@distributed_trace_async
async def cancel_job(self, job_id, **kwargs):
Expand All @@ -98,7 +126,26 @@ async def wait_until_done(self, job_id, **kwargs):
:return: JobStatusDetail
:rtype: JobStatusDetail
"""
pass # pylint: disable=unnecessary-pass
pipeline_response = await self._client.document_translation.get_operation_status(
job_id,
cls=lambda pipeline_response, _, response_headers: pipeline_response
)

def callback(raw_response):
detail = self._client._deserialize(_BatchStatusDetail, raw_response) # pylint: disable=protected-access
return JobStatusDetail._from_generated(detail) # pylint: disable=protected-access

poller = AsyncLROPoller(
client=self._client._client, # pylint: disable=protected-access
initial_response=pipeline_response,
deserialization_callback=callback,
polling_method=AsyncLROBasePolling(
timeout=30,
lro_algorithms=[TranslationPolling()],
**kwargs
),
)
return poller.result()

@distributed_trace
def list_submitted_jobs(self, **kwargs):
Expand All @@ -109,7 +156,24 @@ def list_submitted_jobs(self, **kwargs):
:keyword int skip:
:rtype: ~azure.core.polling.AsyncItemPaged[JobStatusDetail]
"""
return self._client.document_translation.get_operations(**kwargs)
skip = kwargs.pop('skip', None)
results_per_page = kwargs.pop('results_per_page', None)

def _convert_from_generated_model(generated_model):
# pylint: disable=protected-access
return JobStatusDetail._from_generated(generated_model)

model_conversion_function = kwargs.pop(
"cls",
lambda job_statuses: [_convert_from_generated_model(job_status) for job_status in job_statuses]
)

return self._client.document_translation.get_operations(
top=results_per_page,
skip=skip,
cls=model_conversion_function,
**kwargs
)

@distributed_trace
def list_documents_statuses(self, job_id, **kwargs):
Expand All @@ -122,8 +186,26 @@ def list_documents_statuses(self, job_id, **kwargs):
:keyword int skip:
:rtype: ~azure.core.paging.AsyncItemPaged[DocumentStatusDetail]
"""
skip = kwargs.pop('skip', None)
results_per_page = kwargs.pop('results_per_page', None)

def _convert_from_generated_model(generated_model):
# pylint: disable=protected-access
return DocumentStatusDetail._from_generated(generated_model)

model_conversion_function = kwargs.pop(
"cls",
lambda doc_statuses: [_convert_from_generated_model(doc_status) for doc_status in doc_statuses]
)

return self._client.document_translation.get_operation_documents_status(
id=job_id,
top=results_per_page,
skip=skip,
cls=model_conversion_function,
**kwargs
)

return self._client.document_translation.get_operation_documents_status(job_id, **kwargs)

@distributed_trace_async
async def get_document_status(self, job_id, document_id, **kwargs):
Expand All @@ -136,16 +218,10 @@ async def get_document_status(self, job_id, document_id, **kwargs):
:type document_id: str
:rtype: ~azure.ai.documenttranslation.DocumentStatusDetail
"""
return await self._client.document_translation.get_document_status(job_id, document_id, **kwargs)
document_status = await self._client.document_translation.get_document_status(job_id, document_id, **kwargs)
# pylint: disable=protected-access
return DocumentStatusDetail._from_generated(document_status)

@distributed_trace_async
async def get_supported_storage_sources(self, **kwargs):
# type: (**Any) -> List[str]
"""
:rtype: list[str]
"""
return await self._client.document_translation.get_document_storage_source(**kwargs)

@distributed_trace_async
async def get_supported_glossary_formats(self, **kwargs):
Expand All @@ -154,8 +230,9 @@ async def get_supported_glossary_formats(self, **kwargs):
:rtype: list[FileFormat]
"""

return await self._client.document_translation.get_glossary_formats(**kwargs)
glossary_formats = await self._client.document_translation.get_glossary_formats(**kwargs)
# pylint: disable=protected-access
return FileFormat._from_generated_list(glossary_formats.value)

@distributed_trace_async
async def get_supported_document_formats(self, **kwargs):
Expand All @@ -164,5 +241,6 @@ async def get_supported_document_formats(self, **kwargs):
:rtype: list[FileFormat]
"""

return await self._client.document_translation.get_document_formats(**kwargs)
document_formats = await self._client.document_translation.get_document_formats(**kwargs)
# pylint: disable=protected-access
return FileFormat._from_generated_list(document_formats.value)

0 comments on commit b6e2230

Please sign in to comment.