Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adding support to return additional features from vector retrieval for Milvus db #4971

Merged
merged 5 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 143 additions & 2 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from feast.feast_object import FeastObject
from feast.feature_service import FeatureService
from feast.feature_view import DUMMY_ENTITY, DUMMY_ENTITY_NAME, FeatureView
from feast.field import Field
from feast.inference import (
update_data_sources_with_inferred_event_timestamp_col,
update_feature_views_with_inferred_features_and_entities,
Expand Down Expand Up @@ -1834,7 +1835,6 @@ def retrieve_online_documents(
top_k,
distance_metric,
)

# TODO currently not return the vector value since it is same as feature value, if embedding is supported,
# the feature value can be raw text before embedded
entity_key_vals = [feature[1] for feature in document_features]
Expand Down Expand Up @@ -1862,6 +1862,66 @@ def retrieve_online_documents(
)
return OnlineResponse(online_features_response)

def retrieve_online_documents_v2(
self,
query: Union[str, List[float]],
top_k: int,
features: List[str],
distance_metric: Optional[str] = "L2",
) -> OnlineResponse:
"""
Retrieves the top k closest document features. Note, embeddings are a subset of features.

Args:
features: The list of features that should be retrieved from the online document store. These features can be
specified either as a list of string document feature references or as a feature service. String feature
references must have format "feature_view:feature", e.g, "document_fv:document_embeddings".
query: The query to retrieve the closest document features for.
top_k: The number of closest document features to retrieve.
distance_metric: The distance metric to use for retrieval.
"""
if isinstance(query, str):
raise ValueError(
"Using embedding functionality is not supported for document retrieval. Please embed the query before calling retrieve_online_documents."
)

(
available_feature_views,
_,
) = utils._get_feature_views_to_use(
registry=self._registry,
project=self.project,
features=features,
allow_cache=True,
hide_dummy_entity=False,
)
feature_view_set = set()
for feature in features:
feature_view_name = feature.split(":")[0]
feature_view = self.get_feature_view(feature_view_name)
feature_view_set.add(feature_view.name)
if len(feature_view_set) > 1:
raise ValueError("Document retrieval only supports a single feature view.")
requested_features = [
f.split(":")[1] for f in features if isinstance(f, str) and ":" in f
]

requested_feature_view = available_feature_views[0]
if not requested_feature_view:
raise ValueError(
f"Feature view {requested_feature_view} not found in the registry."
)

provider = self._get_provider()
return self._retrieve_from_online_store_v2(
provider,
requested_feature_view,
requested_features,
query,
top_k,
distance_metric,
)

def _retrieve_from_online_store(
self,
provider: Provider,
Expand All @@ -1879,6 +1939,10 @@ def _retrieve_from_online_store(
"""
Search and return document features from the online document store.
"""
vector_field_metadata = _get_feature_view_vector_field_metadata(table)
if vector_field_metadata:
distance_metric = vector_field_metadata.vector_search_metric

documents = provider.retrieve_online_documents(
config=self.config,
table=table,
Expand All @@ -1892,7 +1956,7 @@ def _retrieve_from_online_store(
read_row_protos = []
row_ts_proto = Timestamp()

for row_ts, entity_key, feature_val, vector_value, distance_val in documents:
for row_ts, entity_key, feature_val, vector_value, distance_val in documents: # type: ignore[misc]
# Reset timestamp to default or update if row_ts is not None
if row_ts is not None:
row_ts_proto.FromDatetime(row_ts)
Expand All @@ -1917,6 +1981,70 @@ def _retrieve_from_online_store(
)
return read_row_protos

def _retrieve_from_online_store_v2(
self,
provider: Provider,
table: FeatureView,
requested_features: List[str],
query: List[float],
top_k: int,
distance_metric: Optional[str],
) -> OnlineResponse:
"""
Search and return document features from the online document store.
"""
vector_field_metadata = _get_feature_view_vector_field_metadata(table)
if vector_field_metadata:
distance_metric = vector_field_metadata.vector_search_metric

documents = provider.retrieve_online_documents_v2(
config=self.config,
table=table,
requested_features=requested_features,
query=query,
top_k=top_k,
distance_metric=distance_metric,
)

entity_key_dict: Dict[str, List[ValueProto]] = {}
datevals, entityvals, list_of_feature_dicts = [], [], []
for row_ts, entity_key, feature_dict in documents: # type: ignore[misc]
datevals.append(row_ts)
entityvals.append(entity_key)
list_of_feature_dicts.append(feature_dict)
if entity_key:
for key, value in zip(entity_key.join_keys, entity_key.entity_values):
python_value = value
if key not in entity_key_dict:
entity_key_dict[key] = []
entity_key_dict[key].append(python_value)

table_entity_values, idxs = utils._get_unique_entities_from_values(
entity_key_dict,
)

features_to_request: List[str] = []
if requested_features:
features_to_request = requested_features + ["distance"]
else:
features_to_request = ["distance"]
feature_data = utils._convert_rows_to_protobuf(
requested_features=features_to_request,
read_rows=list(zip(datevals, list_of_feature_dicts)),
)

online_features_response = GetOnlineFeaturesResponse(results=[])
utils._populate_response_from_feature_data(
feature_data=feature_data,
indexes=idxs,
online_features_response=online_features_response,
full_feature_names=False,
requested_features=features_to_request,
table=table,
)

return OnlineResponse(online_features_response)

def serve(
self,
host: str,
Expand Down Expand Up @@ -2266,3 +2394,16 @@ def _validate_data_sources(data_sources: List[DataSource]):
raise DataSourceRepeatNamesException(case_insensitive_ds_name)
else:
ds_names.add(case_insensitive_ds_name)


def _get_feature_view_vector_field_metadata(
feature_view: FeatureView,
) -> Optional[Field]:
vector_fields = [field for field in feature_view.schema if field.vector_index]
if len(vector_fields) > 1:
raise ValueError(
f"Feature view {feature_view.name} has multiple vector fields. Only one vector field per feature view is supported."
)
if not vector_fields:
return None
return vector_fields[0]
Loading
Loading