-
Notifications
You must be signed in to change notification settings - Fork 1.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(automl): add support for image classification, image object detection, text classification, text extraction; add batch_predict
; add deploy_model
, undeploy_model
, export_model
; add annotation specs (via synth)
#9628
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,80 @@ | |
import enum | ||
|
||
|
||
class ClassificationType(enum.IntEnum): | ||
""" | ||
Type of the classification problem. | ||
|
||
Attributes: | ||
CLASSIFICATION_TYPE_UNSPECIFIED (int): An un-set value of this enum. | ||
MULTICLASS (int): At most one label is allowed per example. | ||
MULTILABEL (int): Multiple labels are allowed for one example. | ||
""" | ||
|
||
CLASSIFICATION_TYPE_UNSPECIFIED = 0 | ||
MULTICLASS = 1 | ||
MULTILABEL = 2 | ||
|
||
|
||
class Document(object): | ||
class Layout(object): | ||
class TextSegmentType(enum.IntEnum): | ||
""" | ||
The type of TextSegment in the context of the original document. | ||
|
||
Attributes: | ||
TEXT_SEGMENT_TYPE_UNSPECIFIED (int): Should not be used. | ||
TOKEN (int): The text segment is a token. e.g. word. | ||
PARAGRAPH (int): The text segment is a paragraph. | ||
FORM_FIELD (int): The text segment is a form field. | ||
FORM_FIELD_NAME (int): The text segment is the name part of a form field. It will be treated as | ||
child of another FORM\_FIELD TextSegment if its span is subspan of | ||
another TextSegment with type FORM\_FIELD. | ||
FORM_FIELD_CONTENTS (int): The text segment is the text content part of a form field. It will be | ||
treated as child of another FORM\_FIELD TextSegment if its span is | ||
subspan of another TextSegment with type FORM\_FIELD. | ||
TABLE (int): The text segment is a whole table, including headers, and all rows. | ||
TABLE_HEADER (int): The text segment is a table's headers. It will be treated as child of | ||
another TABLE TextSegment if its span is subspan of another TextSegment | ||
with type TABLE. | ||
TABLE_ROW (int): The text segment is a row in table. It will be treated as child of | ||
another TABLE TextSegment if its span is subspan of another TextSegment | ||
with type TABLE. | ||
TABLE_CELL (int): The text segment is a cell in table. It will be treated as child of | ||
another TABLE\_ROW TextSegment if its span is subspan of another | ||
TextSegment with type TABLE\_ROW. | ||
""" | ||
|
||
TEXT_SEGMENT_TYPE_UNSPECIFIED = 0 | ||
TOKEN = 1 | ||
PARAGRAPH = 2 | ||
FORM_FIELD = 3 | ||
FORM_FIELD_NAME = 4 | ||
FORM_FIELD_CONTENTS = 5 | ||
TABLE = 6 | ||
TABLE_HEADER = 7 | ||
TABLE_ROW = 8 | ||
TABLE_CELL = 9 | ||
|
||
|
||
class DocumentDimensions(object): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add |
||
class DocumentDimensionUnit(enum.IntEnum): | ||
""" | ||
Unit of the document dimension. | ||
|
||
Attributes: | ||
DOCUMENT_DIMENSION_UNIT_UNSPECIFIED (int): Should not be used. | ||
INCH (int): Document dimension is measured in inches. | ||
CENTIMETER (int): Document dimension is measured in centimeters. | ||
POINT (int): Document dimension is measured in points. 72 points = 1 inch. | ||
""" | ||
|
||
DOCUMENT_DIMENSION_UNIT_UNSPECIFIED = 0 | ||
INCH = 1 | ||
CENTIMETER = 2 | ||
POINT = 3 | ||
|
||
|
||
class Model(object): | ||
class DeploymentState(enum.IntEnum): | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,14 +26,18 @@ | |
import google.api_core.gapic_v1.method | ||
import google.api_core.gapic_v1.routing_header | ||
import google.api_core.grpc_helpers | ||
import google.api_core.operation | ||
import google.api_core.operations_v1 | ||
import google.api_core.path_template | ||
import grpc | ||
|
||
from google.cloud.automl_v1.gapic import enums | ||
from google.cloud.automl_v1.gapic import prediction_service_client_config | ||
from google.cloud.automl_v1.gapic.transports import prediction_service_grpc_transport | ||
from google.cloud.automl_v1.proto import annotation_spec_pb2 | ||
from google.cloud.automl_v1.proto import data_items_pb2 | ||
from google.cloud.automl_v1.proto import dataset_pb2 | ||
from google.cloud.automl_v1.proto import image_pb2 | ||
from google.cloud.automl_v1.proto import io_pb2 | ||
from google.cloud.automl_v1.proto import model_evaluation_pb2 | ||
from google.cloud.automl_v1.proto import model_pb2 | ||
|
@@ -222,8 +226,18 @@ def predict( | |
returned in the response. Available for following ML problems, and their | ||
expected request payloads: | ||
|
||
- Image Classification - Image in .JPEG, .GIF or .PNG format, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add online predictions for Image Classification, Image Object Detection, Text Classification, Text Extraction, Text Sentiment (unclear if this feature was added or if this is just a docstring update) |
||
image\_bytes up to 30MB. | ||
- Image Object Detection - Image in .JPEG, .GIF or .PNG format, | ||
image\_bytes up to 30MB. | ||
- Text Classification - TextSnippet, content up to 60,000 characters, | ||
UTF-8 encoded. | ||
- Text Extraction - TextSnippet, content up to 30,000 characters, UTF-8 | ||
NFC encoded. | ||
- Translation - TextSnippet, content up to 25,000 characters, UTF-8 | ||
encoded. | ||
- Text Sentiment - TextSnippet, content up 500 characters, UTF-8 | ||
encoded. | ||
|
||
Example: | ||
>>> from google.cloud import automl_v1 | ||
|
@@ -246,6 +260,19 @@ def predict( | |
message :class:`~google.cloud.automl_v1.types.ExamplePayload` | ||
params (dict[str -> str]): Additional domain-specific parameters, any string must be up to 25000 | ||
characters long. | ||
|
||
- For Image Classification: | ||
|
||
``score_threshold`` - (float) A value from 0.0 to 1.0. When the model | ||
makes predictions for an image, it will only produce results that | ||
have at least this confidence score. The default is 0.5. | ||
|
||
- For Image Object Detection: ``score_threshold`` - (float) When Model | ||
detects objects on the image, it will only produce bounding boxes | ||
which have at least this confidence score. Value in 0 to 1 range, | ||
default is 0.5. ``max_bounding_box_count`` - (int64) No more than | ||
this number of bounding boxes will be returned in the response. | ||
Default is 100, the requested value may be limited by server. | ||
retry (Optional[google.api_core.retry.Retry]): A retry object used | ||
to retry requests. If ``None`` is specified, requests will | ||
be retried using a default configuration. | ||
|
@@ -295,3 +322,142 @@ def predict( | |
return self._inner_api_calls["predict"]( | ||
request, retry=retry, timeout=timeout, metadata=metadata | ||
) | ||
|
||
def batch_predict( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add |
||
self, | ||
name, | ||
input_config, | ||
output_config, | ||
params=None, | ||
retry=google.api_core.gapic_v1.method.DEFAULT, | ||
timeout=google.api_core.gapic_v1.method.DEFAULT, | ||
metadata=None, | ||
): | ||
""" | ||
Perform a batch prediction. Unlike the online ``Predict``, batch | ||
prediction result won't be immediately available in the response. | ||
Instead, a long running operation object is returned. User can poll the | ||
operation result via ``GetOperation`` method. Once the operation is | ||
done, ``BatchPredictResult`` is returned in the ``response`` field. | ||
Available for following ML problems: | ||
|
||
- Image Classification | ||
- Image Object Detection | ||
- Text Extraction | ||
|
||
Example: | ||
>>> from google.cloud import automl_v1 | ||
>>> | ||
>>> client = automl_v1.PredictionServiceClient() | ||
>>> | ||
>>> name = client.model_path('[PROJECT]', '[LOCATION]', '[MODEL]') | ||
>>> | ||
>>> # TODO: Initialize `input_config`: | ||
>>> input_config = {} | ||
>>> | ||
>>> # TODO: Initialize `output_config`: | ||
>>> output_config = {} | ||
>>> | ||
>>> response = client.batch_predict(name, input_config, output_config) | ||
>>> | ||
>>> def callback(operation_future): | ||
... # Handle result. | ||
... result = operation_future.result() | ||
>>> | ||
>>> response.add_done_callback(callback) | ||
>>> | ||
>>> # Handle metadata. | ||
>>> metadata = response.metadata() | ||
|
||
Args: | ||
name (str): Name of the model requested to serve the batch prediction. | ||
input_config (Union[dict, ~google.cloud.automl_v1.types.BatchPredictInputConfig]): Required. The input configuration for batch prediction. | ||
|
||
If a dict is provided, it must be of the same form as the protobuf | ||
message :class:`~google.cloud.automl_v1.types.BatchPredictInputConfig` | ||
output_config (Union[dict, ~google.cloud.automl_v1.types.BatchPredictOutputConfig]): Required. The Configuration specifying where output predictions should | ||
be written. | ||
|
||
If a dict is provided, it must be of the same form as the protobuf | ||
message :class:`~google.cloud.automl_v1.types.BatchPredictOutputConfig` | ||
params (dict[str -> str]): Additional domain-specific parameters for the predictions, any string | ||
must be up to 25000 characters long. | ||
|
||
- For Text Classification: | ||
|
||
``score_threshold`` - (float) A value from 0.0 to 1.0. When the model | ||
makes predictions for a text snippet, it will only produce results | ||
that have at least this confidence score. The default is 0.5. | ||
|
||
- For Image Classification: | ||
|
||
``score_threshold`` - (float) A value from 0.0 to 1.0. When the model | ||
makes predictions for an image, it will only produce results that | ||
have at least this confidence score. The default is 0.5. | ||
|
||
- For Image Object Detection: | ||
|
||
``score_threshold`` - (float) When Model detects objects on the | ||
image, it will only produce bounding boxes which have at least this | ||
confidence score. Value in 0 to 1 range, default is 0.5. | ||
``max_bounding_box_count`` - (int64) No more than this number of | ||
bounding boxes will be produced per image. Default is 100, the | ||
requested value may be limited by server. | ||
retry (Optional[google.api_core.retry.Retry]): A retry object used | ||
to retry requests. If ``None`` is specified, requests will | ||
be retried using a default configuration. | ||
timeout (Optional[float]): The amount of time, in seconds, to wait | ||
for the request to complete. Note that if ``retry`` is | ||
specified, the timeout applies to each individual attempt. | ||
metadata (Optional[Sequence[Tuple[str, str]]]): Additional metadata | ||
that is provided to the method. | ||
|
||
Returns: | ||
A :class:`~google.cloud.automl_v1.types._OperationFuture` instance. | ||
|
||
Raises: | ||
google.api_core.exceptions.GoogleAPICallError: If the request | ||
failed for any reason. | ||
google.api_core.exceptions.RetryError: If the request failed due | ||
to a retryable error and retry attempts failed. | ||
ValueError: If the parameters are invalid. | ||
""" | ||
# Wrap the transport method to add retry and timeout logic. | ||
if "batch_predict" not in self._inner_api_calls: | ||
self._inner_api_calls[ | ||
"batch_predict" | ||
] = google.api_core.gapic_v1.method.wrap_method( | ||
self.transport.batch_predict, | ||
default_retry=self._method_configs["BatchPredict"].retry, | ||
default_timeout=self._method_configs["BatchPredict"].timeout, | ||
client_info=self._client_info, | ||
) | ||
|
||
request = prediction_service_pb2.BatchPredictRequest( | ||
name=name, | ||
input_config=input_config, | ||
output_config=output_config, | ||
params=params, | ||
) | ||
if metadata is None: | ||
metadata = [] | ||
metadata = list(metadata) | ||
try: | ||
routing_header = [("name", name)] | ||
except AttributeError: | ||
pass | ||
else: | ||
routing_metadata = google.api_core.gapic_v1.routing_header.to_grpc_metadata( | ||
routing_header | ||
) | ||
metadata.append(routing_metadata) | ||
|
||
operation = self._inner_api_calls["batch_predict"]( | ||
request, retry=retry, timeout=timeout, metadata=metadata | ||
) | ||
return google.api_core.operation.from_gapic( | ||
operation, | ||
self.transport._operations_client, | ||
prediction_service_pb2.BatchPredictResult, | ||
metadata_type=proto_operations_pb2.OperationMetadata, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add
ClasisificationType
enum (classification problem type)