From 1634940c91182fbd080556949d6c2557288216fb Mon Sep 17 00:00:00 2001
From: Matthew Tang
Date: Thu, 21 Dec 2023 11:23:45 -0800
Subject: [PATCH] feat: Support custom target y column name for Bigframes
 Tensorflow

PiperOrigin-RevId: 592910297
---
 .../system/vertexai/test_bigframes_tensorflow.py  |  8 +++++---
 tests/unit/vertexai/test_any_serializer.py        |  1 +
 .../_workflow/serialization_engine/serializers.py | 15 +++++++++------
 3 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/tests/system/vertexai/test_bigframes_tensorflow.py b/tests/system/vertexai/test_bigframes_tensorflow.py
index 22ecd068ec..18cb384cff 100644
--- a/tests/system/vertexai/test_bigframes_tensorflow.py
+++ b/tests/system/vertexai/test_bigframes_tensorflow.py
@@ -80,8 +80,7 @@ def test_remote_execution_keras(self, shared_state):
             "virginica": 1,
             "setosa": 2,
         }
-        df["target"] = df["species"].map(species_categories)
-        df = df.drop(columns=["species"])
+        df["species"] = df["species"].map(species_categories)
 
         train, _ = bf_train_test_split(df, test_size=0.2)
 
@@ -96,7 +95,10 @@ def test_remote_execution_keras(self, shared_state):
            enable_cuda=True,
            display_name=self._make_display_name("bigframes-keras-training"),
        )
-        model.fit.vertex.remote_config.serializer_args[train] = {"batch_size": 10}
+        model.fit.vertex.remote_config.serializer_args[train] = {
+            "batch_size": 10,
+            "target_col": "species",
+        }
 
         # Train model on Vertex
         model.fit(train, epochs=10)
diff --git a/tests/unit/vertexai/test_any_serializer.py b/tests/unit/vertexai/test_any_serializer.py
index b46444f44c..d29ec441e2 100644
--- a/tests/unit/vertexai/test_any_serializer.py
+++ b/tests/unit/vertexai/test_any_serializer.py
@@ -1106,6 +1106,7 @@ def test_any_serializer_deserialize_bigframe_tensorflow(
             any_serializer_instance._instances[serializers.BigframeSerializer],
             serialized_gcs_path=fake_gcs_path,
             batch_size=None,
+            target_col=None,
         )
 
     def test_any_serializer_deserialize_tf_dataset(
diff --git a/vertexai/preview/_workflow/serialization_engine/serializers.py b/vertexai/preview/_workflow/serialization_engine/serializers.py
index 5161248e07..6886a47565 100644
--- a/vertexai/preview/_workflow/serialization_engine/serializers.py
+++ b/vertexai/preview/_workflow/serialization_engine/serializers.py
@@ -1200,7 +1200,7 @@ def deserialize(
             return self._deserialize_torch(serialized_gcs_path)
         elif detected_framework == "tensorflow":
             return self._deserialize_tensorflow(
-                serialized_gcs_path, kwargs.get("batch_size")
+                serialized_gcs_path, kwargs.get("batch_size"), kwargs.get("target_col")
             )
         else:
             raise ValueError(f"Unsupported framework: {detected_framework}")
@@ -1273,14 +1273,18 @@ def reduce_tensors(a, b):
         return functools.reduce(reduce_tensors, list(parquet_df_dp))
 
     def _deserialize_tensorflow(
-        self, serialized_gcs_path: str, batch_size: Optional[int] = None
+        self,
+        serialized_gcs_path: str,
+        batch_size: Optional[int] = None,
+        target_col: Optional[str] = None,
     ) -> TFDataset:
         """Tensorflow deserializes parquet (GCS) --> tf.data.Dataset
 
         serialized_gcs_path is a folder containing one or more parquet files.
""" - # Set default batch_size + # Set default kwarg values batch_size = batch_size or DEFAULT_TENSORFLOW_BATCHSIZE + target_col = target_col.encode("ASCII") or b"target" # Deserialization at remote environment try: @@ -1301,13 +1305,12 @@ def _deserialize_tensorflow( ds_shard = tfio.IODataset.from_parquet(file_name) ds = ds.concatenate(ds_shard) - # TODO(b/296474656) Parquet must have "target" column for y def map_fn(row): - target = row[b"target"] + target = row[target_col] row = { k: tf.expand_dims(v, -1) for k, v in row.items() - if k != b"target" and k != b"index" + if k != target_col and k != b"index" } def reduce_fn(a, b):