diff --git a/CHANGES.md b/CHANGES.md index 10310c6cbef1..55a106f3513b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -68,6 +68,7 @@ * Add UDF metrics support for Samza portable mode. * Option for SparkRunner to avoid the need of SDF output to fit in memory ([#23852](https://github.com/apache/beam/issues/23852)). This helps e.g. with ParquetIO reads. Turn the feature on by adding experiment `use_bounded_concurrent_output_for_sdf`. +* Add `WatchFilePattern` transform, which can be used as a side input to the RunInference PTransfrom to watch for model updates using a file pattern. ([#24042](https://github.com/apache/beam/issues/24042)) * Add support for loading TorchScript models with `PytorchModelHandler`. The TorchScript model path can be passed to PytorchModelHandler using `torch_script_model_path=`. ([#25321](https://github.com/apache/beam/pull/25321)) diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py new file mode 100644 index 000000000000..2a4e6e9a9bc6 --- /dev/null +++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py @@ -0,0 +1,218 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A pipeline that uses RunInference PTransform to perform image classification +and uses WatchFilePattern as side input to the RunInference PTransform. +WatchFilePattern is used to watch for a file updates matching the file_pattern +based on timestamps and emits latest model metadata, which is used in +RunInference API for the dynamic model updates without the need for stopping +the beam pipeline. + +This pipeline follows the pattern from +https://beam.apache.org/documentation/patterns/side-inputs/ + +To use the PubSub reading from a topic in the pipeline as source, you can +publish a path to the model(resnet152 used in the pipeline from +torchvision.models.resnet152) to the PubSub topic. Then pass that +topic via command line arg --topic. The published path(str) should be +UTF-8 encoded. + +To run the example on DataflowRunner, + +python apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py # pylint: disable=line-too-long + --project= + --re= + --temp_location= + --staging_location= + --runner=DataflowRunner + --streaming + --interval=10 + --num_workers=5 + --requirements_file=apache_beam/ml/inference/torch_tests_requirements.txt + --topic= + --file_pattern= + +file_pattern is path(can contain glob characters), which will be passed to +WatchContinuously transform for model updates. WatchContinuously watches the +file_pattern and emits a latest file path, sorted by timestamp. Files that +are read before and updated with same name will be ignored as an update. + +The pipeline expects there is at least one file present to match the +file_pattern before pipeline startup. Presumably, this would be the +`initial_model_path`. If there is no file matching before pipeline +startup time, the pipeline would fail. +""" + +import argparse +import io +import logging +import os +from typing import Iterable +from typing import Iterator +from typing import Optional +from typing import Tuple + +import apache_beam as beam +import torch +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor +from apache_beam.ml.inference.utils import WatchFilePattern +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.runners.runner import PipelineResult +from PIL import Image +from torchvision import models +from torchvision import transforms + + +def read_image(image_file_name: str, + path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + if path_to_dir is not None: + image_file_name = os.path.join(path_to_dir, image_file_name) + with FileSystems().open(image_file_name, 'r') as file: + data = Image.open(io.BytesIO(file.read())).convert('RGB') + return image_file_name, data + + +def preprocess_image(data: Image.Image) -> torch.Tensor: + image_size = (224, 224) + # Pre-trained PyTorch models expect input images normalized with the + # below values (see: https://pytorch.org/vision/stable/models.html) + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + transform = transforms.Compose([ + transforms.Resize(image_size), + transforms.ToTensor(), + normalize, + ]) + return transform(data) + + +def filter_empty_lines(text: str) -> Iterator[str]: + if len(text.strip()) > 0: + yield text + + +class PostProcessor(beam.DoFn): + """ + Return filename, prediction and the model id used to perform the + prediction + """ + def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + filename, prediction_result = element + prediction = torch.argmax(prediction_result.inference, dim=0) + yield filename, prediction, prediction_result.model_id + + +def parse_known_args(argv): + """Parses args for the workflow.""" + parser = argparse.ArgumentParser() + parser.add_argument( + '--topic', + dest='topic', + help='PubSub topic emitting absolute path to the images.' + 'Path must be accessible by the pipeline.') + parser.add_argument( + '--model_path', + '--initial_model_path', + dest='model_path', + default='gs://apache-beam-samples/run_inference/resnet152.pth', + help="Path to the initial model's state_dict. " + "This will be used until the first model update occurs.") + parser.add_argument( + '--file_pattern', help='Glob pattern to watch for an update.') + parser.add_argument( + '--interval', + default=10, + type=int, + help='Interval used to check for file updates.') + + return parser.parse_known_args(argv) + + +def run( + argv=None, + model_class=None, + model_params=None, + save_main_session=True, + device='CPU', + test_pipeline=None) -> PipelineResult: + """ + Args: + argv: Command line arguments defined for this example. + model_class: Reference to the class definition of the model. + model_params: Parameters passed to the constructor of the model_class. + These will be used to instantiate the model object in the + RunInference PTransform. + save_main_session: Used for internal testing. + device: Device to be used on the Runner. Choices are (CPU, GPU). + test_pipeline: Used for internal testing. + """ + known_args, pipeline_args = parse_known_args(argv) + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + + if not model_class: + model_class = models.resnet152 + model_params = {'num_classes': 1000} + + # In this example we pass keyed inputs to RunInference transform. + # Therefore, we use KeyedModelHandler wrapper over PytorchModelHandler. + model_handler = KeyedModelHandler( + PytorchModelHandlerTensor( + state_dict_path=known_args.model_path, + model_class=model_class, + model_params=model_params, + device=device, + min_batch_size=10, + max_batch_size=100)) + + pipeline = test_pipeline + if not test_pipeline: + pipeline = beam.Pipeline(options=pipeline_options) + + side_input = pipeline | WatchFilePattern( + interval=known_args.interval, file_pattern=known_args.file_pattern) + + filename_value_pair = ( + pipeline + | 'ReadImageNamesFromPubSub' >> beam.io.ReadFromPubSub(known_args.topic) + | 'DecodeBytes' >> beam.Map(lambda x: x.decode('utf-8')) + | 'ReadImageData' >> + beam.Map(lambda image_name: read_image(image_file_name=image_name)) + | 'PreprocessImages' >> beam.MapTuple( + lambda file_name, data: (file_name, preprocess_image(data)))) + predictions = ( + filename_value_pair + | 'PyTorchRunInference' >> RunInference( + model_handler, model_metadata_pcoll=side_input) + | 'ProcessOutput' >> beam.ParDo(PostProcessor())) + + _ = predictions | beam.Map(logging.info) + + result = pipeline.run() + result.wait_until_finish() + return result + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py index da6e86e2cf34..2be5d06a0264 100644 --- a/sdks/python/apache_beam/io/fileio.py +++ b/sdks/python/apache_beam/io/fileio.py @@ -277,7 +277,8 @@ def __init__( start_timestamp=Timestamp.now(), stop_timestamp=MAX_TIMESTAMP, match_updated_files=False, - apply_windowing=False): + apply_windowing=False, + empty_match_treatment=EmptyMatchTreatment.ALLOW): """Initializes a MatchContinuously transform. Args: @@ -299,6 +300,7 @@ def __init__( self.stop_ts = stop_timestamp self.match_upd = match_updated_files self.apply_windowing = apply_windowing + self.empty_match_treatment = empty_match_treatment def expand(self, pbegin) -> beam.PCollection[filesystem.FileMetadata]: # invoke periodic impulse @@ -311,7 +313,7 @@ def expand(self, pbegin) -> beam.PCollection[filesystem.FileMetadata]: match_files = ( impulse | 'GetFilePattern' >> beam.Map(lambda x: self.file_pattern) - | MatchAll()) + | MatchAll(self.empty_match_treatment)) # apply deduplication strategy if required if self.has_deduplication: diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 319735da2363..dad18c7b9e18 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -440,15 +440,19 @@ def test_run_inference_side_input_in_batch(self): first_ts + 22, ]) - sample_side_input_elements = [( - first_ts + 8, - base.ModelMetadata( - model_id='fake_model_id_1', model_name='fake_model_id_1')), - ( - first_ts + 15, - base.ModelMetadata( - model_id='fake_model_id_2', - model_name='fake_model_id_2'))] + sample_side_input_elements = [ + (first_ts + 1, base.ModelMetadata(model_id='', model_name='')), + # if model_id is empty string, we use the default model + # handler model URI. + ( + first_ts + 8, + base.ModelMetadata( + model_id='fake_model_id_1', model_name='fake_model_id_1')), + ( + first_ts + 15, + base.ModelMetadata( + model_id='fake_model_id_2', model_name='fake_model_id_2')) + ] model_handler = FakeModelHandlerReturnsPredictionResult() diff --git a/sdks/python/apache_beam/ml/inference/utils.py b/sdks/python/apache_beam/ml/inference/utils.py index f30d8a8f6486..4936ab5fe1d4 100644 --- a/sdks/python/apache_beam/ml/inference/utils.py +++ b/sdks/python/apache_beam/ml/inference/utils.py @@ -19,13 +19,26 @@ """ Util/helper functions used in apache_beam.ml.inference. """ +import os +from functools import partial from typing import Any from typing import Dict from typing import Iterable from typing import Optional from typing import Union +import apache_beam as beam +from apache_beam.io.fileio import EmptyMatchTreatment +from apache_beam.io.fileio import MatchContinuously +from apache_beam.ml.inference.base import ModelMetadata from apache_beam.ml.inference.base import PredictionResult +from apache_beam.transforms import trigger +from apache_beam.transforms import window +from apache_beam.transforms.userstate import CombiningValueStateSpec +from apache_beam.utils.timestamp import MAX_TIMESTAMP +from apache_beam.utils.timestamp import Timestamp + +_START_TIME_STAMP = Timestamp.now() def _convert_to_result( @@ -46,3 +59,106 @@ def _convert_to_result( y in zip(batch, predictions_per_tensor) ] return [PredictionResult(x, y, model_id) for x, y in zip(batch, predictions)] + + +class _ConvertIterToSingleton(beam.DoFn): + """ + Internal only; No backwards compatibility. + + The MatchContinuously transform examines all files present in a given + directory and returns those that have timestamps older than the + pipeline's start time. This can produce an Iterable rather than a + Singleton. This class only returns the file path when it is first + encountered, and it is cached as part of the side input caching mechanism. + If the path is seen again, it will not return anything. + By doing this, we can ensure that the output of this transform can be wrapped + with beam.pvalue.AsSingleton(). + """ + COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum) + + def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)): + counter = count_state.read() + if counter == 0: + count_state.add(1) + yield element[1] + + +class _GetLatestFileByTimeStamp(beam.DoFn): + """ + Internal only; No backwards compatibility. + + This DoFn checks the timestamps of files against the time that the pipeline + began running. It returns the files that were modified after the pipeline + started. If no such files are found, it returns a default file as fallback. + """ + TIME_STATE = CombiningValueStateSpec( + 'max', combine_fn=partial(max, default=_START_TIME_STAMP)) + + def process(self, element, time_state=beam.DoFn.StateParam(TIME_STATE)): + _, file_metadata = element + new_ts = file_metadata.last_updated_in_seconds + old_ts = time_state.read() + if new_ts > old_ts: + time_state.clear() + time_state.add(new_ts) + model_path = file_metadata.path + else: + model_path = '' + + model_name = os.path.splitext(os.path.basename(model_path))[0] + return [ + (model_path, ModelMetadata(model_id=model_path, model_name=model_name)) + ] + + +class WatchFilePattern(beam.PTransform): + def __init__( + self, + file_pattern, + interval=360, + stop_timestamp=MAX_TIMESTAMP, + ): + """ + Watches a directory for updates to files matching a given file pattern. + + Args: + file_pattern: The file path to read from as a local file path or a + GCS ``gs://`` path. The path can contain glob characters + (``*``, ``?``, and ``[...]`` sets). + interval: Interval at which to check for files matching file_pattern + in seconds. + stop_timestamp: Timestamp after which no more files will be checked. + + **Note**: + + 1. Any previously used filenames cannot be reused. If a file is added + or updated to a previously used filename, this transform will ignore + that update. To trigger a model update, always upload a file with + unique name. + 2. Initially, before the pipeline startup time, WatchFilePattern expects + at least one file present that matches the file_pattern. + 3. This transform is supported in streaming mode since + MatchContinuously produces an unbounded source. Running in batch + mode can lead to undesired results or result in pipeline being stuck. + + + """ + self.file_pattern = file_pattern + self.interval = interval + self.stop_timestamp = stop_timestamp + + def expand(self, pcoll) -> beam.PCollection[ModelMetadata]: + return ( + pcoll + | 'MatchContinuously' >> MatchContinuously( + file_pattern=self.file_pattern, + interval=self.interval, + stop_timestamp=self.stop_timestamp, + empty_match_treatment=EmptyMatchTreatment.DISALLOW) + | "AttachKey" >> beam.Map(lambda x: (x.path, x)) + | "GetLatestFileMetaData" >> beam.ParDo(_GetLatestFileByTimeStamp()) + | "AcceptNewSideInputOnly" >> beam.ParDo(_ConvertIterToSingleton()) + | 'ApplyGlobalWindow' >> beam.transforms.WindowInto( + window.GlobalWindows(), + trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)), + accumulation_mode=trigger.AccumulationMode.DISCARDING)) diff --git a/sdks/python/apache_beam/ml/inference/utils_test.py b/sdks/python/apache_beam/ml/inference/utils_test.py new file mode 100644 index 000000000000..66499a5a6f48 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/utils_test.py @@ -0,0 +1,103 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pytype: skip-file + +import unittest + +import apache_beam as beam +from apache_beam.io.filesystem import FileMetadata +from apache_beam.ml.inference import utils +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + + +class WatchFilePatternTest(unittest.TestCase): + def test_latest_file_by_timestamp_default_value(self): + # match continuously returns the files in sorted timestamp order. + main_input_pcoll = [ + FileMetadata( + 'path1.py', + 10, + last_updated_in_seconds=utils._START_TIME_STAMP - 20), + FileMetadata( + 'path2.py', + 10, + last_updated_in_seconds=utils._START_TIME_STAMP - 10) + ] + with TestPipeline() as p: + files_pc = ( + p + | beam.Create(main_input_pcoll) + | beam.Map(lambda x: (x.path, x)) + | beam.ParDo(utils._GetLatestFileByTimeStamp()) + | beam.Map(lambda x: x[0])) + assert_that(files_pc, equal_to(['', ''])) + + def test_latest_file_with_timestamp_after_pipeline_construction_time(self): + main_input_pcoll = [ + FileMetadata( + 'path1.py', + 10, + last_updated_in_seconds=utils._START_TIME_STAMP + 10) + ] + with TestPipeline() as p: + files_pc = ( + p + | beam.Create(main_input_pcoll) + | beam.Map(lambda x: (x.path, x)) + | beam.ParDo(utils._GetLatestFileByTimeStamp()) + | beam.Map(lambda x: x[0])) + assert_that(files_pc, equal_to(['path1.py'])) + + def test_emitting_singleton_output(self): + # match continuously returns the files in sorted timestamp order. + main_input_pcoll = [ + FileMetadata( + 'path1.py', + 10, + last_updated_in_seconds=utils._START_TIME_STAMP - 20), + # returns default + FileMetadata( + 'path2.py', + 10, + last_updated_in_seconds=utils._START_TIME_STAMP - 10), + # returns default + FileMetadata( + 'path3.py', + 10, + last_updated_in_seconds=utils._START_TIME_STAMP + 10), + FileMetadata( + 'path4.py', + 10, + last_updated_in_seconds=utils._START_TIME_STAMP + 20) + ] + # returns path3.py + + with TestPipeline() as p: + files_pc = ( + p + | beam.Create(main_input_pcoll) + | beam.Map(lambda x: (x.path, x)) + | beam.ParDo(utils._GetLatestFileByTimeStamp()) + | beam.ParDo(utils._ConvertIterToSingleton()) + | beam.Map(lambda x: x[0])) + assert_that(files_pc, equal_to(['', 'path3.py', 'path4.py'])) + + +if __name__ == '__main__': + unittest.main()