feat: Support custom target y column name for Bigframes Tensorflow
PiperOrigin-RevId: 592910297
matthew29tang authored and copybara-github committed Dec 21, 2023
1 parent 6e6d005 commit 1634940
Showing 3 changed files with 15 additions and 9 deletions.
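Before this change, the Bigframes --> TensorFlow deserializer hard-coded the label column name to "target"; this commit makes it configurable through serializer_args. A minimal usage sketch mirroring the system test below (the Keras model wrapper `model` and Bigframes split `train` are assumed to be set up as in that test):

    # Route the label column name to the remote deserializer via serializer_args.
    model.fit.vertex.remote_config.serializer_args[train] = {
        "batch_size": 10,          # rows per batch in the remote tf.data pipeline
        "target_col": "species",   # new kwarg: which column becomes y (default "target")
    }
    model.fit(train, epochs=10)    # trains on Vertex with "species" as the label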
8 changes: 5 additions & 3 deletions tests/system/vertexai/test_bigframes_tensorflow.py
@@ -80,8 +80,7 @@ def test_remote_execution_keras(self, shared_state):
             "virginica": 1,
             "setosa": 2,
         }
-        df["target"] = df["species"].map(species_categories)
-        df = df.drop(columns=["species"])
+        df["species"] = df["species"].map(species_categories)

         train, _ = bf_train_test_split(df, test_size=0.2)

@@ -96,7 +95,10 @@ def test_remote_execution_keras(self, shared_state):
             enable_cuda=True,
             display_name=self._make_display_name("bigframes-keras-training"),
         )
-        model.fit.vertex.remote_config.serializer_args[train] = {"batch_size": 10}
+        model.fit.vertex.remote_config.serializer_args[train] = {
+            "batch_size": 10,
+            "target_col": "species",
+        }

         # Train model on Vertex
         model.fit(train, epochs=10)
1 change: 1 addition & 0 deletions tests/unit/vertexai/test_any_serializer.py
@@ -1106,6 +1106,7 @@ def test_any_serializer_deserialize_bigframe_tensorflow(
             any_serializer_instance._instances[serializers.BigframeSerializer],
             serialized_gcs_path=fake_gcs_path,
             batch_size=None,
+            target_col=None,
         )

     def test_any_serializer_deserialize_tf_dataset(
15 changes: 9 additions & 6 deletions vertexai/preview/_workflow/serialization_engine/serializers.py
@@ -1200,7 +1200,7 @@ def deserialize(
             return self._deserialize_torch(serialized_gcs_path)
         elif detected_framework == "tensorflow":
             return self._deserialize_tensorflow(
-                serialized_gcs_path, kwargs.get("batch_size")
+                serialized_gcs_path, kwargs.get("batch_size"), kwargs.get("target_col")
             )
         else:
             raise ValueError(f"Unsupported framework: {detected_framework}")
@@ -1273,14 +1273,18 @@ def reduce_tensors(a, b):
         return functools.reduce(reduce_tensors, list(parquet_df_dp))

     def _deserialize_tensorflow(
-        self, serialized_gcs_path: str, batch_size: Optional[int] = None
+        self,
+        serialized_gcs_path: str,
+        batch_size: Optional[int] = None,
+        target_col: Optional[str] = None,
     ) -> TFDataset:
         """Tensorflow deserializes parquet (GCS) --> tf.data.Dataset

         serialized_gcs_path is a folder containing one or more parquet files.
         """
-        # Set default batch_size
+        # Set default kwarg values
         batch_size = batch_size or DEFAULT_TENSORFLOW_BATCHSIZE
+        target_col = target_col.encode("ASCII") if target_col else b"target"

         # Deserialization at remote environment
         try:
@@ -1301,13 +1305,12 @@ def _deserialize_tensorflow(
             ds_shard = tfio.IODataset.from_parquet(file_name)
             ds = ds.concatenate(ds_shard)

-        # TODO(b/296474656) Parquet must have "target" column for y
         def map_fn(row):
-            target = row[b"target"]
+            target = row[target_col]
             row = {
                 k: tf.expand_dims(v, -1)
                 for k, v in row.items()
-                if k != b"target" and k != b"index"
+                if k != target_col and k != b"index"
             }

             def reduce_fn(a, b):
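Note: tfio's IODataset.from_parquet yields rows as dicts keyed by byte strings, which is why the serializer ASCII-encodes target_col before the lookup. A simplified, self-contained sketch of the column split that map_fn performs (the example row is hypothetical, and the nested reduce_fn that assembles the final feature tensor is elided):

    import tensorflow as tf

    target_col = b"species"  # what target_col.encode("ASCII") yields in the serializer

    def map_fn(row):
        # Pull out the label; keep every other non-index column as a feature.
        target = row[target_col]
        features = {
            k: tf.expand_dims(v, -1)
            for k, v in row.items()
            if k != target_col and k != b"index"
        }
        return features, target

    # Hypothetical row in the shape tfio produces (bytes keys, scalar tensors).
    row = {b"sepal_length": tf.constant(5.1), b"species": tf.constant(2, tf.int64)}
    features, label = map_fn(row)  # features[b"sepal_length"].shape == (1,); label == 2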
