From c230cf630cde7c5be0784670d2d430060532d0e5 Mon Sep 17 00:00:00 2001 From: Thomas Schultz Date: Wed, 4 Jan 2017 22:23:55 -0500 Subject: [PATCH] Add gax support for entity annotations. --- system_tests/vision.py | 15 +++ vision/google/cloud/vision/_gax.py | 23 ++++ vision/google/cloud/vision/_http.py | 3 +- vision/google/cloud/vision/annotations.py | 49 ++++++++ vision/google/cloud/vision/entity.py | 20 ++++ vision/google/cloud/vision/geometry.py | 27 +++++ vision/google/cloud/vision/image.py | 4 +- vision/unit_tests/test__gax.py | 20 ++++ vision/unit_tests/test_annotations.py | 131 ++++++++++++++++++++++ vision/unit_tests/test_client.py | 3 +- vision/unit_tests/test_entity.py | 31 +++++ 11 files changed, 321 insertions(+), 5 deletions(-) create mode 100644 vision/unit_tests/test_annotations.py diff --git a/system_tests/vision.py b/system_tests/vision.py index 277bcd9d657e5..c5626d4a8f874 100644 --- a/system_tests/vision.py +++ b/system_tests/vision.py @@ -190,6 +190,7 @@ def _assert_face(self, face): def test_detect_faces_content(self): client = Config.CLIENT + client._use_gax = False with open(FACE_FILE, 'rb') as image_file: image = client.image(content=image_file.read()) faces = image.detect_faces() @@ -208,6 +209,7 @@ def test_detect_faces_gcs(self): source_uri = 'gs://%s/%s' % (bucket_name, blob_name) client = Config.CLIENT + client._use_gax = False image = client.image(source_uri=source_uri) faces = image.detect_faces() self.assertEqual(len(faces), 5) @@ -216,6 +218,7 @@ def test_detect_faces_gcs(self): def test_detect_faces_filename(self): client = Config.CLIENT + client._use_gax = False image = client.image(filename=FACE_FILE) faces = image.detect_faces() self.assertEqual(len(faces), 5) @@ -310,6 +313,7 @@ def _assert_landmark(self, landmark): def test_detect_landmark_content(self): client = Config.CLIENT + client._use_gax = True with open(LANDMARK_FILE, 'rb') as image_file: image = client.image(content=image_file.read()) landmarks = image.detect_landmarks() @@ -328,6 +332,7 @@ def test_detect_landmark_gcs(self): source_uri = 'gs://%s/%s' % (bucket_name, blob_name) client = Config.CLIENT + client._use_gax = True image = client.image(source_uri=source_uri) landmarks = image.detect_landmarks() self.assertEqual(len(landmarks), 1) @@ -336,6 +341,7 @@ def test_detect_landmark_gcs(self): def test_detect_landmark_filename(self): client = Config.CLIENT + client._use_gax = True image = client.image(filename=LANDMARK_FILE) landmarks = image.detect_landmarks() self.assertEqual(len(landmarks), 1) @@ -362,6 +368,7 @@ def _assert_safe_search(self, safe_search): def test_detect_safe_search_content(self): client = Config.CLIENT + client._use_gax = False with open(FACE_FILE, 'rb') as image_file: image = client.image(content=image_file.read()) safe_searches = image.detect_safe_search() @@ -380,6 +387,7 @@ def test_detect_safe_search_gcs(self): source_uri = 'gs://%s/%s' % (bucket_name, blob_name) client = Config.CLIENT + client._use_gax = False image = client.image(source_uri=source_uri) safe_searches = image.detect_safe_search() self.assertEqual(len(safe_searches), 1) @@ -388,6 +396,7 @@ def test_detect_safe_search_gcs(self): def test_detect_safe_search_filename(self): client = Config.CLIENT + client._use_gax = False image = client.image(filename=FACE_FILE) safe_searches = image.detect_safe_search() self.assertEqual(len(safe_searches), 1) @@ -423,6 +432,7 @@ def _assert_text(self, text): def test_detect_text_content(self): client = Config.CLIENT + client._use_gax = True with open(TEXT_FILE, 'rb') as image_file: image = client.image(content=image_file.read()) texts = image.detect_text() @@ -441,6 +451,7 @@ def test_detect_text_gcs(self): source_uri = 'gs://%s/%s' % (bucket_name, blob_name) client = Config.CLIENT + client._use_gax = True image = client.image(source_uri=source_uri) texts = image.detect_text() self.assertEqual(len(texts), 9) @@ -449,6 +460,7 @@ def test_detect_text_gcs(self): def test_detect_text_filename(self): client = Config.CLIENT + client._use_gax = True image = client.image(filename=TEXT_FILE) texts = image.detect_text() self.assertEqual(len(texts), 9) @@ -485,6 +497,7 @@ def _assert_properties(self, image_property): def test_detect_properties_content(self): client = Config.CLIENT + client._use_gax = False with open(FACE_FILE, 'rb') as image_file: image = client.image(content=image_file.read()) properties = image.detect_properties() @@ -503,6 +516,7 @@ def test_detect_properties_gcs(self): source_uri = 'gs://%s/%s' % (bucket_name, blob_name) client = Config.CLIENT + client._use_gax = False image = client.image(source_uri=source_uri) properties = image.detect_properties() self.assertEqual(len(properties), 1) @@ -511,6 +525,7 @@ def test_detect_properties_gcs(self): def test_detect_properties_filename(self): client = Config.CLIENT + client._use_gax = False image = client.image(filename=FACE_FILE) properties = image.detect_properties() self.assertEqual(len(properties), 1) diff --git a/vision/google/cloud/vision/_gax.py b/vision/google/cloud/vision/_gax.py index 1820d8a958608..3ac927dbdb3bd 100644 --- a/vision/google/cloud/vision/_gax.py +++ b/vision/google/cloud/vision/_gax.py @@ -19,6 +19,8 @@ from google.cloud._helpers import _to_bytes +from google.cloud.vision.annotations import Annotations + class _GAPICVisionAPI(object): """Vision API for interacting with the gRPC version of Vision. @@ -30,6 +32,27 @@ def __init__(self, client=None): self._client = client self._api = image_annotator_client.ImageAnnotatorClient() + def annotate(self, image, features): + """Annotate images through GAX. + + :type image: :class:`~google.cloud.vision.image.Image` + :param image: Instance of ``Image``. + + :type features: list + :param features: List of :class:`~google.cloud.vision.feature.Feature`. + + :rtype: :class:`~google.cloud.vision.annotations.Annotations` + :returns: Instance of ``Annotations`` with results. + """ + gapic_features = [_to_gapic_feature(feature) for feature in features] + gapic_image = _to_gapic_image(image) + request = image_annotator_pb2.AnnotateImageRequest( + image=gapic_image, features=gapic_features) + requests = [request] + api = self._api + responses = api.batch_annotate_images(requests) + return Annotations.from_pb(responses.responses[0]) + def _to_gapic_feature(feature): """Helper function to convert a ``Feature`` to a gRPC ``Feature``. diff --git a/vision/google/cloud/vision/_http.py b/vision/google/cloud/vision/_http.py index 8bacdf01bb70a..c4948e7065b3e 100644 --- a/vision/google/cloud/vision/_http.py +++ b/vision/google/cloud/vision/_http.py @@ -14,6 +14,7 @@ """HTTP Client for interacting with the Google Cloud Vision API.""" +from google.cloud.vision.annotations import Annotations from google.cloud.vision.feature import Feature @@ -49,7 +50,7 @@ def annotate(self, image, features): api_response = self._connection.api_request( method='POST', path='/images:annotate', data=data) responses = api_response.get('responses') - return responses[0] + return Annotations.from_api_repr(responses[0]) def _make_request(image, features): diff --git a/vision/google/cloud/vision/annotations.py b/vision/google/cloud/vision/annotations.py index bbc14bc1414e4..5ec994f7e61da 100644 --- a/vision/google/cloud/vision/annotations.py +++ b/vision/google/cloud/vision/annotations.py @@ -93,6 +93,55 @@ def from_api_repr(cls, response): _entity_from_response_type(feature_type, annotation)) return cls(**annotations) + @classmethod + def from_pb(cls, response): + """Factory: construct an instance of ``Annotations`` from gRPC response. + + :type response: :class:`~google.cloud.grpc.vision.v1.\ + image_annotator_pb2.AnnotateImageResponse` + :param response: ``AnnotateImageResponse`` from gRPC call. + + :rtype: :class:`~google.cloud.vision.annotations.Annotations` + :returns: ``Annotations`` instance populated from gRPC response. + """ + annotations = _process_image_annotations(response) + return cls(**annotations) + + +def _process_image_annotations(image): + """Helper for processing annotation types from gRPC responses. + + :type image: :class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.\ + AnnotateImageResponse` + :param image: ``AnnotateImageResponse`` from gRPC response. + + :rtype: dict + :returns: Dictionary populated with entities from response. + """ + annotations = {} + annotations['labels'] = _make_entity_from_pb(image.label_annotations) + annotations['landmarks'] = _make_entity_from_pb(image.landmark_annotations) + annotations['logos'] = _make_entity_from_pb(image.logo_annotations) + annotations['texts'] = _make_entity_from_pb(image.text_annotations) + return annotations + + +def _make_entity_from_pb(annotations): + """Create an entity from a gRPC response. + + :type annotations: + :class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.EntityAnnotation` + :param annotations: gRPC instance of ``EntityAnnotation``. + + :rtype: list + :returns: List of ``EntityAnnotation``. + """ + + entities = [] + for annotation in annotations: + entities.append(EntityAnnotation.from_pb(annotation)) + return entities + def _entity_from_response_type(feature_type, results): """Convert a JSON result to an entity type based on the feature. diff --git a/vision/google/cloud/vision/entity.py b/vision/google/cloud/vision/entity.py index 80cb0a74f5fc7..ca699cd65ce8d 100644 --- a/vision/google/cloud/vision/entity.py +++ b/vision/google/cloud/vision/entity.py @@ -70,6 +70,26 @@ def from_api_repr(cls, response): return cls(bounds, description, locale, locations, mid, score) + @classmethod + def from_pb(cls, response): + """Factory: construct entity from Vision gRPC response. + + :type response: :class:`~google.cloud.grpc.vision.v1.\ + image_annotator_pb2.AnnotateImageResponse` + :param response: gRPC response from Vision API with entity data. + + :rtype: :class:`~google.cloud.vision.entity.EntityAnnotation` + :returns: Instance of ``EntityAnnotation``. + """ + bounds = Bounds.from_pb(response.bounding_poly) + description = response.description + locale = response.locale + locations = [LocationInformation.from_pb(location) + for location in response.locations] + mid = response.mid + score = response.score + return cls(bounds, description, locale, locations, mid, score) + @property def bounds(self): """Bounding polygon of detected image feature. diff --git a/vision/google/cloud/vision/geometry.py b/vision/google/cloud/vision/geometry.py index 862ce60c5ef8f..23115f92ddb98 100644 --- a/vision/google/cloud/vision/geometry.py +++ b/vision/google/cloud/vision/geometry.py @@ -41,6 +41,20 @@ def from_api_repr(cls, response_vertices): vertex in response_vertices.get('vertices', [])] return cls(vertices) + @classmethod + def from_pb(cls, response_vertices): + """Factory: construct BoundsBase instance from Vision gRPC response. + + :type response_vertices: :class:`~google.cloud.grpc.vision.v1.\ + geometry_pb2.BoundingPoly` + :param response_vertices: List of vertices. + + :rtype: :class:`~google.cloud.vision.geometry.BoundsBase` or None + :returns: Instance of ``BoundsBase`` with populated verticies. + """ + return cls([Vertex(vertex.x, vertex.y) + for vertex in response_vertices.vertices]) + @property def vertices(self): """List of vertices. @@ -87,6 +101,19 @@ def from_api_repr(cls, response): longitude = response['latLng']['longitude'] return cls(latitude, longitude) + @classmethod + def from_pb(cls, response): + """Factory: construct location information from Vision gRPC response. + + :type response: :class:`~google.cloud.vision.v1.LocationInfo` + :param response: gRPC response of ``LocationInfo``. + + :rtype: :class:`~google.cloud.vision.geometry.LocationInformation` + :returns: ``LocationInformation`` with populated latitude and + longitude. + """ + return cls(response.lat_lng.latitude, response.lat_lng.longitude) + @property def latitude(self): """Latitude coordinate. diff --git a/vision/google/cloud/vision/image.py b/vision/google/cloud/vision/image.py index f9a429e24e0b4..52f6cfad9869a 100644 --- a/vision/google/cloud/vision/image.py +++ b/vision/google/cloud/vision/image.py @@ -19,7 +19,6 @@ from google.cloud._helpers import _to_bytes from google.cloud._helpers import _bytes_to_unicode -from google.cloud.vision.annotations import Annotations from google.cloud.vision.feature import Feature from google.cloud.vision.feature import FeatureTypes @@ -109,8 +108,7 @@ def _detect_annotation(self, features): :class:`~google.cloud.vision.color.ImagePropertiesAnnotation`, :class:`~google.cloud.vision.sage.SafeSearchAnnotation`, """ - results = self.client._vision_api.annotate(self, features) - return Annotations.from_api_repr(results) + return self.client._vision_api.annotate(self, features) def detect(self, features): """Detect multiple feature types. diff --git a/vision/unit_tests/test__gax.py b/vision/unit_tests/test__gax.py index 8875928443956..b0ab2eace1f56 100644 --- a/vision/unit_tests/test__gax.py +++ b/vision/unit_tests/test__gax.py @@ -32,6 +32,26 @@ def test_ctor(self): api = self._make_one(client) self.assertIs(api._client, client) + def test_annotation(self): + from google.cloud.vision.feature import Feature + from google.cloud.vision.feature import FeatureTypes + from google.cloud.vision.image import Image + + client = mock.Mock() + feature = Feature(FeatureTypes.LABEL_DETECTION, 5) + image_content = b'abc 1 2 3' + image = Image(client, content=image_content) + api = self._make_one(client) + + api._api = mock.Mock() + mock_response = mock.Mock(responses=['mock response data']) + api._api.batch_annotate_images.return_value = mock_response + + with mock.patch('google.cloud.vision._gax.Annotations') as mock_anno: + api.annotate(image, [feature]) + mock_anno.from_pb.assert_called_with('mock response data') + api._api.batch_annotate_images.assert_called() + class TestToGAPICFeature(unittest.TestCase): def _call_fut(self, feature): diff --git a/vision/unit_tests/test_annotations.py b/vision/unit_tests/test_annotations.py new file mode 100644 index 0000000000000..609b9fcf6213b --- /dev/null +++ b/vision/unit_tests/test_annotations.py @@ -0,0 +1,131 @@ +# Copyright 2016 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + + +class TestAnnotations(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.vision.annotations import Annotations + + return Annotations + + def _make_one(self, *args, **kw): + return self._get_target_class()(*args, **kw) + + def test_ctor(self): + annotations = self._make_one( + faces=[True], properties=[True], labels=[True], landmarks=[True], + logos=[True], safe_searches=[True], texts=[True]) + self.assertTrue(annotations.faces[0]) + self.assertTrue(annotations.properties[0]) + self.assertTrue(annotations.labels[0]) + self.assertTrue(annotations.landmarks[0]) + self.assertTrue(annotations.logos[0]) + self.assertTrue(annotations.safe_searches[0]) + self.assertTrue(annotations.texts[0]) + + def test_from_pb(self): + from google.cloud.grpc.vision.v1 import image_annotator_pb2 + + image_response = image_annotator_pb2.AnnotateImageResponse() + annotations = self._make_one().from_pb(image_response) + self.assertEquals(annotations.labels, []) + self.assertEquals(annotations.logos, []) + self.assertEquals(annotations.faces, ()) + self.assertEquals(annotations.landmarks, []) + self.assertEquals(annotations.texts, []) + self.assertEquals(annotations.safe_searches, ()) + self.assertEquals(annotations.properties, ()) + + +class TestMakeEntityFromPB(unittest.TestCase): + def _call_fut(self, annotations): + from google.cloud.vision.annotations import _make_entity_from_pb + return _make_entity_from_pb(annotations) + + def test_make_entity_from_pb(self): + from google.cloud.grpc.vision.v1 import image_annotator_pb2 + + description = 'testing 1 2 3' + locale = 'US' + mid = 'm/w/45342234' + score = 0.235434231 + entity_annotation = image_annotator_pb2.EntityAnnotation() + entity_annotation.mid = mid + entity_annotation.locale = locale + entity_annotation.description = description + entity_annotation.score = score + entity_annotation.bounding_poly.vertices.add() + entity_annotation.bounding_poly.vertices[0].x = 1 + entity_annotation.bounding_poly.vertices[0].y = 2 + entity_annotation.locations.add() + entity_annotation.locations[0].lat_lng.latitude = 1.0 + entity_annotation.locations[0].lat_lng.longitude = 2.0 + + entities = self._call_fut([entity_annotation]) + self.assertEqual(len(entities), 1) + entity = entities[0] + self.assertEqual(entity.description, description) + self.assertEqual(entity.mid, mid) + self.assertEqual(entity.locale, locale) + self.assertEqual(entity.score, score) + self.assertEqual(entity.bounds.vertices[0].x_coordinate, 1) + self.assertEqual(entity.bounds.vertices[0].y_coordinate, 2) + self.assertEqual(entity.locations[0].latitude, 1.0) + self.assertEqual(entity.locations[0].longitude, 2.0) + + +class TestProcessImageAnnotations(unittest.TestCase): + def _call_fut(self, image): + from google.cloud.vision.annotations import _process_image_annotations + + return _process_image_annotations(image) + + def test_process_image_annotations(self): + from google.cloud.grpc.vision.v1 import image_annotator_pb2 + + description = 'testing 1 2 3' + locale = 'US' + mid = 'm/w/45342234' + score = 0.235434231 + entity_annotation = image_annotator_pb2.EntityAnnotation() + entity_annotation.mid = mid + entity_annotation.locale = locale + entity_annotation.description = description + entity_annotation.score = score + entity_annotation.bounding_poly.vertices.add() + entity_annotation.bounding_poly.vertices[0].x = 1 + entity_annotation.bounding_poly.vertices[0].y = 2 + entity_annotation.locations.add() + entity_annotation.locations[0].lat_lng.latitude = 1.0 + entity_annotation.locations[0].lat_lng.longitude = 2.0 + + image_response = image_annotator_pb2.AnnotateImageResponse( + label_annotations=[entity_annotation]) + print(image_response.label_annotations) + + annotations = self._call_fut(image_response) + self.assertEqual(len(annotations['labels']), 1) + entity = annotations['labels'][0] + + self.assertEqual(entity.description, description) + self.assertEqual(entity.mid, mid) + self.assertEqual(entity.locale, locale) + self.assertEqual(entity.score, score) + self.assertEqual(entity.bounds.vertices[0].x_coordinate, 1) + self.assertEqual(entity.bounds.vertices[0].y_coordinate, 2) + self.assertEqual(entity.locations[0].latitude, 1.0) + self.assertEqual(entity.locations[0].longitude, 2.0) diff --git a/vision/unit_tests/test_client.py b/vision/unit_tests/test_client.py index 44f76e944012d..ccf3609a5ed6f 100644 --- a/vision/unit_tests/test_client.py +++ b/vision/unit_tests/test_client.py @@ -76,6 +76,7 @@ def test_make_http_client(self): self.assertIsInstance(client._vision_api, _HTTPVisionAPI) def test_face_annotation(self): + from google.cloud.vision.annotations import Annotations from google.cloud.vision.feature import Feature, FeatureTypes from unit_tests._fixtures import FACE_DETECTION_RESPONSE @@ -106,7 +107,7 @@ def test_face_annotation(self): self.assertEqual(REQUEST, client._connection._requested[0]['data']) - self.assertTrue('faceAnnotations' in response) + self.assertIsInstance(response, Annotations) def test_image_with_client_gcs_source(self): from google.cloud.vision.image import Image diff --git a/vision/unit_tests/test_entity.py b/vision/unit_tests/test_entity.py index 7ead8388ad27b..d644de925aabf 100644 --- a/vision/unit_tests/test_entity.py +++ b/vision/unit_tests/test_entity.py @@ -32,3 +32,34 @@ def test_logo_annotation(self): self.assertEqual('Brand1', logo.description) self.assertEqual(0.63192177, logo.score) self.assertEqual(162, logo.bounds.vertices[0].y_coordinate) + + def test_logo_pb_annotation(self): + from google.cloud.grpc.vision.v1 import image_annotator_pb2 + + description = 'testing 1 2 3' + locale = 'US' + mid = 'm/w/45342234' + score = 0.235434231 + entity_annotation = image_annotator_pb2.EntityAnnotation() + entity_annotation.mid = mid + entity_annotation.locale = locale + entity_annotation.description = description + entity_annotation.score = score + entity_annotation.bounding_poly.vertices.add() + entity_annotation.bounding_poly.vertices[0].x = 1 + entity_annotation.bounding_poly.vertices[0].y = 2 + entity_annotation.locations.add() + entity_annotation.locations[0].lat_lng.latitude = 1.0 + entity_annotation.locations[0].lat_lng.longitude = 2.0 + + entity_class = self._get_target_class() + entity = entity_class.from_pb(entity_annotation) + + self.assertEqual(entity.description, description) + self.assertEqual(entity.mid, mid) + self.assertEqual(entity.locale, locale) + self.assertEqual(entity.score, score) + self.assertEqual(entity.bounds.vertices[0].x_coordinate, 1) + self.assertEqual(entity.bounds.vertices[0].y_coordinate, 2) + self.assertEqual(entity.locations[0].latitude, 1.0) + self.assertEqual(entity.locations[0].longitude, 2.0)