Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Python] Add feast feature store handler for enrichment transform #30957

Merged
merged 10 commits into from
Apr 26, 2024
1 change: 0 additions & 1 deletion .github/trigger_files/beam_PostCommit_Python.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run"
}

Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
import logging
import tempfile
from pathlib import Path
from typing import Any
from typing import Callable
from typing import List
from typing import Mapping
from typing import Optional

import apache_beam as beam
from apache_beam.io.gcp.gcsfilesystem import GCSFileSystem
from apache_beam.io.filesystems import FileSystems
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel
from feast import FeatureStore
Expand All @@ -30,6 +33,8 @@
'FeastFeatureStoreEnrichmentHandler',
]

EntityRowFn = Callable[[beam.Row], Mapping[str, Any]]

_LOGGER = logging.getLogger(__name__)

LOCAL_FEATURE_STORE_YAML_FILENAME = 'fs_yaml_file.yaml'
Expand All @@ -38,8 +43,7 @@
def download_fs_yaml_file(gcs_fs_yaml_file: str):
"""Download the feature store config file for Feast."""
try:
fs = GCSFileSystem(pipeline_options={})
with fs.open(gcs_fs_yaml_file, 'r') as gcs_file:
with FileSystems.open(gcs_fs_yaml_file, 'r') as gcs_file:
with tempfile.NamedTemporaryFile(suffix=LOCAL_FEATURE_STORE_YAML_FILENAME,
delete=False) as local_file:
local_file.write(gcs_file.read())
Expand All @@ -55,19 +59,28 @@ def _validate_feature_names(feature_names, feature_service_name):
if ((not feature_names and not feature_service_name) or
bool(feature_names and feature_service_name)):
raise ValueError(
'Please provide either a list of feature names to fetch '
'from online store or a feature service name for the '
'Feast online feature store!')
'Please provide exactly one of a list of feature names to fetch '
'from online store (`feature_names`) or a feature service name for '
'the Feast online feature store (`feature_service_name`).')


def _validate_feature_store_yaml_path_exists(fs_yaml_file):
"""Check if the feature store yaml path exists."""
fs = GCSFileSystem(pipeline_options={})
if not fs.exists(fs_yaml_file):
if not FileSystems.exists(fs_yaml_file):
raise ValueError(
'The feature store yaml path (%s) does not exist.' % fs_yaml_file)


def _validate_entity_key_exists(entity_id, entity_row_fn):
"""Checks if the entity key or a lambda to build entity key exists."""
if ((not entity_row_fn and not entity_id) or
bool(entity_row_fn and entity_id)):
raise ValueError(
"Please specify exactly one of a `entity_id` or a lambda "
"function with `entity_row_fn` to extract the entity id "
"from the input row.")


class FeastFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row,
beam.Row]):
"""Enrichment handler to interact with the Feast feature store.
Expand All @@ -81,12 +94,13 @@ class FeastFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row,
"""
def __init__(
self,
entity_id: str,
feature_store_yaml_path: str,
feature_names: Optional[List[str]] = None,
feature_service_name: Optional[str] = "",
full_feature_names: Optional[bool] = False,
entity_id: str = "",
*,
entity_row_fn: Optional[EntityRowFn] = None,
exception_level: ExceptionLevel = ExceptionLevel.WARN,
):
"""Initializes an instance of `FeastFeatureStoreEnrichmentHandler`.
Expand All @@ -95,13 +109,21 @@ def __init__(
entity_id (str): entity name for the entity associated with the features.
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved
The `entity_id` is used to extract the entity value from the input row.
feature_store_yaml_path (str): The path to a YAML configuration file for
the Feast feature store.
the Feast feature store. See
https://docs.feast.dev/reference/feature-repository/feature-store-yaml
for configuration options supported by Feast.
feature_names: A list of feature names to be retrieved from the online
Feast feature store.
feature_service_name (str): The name of the feature service containing
the features to fetch from the online Feast feature store.
full_feature_names (bool): Whether to use full feature names
(including namespaces, etc.). Defaults to False.
entity_row_fn: a lambda function that returns a dictionary with
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved
a mapping from the entity key column name to entity key value from the
input row. It is used to build/extract the entity dict for feature
retrieval.
See https://docs.feast.dev/getting-started/concepts/feature-retrieval
for more information.
exception_level: a `enum.Enum` value from
`apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel`
to set the level when `None` feature values are fetched from the
Expand All @@ -112,7 +134,9 @@ def __init__(
self.feature_names = feature_names
self.feature_service_name = feature_service_name
self.full_feature_names = full_feature_names
self.entity_row_fn = entity_row_fn
self._exception_level = exception_level
_validate_entity_key_exists(self.entity_id, self.entity_row_fn)
_validate_feature_store_yaml_path_exists(self.feature_store_yaml_path)
_validate_feature_names(self.feature_names, self.feature_service_name)

Expand Down Expand Up @@ -144,12 +168,14 @@ def __call__(self, request: beam.Row, *args, **kwargs):
Args:
request: the input `beam.Row` to enrich.
"""
request_dict = request._asdict()
if self.entity_row_fn:
entity_dict = self.entity_row_fn(request)
else:
request_dict = request._asdict()
entity_dict = {self.entity_id: request_dict[self.entity_id]}
feature_values = self.store.get_online_features(
features=self.features,
entity_rows=[{
self.entity_id: request_dict[self.entity_id]
}],
entity_rows=[entity_dict],
full_feature_names=self.full_feature_names).to_dict()
# get_online_features() returns a list of feature values per entity-id.
# Since we do this per entity, the list of feature values only contain
Expand All @@ -164,4 +190,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def get_cache_key(self, request: beam.Row) -> str:
"""Returns a string formatted with unique entity-id for the feature values.
"""
return 'entity_id: %s' % request._asdict()[self.entity_id]
if self.entity_row_fn:
entity_dict = self.entity_row_fn(request)
entity_id = list(entity_dict.keys())[0]
else:
entity_id = self.entity_id
return 'entity_id: %s' % request._asdict()[entity_id]
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
"""

import unittest
from typing import Any
from typing import Mapping

import pytest

Expand All @@ -39,6 +41,11 @@
'Feast feature store test dependencies are not installed.')


def _entity_row_fn(request: beam.Row) -> Mapping[str, Any]:
entity_value = request.user_id # type: ignore[attr-defined]
return {'user_id': entity_value}


@pytest.mark.uses_feast
class TestFeastEnrichmentHandler(unittest.TestCase):
def setUp(self) -> None:
Expand Down Expand Up @@ -87,6 +94,28 @@ def test_feast_enrichment_bad_feature_service_name(self):
res = test_pipeline.run()
res.wait_until_finish()

def test_feast_enrichment_with_lambda(self):
requests = [
beam.Row(user_id=2, product_id=1),
beam.Row(user_id=6, product_id=2),
beam.Row(user_id=9, product_id=3),
]
expected_fields = [
'user_id', 'product_id', 'state', 'country', 'gender', 'age'
]
handler = FeastFeatureStoreEnrichmentHandler(
feature_store_yaml_path=self.feature_store_yaml_file,
feature_service_name=self.feature_service_name,
entity_row_fn=_entity_row_fn,
)

with TestPipeline(is_integration_test=True) as test_pipeline:
_ = (
test_pipeline
| beam.Create(requests)
| Enrichment(handler)
| beam.ParDo(ValidateResponse(expected_fields)))


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
#
import unittest

from parameterized import parameterized

try:
from apache_beam.transforms.enrichment_handlers.feast_feature_store import \
FeastFeatureStoreEnrichmentHandler
from apache_beam.transforms.enrichment_handlers.feast_feature_store_it_test \
import _entity_row_fn
except ImportError:
raise unittest.SkipTest(
'Feast feature store test dependencies are not installed.')
Expand Down Expand Up @@ -49,6 +53,15 @@ def test_feast_enrichment_no_feature_service(self):
feature_store_yaml_path=self.feature_store_yaml_file,
)

@parameterized.expand([('user_id', _entity_row_fn), ('', None)])
def test_feast_enrichment_invalid_args(self, entity_id, entity_row_fn):
with self.assertRaises(ValueError):
_ = FeastFeatureStoreEnrichmentHandler(
feature_store_yaml_path=self.feature_store_yaml_file,
entity_id=entity_id,
entity_row_fn=entity_row_fn,
)


if __name__ == '__main__':
unittest.main()
Loading