From f9a4bfcb027f2e3a8e32578adf49981aeef3586a Mon Sep 17 00:00:00 2001
From: Matthias Baetens
Date: Wed, 30 Jun 2021 16:16:14 +0200
Subject: [PATCH] [BEAM-11289] [Python] Integrate Google Cloud Recommendations
 AI functionality (#14806)

[BEAM-11289] [Python] Integrate Google Cloud Recommendations AI functionality
---
 .../apache_beam/ml/gcp/recommendations_ai.py  | 585 ++++++++++++++++++
 .../ml/gcp/recommendations_ai_test.py         | 207 +++++++
 .../ml/gcp/recommendations_ai_test_it.py      | 108 ++++
 .../load_tests/load_test_metrics_utils.py     |   2 +-
 sdks/python/setup.py                          |   3 +-
 5 files changed, 903 insertions(+), 2 deletions(-)
 create mode 100644 sdks/python/apache_beam/ml/gcp/recommendations_ai.py
 create mode 100644 sdks/python/apache_beam/ml/gcp/recommendations_ai_test.py
 create mode 100644 sdks/python/apache_beam/ml/gcp/recommendations_ai_test_it.py

diff --git a/sdks/python/apache_beam/ml/gcp/recommendations_ai.py b/sdks/python/apache_beam/ml/gcp/recommendations_ai.py
new file mode 100644
index 000000000000..b6eb4cfb4bcd
--- /dev/null
+++ b/sdks/python/apache_beam/ml/gcp/recommendations_ai.py
@@ -0,0 +1,585 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A connector for sending API requests to the GCP Recommendations AI
+API (https://cloud.google.com/recommendations).
+""" + +from __future__ import absolute_import + +from typing import Sequence +from typing import Tuple + +from google.api_core.retry import Retry + +from apache_beam import pvalue +from apache_beam.metrics import Metrics +from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.transforms import DoFn +from apache_beam.transforms import ParDo +from apache_beam.transforms import PTransform +from apache_beam.transforms.util import GroupIntoBatches +from cachetools.func import ttl_cache + +# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports +try: + from google.cloud import recommendationengine +except ImportError: + raise ImportError( + 'Google Cloud Recommendation AI not supported for this execution ' + 'environment (could not import google.cloud.recommendationengine).') +# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports + +__all__ = [ + 'CreateCatalogItem', + 'WriteUserEvent', + 'ImportCatalogItems', + 'ImportUserEvents', + 'PredictUserEvent' +] + +FAILED_CATALOG_ITEMS = "failed_catalog_items" + + +@ttl_cache(maxsize=128, ttl=3600) +def get_recommendation_prediction_client(): + """Returns a Recommendation AI - Prediction Service client.""" + _client = recommendationengine.PredictionServiceClient() + return _client + + +@ttl_cache(maxsize=128, ttl=3600) +def get_recommendation_catalog_client(): + """Returns a Recommendation AI - Catalog Service client.""" + _client = recommendationengine.CatalogServiceClient() + return _client + + +@ttl_cache(maxsize=128, ttl=3600) +def get_recommendation_user_event_client(): + """Returns a Recommendation AI - UserEvent Service client.""" + _client = recommendationengine.UserEventServiceClient() + return _client + + +class CreateCatalogItem(PTransform): + """Creates catalogitem information. + The ``PTranform`` returns a PCollectionTuple with a PCollections of + successfully and failed created CatalogItems. + + Example usage:: + + pipeline | CreateCatalogItem( + project='example-gcp-project', + catalog_name='my-catalog') + """ + def __init__( + self, + project: str = None, + retry: Retry = None, + timeout: float = 120, + metadata: Sequence[Tuple[str, str]] = (), + catalog_name: str = "default_catalog"): + """Initializes a :class:`CreateCatalogItem` transform. + + Args: + project (str): Optional. GCP project name in which the catalog + data will be imported. + retry: Optional. Designation of what + errors, if any, should be retried. + timeout (float): Optional. The amount of time, in seconds, to wait + for the request to complete. + metadata: Optional. Strings which + should be sent along with the request as metadata. + catalog_name (str): Optional. Name of the catalog. 
+        Default: 'default_catalog'
+    """
+    self.project = project
+    self.retry = retry
+    self.timeout = timeout
+    self.metadata = metadata
+    self.catalog_name = catalog_name
+
+  def expand(self, pcoll):
+    if self.project is None:
+      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
+    if self.project is None:
+      raise ValueError(
+          'GCP project name needs to be specified in "project" pipeline option')
+    return pcoll | ParDo(
+        _CreateCatalogItemFn(
+            self.project,
+            self.retry,
+            self.timeout,
+            self.metadata,
+            self.catalog_name))
+
+
+class _CreateCatalogItemFn(DoFn):
+  def __init__(
+      self,
+      project: str = None,
+      retry: Retry = None,
+      timeout: float = 120,
+      metadata: Sequence[Tuple[str, str]] = (),
+      catalog_name: str = None):
+    self._client = None
+    self.retry = retry
+    self.timeout = timeout
+    self.metadata = metadata
+    self.parent = f"projects/{project}/locations/global/catalogs/{catalog_name}"
+    self.counter = Metrics.counter(self.__class__, "api_calls")
+
+  def setup(self):
+    if self._client is None:
+      self._client = get_recommendation_catalog_client()
+
+  def process(self, element):
+    catalog_item = recommendationengine.CatalogItem(element)
+    request = recommendationengine.CreateCatalogItemRequest(
+        parent=self.parent, catalog_item=catalog_item)
+
+    try:
+      created_catalog_item = self._client.create_catalog_item(
+          request=request,
+          retry=self.retry,
+          timeout=self.timeout,
+          metadata=self.metadata)
+
+      self.counter.inc()
+      yield recommendationengine.CatalogItem.to_dict(created_catalog_item)
+    except Exception:
+      yield pvalue.TaggedOutput(
+          FAILED_CATALOG_ITEMS,
+          recommendationengine.CatalogItem.to_dict(catalog_item))
+
+
+class ImportCatalogItems(PTransform):
+  """Imports catalog items in bulk.
+  The ``PTransform`` returns a PCollectionTuple with PCollections of
+  successfully imported and failed CatalogItems.
+
+  Example usage::
+
+    pipeline
+    | ImportCatalogItems(
+        project='example-gcp-project',
+        catalog_name='my-catalog')
+  """
+  def __init__(
+      self,
+      max_batch_size: int = 5000,
+      project: str = None,
+      retry: Retry = None,
+      timeout: float = 120,
+      metadata: Sequence[Tuple[str, str]] = (),
+      catalog_name: str = "default_catalog"):
+    """Initializes an :class:`ImportCatalogItems` transform.
+
+    Args:
+      max_batch_size (int): Optional. Maximum number of catalog items
+        per request. Default: 5000
+      project (str): Optional. GCP project name in which the catalog
+        data will be imported.
+      retry: Optional. Designation of what
+        errors, if any, should be retried.
+      timeout (float): Optional. The amount of time, in seconds, to wait
+        for the request to complete.
+      metadata: Optional. Strings which
+        should be sent along with the request as metadata.
+      catalog_name (str): Optional. Name of the catalog.
+        Default: 'default_catalog'
+    """
+    self.max_batch_size = max_batch_size
+    self.project = project
+    self.retry = retry
+    self.timeout = timeout
+    self.metadata = metadata
+    self.catalog_name = catalog_name
+
+  def expand(self, pcoll):
+    if self.project is None:
+      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
+    if self.project is None:
+      raise ValueError(
+          'GCP project name needs to be specified in "project" pipeline option')
+    return (
+        pcoll | GroupIntoBatches.WithShardedKey(self.max_batch_size) | ParDo(
+            _ImportCatalogItemsFn(
+                self.project,
+                self.retry,
+                self.timeout,
+                self.metadata,
+                self.catalog_name)))
+
+
+class _ImportCatalogItemsFn(DoFn):
+  def __init__(
+      self,
+      project=None,
+      retry=None,
+      timeout=120,
+      metadata=None,
+      catalog_name=None):
+    self._client = None
+    self.retry = retry
+    self.timeout = timeout
+    self.metadata = metadata
+    self.parent = f"projects/{project}/locations/global/catalogs/{catalog_name}"
+    self.counter = Metrics.counter(self.__class__, "api_calls")
+
+  def setup(self):
+    if self._client is None:
+      self._client = get_recommendation_catalog_client()
+
+  def process(self, element):
+    catalog_items = [recommendationengine.CatalogItem(e) for e in element[1]]
+    catalog_inline_source = recommendationengine.CatalogInlineSource(
+        {"catalog_items": catalog_items})
+    input_config = recommendationengine.InputConfig(
+        catalog_inline_source=catalog_inline_source)
+
+    request = recommendationengine.ImportCatalogItemsRequest(
+        parent=self.parent, input_config=input_config)
+
+    try:
+      operation = self._client.import_catalog_items(
+          request=request,
+          retry=self.retry,
+          timeout=self.timeout,
+          metadata=self.metadata)
+      self.counter.inc()
+      yield operation.result()
+    except Exception:
+      yield pvalue.TaggedOutput(FAILED_CATALOG_ITEMS, catalog_items)
+
+
+class WriteUserEvent(PTransform):
+  """Writes user event information.
+  The ``PTransform`` returns a PCollectionTuple with PCollections of
+  successfully written and failed UserEvents.
+
+  Example usage::
+
+    pipeline
+    | WriteUserEvent(
+        project='example-gcp-project',
+        catalog_name='my-catalog',
+        event_store='my_event_store')
+  """
+  def __init__(
+      self,
+      project: str = None,
+      retry: Retry = None,
+      timeout: float = 120,
+      metadata: Sequence[Tuple[str, str]] = (),
+      catalog_name: str = "default_catalog",
+      event_store: str = "default_event_store"):
+    """Initializes a :class:`WriteUserEvent` transform.
+
+    Args:
+      project (str): Optional. GCP project name in which the user event
+        data will be written.
+      retry: Optional. Designation of what
+        errors, if any, should be retried.
+      timeout (float): Optional. The amount of time, in seconds, to wait
+        for the request to complete.
+      metadata: Optional. Strings which
+        should be sent along with the request as metadata.
+      catalog_name (str): Optional. Name of the catalog.
+        Default: 'default_catalog'
+      event_store (str): Optional. Name of the event store.
+        Default: 'default_event_store'
+    """
+    self.project = project
+    self.retry = retry
+    self.timeout = timeout
+    self.metadata = metadata
+    self.catalog_name = catalog_name
+    self.event_store = event_store
+
+  def expand(self, pcoll):
+    if self.project is None:
+      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
+    if self.project is None:
+      raise ValueError(
+          'GCP project name needs to be specified in "project" pipeline option')
+    return pcoll | ParDo(
+        _WriteUserEventFn(
+            self.project,
+            self.retry,
+            self.timeout,
+            self.metadata,
+            self.catalog_name,
+            self.event_store))
+
+
+class _WriteUserEventFn(DoFn):
+  FAILED_USER_EVENTS = "failed_user_events"
+
+  def __init__(
+      self,
+      project=None,
+      retry=None,
+      timeout=120,
+      metadata=None,
+      catalog_name=None,
+      event_store=None):
+    self._client = None
+    self.retry = retry
+    self.timeout = timeout
+    self.metadata = metadata
+    self.parent = f"projects/{project}/locations/global/catalogs/"\
+        f"{catalog_name}/eventStores/{event_store}"
+    self.counter = Metrics.counter(self.__class__, "api_calls")
+
+  def setup(self):
+    if self._client is None:
+      self._client = get_recommendation_user_event_client()
+
+  def process(self, element):
+    user_event = recommendationengine.UserEvent(element)
+    request = recommendationengine.WriteUserEventRequest(
+        parent=self.parent, user_event=user_event)
+
+    try:
+      created_user_event = self._client.write_user_event(
+          request=request,
+          retry=self.retry,
+          timeout=self.timeout,
+          metadata=self.metadata)
+      self.counter.inc()
+      yield recommendationengine.UserEvent.to_dict(created_user_event)
+    except Exception:
+      yield pvalue.TaggedOutput(
+          self.FAILED_USER_EVENTS,
+          recommendationengine.UserEvent.to_dict(user_event))
+
+
+class ImportUserEvents(PTransform):
+  """Imports user events in bulk.
+  The ``PTransform`` returns a PCollectionTuple with PCollections of
+  successfully imported and failed UserEvents.
+
+  Example usage::
+
+    pipeline
+    | ImportUserEvents(
+        project='example-gcp-project',
+        catalog_name='my-catalog',
+        event_store='my_event_store')
+  """
+  def __init__(
+      self,
+      max_batch_size: int = 5000,
+      project: str = None,
+      retry: Retry = None,
+      timeout: float = 120,
+      metadata: Sequence[Tuple[str, str]] = (),
+      catalog_name: str = "default_catalog",
+      event_store: str = "default_event_store"):
+    """Initializes an :class:`ImportUserEvents` transform.
+
+    Args:
+      max_batch_size (int): Optional. Maximum number of user events
+        per request. Default: 5000
+      project (str): Optional. GCP project name in which the user event
+        data will be imported.
+      retry: Optional. Designation of what
+        errors, if any, should be retried.
+      timeout (float): Optional. The amount of time, in seconds, to wait
+        for the request to complete.
+      metadata: Optional. Strings which
+        should be sent along with the request as metadata.
+      catalog_name (str): Optional. Name of the catalog.
+        Default: 'default_catalog'
+      event_store (str): Optional. Name of the event store.
+        Default: 'default_event_store'
+    """
+    self.max_batch_size = max_batch_size
+    self.project = project
+    self.retry = retry
+    self.timeout = timeout
+    self.metadata = metadata
+    self.catalog_name = catalog_name
+    self.event_store = event_store
+
+  def expand(self, pcoll):
+    if self.project is None:
+      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
+    if self.project is None:
+      raise ValueError(
+          'GCP project name needs to be specified in "project" pipeline option')
+    return (
+        pcoll | GroupIntoBatches.WithShardedKey(self.max_batch_size) | ParDo(
+            _ImportUserEventsFn(
+                self.project,
+                self.retry,
+                self.timeout,
+                self.metadata,
+                self.catalog_name,
+                self.event_store)))
+
+
+class _ImportUserEventsFn(DoFn):
+  FAILED_USER_EVENTS = "failed_user_events"
+
+  def __init__(
+      self,
+      project=None,
+      retry=None,
+      timeout=120,
+      metadata=None,
+      catalog_name=None,
+      event_store=None):
+    self._client = None
+    self.retry = retry
+    self.timeout = timeout
+    self.metadata = metadata
+    self.parent = f"projects/{project}/locations/global/catalogs/"\
+        f"{catalog_name}/eventStores/{event_store}"
+    self.counter = Metrics.counter(self.__class__, "api_calls")
+
+  def setup(self):
+    if self._client is None:
+      self._client = get_recommendation_user_event_client()
+
+  def process(self, element):
+    user_events = [recommendationengine.UserEvent(e) for e in element[1]]
+    user_event_inline_source = recommendationengine.UserEventInlineSource(
+        {"user_events": user_events})
+    input_config = recommendationengine.InputConfig(
+        user_event_inline_source=user_event_inline_source)
+
+    request = recommendationengine.ImportUserEventsRequest(
+        parent=self.parent, input_config=input_config)
+
+    try:
+      operation = self._client.import_user_events(
+          request=request,
+          retry=self.retry,
+          timeout=self.timeout,
+          metadata=self.metadata)
+      self.counter.inc()
+      yield operation.result()
+    except Exception:
+      yield pvalue.TaggedOutput(self.FAILED_USER_EVENTS, user_events)
+
+
+class PredictUserEvent(PTransform):
+  """Makes a recommendation prediction.
+  The ``PTransform`` returns a PCollection with the prediction results.
+
+  Example usage::
+
+    pipeline
+    | PredictUserEvent(
+        project='example-gcp-project',
+        catalog_name='my-catalog',
+        event_store='my_event_store',
+        placement_id='recently_viewed_default')
+  """
+  def __init__(
+      self,
+      project: str = None,
+      retry: Retry = None,
+      timeout: float = 120,
+      metadata: Sequence[Tuple[str, str]] = (),
+      catalog_name: str = "default_catalog",
+      event_store: str = "default_event_store",
+      placement_id: str = None):
+    """Initializes a :class:`PredictUserEvent` transform.
+
+    Args:
+      project (str): Optional. GCP project name in which the prediction
+        will be made.
+      retry: Optional. Designation of what
+        errors, if any, should be retried.
+      timeout (float): Optional. The amount of time, in seconds, to wait
+        for the request to complete.
+      metadata: Optional. Strings which
+        should be sent along with the request as metadata.
+      catalog_name (str): Optional. Name of the catalog.
+        Default: 'default_catalog'
+      event_store (str): Optional. Name of the event store.
+        Default: 'default_event_store'
+      placement_id (str): Required. ID of the recommendation engine
+        placement. This id is used to identify the set of models that
+        will be used to make the prediction.
+    """
+    self.project = project
+    self.retry = retry
+    self.timeout = timeout
+    self.metadata = metadata
+    self.catalog_name = catalog_name
+    self.event_store = event_store
+    if placement_id is None:
+      raise ValueError('placement_id must be specified')
+    self.placement_id = placement_id
+
+  def expand(self, pcoll):
+    if self.project is None:
+      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
+    if self.project is None:
+      raise ValueError(
+          'GCP project name needs to be specified in "project" pipeline option')
+    return pcoll | ParDo(
+        _PredictUserEventFn(
+            self.project,
+            self.retry,
+            self.timeout,
+            self.metadata,
+            self.catalog_name,
+            self.event_store,
+            self.placement_id))
+
+
+class _PredictUserEventFn(DoFn):
+  FAILED_PREDICTIONS = "failed_predictions"
+
+  def __init__(
+      self,
+      project=None,
+      retry=None,
+      timeout=120,
+      metadata=None,
+      catalog_name=None,
+      event_store=None,
+      placement_id=None):
+    self._client = None
+    self.retry = retry
+    self.timeout = timeout
+    self.metadata = metadata
+    self.name = f"projects/{project}/locations/global/catalogs/"\
+        f"{catalog_name}/eventStores/{event_store}/placements/"\
+        f"{placement_id}"
+    self.counter = Metrics.counter(self.__class__, "api_calls")
+
+  def setup(self):
+    if self._client is None:
+      self._client = get_recommendation_prediction_client()
+
+  def process(self, element):
+    user_event = recommendationengine.UserEvent(element)
+    request = recommendationengine.PredictRequest(
+        name=self.name, user_event=user_event)
+
+    try:
+      prediction = self._client.predict(
+          request=request,
+          retry=self.retry,
+          timeout=self.timeout,
+          metadata=self.metadata)
+      self.counter.inc()
+      yield [
+          recommendationengine.PredictResponse.to_dict(p)
+          for p in prediction.pages
+      ]
+    except Exception:
+      yield pvalue.TaggedOutput(self.FAILED_PREDICTIONS, user_event)
diff --git a/sdks/python/apache_beam/ml/gcp/recommendations_ai_test.py b/sdks/python/apache_beam/ml/gcp/recommendations_ai_test.py
new file mode 100644
index 000000000000..2f688d97a309
--- /dev/null
+++ b/sdks/python/apache_beam/ml/gcp/recommendations_ai_test.py
@@ -0,0 +1,207 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +"""Unit tests for Recommendations AI transforms.""" + +from __future__ import absolute_import + +import unittest + +import mock + +import apache_beam as beam +from apache_beam.metrics import MetricsFilter + +# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports +try: + from google.cloud import recommendationengine + from apache_beam.ml.gcp import recommendations_ai +except ImportError: + recommendationengine = None +# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports + + +@unittest.skipIf( + recommendationengine is None, + "Recommendations AI dependencies not installed.") +class RecommendationsAICatalogItemTest(unittest.TestCase): + def setUp(self): + self._mock_client = mock.Mock() + self._mock_client.create_catalog_item.return_value = ( + recommendationengine.CatalogItem()) + self.m2 = mock.Mock() + self.m2.result.return_value = None + self._mock_client.import_catalog_items.return_value = self.m2 + + self._catalog_item = { + "id": "12345", + "title": "Sample laptop", + "description": "Indisputably the most fantastic laptop ever created.", + "language_code": "en", + "category_hierarchies": [{ + "categories": ["Electronic", "Computers"] + }] + } + + def test_CreateCatalogItem(self): + expected_counter = 1 + with mock.patch.object(recommendations_ai, + 'get_recommendation_catalog_client', + return_value=self._mock_client): + p = beam.Pipeline() + + _ = ( + p | "Create data" >> beam.Create([self._catalog_item]) + | "Create CatalogItem" >> + recommendations_ai.CreateCatalogItem(project="test")) + + result = p.run() + result.wait_until_finish() + + read_filter = MetricsFilter().with_name('api_calls') + query_result = result.metrics().query(read_filter) + if query_result['counters']: + read_counter = query_result['counters'][0] + self.assertTrue(read_counter.result == expected_counter) + + def test_ImportCatalogItems(self): + expected_counter = 1 + with mock.patch.object(recommendations_ai, + 'get_recommendation_catalog_client', + return_value=self._mock_client): + p = beam.Pipeline() + + _ = ( + p | "Create data" >> beam.Create([ + (self._catalog_item["id"], self._catalog_item), + (self._catalog_item["id"], self._catalog_item) + ]) | "Create CatalogItems" >> + recommendations_ai.ImportCatalogItems(project="test")) + + result = p.run() + result.wait_until_finish() + + read_filter = MetricsFilter().with_name('api_calls') + query_result = result.metrics().query(read_filter) + if query_result['counters']: + read_counter = query_result['counters'][0] + self.assertTrue(read_counter.result == expected_counter) + + +@unittest.skipIf( + recommendationengine is None, + "Recommendations AI dependencies not installed.") +class RecommendationsAIUserEventTest(unittest.TestCase): + def setUp(self): + self._mock_client = mock.Mock() + self._mock_client.write_user_event.return_value = ( + recommendationengine.UserEvent()) + self.m2 = mock.Mock() + self.m2.result.return_value = None + self._mock_client.import_user_events.return_value = self.m2 + + self._user_event = { + "event_type": "page-visit", "user_info": { + "visitor_id": "1" + } + } + + def test_CreateUserEvent(self): + expected_counter = 1 + with mock.patch.object(recommendations_ai, + 'get_recommendation_user_event_client', + return_value=self._mock_client): + p = beam.Pipeline() + + _ = ( + p | "Create data" >> beam.Create([self._user_event]) + | "Create UserEvent" >> + recommendations_ai.WriteUserEvent(project="test")) + + result = p.run() + result.wait_until_finish() + + read_filter = 
+      query_result = result.metrics().query(read_filter)
+      if query_result['counters']:
+        read_counter = query_result['counters'][0]
+        self.assertEqual(read_counter.result, expected_counter)
+
+  def test_ImportUserEvents(self):
+    expected_counter = 1
+    with mock.patch.object(recommendations_ai,
+                           'get_recommendation_user_event_client',
+                           return_value=self._mock_client):
+      p = beam.Pipeline()
+
+      _ = (
+          p | "Create data" >> beam.Create([
+              (self._user_event["user_info"]["visitor_id"], self._user_event),
+              (self._user_event["user_info"]["visitor_id"], self._user_event)
+          ]) | "Create UserEvents" >>
+          recommendations_ai.ImportUserEvents(project="test"))
+
+      result = p.run()
+      result.wait_until_finish()
+
+      read_filter = MetricsFilter().with_name('api_calls')
+      query_result = result.metrics().query(read_filter)
+      if query_result['counters']:
+        read_counter = query_result['counters'][0]
+        self.assertEqual(read_counter.result, expected_counter)
+
+
+@unittest.skipIf(
+    recommendationengine is None,
+    "Recommendations AI dependencies not installed.")
+class RecommendationsAIPredictTest(unittest.TestCase):
+  def setUp(self):
+    self._mock_client = mock.Mock()
+    self._mock_client.predict.return_value = [
+        recommendationengine.PredictResponse()
+    ]
+
+    self._user_event = {
+        "event_type": "page-visit", "user_info": {
+            "visitor_id": "1"
+        }
+    }
+
+  def test_Predict(self):
+    expected_counter = 1
+    with mock.patch.object(recommendations_ai,
+                           'get_recommendation_prediction_client',
+                           return_value=self._mock_client):
+      p = beam.Pipeline()
+
+      _ = (
+          p | "Create data" >> beam.Create([self._user_event])
+          | "Prediction UserEvents" >> recommendations_ai.PredictUserEvent(
+              project="test", placement_id="recently_viewed_default"))
+
+      result = p.run()
+      result.wait_until_finish()
+
+      read_filter = MetricsFilter().with_name('api_calls')
+      query_result = result.metrics().query(read_filter)
+      if query_result['counters']:
+        read_counter = query_result['counters'][0]
+        self.assertEqual(read_counter.result, expected_counter)
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/ml/gcp/recommendations_ai_test_it.py b/sdks/python/apache_beam/ml/gcp/recommendations_ai_test_it.py
new file mode 100644
index 000000000000..8c94a29004d8
--- /dev/null
+++ b/sdks/python/apache_beam/ml/gcp/recommendations_ai_test_it.py
@@ -0,0 +1,108 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +"""Integration tests for Recommendations AI transforms.""" + +from __future__ import absolute_import + +import random +import unittest + +from nose.plugins.attrib import attr + +import apache_beam as beam +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.testing.util import is_not_empty + +# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports +try: + from google.cloud import recommendationengine + from apache_beam.ml.gcp import recommendations_ai +except ImportError: + recommendationengine = None +# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports + +GCP_TEST_PROJECT = 'apache-beam-testing' + + +def extract_id(response): + yield response["id"] + + +def extract_event_type(response): + yield response["event_type"] + + +def extract_prediction(response): + yield response[0]["results"] + + +@attr('IT') +@unittest.skipIf( + recommendationengine is None, + "Recommendations AI dependencies not installed.") +class RecommendationAIIT(unittest.TestCase): + def test_create_catalog_item(self): + + CATALOG_ITEM = { + "id": str(int(random.randrange(100000))), + "title": "Sample laptop", + "description": "Indisputably the most fantastic laptop ever created.", + "language_code": "en", + "category_hierarchies": [{ + "categories": ["Electronic", "Computers"] + }] + } + + with TestPipeline(is_integration_test=True) as p: + output = ( + p | 'Create data' >> beam.Create([CATALOG_ITEM]) + | 'Create CatalogItem' >> + recommendations_ai.CreateCatalogItem(project=GCP_TEST_PROJECT) + | beam.ParDo(extract_id) | beam.combiners.ToList()) + + assert_that(output, equal_to([[CATALOG_ITEM["id"]]])) + + def test_create_user_event(self): + USER_EVENT = {"event_type": "page-visit", "user_info": {"visitor_id": "1"}} + + with TestPipeline(is_integration_test=True) as p: + output = ( + p | 'Create data' >> beam.Create([USER_EVENT]) | 'Create UserEvent' >> + recommendations_ai.WriteUserEvent(project=GCP_TEST_PROJECT) + | beam.ParDo(extract_event_type) | beam.combiners.ToList()) + + assert_that(output, equal_to([[USER_EVENT["event_type"]]])) + + def test_predict(self): + USER_EVENT = {"event_type": "page-visit", "user_info": {"visitor_id": "1"}} + + with TestPipeline(is_integration_test=True) as p: + output = ( + p | 'Create data' >> beam.Create([USER_EVENT]) + | 'Predict UserEvent' >> recommendations_ai.PredictUserEvent( + project=GCP_TEST_PROJECT, placement_id="recently_viewed_default") + | beam.ParDo(extract_prediction)) + + assert_that(output, is_not_empty()) + + +if __name__ == '__main__': + print(recommendationengine.CatalogItem.__module__) + unittest.main() diff --git a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py index 68139cf54e5d..f6d33409c7e9 100644 --- a/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py +++ b/sdks/python/apache_beam/testing/load_tests/load_test_metrics_utils.py @@ -48,7 +48,7 @@ from apache_beam.utils.timestamp import Timestamp try: - from google.cloud import bigquery + from google.cloud import bigquery # type: ignore from google.cloud.bigquery.schema import SchemaField from google.cloud.exceptions import NotFound except ImportError: diff --git a/sdks/python/setup.py b/sdks/python/setup.py index f90f5de98695..db6ea6672a14 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -200,7 +200,8 @@ def 
get_version(): 'google-cloud-videointelligence>=1.8.0,<2', 'google-cloud-vision>=0.38.0,<2', # GCP Package required by Google Cloud Profiler. - 'google-cloud-profiler>=3.0.4,<4' + 'google-cloud-profiler>=3.0.4,<4', + 'google-cloud-recommendations-ai>=0.1.0,<=0.2.0' ] INTERACTIVE_BEAM = [
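
For reference, a minimal sketch of how the new transforms compose in a
pipeline. This is illustrative only: the project name, catalog item, and
user event values below are placeholders (not values from this patch), and
running it requires the google-cloud-recommendations-ai dependency added in
setup.py above plus a GCP project with Recommendations AI enabled::

    import apache_beam as beam
    from apache_beam.ml.gcp import recommendations_ai

    # Placeholder inputs; any dicts accepted by the CatalogItem and
    # UserEvent protos will do.
    catalog_item = {
        "id": "12345",
        "title": "Sample laptop",
        "category_hierarchies": [{"categories": ["Electronic", "Computers"]}]
    }
    user_event = {"event_type": "page-visit", "user_info": {"visitor_id": "1"}}

    with beam.Pipeline() as p:
      # Create a single catalog item ("example-gcp-project" is a placeholder;
      # omitting project falls back to the "project" pipeline option).
      _ = (
          p
          | "Catalog data" >> beam.Create([catalog_item])
          | "Create item" >> recommendations_ai.CreateCatalogItem(
              project="example-gcp-project"))

      # Request predictions for a user event against an existing placement.
      _ = (
          p
          | "Event data" >> beam.Create([user_event])
          | "Predict" >> recommendations_ai.PredictUserEvent(
              project="example-gcp-project",
              placement_id="recently_viewed_default"))

Failed elements are emitted on tagged outputs (for example
'failed_catalog_items'), so callers can branch on them with
``beam.pvalue.TaggedOutput``-style multi-output handling if needed.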