feat: Support bigframes sharded parquet ingestion at remote deserialization (Tensorflow)

PiperOrigin-RevId: 562030438
matthew29tang authored and copybara-github committed Sep 1, 2023
1 parent 468e6e7 commit a8f85ec
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions vertexai/preview/_workflow/serialization_engine/serializers.py
@@ -1047,6 +1047,8 @@ def _deserialize_sklearn(self, serialized_gcs_path: str) -> PandasData:
         By default, sklearn returns a numpy array which uses CloudPickleSerializer.
         If a bigframes.dataframe.DataFrame is desired for the return type,
         b/291147206 (cl/548228568) is required
+
+        serialized_gcs_path is a folder containing one or more parquet files.
         """
         # Deserialization at remote environment
         try:
@@ -1069,7 +1071,7 @@ def _deserialize_sklearn(self, serialized_gcs_path: str) -> PandasData:
     def _deserialize_torch(self, serialized_gcs_path: str) -> TorchTensor:
         """Torch deserializes parquet (GCS) --> torch.tensor
-        Assumes one parquet file is created.
+        serialized_gcs_path is a folder containing one or more parquet files.
         """
         # Deserialization at remote environment
         try:
@@ -1107,7 +1109,7 @@ def reduce_tensors(a, b):
     def _deserialize_tensorflow(self, serialized_gcs_path: str) -> TFDataset:
         """Tensorflow deserializes parquet (GCS) --> tf.data.Dataset
-        Assumes one parquet file is created.
+        serialized_gcs_path is a folder containing one or more parquet files.
         """
         # Deserialization at remote environment
         try:
@@ -1118,14 +1120,15 @@ def _deserialize_tensorflow(self, serialized_gcs_path: str) -> TFDataset:
             ) from e
 
         # Deserialization always happens at remote, so gcs filesystem is mounted to /gcs/
-        # TODO(b/296475384): Handle multiple parquet shards
-        if len(os.listdir(serialized_gcs_path + "/")) > 1:
-            raise RuntimeError(
-                "Large datasets which are serialized into sharded parquet are not yet supported (b/296475384)"
-            )
+        files = os.listdir(serialized_gcs_path + "/")
+        files = list(
+            map(lambda file_name: serialized_gcs_path + "/" + file_name, files)
+        )
+        ds = tfio.IODataset.from_parquet(files[0])
 
-        single_parquet_gcs_path = serialized_gcs_path + "/" + "000000000000"
-        ds = tfio.IODataset.from_parquet(single_parquet_gcs_path)
+        for file_name in files[1:]:
+            ds_shard = tfio.IODataset.from_parquet(file_name)
+            ds = ds.concatenate(ds_shard)
 
         # TODO(b/296474656) Parquet must have "target" column for y
         def map_fn(row):
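For context, here is a minimal standalone sketch of the sharded-parquet ingestion pattern this commit adopts: list every parquet shard in the serialized folder, build a dataset from the first shard, and concatenate the rest. The helper name, the example folder path, and the standalone-function form are illustrative only and not part of the Vertex AI SDK; the actual change lives inside _deserialize_tensorflow above.

import os

import tensorflow as tf
import tensorflow_io as tfio


def load_sharded_parquet(folder: str) -> tf.data.Dataset:
    # Illustrative helper, not an SDK API: mirrors the commit's ingestion pattern.
    # List every parquet shard in the (GCS-mounted) folder, e.g. /gcs/bucket/path.
    files = [os.path.join(folder, name) for name in os.listdir(folder)]
    # Build a dataset from the first shard, then append the remaining shards.
    ds = tfio.IODataset.from_parquet(files[0])
    for file_name in files[1:]:
        ds = ds.concatenate(tfio.IODataset.from_parquet(file_name))
    return ds


# Hypothetical usage with a serialized dataframe folder:
# ds = load_sharded_parquet("/gcs/my-bucket/serialized_df")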
