Skip to content

Commit

Permalink
Merge pull request #2918 from daspecster/vision-add-gapic-entity-anno…
Browse files Browse the repository at this point in the history
…tation

Add gax support for entity annotations.
  • Loading branch information
daspecster authored Jan 13, 2017
2 parents 40fd881 + 7ff033d commit 885397b
Show file tree
Hide file tree
Showing 12 changed files with 455 additions and 29 deletions.
16 changes: 14 additions & 2 deletions system_tests/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def _assert_likelihood(self, likelihood):
Likelihood.VERY_UNLIKELY]
self.assertIn(likelihood, levels)

def _maybe_http_skip(self, message):
if not Config.CLIENT._use_gax:
self.skipTest(message)


class TestVisionClientLogo(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -190,6 +194,7 @@ def _assert_face(self, face):

def test_detect_faces_content(self):
client = Config.CLIENT
self._maybe_http_skip('gRPC is required for face detection.')
with open(FACE_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
faces = image.detect_faces()
Expand All @@ -198,6 +203,7 @@ def test_detect_faces_content(self):
self._assert_face(face)

def test_detect_faces_gcs(self):
self._maybe_http_skip('gRPC is required for face detection.')
bucket_name = Config.TEST_BUCKET.name
blob_name = 'faces.jpg'
blob = Config.TEST_BUCKET.blob(blob_name)
Expand All @@ -206,7 +212,6 @@ def test_detect_faces_gcs(self):
blob.upload_from_file(file_obj)

source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

client = Config.CLIENT
image = client.image(source_uri=source_uri)
faces = image.detect_faces()
Expand All @@ -215,6 +220,7 @@ def test_detect_faces_gcs(self):
self._assert_face(face)

def test_detect_faces_filename(self):
self._maybe_http_skip('gRPC is required for face detection.')
client = Config.CLIENT
image = client.image(filename=FACE_FILE)
faces = image.detect_faces()
Expand Down Expand Up @@ -361,6 +367,7 @@ def _assert_safe_search(self, safe_search):
self._assert_likelihood(safe_search.violence)

def test_detect_safe_search_content(self):
self._maybe_http_skip('gRPC is required for safe search detection.')
client = Config.CLIENT
with open(FACE_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
Expand All @@ -370,6 +377,7 @@ def test_detect_safe_search_content(self):
self._assert_safe_search(safe_search)

def test_detect_safe_search_gcs(self):
self._maybe_http_skip('gRPC is required for safe search detection.')
bucket_name = Config.TEST_BUCKET.name
blob_name = 'faces.jpg'
blob = Config.TEST_BUCKET.blob(blob_name)
Expand All @@ -387,6 +395,7 @@ def test_detect_safe_search_gcs(self):
self._assert_safe_search(safe_search)

def test_detect_safe_search_filename(self):
self._maybe_http_skip('gRPC is required for safe search detection.')
client = Config.CLIENT
image = client.image(filename=FACE_FILE)
safe_searches = image.detect_safe_search()
Expand Down Expand Up @@ -484,6 +493,7 @@ def _assert_properties(self, image_property):
self.assertNotEqual(color_info.score, 0.0)

def test_detect_properties_content(self):
self._maybe_http_skip('gRPC is required for text detection.')
client = Config.CLIENT
with open(FACE_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
Expand All @@ -493,6 +503,8 @@ def test_detect_properties_content(self):
self._assert_properties(image_property)

def test_detect_properties_gcs(self):
self._maybe_http_skip('gRPC is required for text detection.')
client = Config.CLIENT
bucket_name = Config.TEST_BUCKET.name
blob_name = 'faces.jpg'
blob = Config.TEST_BUCKET.blob(blob_name)
Expand All @@ -502,14 +514,14 @@ def test_detect_properties_gcs(self):

source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

client = Config.CLIENT
image = client.image(source_uri=source_uri)
properties = image.detect_properties()
self.assertEqual(len(properties), 1)
image_property = properties[0]
self._assert_properties(image_property)

def test_detect_properties_filename(self):
self._maybe_http_skip('gRPC is required for text detection.')
client = Config.CLIENT
image = client.image(filename=FACE_FILE)
properties = image.detect_properties()
Expand Down
29 changes: 28 additions & 1 deletion vision/google/cloud/vision/_gax.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from google.cloud._helpers import _to_bytes

from google.cloud.vision.annotations import Annotations


class _GAPICVisionAPI(object):
"""Vision API for interacting with the gRPC version of Vision.
Expand All @@ -28,7 +30,32 @@ class _GAPICVisionAPI(object):
"""
def __init__(self, client=None):
self._client = client
self._api = image_annotator_client.ImageAnnotatorClient()
self._annotator_client = image_annotator_client.ImageAnnotatorClient()

def annotate(self, image, features):
"""Annotate images through GAX.
:type image: :class:`~google.cloud.vision.image.Image`
:param image: Instance of ``Image``.
:type features: list
:param features: List of :class:`~google.cloud.vision.feature.Feature`.
:rtype: :class:`~google.cloud.vision.annotations.Annotations`
:returns: Instance of ``Annotations`` with results or ``None``.
"""
gapic_features = [_to_gapic_feature(feature) for feature in features]
gapic_image = _to_gapic_image(image)
request = image_annotator_pb2.AnnotateImageRequest(
image=gapic_image, features=gapic_features)
requests = [request]
annotator_client = self._annotator_client
images = annotator_client.batch_annotate_images(requests)
if len(images.responses) == 1:
return Annotations.from_pb(images.responses[0])
elif len(images.responses) > 1:
raise NotImplementedError(
'Multiple image processing is not yet supported.')


def _to_gapic_feature(feature):
Expand Down
9 changes: 7 additions & 2 deletions vision/google/cloud/vision/_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""HTTP Client for interacting with the Google Cloud Vision API."""

from google.cloud.vision.annotations import Annotations
from google.cloud.vision.feature import Feature


Expand Down Expand Up @@ -48,8 +49,12 @@ def annotate(self, image, features):
data = {'requests': [request]}
api_response = self._connection.api_request(
method='POST', path='/images:annotate', data=data)
responses = api_response.get('responses')
return responses[0]
images = api_response.get('responses')
if len(images) == 1:
return Annotations.from_api_repr(images[0])
elif len(images) > 1:
raise NotImplementedError(
'Multiple image processing is not yet supported.')


def _make_request(image, features):
Expand Down
45 changes: 45 additions & 0 deletions vision/google/cloud/vision/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,51 @@ def from_api_repr(cls, response):
_entity_from_response_type(feature_type, annotation))
return cls(**annotations)

@classmethod
def from_pb(cls, response):
"""Factory: construct an instance of ``Annotations`` from protobuf.
:type response: :class:`~google.cloud.grpc.vision.v1.\
image_annotator_pb2.AnnotateImageResponse`
:param response: ``AnnotateImageResponse`` from protobuf call.
:rtype: :class:`~google.cloud.vision.annotations.Annotations`
:returns: ``Annotations`` instance populated from gRPC response.
"""
annotations = _process_image_annotations(response)
return cls(**annotations)


def _process_image_annotations(image):
"""Helper for processing annotation types from protobuf.
:type image: :class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.\
AnnotateImageResponse`
:param image: ``AnnotateImageResponse`` from protobuf.
:rtype: dict
:returns: Dictionary populated with entities from response.
"""
return {
'labels': _make_entity_from_pb(image.label_annotations),
'landmarks': _make_entity_from_pb(image.landmark_annotations),
'logos': _make_entity_from_pb(image.logo_annotations),
'texts': _make_entity_from_pb(image.text_annotations),
}


def _make_entity_from_pb(annotations):
"""Create an entity from a gRPC response.
:type annotations:
:class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.EntityAnnotation`
:param annotations: protobuf instance of ``EntityAnnotation``.
:rtype: list
:returns: List of ``EntityAnnotation``.
"""
return [EntityAnnotation.from_pb(annotation) for annotation in annotations]


def _entity_from_response_type(feature_type, results):
"""Convert a JSON result to an entity type based on the feature.
Expand Down
22 changes: 21 additions & 1 deletion vision/google/cloud/vision/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,32 @@ def from_api_repr(cls, response):
description = response['description']
locale = response.get('locale', None)
locations = [LocationInformation.from_api_repr(location)
for location in response.get('locations', [])]
for location in response.get('locations', ())]
mid = response.get('mid', None)
score = response.get('score', None)

return cls(bounds, description, locale, locations, mid, score)

@classmethod
def from_pb(cls, response):
"""Factory: construct entity from Vision gRPC response.
:type response: :class:`~google.cloud.grpc.vision.v1.\
image_annotator_pb2.AnnotateImageResponse`
:param response: gRPC response from Vision API with entity data.
:rtype: :class:`~google.cloud.vision.entity.EntityAnnotation`
:returns: Instance of ``EntityAnnotation``.
"""
bounds = Bounds.from_pb(response.bounding_poly)
description = response.description
locale = response.locale
locations = [LocationInformation.from_pb(location)
for location in response.locations]
mid = response.mid
score = response.score
return cls(bounds, description, locale, locations, mid, score)

@property
def bounds(self):
"""Bounding polygon of detected image feature.
Expand Down
62 changes: 46 additions & 16 deletions vision/google/cloud/vision/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,33 @@ def __init__(self, vertices):
self._vertices = vertices

@classmethod
def from_api_repr(cls, response_vertices):
def from_api_repr(cls, vertices):
"""Factory: construct BoundsBase instance from Vision API response.
:type response_vertices: dict
:param response_vertices: List of vertices.
:type vertices: dict
:param vertices: List of vertices.
:rtype: :class:`~google.cloud.vision.geometry.BoundsBase` or None
:returns: Instance of BoundsBase with populated verticies or None.
"""
if not response_vertices:
if vertices is None:
return None
return cls([Vertex(vertex.get('x', None), vertex.get('y', None))
for vertex in vertices.get('vertices', ())])

vertices = [Vertex(vertex.get('x', None), vertex.get('y', None)) for
vertex in response_vertices.get('vertices', [])]
return cls(vertices)
@classmethod
def from_pb(cls, vertices):
"""Factory: construct BoundsBase instance from Vision gRPC response.
:type vertices: :class:`~google.cloud.grpc.vision.v1.\
geometry_pb2.BoundingPoly`
:param vertices: List of vertices.
:rtype: :class:`~google.cloud.vision.geometry.BoundsBase` or None
:returns: Instance of ``BoundsBase`` with populated verticies.
"""
return cls([Vertex(vertex.x, vertex.y)
for vertex in vertices.vertices])

@property
def vertices(self):
Expand Down Expand Up @@ -73,20 +85,35 @@ def __init__(self, latitude, longitude):
self._longitude = longitude

@classmethod
def from_api_repr(cls, response):
def from_api_repr(cls, location_info):
"""Factory: construct location information from Vision API response.
:type response: dict
:param response: Dictionary response of locations.
:type location_info: dict
:param location_info: Dictionary response of locations.
:rtype: :class:`~google.cloud.vision.geometry.LocationInformation`
:returns: ``LocationInformation`` with populated latitude and
longitude.
"""
latitude = response['latLng']['latitude']
longitude = response['latLng']['longitude']
lat_long = location_info.get('latLng', {})
latitude = lat_long.get('latitude')
longitude = lat_long.get('longitude')
return cls(latitude, longitude)

@classmethod
def from_pb(cls, location_info):
"""Factory: construct location information from Vision gRPC response.
:type location_info: :class:`~google.cloud.vision.v1.LocationInfo`
:param location_info: gRPC response of ``LocationInfo``.
:rtype: :class:`~google.cloud.vision.geometry.LocationInformation`
:returns: ``LocationInformation`` with populated latitude and
longitude.
"""
return cls(location_info.lat_lng.latitude,
location_info.lat_lng.longitude)

@property
def latitude(self):
"""Latitude coordinate.
Expand Down Expand Up @@ -127,15 +154,18 @@ def __init__(self, x_coordinate, y_coordinate, z_coordinate):
self._z_coordinate = z_coordinate

@classmethod
def from_api_repr(cls, response_position):
def from_api_repr(cls, position):
"""Factory: construct 3D position from API response.
:type position: dict
:param position: Dictionary with 3 axis position data.
:rtype: :class:`~google.cloud.vision.geometry.Position`
:returns: `Position` constructed with 3D points from API response.
"""
x_coordinate = response_position['x']
y_coordinate = response_position['y']
z_coordinate = response_position['z']
x_coordinate = position['x']
y_coordinate = position['y']
z_coordinate = position['z']
return cls(x_coordinate, y_coordinate, z_coordinate)

@property
Expand Down
4 changes: 1 addition & 3 deletions vision/google/cloud/vision/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from google.cloud._helpers import _to_bytes
from google.cloud._helpers import _bytes_to_unicode
from google.cloud.vision.annotations import Annotations
from google.cloud.vision.feature import Feature
from google.cloud.vision.feature import FeatureTypes

Expand Down Expand Up @@ -109,8 +108,7 @@ def _detect_annotation(self, features):
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`,
:class:`~google.cloud.vision.sage.SafeSearchAnnotation`,
"""
results = self.client._vision_api.annotate(self, features)
return Annotations.from_api_repr(results)
return self.client._vision_api.annotate(self, features)

def detect(self, features):
"""Detect multiple feature types.
Expand Down
Loading

0 comments on commit 885397b

Please sign in to comment.