From f84dba4079053ebd1a9b1610a601adb9e18fd79e Mon Sep 17 00:00:00 2001 From: Anand Inguva <34158215+AnandInguva@users.noreply.github.com> Date: Fri, 14 Jul 2023 14:17:33 -0400 Subject: [PATCH] Refactor MLTransform basic example (#27430) * Refactor MLTransform basic example * Refactor and add comments on artifacts * Add example output in comments * Add comments --- .../ml_transform/ml_transform_basic.py | 122 +++++++++++++----- 1 file changed, 92 insertions(+), 30 deletions(-) diff --git a/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py b/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py index 65e943f5c697..2166d0db366e 100644 --- a/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py +++ b/sdks/python/apache_beam/examples/ml_transform/ml_transform_basic.py @@ -61,49 +61,111 @@ def parse_args(): return parser.parse_known_args() -def run(args): - data = [ - dict(x=["Let's", "go", "to", "the", "park"]), - dict(x=["I", "enjoy", "going", "to", "the", "park"]), - dict(x=["I", "enjoy", "reading", "books"]), - dict(x=["Beam", "can", "be", "fun"]), - dict(x=["The", "weather", "is", "really", "nice", "today"]), - dict(x=["I", "love", "to", "go", "to", "the", "park"]), - dict(x=["I", "love", "to", "read", "books"]), - dict(x=["I", "love", "to", "program"]), - ] +def preprocess_data_for_ml_training(train_data, artifact_mode, args): + """ + Preprocess the data for ML training. This method runs a pipeline to + preprocess the data needed for ML training. It produces artifacts that + can be used for ML inference later. + """ with beam.Pipeline() as p: - input_data = p | beam.Create(data) - - # arfifacts produce mode. - input_data |= ( - 'MLTransform' >> MLTransform( + train_data_pcoll = (p | "CreateData" >> beam.Create(train_data)) + + # When 'artifact_mode' is set to 'produce', the ComputeAndApplyVocabulary + # function generates a vocabulary file. 
This file, stored in + # 'artifact_location', contains the vocabulary of the entire dataset. + # This file is considered an artifact of the ComputeAndApplyVocabulary transform. + # The indices of the vocabulary in this file are returned as + # the output of MLTransform. + transformed_data_pcoll = ( + train_data_pcoll + | 'MLTransform' >> MLTransform( artifact_location=args.artifact_location, - artifact_mode=ArtifactMode.PRODUCE, + artifact_mode=artifact_mode, ).with_transform(ComputeAndApplyVocabulary( columns=['x'])).with_transform(TFIDF(columns=['x']))) - # _ = input_data | beam.Map(logging.info) + _ = transformed_data_pcoll | beam.Map(logging.info) + # output for the element dict(x=["Let's", "go", "to", "the", "park"]) + # will be: + # Row(x=array([21, 5, 0, 2, 3]), + # x_tfidf_weight=array([0.28109303, 0.36218604, 0.36218604, 0.41972247, + # 0.5008155 ], dtype=float32), x_vocab_index=array([ 0, 2, 3, 5, 21])) + +def preprocess_data_for_ml_inference(test_data, artifact_mode, args): + """ + Preprocess the data for ML inference. This function runs a pipeline to + preprocess the data needed for ML inference. It consumes the artifacts + produced during the preprocessing stage for ML training. + """ with beam.Pipeline() as p: - input_data = [ - dict(x=['I', 'love', 'books']), dict(x=['I', 'love', 'Apache', 'Beam']) - ] - input_data = p | beam.Create(input_data) - - # artifacts consume mode. - input_data |= ( - MLTransform( + + test_data_pcoll = (p | beam.Create(test_data)) + # Here, the previously saved vocabulary from an MLTransform run is used by + # ComputeAndApplyVocabulary to access and apply the stored artifacts to the + # test data.
+ transformed_data_pcoll = ( + test_data_pcoll + | "MLTransformOnTestData" >> MLTransform( artifact_location=args.artifact_location, - artifact_mode=ArtifactMode.CONSUME, - # you don't need to specify transforms as they are already saved in + artifact_mode=artifact_mode, + # We don't need to specify transforms as they are already saved # in the artifacts. )) + _ = transformed_data_pcoll | beam.Map(logging.info) + # output for dict(x=['I', 'love', 'books']) will be: + # Row(x=array([1, 4, 7]), + # x_tfidf_weight=array([0.4684884 , 0.6036434 , 0.69953746], dtype=float32) + # , x_vocab_index=array([1, 4, 7])) - _ = input_data | beam.Map(logging.info) - # To fetch the artifacts after the pipeline is run +def run(args): + """ + This example demonstrates how to use MLTransform in an ML workflow. + 1. Preprocess the data for ML training. + 2. Do some ML model training. + 3. Preprocess the data for ML inference. + + Training and inference of ML models are not shown in this example. + This example only shows how to use MLTransform for preparing data for ML + training and inference. + """ + + train_data = [ + dict(x=["Let's", "go", "to", "the", "park"]), + dict(x=["I", "enjoy", "going", "to", "the", "park"]), + dict(x=["I", "enjoy", "reading", "books"]), + dict(x=["Beam", "can", "be", "fun"]), + dict(x=["The", "weather", "is", "really", "nice", "today"]), + dict(x=["I", "love", "to", "go", "to", "the", "park"]), + dict(x=["I", "love", "to", "read", "books"]), + dict(x=["I", "love", "to", "program"]), + ] + + test_data = [ + dict(x=['I', 'love', 'books']), dict(x=['I', 'love', 'Apache', 'Beam']) + ] + + # Preprocess the data for ML training. + # For the data going into the ML model training, we want to produce the + # artifacts. So, we set artifact_mode to ArtifactMode.PRODUCE. + preprocess_data_for_ml_training( + train_data, artifact_mode=ArtifactMode.PRODUCE, args=args) + + # Do some ML model training here. + + # Preprocess the data for ML inference.
+ # For the data going into the ML model inference, we want to consume the + # artifacts produced during the stage where we preprocessed the data for ML + # training. So, we set artifact_mode to ArtifactMode.CONSUME. + preprocess_data_for_ml_inference( + test_data, artifact_mode=ArtifactMode.CONSUME, args=args) + + # To fetch the artifacts produced in MLTransform, you can use + # ArtifactsFetcher for fetching vocab related artifacts. For + # others such as TFIDF weight, they can be accessed directly + # from the output of MLTransform. artifacts_fetcher = ArtifactsFetcher(artifact_location=args.artifact_location) vocab_list = artifacts_fetcher.get_vocab_list() assert vocab_list[22] == 'Beam'