feat: add start_time support for BatchReadFeatureValues wrapper methods.
Always run BatchRead Dataflow tests.

PiperOrigin-RevId: 520782661
vertex-sdk-bot authored and copybara-github committed Mar 31, 2023
1 parent f66beaa commit 91d8459
Showing 2 changed files with 216 additions and 15 deletions.
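For context, a minimal usage sketch of the new parameter (the project, featurestore ID, entity type, feature IDs, and table URIs below are hypothetical, not taken from this commit):

from google.cloud import aiplatform
from google.protobuf import timestamp_pb2

aiplatform.init(project="my-project", location="us-central1")

my_featurestore = aiplatform.Featurestore(featurestore_name="my_featurestore_id")

# Exclude feature values generated before 2023-01-01 00:00:00 UTC.
start_time = timestamp_pb2.Timestamp()
start_time.FromJsonString("2023-01-01T00:00:00Z")

my_featurestore.batch_serve_to_bq(
    bq_destination_output_uri="bq://my-project.my_dataset.served_values",
    serving_feature_ids={"my_entity_type": ["feature_a", "feature_b"]},
    read_instances_uri="bq://my-project.my_dataset.read_instances",
    start_time=start_time,
)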
36 changes: 35 additions & 1 deletion google/cloud/aiplatform/featurestore/featurestore.py
@@ -20,6 +20,7 @@

from google.auth import credentials as auth_credentials
from google.protobuf import field_mask_pb2
from google.protobuf import timestamp_pb2

from google.cloud.aiplatform import base
from google.cloud.aiplatform.compat.types import (
@@ -31,7 +32,10 @@
from google.cloud.aiplatform import featurestore
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import featurestore_utils, resource_manager_utils
from google.cloud.aiplatform.utils import (
featurestore_utils,
resource_manager_utils,
)

from google.cloud import bigquery

@@ -695,6 +699,7 @@ def _validate_and_get_batch_read_feature_values_request(
read_instances: Union[gca_io.BigQuerySource, gca_io.CsvSource],
pass_through_fields: Optional[List[str]] = None,
feature_destination_fields: Optional[Dict[str, str]] = None,
start_time: Optional[timestamp_pb2.Timestamp] = None,
) -> gca_featurestore_service.BatchReadFeatureValuesRequest:
"""Validates and gets batch_read_feature_values_request
@@ -736,6 +741,10 @@ def _validate_and_get_batch_read_feature_values_request(
'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id2/features/f_id22': 'bar',
}
start_time (timestamp_pb2.Timestamp):
Optional. Excludes Feature values with a feature generation timestamp before this timestamp. If not set, retrieves
the oldest values kept in the Feature Store. The timestamp, if present, must not have higher than millisecond precision.
Returns:
gca_featurestore_service.BatchReadFeatureValuesRequest: batch read feature values request
"""
@@ -819,6 +828,9 @@ def _validate_and_get_batch_read_feature_values_request(
for pass_through_field in pass_through_fields
]

if start_time is not None:
batch_read_feature_values_request.start_time = start_time

return batch_read_feature_values_request

@base.optional_sync(return_input_arg="self")
@@ -829,6 +841,7 @@ def batch_serve_to_bq(
read_instances_uri: str,
pass_through_fields: Optional[List[str]] = None,
feature_destination_fields: Optional[Dict[str, str]] = None,
start_time: Optional[timestamp_pb2.Timestamp] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
serve_request_timeout: Optional[float] = None,
sync: bool = True,
@@ -903,8 +916,14 @@ def batch_serve_to_bq(
'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id1/features/f_id11': 'foo',
'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id2/features/f_id22': 'bar',
}
start_time (timestamp_pb2.Timestamp):
Optional. Excludes Feature values with a feature generation timestamp before this timestamp. If not set, retrieves
the oldest values kept in the Feature Store. The timestamp, if present, must not have higher than millisecond precision.
serve_request_timeout (float):
Optional. The timeout for the serve request in seconds.
Returns:
Featurestore: The featurestore resource object that feature values were batch read from.
@@ -924,6 +943,7 @@ def batch_serve_to_bq(
feature_destination_fields=feature_destination_fields,
read_instances=read_instances,
pass_through_fields=pass_through_fields,
start_time=start_time,
)
)

@@ -942,6 +962,7 @@ def batch_serve_to_gcs(
read_instances_uri: str,
pass_through_fields: Optional[List[str]] = None,
feature_destination_fields: Optional[Dict[str, str]] = None,
start_time: Optional[timestamp_pb2.Timestamp] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
sync: bool = True,
serve_request_timeout: Optional[float] = None,
@@ -1037,6 +1058,11 @@ def batch_serve_to_gcs(
'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id1/features/f_id11': 'foo',
'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id2/features/f_id22': 'bar',
}
start_time (timestamp_pb2.Timestamp):
Optional. Excludes Feature values with a feature generation timestamp before this timestamp. If not set, retrieves
the oldest values kept in the Feature Store. The timestamp, if present, must not have higher than millisecond precision.
serve_request_timeout (float):
Optional. The timeout for the serve request in seconds.
@@ -1075,6 +1101,7 @@ def batch_serve_to_gcs(
feature_destination_fields=feature_destination_fields,
read_instances=read_instances,
pass_through_fields=pass_through_fields,
start_time=start_time,
)
)

@@ -1090,6 +1117,7 @@ def batch_serve_to_df(
read_instances_df: "pd.DataFrame", # noqa: F821 - skip check for undefined name 'pd'
pass_through_fields: Optional[List[str]] = None,
feature_destination_fields: Optional[Dict[str, str]] = None,
start_time: Optional[timestamp_pb2.Timestamp] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
serve_request_timeout: Optional[float] = None,
bq_dataset_id: Optional[str] = None,
@@ -1182,6 +1210,11 @@ def batch_serve_to_df(
for temporarily staging data. If specified, caller must have
`bigquery.tables.create` permissions for Dataset.
start_time (timestamp_pb2.Timestamp):
Optional. Excludes Feature values with a feature generation timestamp before this timestamp. If not set, retrieves
the oldest values kept in the Feature Store. The timestamp, if present, must not have higher than millisecond precision.
Returns:
pd.DataFrame: The pandas DataFrame containing feature values from batch serving.
@@ -1264,6 +1297,7 @@ def batch_serve_to_df(
feature_destination_fields=feature_destination_fields,
request_metadata=request_metadata,
serve_request_timeout=serve_request_timeout,
start_time=start_time,
)

bigquery_storage_read_client = bigquery_storage.BigQueryReadClient(
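Note the docstring constraint that start_time must not have higher than millisecond precision. One way to satisfy it, sketched with only protobuf well-known-type helpers (the truncation step is illustrative, not part of this commit):

from google.protobuf import timestamp_pb2

start_time = timestamp_pb2.Timestamp()
start_time.GetCurrentTime()                        # populates seconds + nanoseconds
start_time.nanos -= start_time.nanos % 1_000_000   # truncate to millisecond precision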
195 changes: 181 additions & 14 deletions tests/unit/aiplatform/test_featurestores.py
@@ -74,14 +74,8 @@
)

from google.cloud import bigquery

try:
from google.cloud import bigquery_storage
from google.cloud.bigquery_storage_v1.types import stream as gcbqs_stream

_USE_BQ_STORAGE = True
except ImportError:
_USE_BQ_STORAGE = False
from google.cloud import bigquery_storage
from google.cloud.bigquery_storage_v1.types import stream as gcbqs_stream

from google.cloud import resourcemanager

@@ -283,6 +277,7 @@
_TEST_GCS_SOURCE_TYPE_AVRO = "avro"
_TEST_GCS_SOURCE_TYPE_INVALID = "json"

_TEST_BATCH_SERVE_START_TIME = datetime.datetime.now()
_TEST_BQ_DESTINATION_URI = "bq://project.dataset.table_name"
_TEST_GCS_OUTPUT_URI_PREFIX = "gs://my_bucket/path/to_prefix"

@@ -1613,6 +1608,57 @@ def test_batch_serve_to_bq_with_timeout_not_explicitly_set(
timeout=None,
)

@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.usefixtures("get_featurestore_mock")
def test_batch_serve_to_bq_with_start_time(
self, batch_read_feature_values_mock, sync
):
aiplatform.init(project=_TEST_PROJECT)
my_featurestore = aiplatform.Featurestore(
featurestore_name=_TEST_FEATURESTORE_NAME
)

expected_entity_type_specs = [
_get_entity_type_spec_proto_with_feature_ids(
entity_type_id="my_entity_type_id_1",
feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"],
),
_get_entity_type_spec_proto_with_feature_ids(
entity_type_id="my_entity_type_id_2",
feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"],
),
]

expected_batch_read_feature_values_request = (
gca_featurestore_service.BatchReadFeatureValuesRequest(
featurestore=my_featurestore.resource_name,
destination=gca_featurestore_service.FeatureValueDestination(
bigquery_destination=_TEST_BQ_DESTINATION,
),
entity_type_specs=expected_entity_type_specs,
bigquery_read_instances=_TEST_BQ_SOURCE,
start_time=_TEST_BATCH_SERVE_START_TIME,
)
)

my_featurestore.batch_serve_to_bq(
bq_destination_output_uri=_TEST_BQ_DESTINATION_URI,
serving_feature_ids=_TEST_SERVING_FEATURE_IDS,
read_instances_uri=_TEST_BQ_SOURCE_URI,
sync=sync,
serve_request_timeout=None,
start_time=_TEST_BATCH_SERVE_START_TIME,
)

if not sync:
my_featurestore.wait()

batch_read_feature_values_mock.assert_called_once_with(
request=expected_batch_read_feature_values_request,
metadata=_TEST_REQUEST_METADATA,
timeout=None,
)

@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.usefixtures("get_featurestore_mock")
def test_batch_serve_to_gcs(self, batch_read_feature_values_mock, sync):
@@ -1677,9 +1723,58 @@ def test_batch_serve_to_gcs_with_invalid_gcs_destination_type(self):
read_instances_uri=_TEST_GCS_CSV_SOURCE_URI,
)

@pytest.mark.skipif(
_USE_BQ_STORAGE is False, reason="batch_serve_to_df requires bigquery_storage"
)
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.usefixtures("get_featurestore_mock")
def test_batch_serve_to_gcs_with_start_time(
self, batch_read_feature_values_mock, sync
):
aiplatform.init(project=_TEST_PROJECT)
my_featurestore = aiplatform.Featurestore(
featurestore_name=_TEST_FEATURESTORE_NAME
)

expected_entity_type_specs = [
_get_entity_type_spec_proto_with_feature_ids(
entity_type_id="my_entity_type_id_1",
feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"],
),
_get_entity_type_spec_proto_with_feature_ids(
entity_type_id="my_entity_type_id_2",
feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"],
),
]

expected_batch_read_feature_values_request = (
gca_featurestore_service.BatchReadFeatureValuesRequest(
featurestore=my_featurestore.resource_name,
destination=gca_featurestore_service.FeatureValueDestination(
tfrecord_destination=_TEST_TFRECORD_DESTINATION,
),
entity_type_specs=expected_entity_type_specs,
csv_read_instances=_TEST_CSV_SOURCE,
start_time=_TEST_BATCH_SERVE_START_TIME,
)
)

my_featurestore.batch_serve_to_gcs(
gcs_destination_output_uri_prefix=_TEST_GCS_OUTPUT_URI_PREFIX,
gcs_destination_type=_TEST_GCS_DESTINATION_TYPE_TFRECORD,
serving_feature_ids=_TEST_SERVING_FEATURE_IDS,
read_instances_uri=_TEST_GCS_CSV_SOURCE_URI,
sync=sync,
serve_request_timeout=None,
start_time=_TEST_BATCH_SERVE_START_TIME,
)

if not sync:
my_featurestore.wait()

batch_read_feature_values_mock.assert_called_once_with(
request=expected_batch_read_feature_values_request,
metadata=_TEST_REQUEST_METADATA,
timeout=None,
)

@pytest.mark.usefixtures(
"get_featurestore_mock",
"bq_init_client_mock",
@@ -1753,9 +1848,6 @@ def test_batch_serve_to_df(self, batch_read_feature_values_mock):
timeout=None,
)

@pytest.mark.skipif(
_USE_BQ_STORAGE is False, reason="batch_serve_to_df requires bigquery_storage"
)
@pytest.mark.usefixtures(
"get_featurestore_mock",
"bq_init_client_mock",
@@ -1850,6 +1942,81 @@ def test_batch_serve_to_df_user_specified_bq_dataset(
bq_create_dataset_mock.assert_not_called()
bq_delete_dataset_mock.assert_not_called()

@pytest.mark.usefixtures(
"get_featurestore_mock",
"bq_init_client_mock",
"bq_init_dataset_mock",
"bq_create_dataset_mock",
"bq_load_table_from_dataframe_mock",
"bq_delete_dataset_mock",
"bqs_init_client_mock",
"bqs_create_read_session",
"get_project_mock",
)
@patch("uuid.uuid4", uuid_mock)
def test_batch_serve_to_df_with_start_time(self, batch_read_feature_values_mock):

aiplatform.init(project=_TEST_PROJECT_DIFF)

my_featurestore = aiplatform.Featurestore(
featurestore_name=_TEST_FEATURESTORE_NAME
)

read_instances_df = pd.DataFrame()

expected_temp_bq_dataset_name = (
f"temp_{_TEST_FEATURESTORE_ID}_{uuid.uuid4()}".replace("-", "_")
)
expected_temp_bq_dataset_id = f"{_TEST_PROJECT}.{expected_temp_bq_dataset_name}"[
:1024
]
expected_temp_bq_read_instances_table_id = (
f"{expected_temp_bq_dataset_id}.read_instances"
)
expected_temp_bq_batch_serve_table_id = (
f"{expected_temp_bq_dataset_id}.batch_serve"
)

expected_entity_type_specs = [
_get_entity_type_spec_proto_with_feature_ids(
entity_type_id="my_entity_type_id_1",
feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"],
),
_get_entity_type_spec_proto_with_feature_ids(
entity_type_id="my_entity_type_id_2",
feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"],
),
]

expected_batch_read_feature_values_request = (
gca_featurestore_service.BatchReadFeatureValuesRequest(
featurestore=my_featurestore.resource_name,
destination=gca_featurestore_service.FeatureValueDestination(
bigquery_destination=gca_io.BigQueryDestination(
output_uri=f"bq://{expected_temp_bq_batch_serve_table_id}"
),
),
entity_type_specs=expected_entity_type_specs,
bigquery_read_instances=gca_io.BigQuerySource(
input_uri=f"bq://{expected_temp_bq_read_instances_table_id}"
),
start_time=_TEST_BATCH_SERVE_START_TIME,
)
)

my_featurestore.batch_serve_to_df(
serving_feature_ids=_TEST_SERVING_FEATURE_IDS,
read_instances_df=read_instances_df,
serve_request_timeout=None,
start_time=_TEST_BATCH_SERVE_START_TIME,
)

batch_read_feature_values_mock.assert_called_once_with(
request=expected_batch_read_feature_values_request,
metadata=_TEST_REQUEST_METADATA,
timeout=None,
)


@pytest.mark.usefixtures("google_auth_mock")
class TestEntityType:
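A note on the tests above: _TEST_BATCH_SERVE_START_TIME is a plain datetime.datetime rather than a timestamp_pb2.Timestamp. This presumably works because proto-plus message fields typed as google.protobuf.Timestamp coerce Python datetime values on assignment; this is my reading of proto-plus behavior, not something the commit asserts. The explicit equivalent, as a sketch:

import datetime

from google.protobuf import timestamp_pb2

start_time = timestamp_pb2.Timestamp()
start_time.FromDatetime(datetime.datetime.now(tz=datetime.timezone.utc))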
