diff --git a/system_tests/vision.py b/system_tests/vision.py
index 277bcd9d657e..d9c25fa4a2f5 100644
--- a/system_tests/vision.py
+++ b/system_tests/vision.py
@@ -73,6 +73,10 @@ def _assert_likelihood(self, likelihood):
                   Likelihood.VERY_UNLIKELY]
         self.assertIn(likelihood, levels)
 
+    def _maybe_http_skip(self, message):
+        if not Config.CLIENT._use_gax:
+            self.skipTest(message)
+
 
 class TestVisionClientLogo(unittest.TestCase):
     def setUp(self):
@@ -190,6 +194,7 @@ def _assert_face(self, face):
 
     def test_detect_faces_content(self):
         client = Config.CLIENT
+        self._maybe_http_skip('gRPC is required for face detection.')
         with open(FACE_FILE, 'rb') as image_file:
             image = client.image(content=image_file.read())
         faces = image.detect_faces()
@@ -198,6 +203,7 @@ def test_detect_faces_content(self):
             self._assert_face(face)
 
     def test_detect_faces_gcs(self):
+        self._maybe_http_skip('gRPC is required for face detection.')
         bucket_name = Config.TEST_BUCKET.name
         blob_name = 'faces.jpg'
         blob = Config.TEST_BUCKET.blob(blob_name)
@@ -206,7 +212,6 @@ def test_detect_faces_gcs(self):
             blob.upload_from_file(file_obj)
 
         source_uri = 'gs://%s/%s' % (bucket_name, blob_name)
 
-        client = Config.CLIENT
         image = client.image(source_uri=source_uri)
         faces = image.detect_faces()
@@ -215,6 +220,7 @@ def test_detect_faces_gcs(self):
             self._assert_face(face)
 
     def test_detect_faces_filename(self):
+        self._maybe_http_skip('gRPC is required for face detection.')
         client = Config.CLIENT
         image = client.image(filename=FACE_FILE)
         faces = image.detect_faces()
@@ -361,6 +367,7 @@ def _assert_safe_search(self, safe_search):
         self._assert_likelihood(safe_search.violence)
 
     def test_detect_safe_search_content(self):
+        self._maybe_http_skip('gRPC is required for safe search detection.')
         client = Config.CLIENT
         with open(FACE_FILE, 'rb') as image_file:
             image = client.image(content=image_file.read())
@@ -370,6 +377,7 @@ def test_detect_safe_search_content(self):
             self._assert_safe_search(safe_search)
 
     def test_detect_safe_search_gcs(self):
+        self._maybe_http_skip('gRPC is required for safe search detection.')
         bucket_name = Config.TEST_BUCKET.name
         blob_name = 'faces.jpg'
         blob = Config.TEST_BUCKET.blob(blob_name)
@@ -387,6 +395,7 @@ def test_detect_safe_search_gcs(self):
             self._assert_safe_search(safe_search)
 
     def test_detect_safe_search_filename(self):
+        self._maybe_http_skip('gRPC is required for safe search detection.')
         client = Config.CLIENT
         image = client.image(filename=FACE_FILE)
         safe_searches = image.detect_safe_search()
@@ -484,6 +493,7 @@ def _assert_properties(self, image_property):
             self.assertNotEqual(color_info.score, 0.0)
 
     def test_detect_properties_content(self):
+        self._maybe_http_skip('gRPC is required for image properties detection.')
         client = Config.CLIENT
         with open(FACE_FILE, 'rb') as image_file:
             image = client.image(content=image_file.read())
@@ -493,6 +503,8 @@ def test_detect_properties_content(self):
             self._assert_properties(image_property)
 
     def test_detect_properties_gcs(self):
+        self._maybe_http_skip('gRPC is required for image properties detection.')
+        client = Config.CLIENT
        bucket_name = Config.TEST_BUCKET.name
         blob_name = 'faces.jpg'
         blob = Config.TEST_BUCKET.blob(blob_name)
@@ -502,7 +514,6 @@ def test_detect_properties_gcs(self):
 
         source_uri = 'gs://%s/%s' % (bucket_name, blob_name)
 
-        client = Config.CLIENT
         image = client.image(source_uri=source_uri)
         properties = image.detect_properties()
         self.assertEqual(len(properties), 1)
@@ -510,6 +521,7 @@ def test_detect_properties_gcs(self):
             self._assert_properties(image_property)
 
     def test_detect_properties_filename(self):
+        self._maybe_http_skip('gRPC is required for image properties detection.')
         client = Config.CLIENT
         image = client.image(filename=FACE_FILE)
         properties = image.detect_properties()
diff --git a/vision/google/cloud/vision/_gax.py b/vision/google/cloud/vision/_gax.py
index 1820d8a95860..886d33eb108b 100644
--- a/vision/google/cloud/vision/_gax.py
+++ b/vision/google/cloud/vision/_gax.py
@@ -19,6 +19,8 @@
 
 from google.cloud._helpers import _to_bytes
 
+from google.cloud.vision.annotations import Annotations
+
 
 class _GAPICVisionAPI(object):
     """Vision API for interacting with the gRPC version of Vision.
@@ -28,7 +30,32 @@ class _GAPICVisionAPI(object):
     """
     def __init__(self, client=None):
         self._client = client
-        self._api = image_annotator_client.ImageAnnotatorClient()
+        self._annotator_client = image_annotator_client.ImageAnnotatorClient()
+
+    def annotate(self, image, features):
+        """Annotate images through GAX.
+
+        :type image: :class:`~google.cloud.vision.image.Image`
+        :param image: Instance of ``Image``.
+
+        :type features: list
+        :param features: List of :class:`~google.cloud.vision.feature.Feature`.
+
+        :rtype: :class:`~google.cloud.vision.annotations.Annotations`
+        :returns: Instance of ``Annotations`` with results or ``None``.
+        """
+        gapic_features = [_to_gapic_feature(feature) for feature in features]
+        gapic_image = _to_gapic_image(image)
+        request = image_annotator_pb2.AnnotateImageRequest(
+            image=gapic_image, features=gapic_features)
+        requests = [request]
+        annotator_client = self._annotator_client
+        images = annotator_client.batch_annotate_images(requests)
+        if len(images.responses) == 1:
+            return Annotations.from_pb(images.responses[0])
+        elif len(images.responses) > 1:
+            raise NotImplementedError(
+                'Multiple image processing is not yet supported.')
 
 
 def _to_gapic_feature(feature):
diff --git a/vision/google/cloud/vision/_http.py b/vision/google/cloud/vision/_http.py
index 8bacdf01bb70..348588693a61 100644
--- a/vision/google/cloud/vision/_http.py
+++ b/vision/google/cloud/vision/_http.py
@@ -14,6 +14,7 @@
 
 """HTTP Client for interacting with the Google Cloud Vision API."""
 
+from google.cloud.vision.annotations import Annotations
 from google.cloud.vision.feature import Feature
 
 
@@ -48,8 +49,12 @@ def annotate(self, image, features):
         data = {'requests': [request]}
         api_response = self._connection.api_request(
             method='POST', path='/images:annotate', data=data)
-        responses = api_response.get('responses')
-        return responses[0]
+        images = api_response.get('responses')
+        if len(images) == 1:
+            return Annotations.from_api_repr(images[0])
+        elif len(images) > 1:
+            raise NotImplementedError(
+                'Multiple image processing is not yet supported.')
 
 
 def _make_request(image, features):
diff --git a/vision/google/cloud/vision/annotations.py b/vision/google/cloud/vision/annotations.py
index bbc14bc1414e..7550bcd8c676 100644
--- a/vision/google/cloud/vision/annotations.py
+++ b/vision/google/cloud/vision/annotations.py
@@ -93,6 +93,51 @@ def from_api_repr(cls, response):
                 _entity_from_response_type(feature_type, annotation))
         return cls(**annotations)
 
+    @classmethod
+    def from_pb(cls, response):
+        """Factory: construct an instance of ``Annotations`` from protobuf.
+
+        :type response: :class:`~google.cloud.grpc.vision.v1.\
+            image_annotator_pb2.AnnotateImageResponse`
+        :param response: ``AnnotateImageResponse`` from protobuf call.
+
+        :rtype: :class:`~google.cloud.vision.annotations.Annotations`
+        :returns: ``Annotations`` instance populated from gRPC response.
+        """
+        annotations = _process_image_annotations(response)
+        return cls(**annotations)
+
+
+def _process_image_annotations(image):
+    """Helper for processing annotation types from protobuf.
+
+    :type image: :class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.\
+        AnnotateImageResponse`
+    :param image: ``AnnotateImageResponse`` from protobuf.
+
+    :rtype: dict
+    :returns: Dictionary populated with entities from response.
+    """
+    return {
+        'labels': _make_entity_from_pb(image.label_annotations),
+        'landmarks': _make_entity_from_pb(image.landmark_annotations),
+        'logos': _make_entity_from_pb(image.logo_annotations),
+        'texts': _make_entity_from_pb(image.text_annotations),
+    }
+
+
+def _make_entity_from_pb(annotations):
+    """Create entities from a gRPC response.
+
+    :type annotations:
+        :class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.EntityAnnotation`
+    :param annotations: protobuf instances of ``EntityAnnotation``.
+
+    :rtype: list
+    :returns: List of ``EntityAnnotation``.
+    """
+    return [EntityAnnotation.from_pb(annotation) for annotation in annotations]
+
 
 def _entity_from_response_type(feature_type, results):
     """Convert a JSON result to an entity type based on the feature.
diff --git a/vision/google/cloud/vision/entity.py b/vision/google/cloud/vision/entity.py
index 80cb0a74f5fc..774c220732d6 100644
--- a/vision/google/cloud/vision/entity.py
+++ b/vision/google/cloud/vision/entity.py
@@ -64,12 +64,32 @@ def from_api_repr(cls, response):
         description = response['description']
         locale = response.get('locale', None)
         locations = [LocationInformation.from_api_repr(location)
-                     for location in response.get('locations', [])]
+                     for location in response.get('locations', ())]
         mid = response.get('mid', None)
         score = response.get('score', None)
         return cls(bounds, description, locale, locations, mid, score)
 
+    @classmethod
+    def from_pb(cls, response):
+        """Factory: construct entity from Vision gRPC response.
+
+        :type response: :class:`~google.cloud.grpc.vision.v1.\
+            image_annotator_pb2.EntityAnnotation`
+        :param response: gRPC response from Vision API with entity data.
+
+        :rtype: :class:`~google.cloud.vision.entity.EntityAnnotation`
+        :returns: Instance of ``EntityAnnotation``.
+        """
+        bounds = Bounds.from_pb(response.bounding_poly)
+        description = response.description
+        locale = response.locale
+        locations = [LocationInformation.from_pb(location)
+                     for location in response.locations]
+        mid = response.mid
+        score = response.score
+        return cls(bounds, description, locale, locations, mid, score)
+
     @property
     def bounds(self):
         """Bounding polygon of detected image feature.
diff --git a/vision/google/cloud/vision/geometry.py b/vision/google/cloud/vision/geometry.py
index 862ce60c5ef8..6b477d8b300e 100644
--- a/vision/google/cloud/vision/geometry.py
+++ b/vision/google/cloud/vision/geometry.py
@@ -25,21 +25,33 @@ def __init__(self, vertices):
         self._vertices = vertices
 
     @classmethod
-    def from_api_repr(cls, response_vertices):
+    def from_api_repr(cls, vertices):
         """Factory: construct BoundsBase instance from Vision API response.
 
-        :type response_vertices: dict
-        :param response_vertices: List of vertices.
+        :type vertices: dict
+        :param vertices: List of vertices.
 
         :rtype: :class:`~google.cloud.vision.geometry.BoundsBase` or None
         :returns: Instance of BoundsBase with populated verticies or None.
""" - if not response_vertices: + if vertices is None: return None + return cls([Vertex(vertex.get('x', None), vertex.get('y', None)) + for vertex in vertices.get('vertices', ())]) - vertices = [Vertex(vertex.get('x', None), vertex.get('y', None)) for - vertex in response_vertices.get('vertices', [])] - return cls(vertices) + @classmethod + def from_pb(cls, vertices): + """Factory: construct BoundsBase instance from Vision gRPC response. + + :type vertices: :class:`~google.cloud.grpc.vision.v1.\ + geometry_pb2.BoundingPoly` + :param vertices: List of vertices. + + :rtype: :class:`~google.cloud.vision.geometry.BoundsBase` or None + :returns: Instance of ``BoundsBase`` with populated verticies. + """ + return cls([Vertex(vertex.x, vertex.y) + for vertex in vertices.vertices]) @property def vertices(self): @@ -73,20 +85,35 @@ def __init__(self, latitude, longitude): self._longitude = longitude @classmethod - def from_api_repr(cls, response): + def from_api_repr(cls, location_info): """Factory: construct location information from Vision API response. - :type response: dict - :param response: Dictionary response of locations. + :type location_info: dict + :param location_info: Dictionary response of locations. :rtype: :class:`~google.cloud.vision.geometry.LocationInformation` :returns: ``LocationInformation`` with populated latitude and longitude. """ - latitude = response['latLng']['latitude'] - longitude = response['latLng']['longitude'] + lat_long = location_info.get('latLng', {}) + latitude = lat_long.get('latitude') + longitude = lat_long.get('longitude') return cls(latitude, longitude) + @classmethod + def from_pb(cls, location_info): + """Factory: construct location information from Vision gRPC response. + + :type location_info: :class:`~google.cloud.vision.v1.LocationInfo` + :param location_info: gRPC response of ``LocationInfo``. + + :rtype: :class:`~google.cloud.vision.geometry.LocationInformation` + :returns: ``LocationInformation`` with populated latitude and + longitude. + """ + return cls(location_info.lat_lng.latitude, + location_info.lat_lng.longitude) + @property def latitude(self): """Latitude coordinate. @@ -127,15 +154,18 @@ def __init__(self, x_coordinate, y_coordinate, z_coordinate): self._z_coordinate = z_coordinate @classmethod - def from_api_repr(cls, response_position): + def from_api_repr(cls, position): """Factory: construct 3D position from API response. + :type position: dict + :param position: Dictionary with 3 axis position data. + :rtype: :class:`~google.cloud.vision.geometry.Position` :returns: `Position` constructed with 3D points from API response. 
""" - x_coordinate = response_position['x'] - y_coordinate = response_position['y'] - z_coordinate = response_position['z'] + x_coordinate = position['x'] + y_coordinate = position['y'] + z_coordinate = position['z'] return cls(x_coordinate, y_coordinate, z_coordinate) @property diff --git a/vision/google/cloud/vision/image.py b/vision/google/cloud/vision/image.py index f9a429e24e0b..52f6cfad9869 100644 --- a/vision/google/cloud/vision/image.py +++ b/vision/google/cloud/vision/image.py @@ -19,7 +19,6 @@ from google.cloud._helpers import _to_bytes from google.cloud._helpers import _bytes_to_unicode -from google.cloud.vision.annotations import Annotations from google.cloud.vision.feature import Feature from google.cloud.vision.feature import FeatureTypes @@ -109,8 +108,7 @@ def _detect_annotation(self, features): :class:`~google.cloud.vision.color.ImagePropertiesAnnotation`, :class:`~google.cloud.vision.sage.SafeSearchAnnotation`, """ - results = self.client._vision_api.annotate(self, features) - return Annotations.from_api_repr(results) + return self.client._vision_api.annotate(self, features) def detect(self, features): """Detect multiple feature types. diff --git a/vision/unit_tests/test__gax.py b/vision/unit_tests/test__gax.py index 887592844395..02ef77362f37 100644 --- a/vision/unit_tests/test__gax.py +++ b/vision/unit_tests/test__gax.py @@ -32,8 +32,83 @@ def test_ctor(self): api = self._make_one(client) self.assertIs(api._client, client) + def test_annotation(self): + from google.cloud.vision.feature import Feature + from google.cloud.vision.feature import FeatureTypes + from google.cloud.vision.image import Image + + client = mock.Mock(spec_set=[]) + feature = Feature(FeatureTypes.LABEL_DETECTION, 5) + image_content = b'abc 1 2 3' + image = Image(client, content=image_content) + with mock.patch('google.cloud.vision._gax.image_annotator_client.' + 'ImageAnnotatorClient'): + gax_api = self._make_one(client) + + mock_response = { + 'batch_annotate_images.return_value': + mock.Mock(responses=['mock response data']), + } + + gax_api._annotator_client = mock.Mock( + spec_set=['batch_annotate_images'], **mock_response) + + with mock.patch('google.cloud.vision._gax.Annotations') as mock_anno: + gax_api.annotate(image, [feature]) + mock_anno.from_pb.assert_called_with('mock response data') + gax_api._annotator_client.batch_annotate_images.assert_called() + + def test_annotate_no_results(self): + from google.cloud.vision.feature import Feature + from google.cloud.vision.feature import FeatureTypes + from google.cloud.vision.image import Image + + client = mock.Mock(spec_set=[]) + feature = Feature(FeatureTypes.LABEL_DETECTION, 5) + image_content = b'abc 1 2 3' + image = Image(client, content=image_content) + with mock.patch('google.cloud.vision._gax.image_annotator_client.' 
+                        'ImageAnnotatorClient'):
+            gax_api = self._make_one(client)
+
+        mock_response = {
+            'batch_annotate_images.return_value': mock.Mock(responses=[]),
+        }
+
+        gax_api._annotator_client = mock.Mock(
+            spec_set=['batch_annotate_images'], **mock_response)
+        with mock.patch('google.cloud.vision._gax.Annotations'):
+            self.assertIsNone(gax_api.annotate(image, [feature]))
+
+        gax_api._annotator_client.batch_annotate_images.assert_called()
+
+    def test_annotate_multiple_results(self):
+        from google.cloud.vision.feature import Feature
+        from google.cloud.vision.feature import FeatureTypes
+        from google.cloud.vision.image import Image
+
+        client = mock.Mock(spec_set=[])
+        feature = Feature(FeatureTypes.LABEL_DETECTION, 5)
+        image_content = b'abc 1 2 3'
+        image = Image(client, content=image_content)
+        with mock.patch('google.cloud.vision._gax.image_annotator_client.'
+                        'ImageAnnotatorClient'):
+            gax_api = self._make_one(client)
+
+        mock_response = {
+            'batch_annotate_images.return_value': mock.Mock(responses=[1, 2]),
+        }
+
+        gax_api._annotator_client = mock.Mock(
+            spec_set=['batch_annotate_images'], **mock_response)
+        with mock.patch('google.cloud.vision._gax.Annotations'):
+            with self.assertRaises(NotImplementedError):
+                gax_api.annotate(image, [feature])
+
+        gax_api._annotator_client.batch_annotate_images.assert_called()
+
 
-class TestToGAPICFeature(unittest.TestCase):
+class Test__to_gapic_feature(unittest.TestCase):
     def _call_fut(self, feature):
         from google.cloud.vision._gax import _to_gapic_feature
         return _to_gapic_feature(feature)
@@ -50,7 +125,7 @@ def test__to_gapic_feature(self):
         self.assertEqual(feature_pb.max_results, 5)
 
 
-class TestToGAPICImage(unittest.TestCase):
+class Test__to_gapic_image(unittest.TestCase):
     def _call_fut(self, image):
         from google.cloud.vision._gax import _to_gapic_image
         return _to_gapic_image(image)
diff --git a/vision/unit_tests/test__http.py b/vision/unit_tests/test__http.py
index d6c237d9747c..b875a77db2d8 100644
--- a/vision/unit_tests/test__http.py
+++ b/vision/unit_tests/test__http.py
@@ -15,12 +15,54 @@
 import base64
 import unittest
 
+import mock
+
 
 IMAGE_CONTENT = b'/9j/4QNURXhpZgAASUkq'
 PROJECT = 'PROJECT'
 B64_IMAGE_CONTENT = base64.b64encode(IMAGE_CONTENT).decode('ascii')
 
 
+class Test_HTTPVisionAPI(unittest.TestCase):
+    def _get_target_class(self):
+        from google.cloud.vision._http import _HTTPVisionAPI
+        return _HTTPVisionAPI
+
+    def _make_one(self, *args, **kwargs):
+        return self._get_target_class()(*args, **kwargs)
+
+    def test_call_annotate_with_no_results(self):
+        from google.cloud.vision.feature import Feature
+        from google.cloud.vision.feature import FeatureTypes
+        from google.cloud.vision.image import Image
+
+        client = mock.Mock(spec_set=['_connection'])
+        feature = Feature(FeatureTypes.LABEL_DETECTION, 5)
+        image_content = b'abc 1 2 3'
+        image = Image(client, content=image_content)
+
+        http_api = self._make_one(client)
+        http_api._connection = mock.Mock(spec_set=['api_request'])
+        http_api._connection.api_request.return_value = {'responses': []}
+        self.assertIsNone(http_api.annotate(image, [feature]))
+
+    def test_call_annotate_with_more_than_one_result(self):
+        from google.cloud.vision.feature import Feature
+        from google.cloud.vision.feature import FeatureTypes
+        from google.cloud.vision.image import Image
+
+        client = mock.Mock(spec_set=['_connection'])
+        feature = Feature(FeatureTypes.LABEL_DETECTION, 5)
+        image_content = b'abc 1 2 3'
+        image = Image(client, content=image_content)
+
+        http_api = self._make_one(client)
+        http_api._connection = mock.Mock(spec_set=['api_request'])
+        http_api._connection.api_request.return_value = {'responses': [1, 2]}
+        with self.assertRaises(NotImplementedError):
+            http_api.annotate(image, [feature])
+
+
 class TestVisionRequest(unittest.TestCase):
     @staticmethod
     def _get_target_function():
@@ -44,7 +86,6 @@ def test_call_vision_request(self):
         features = request['features']
         self.assertEqual(len(features), 1)
         feature = features[0]
-        print(feature)
         self.assertEqual(feature['type'], FeatureTypes.FACE_DETECTION)
         self.assertEqual(feature['maxResults'], 3)
diff --git a/vision/unit_tests/test_annotations.py b/vision/unit_tests/test_annotations.py
new file mode 100644
index 000000000000..b176f8490859
--- /dev/null
+++ b/vision/unit_tests/test_annotations.py
@@ -0,0 +1,141 @@
+# Copyright 2016 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+
+def _make_pb_entity():
+    from google.cloud.grpc.vision.v1 import geometry_pb2
+    from google.cloud.grpc.vision.v1 import image_annotator_pb2
+    from google.type import latlng_pb2
+
+    description = 'testing 1 2 3'
+    locale = 'US'
+    mid = 'm/w/45342234'
+    score = 0.235434231
+
+    entity_annotation = image_annotator_pb2.EntityAnnotation(
+        mid=mid,
+        locale=locale,
+        description=description,
+        score=score,
+        bounding_poly=geometry_pb2.BoundingPoly(
+            vertices=[
+                geometry_pb2.Vertex(x=1, y=2),
+            ],
+        ),
+        locations=[
+            image_annotator_pb2.LocationInfo(
+                lat_lng=latlng_pb2.LatLng(latitude=1.0, longitude=2.0),
+            ),
+        ],
+    )
+    return entity_annotation
+
+
+class TestAnnotations(unittest.TestCase):
+    @staticmethod
+    def _get_target_class():
+        from google.cloud.vision.annotations import Annotations
+
+        return Annotations
+
+    def _make_one(self, *args, **kw):
+        return self._get_target_class()(*args, **kw)
+
+    def test_ctor(self):
+        annotations = self._make_one(
+            faces=[True], properties=[True], labels=[True], landmarks=[True],
+            logos=[True], safe_searches=[True], texts=[True])
+        self.assertEqual(annotations.faces, [True])
+        self.assertEqual(annotations.properties, [True])
+        self.assertEqual(annotations.labels, [True])
+        self.assertEqual(annotations.landmarks, [True])
+        self.assertEqual(annotations.logos, [True])
+        self.assertEqual(annotations.safe_searches, [True])
+        self.assertEqual(annotations.texts, [True])
+
+    def test_from_pb(self):
+        from google.cloud.grpc.vision.v1 import image_annotator_pb2
+
+        image_response = image_annotator_pb2.AnnotateImageResponse()
+        annotations = self._make_one().from_pb(image_response)
+        self.assertEqual(annotations.labels, [])
+        self.assertEqual(annotations.logos, [])
+        self.assertEqual(annotations.faces, ())
+        self.assertEqual(annotations.landmarks, [])
+        self.assertEqual(annotations.texts, [])
+        self.assertEqual(annotations.safe_searches, ())
+        self.assertEqual(annotations.properties, ())
+
+
+class Test__make_entity_from_pb(unittest.TestCase):
+    def _call_fut(self, annotations):
+        from google.cloud.vision.annotations import _make_entity_from_pb
+
+        return _make_entity_from_pb(annotations)
+
+    def test_it(self):
+        description = 'testing 1 2 3'
+        locale = 'US'
+        mid = 'm/w/45342234'
+        score = 0.235434231
+        entity_annotation = _make_pb_entity()
+        entities = self._call_fut([entity_annotation])
+        self.assertEqual(len(entities), 1)
+        entity = entities[0]
+        self.assertEqual(entity.description, description)
+        self.assertEqual(entity.mid, mid)
+        self.assertEqual(entity.locale, locale)
+        self.assertEqual(entity.score, score)
+        self.assertEqual(len(entity.bounds.vertices), 1)
+        self.assertEqual(entity.bounds.vertices[0].x_coordinate, 1)
+        self.assertEqual(entity.bounds.vertices[0].y_coordinate, 2)
+        self.assertEqual(len(entity.locations), 1)
+        self.assertEqual(entity.locations[0].latitude, 1.0)
+        self.assertEqual(entity.locations[0].longitude, 2.0)
+
+
+class Test__process_image_annotations(unittest.TestCase):
+    def _call_fut(self, image):
+        from google.cloud.vision.annotations import _process_image_annotations
+
+        return _process_image_annotations(image)
+
+    def test_it(self):
+        from google.cloud.grpc.vision.v1 import image_annotator_pb2
+
+        description = 'testing 1 2 3'
+        locale = 'US'
+        mid = 'm/w/45342234'
+        score = 0.235434231
+        entity_annotation = _make_pb_entity()
+
+        image_response = image_annotator_pb2.AnnotateImageResponse(
+            label_annotations=[entity_annotation])
+
+        annotations = self._call_fut(image_response)
+        self.assertEqual(len(annotations['labels']), 1)
+        entity = annotations['labels'][0]
+
+        self.assertEqual(entity.description, description)
+        self.assertEqual(entity.mid, mid)
+        self.assertEqual(entity.locale, locale)
+        self.assertEqual(entity.score, score)
+        self.assertEqual(len(entity.bounds.vertices), 1)
+        self.assertEqual(entity.bounds.vertices[0].x_coordinate, 1)
+        self.assertEqual(entity.bounds.vertices[0].y_coordinate, 2)
+        self.assertEqual(len(entity.locations), 1)
+        self.assertEqual(entity.locations[0].latitude, 1.0)
+        self.assertEqual(entity.locations[0].longitude, 2.0)
diff --git a/vision/unit_tests/test_client.py b/vision/unit_tests/test_client.py
index 44f76e944012..ccf3609a5ed6 100644
--- a/vision/unit_tests/test_client.py
+++ b/vision/unit_tests/test_client.py
@@ -76,6 +76,7 @@ def test_make_http_client(self):
         self.assertIsInstance(client._vision_api, _HTTPVisionAPI)
 
     def test_face_annotation(self):
+        from google.cloud.vision.annotations import Annotations
         from google.cloud.vision.feature import Feature, FeatureTypes
         from unit_tests._fixtures import FACE_DETECTION_RESPONSE
 
@@ -106,7 +107,7 @@ def test_face_annotation(self):
 
         self.assertEqual(REQUEST,
                          client._connection._requested[0]['data'])
-        self.assertTrue('faceAnnotations' in response)
+        self.assertIsInstance(response, Annotations)
 
     def test_image_with_client_gcs_source(self):
         from google.cloud.vision.image import Image
diff --git a/vision/unit_tests/test_entity.py b/vision/unit_tests/test_entity.py
index 7ead8388ad27..d644de925aab 100644
--- a/vision/unit_tests/test_entity.py
+++ b/vision/unit_tests/test_entity.py
@@ -32,3 +32,34 @@ def test_logo_annotation(self):
         self.assertEqual('Brand1', logo.description)
         self.assertEqual(0.63192177, logo.score)
         self.assertEqual(162, logo.bounds.vertices[0].y_coordinate)
+
+    def test_logo_pb_annotation(self):
+        from google.cloud.grpc.vision.v1 import image_annotator_pb2
+
+        description = 'testing 1 2 3'
+        locale = 'US'
+        mid = 'm/w/45342234'
+        score = 0.235434231
+        entity_annotation = image_annotator_pb2.EntityAnnotation()
+        entity_annotation.mid = mid
+        entity_annotation.locale = locale
+        entity_annotation.description = description
+        entity_annotation.score = score
+        entity_annotation.bounding_poly.vertices.add()
+        entity_annotation.bounding_poly.vertices[0].x = 1
+        entity_annotation.bounding_poly.vertices[0].y = 2
+        entity_annotation.locations.add()
+        entity_annotation.locations[0].lat_lng.latitude = 1.0
+        entity_annotation.locations[0].lat_lng.longitude = 2.0
+
+        entity_class = self._get_target_class()
+        entity = entity_class.from_pb(entity_annotation)
+
+        self.assertEqual(entity.description, description)
+        self.assertEqual(entity.mid, mid)
+        self.assertEqual(entity.locale, locale)
+        self.assertEqual(entity.score, score)
+        self.assertEqual(entity.bounds.vertices[0].x_coordinate, 1)
+        self.assertEqual(entity.bounds.vertices[0].y_coordinate, 2)
+        self.assertEqual(entity.locations[0].latitude, 1.0)
+        self.assertEqual(entity.locations[0].longitude, 2.0)
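
Taken together, both backends now hand back a populated Annotations object (or None when the batch response is empty), so Image._detect_annotation simply forwards whatever the backend returns. A rough usage sketch of the resulting call path; the vision.Client() entry point and the detect_logos() helper come from the existing library surface rather than this patch, and the filename is a placeholder:

    from google.cloud import vision

    # Client._vision_api resolves to _GAPICVisionAPI when gRPC is in use
    # (Config.CLIENT._use_gax in the system tests) and to _HTTPVisionAPI
    # otherwise; both annotate() implementations now return Annotations.
    client = vision.Client()
    image = client.image(filename='logo.png')  # placeholder filename

    # detect_logos() -> Image._detect_annotation() -> backend.annotate(),
    # which wraps the single response via Annotations.from_pb (gRPC) or
    # Annotations.from_api_repr (HTTP JSON).
    logos = image.detect_logos()
    for logo in logos:
        print(logo.description, logo.score)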