Skip to content

Commit

Permalink
Update artifacts fetcher to download artifacts locally using FileSystems (#30202)
Browse files Browse the repository at this point in the history

* Update artifacts fetcher to download artifacts

* Use context for tempdir creation

* Refactor artifacts fetcher to support gcs path

* Remove defaults

* Pass vocab filename param

* Fix path

* Update sdks/python/apache_beam/ml/transforms/utils.py

Co-authored-by: tvalentyn <[email protected]>

* Update sdks/python/apache_beam/ml/transforms/utils.py

Co-authored-by: tvalentyn <[email protected]>

* Remove num_workers

* Fix lint

---------

Co-authored-by: tvalentyn <[email protected]>
  • Loading branch information
AnandInguva and tvalentyn authored Feb 9, 2024
1 parent 4c2b8b1 commit aefcada
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def test_process_large_movie_review_dataset(self):

artifacts_fetcher = ArtifactsFetcher(artifact_location=artifact_location)

actual_vocab_list = artifacts_fetcher.get_vocab_list()
vocab_filename = f'vocab_{vocab_tfidf_processing.REVIEW_COLUMN}'
actual_vocab_list = artifacts_fetcher.get_vocab_list(
vocab_filename=vocab_filename)

expected_artifact_filepath = 'gs://apache-beam-ml/testing/expected_outputs/compute_and_apply_vocab' # pylint: disable=line-too-long

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,12 @@ def preprocess_data(
top_k=VOCAB_SIZE,
frequency_threshold=10,
columns=[REVIEW_COLUMN],
vocab_filename='vocab',
split_string_by_delimiter=DELIMITERS)).with_transform(
TFIDF(columns=[REVIEW_COLUMN], vocab_size=VOCAB_SIZE))
TFIDF(
columns=[REVIEW_COLUMN],
vocab_size=VOCAB_SIZE,
))
data_pcoll = data_pcoll | 'MLTransform' >> ml_transform

data_pcoll = (
Expand Down
43 changes: 34 additions & 9 deletions sdks/python/apache_beam/ml/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,47 @@
__all__ = ['ArtifactsFetcher']

import os
import tempfile
import typing

from google.cloud.storage import Client
from google.cloud.storage import transfer_manager

import tensorflow_transform as tft
from apache_beam.ml.transforms import base


class ArtifactsFetcher():
def download_artifacts_from_gcs(bucket_name, prefix, local_path):
  """Downloads artifacts from GCS to the local file system.

  Args:
    bucket_name: The name of the GCS bucket to download from.
    prefix: Prefix of GCS objects to download.
    local_path: The local path to download the folder to.

  Raises:
    RuntimeError: If any of the matched objects fail to download.
  """
  client = Client()
  bucket = client.get_bucket(bucket_name)
  blob_names = [blob.name for blob in bucket.list_blobs(prefix=prefix)]
  # download_many_to_path collects per-blob exceptions in its result list
  # instead of raising (raise_exception defaults to False). The original code
  # discarded the results, so a failed download only surfaced later as a
  # confusing "no files found" error — check and report failures here instead.
  results = transfer_manager.download_many_to_path(
      bucket, blob_names, destination_directory=local_path)
  failures = [
      name for name, result in zip(blob_names, results)
      if isinstance(result, Exception)
  ]
  if failures:
    raise RuntimeError(
        f'Failed to download the following artifacts from GCS: {failures}')


class ArtifactsFetcher:
"""
Utility class used to fetch artifacts from the artifact_location passed
to the TFTProcessHandlers in MLTransform.
This is intended to be used for testing purposes only.
"""
def __init__(self, artifact_location):
def __init__(self, artifact_location: str):
tempdir = tempfile.mkdtemp()
if artifact_location.startswith('gs://'):
parts = artifact_location[5:].split('/')
bucket_name = parts[0]
prefix = '/'.join(parts[1:])
download_artifacts_from_gcs(bucket_name, prefix, tempdir)

assert os.listdir(tempdir), f"No files found in {artifact_location}"
artifact_location = os.path.join(tempdir, prefix)
files = os.listdir(artifact_location)
files.remove(base._ATTRIBUTE_FILE_NAME)
# TODO: https://github.com/apache/beam/issues/29356
Expand All @@ -43,9 +72,7 @@ def __init__(self, artifact_location):
self._artifact_location = os.path.join(artifact_location, files[0])
self.transform_output = tft.TFTransformOutput(self._artifact_location)

def get_vocab_list(
self,
vocab_filename: str = 'compute_and_apply_vocab') -> typing.List[bytes]:
def get_vocab_list(self, vocab_filename: str) -> typing.List[bytes]:
"""
Returns list of vocabulary terms created during MLTransform.
"""
Expand All @@ -57,13 +84,11 @@ def get_vocab_list(
vocab_filename)) from e
return [x.decode('utf-8') for x in vocab_list]

def get_vocab_filepath(
self, vocab_filename: str = 'compute_and_apply_vocab') -> str:
def get_vocab_filepath(self, vocab_filename: str) -> str:
  """Return the path to the vocabulary file created during MLTransform.

  Args:
    vocab_filename: Name of the vocabulary artifact produced by MLTransform.

  Returns:
    Filesystem path of the named vocabulary file in the transform output.
  """
  transform_output = self.transform_output
  return transform_output.vocabulary_file_by_name(vocab_filename)

def get_vocab_size(
self, vocab_filename: str = 'compute_and_apply_vocab') -> int:
def get_vocab_size(self, vocab_filename: str) -> int:
  """Return the size of the vocabulary created during MLTransform.

  Args:
    vocab_filename: Name of the vocabulary artifact produced by MLTransform.

  Returns:
    Number of terms in the named vocabulary of the transform output.
  """
  return self.transform_output.vocabulary_size_by_name(vocab_filename)

0 comments on commit aefcada

Please sign in to comment.