Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WatchFilePattern #25393

Merged
merged 17 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
* Add UDF metrics support for Samza portable mode.
* Option for SparkRunner to avoid the need of SDF output to fit in memory ([#23852](https://github.com/apache/beam/issues/23852)).
This helps e.g. with ParquetIO reads. Turn the feature on by adding experiment `use_bounded_concurrent_output_for_sdf`.
* Add `WatchFilePattern` transform, which can be used as a side input to the RunInference PTransform to watch for model updates using a file pattern. ([#24042](https://github.com/apache/beam/issues/24042))

## Breaking Changes

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
A pipeline that uses RunInference PTransform to perform image classification
and uses WatchFilePattern as side input to the RunInference PTransform.
WatchFilePattern is used to watch for file updates matching the file_pattern
based on timestamps and emits the latest model metadata, which is used in the
RunInference API for dynamic model updates without the need for stopping
the Beam pipeline.

This pipeline follows the pattern from
https://beam.apache.org/documentation/patterns/side-inputs/

This pipeline expects a PubSub topic as source, which emits an image
path (UTF-8 encoded) that is accessible by the pipeline.
damccorm marked this conversation as resolved.
Show resolved Hide resolved

To run the example on DataflowRunner,

python apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py # pylint: disable=line-too-long
--project=<your-project>
--region=<your-region>
--temp_location=<your-tmp-location>
--staging_location=<your-staging-location>
--runner=DataflowRunner
--streaming
--interval=10
--num_workers=5
--requirements_file=apache_beam/ml/inference/torch_tests_requirements.txt
--topic=<pubsub_topic>
AnandInguva marked this conversation as resolved.
Show resolved Hide resolved
AnandInguva marked this conversation as resolved.
Show resolved Hide resolved
--file_pattern=<glob_pattern>
damccorm marked this conversation as resolved.
Show resolved Hide resolved
"""

import argparse
import io
import logging
import os
from typing import Iterable
from typing import Iterator
from typing import Optional
from typing import Tuple

import apache_beam as beam
import torch
from apache_beam.io.filesystems import FileSystems
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
from apache_beam.ml.inference.utils import WatchFilePattern
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult
from PIL import Image
from torchvision import models
from torchvision import transforms


def read_image(image_file_name: str,
               path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
  """Opens an image through Beam's FileSystems and converts it to RGB.

  Args:
    image_file_name: Name or path of the image file to read.
    path_to_dir: Optional directory; when it is not None it is joined in
      front of image_file_name with os.path.join.

  Returns:
    A tuple of the path that was actually opened and the RGB PIL image.
  """
  full_path = image_file_name
  if path_to_dir is not None:
    full_path = os.path.join(path_to_dir, image_file_name)

  # NOTE(review): the second positional argument is kept exactly as written;
  # confirm against the FileSystems.open signature what 'r' is bound to.
  with FileSystems().open(full_path, 'r') as image_file:
    raw_bytes = image_file.read()

  # All bytes are already in memory, so decoding can happen from a buffer.
  rgb_image = Image.open(io.BytesIO(raw_bytes)).convert('RGB')
  return full_path, rgb_image


def preprocess_image(data: Image.Image) -> torch.Tensor:
  """Resizes an image, converts it to a tensor and normalizes it.

  Args:
    data: The PIL image to transform.

  Returns:
    The output of applying Resize((224, 224)), ToTensor and Normalize to
    data, in that order.
  """
  # Pre-trained PyTorch models expect input images normalized with the
  # below values (see: https://pytorch.org/vision/stable/models.html)
  channel_mean = [0.485, 0.456, 0.406]
  channel_std = [0.229, 0.224, 0.225]

  steps = [
      transforms.Resize((224, 224)),
      transforms.ToTensor(),
      transforms.Normalize(mean=channel_mean, std=channel_std),
  ]
  return transforms.Compose(steps)(data)


def filter_empty_lines(text: str) -> Iterator[str]:
  """Yields text unchanged unless it is empty or whitespace-only.

  Args:
    text: The line to inspect.

  Yields:
    The original, unstripped text when it has at least one
    non-whitespace character; nothing otherwise.
  """
  if not text.strip():
    return
  yield text


class PostProcessor(beam.DoFn):
  """
  Emits, for every keyed prediction, the filename, the index of the
  highest-scoring entry and the id of the model that produced it.
  """
  # NOTE(review): the return annotation says Iterable[str], but a 3-tuple is
  # yielded. It is left untouched here because Beam may consume the
  # annotation as a type hint -- confirm before tightening it.
  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:
    image_path, result = element
    # argmax along dim 0 of the inference output picks the predicted index.
    top_index = torch.argmax(result.inference, dim=0)
    yield image_path, top_index, result.model_id


def parse_known_args(argv):
  """Parses args for the workflow.

  Args:
    argv: Command line arguments to parse; handed straight to
      argparse.ArgumentParser.parse_known_args.

  Returns:
    A (known_args, pipeline_args) tuple: the namespace holding --topic,
    --model_path, --file_pattern and --interval, and the list of arguments
    this parser did not recognize.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--topic',
      dest='topic',
      # The trailing space keeps the two adjacent literals from running
      # together as "images.Path" in the rendered help.
      help='PubSub topic emitting absolute path to the images. '
      'Path must be accessible by the pipeline.')
  parser.add_argument(
      '--model_path',
      dest='model_path',
      default='gs://apache-beam-samples/run_inference/resnet152.pth',
      help="Path to the model's state_dict.")
  parser.add_argument(
      '--file_pattern', help='Glob pattern to watch for an update.')
  parser.add_argument(
      '--interval',
      default=10,
      type=int,
      help='Interval used to check for file updates.')

  return parser.parse_known_args(argv)


def run(
    argv=None,
    model_class=None,
    model_params=None,
    save_main_session=True,
    device='CPU',
    test_pipeline=None) -> PipelineResult:
  """
  Builds and runs the image classification pipeline, feeding RunInference
  with model metadata from WatchFilePattern as a side input.

  Args:
    argv: Command line arguments defined for this example.
    model_class: Reference to the class definition of the model.
    model_params: Parameters passed to the constructor of the model_class.
                  These will be used to instantiate the model object in the
                  RunInference PTransform.
    save_main_session: Used for internal testing.
    device: Device to be used on the Runner. Choices are (CPU, GPU).
    test_pipeline: Used for internal testing.

  Returns:
    The PipelineResult of the run, returned after wait_until_finish().
  """
  known_args, pipeline_args = parse_known_args(argv)
  options = PipelineOptions(pipeline_args)
  setup_options = options.view_as(SetupOptions)
  setup_options.save_main_session = save_main_session

  # Without an explicit model class, use resnet152; model_params is
  # overwritten in that case as well.
  if not model_class:
    model_class = models.resnet152
    model_params = {'num_classes': 1000}

  class PytorchModelHandlerTensorWithBatchSize(PytorchModelHandlerTensor):
    def batch_elements_kwargs(self):
      return {'min_batch_size': 10, 'max_batch_size': 100}

  # In this example we pass keyed inputs to RunInference transform.
  # Therefore, we use KeyedModelHandler wrapper over PytorchModelHandler.
  tensor_handler = PytorchModelHandlerTensorWithBatchSize(
      state_dict_path=known_args.model_path,
      model_class=model_class,
      model_params=model_params,
      device=device)
  model_handler = KeyedModelHandler(tensor_handler)

  # A real pipeline is only constructed when no test pipeline is supplied.
  pipeline = test_pipeline or beam.Pipeline(options=options)

  # This transform plus the model_metadata_pcoll argument to RunInference
  # below is everything the example adds to use the side input.
  side_input = pipeline | WatchFilePattern(
      interval=known_args.interval, file_pattern=known_args.file_pattern)

  image_names = (
      pipeline
      | 'ReadImageNamesFromPubSub' >> beam.io.ReadFromPubSub(known_args.topic)
      | 'DecodeBytes' >> beam.Map(lambda x: x.decode('utf-8')))
  filename_value_pair = (
      image_names
      | 'ReadImageData' >>
      beam.Map(lambda image_name: read_image(image_file_name=image_name))
      | 'PreprocessImages' >> beam.MapTuple(
          lambda file_name, data: (file_name, preprocess_image(data))))

  inferences = filename_value_pair | 'PyTorchRunInference' >> RunInference(
      model_handler, model_metadata_pcoll=side_input)
  predictions = inferences | 'ProcessOutput' >> beam.ParDo(PostProcessor())

  # Each post-processed prediction is written to the log.
  _ = predictions | beam.Map(logging.info)

  result = pipeline.run()
  result.wait_until_finish()
  return result


if __name__ == '__main__':
  # Raise the root logger to INFO before starting the pipeline.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  run()
6 changes: 4 additions & 2 deletions sdks/python/apache_beam/io/fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ def __init__(
start_timestamp=Timestamp.now(),
stop_timestamp=MAX_TIMESTAMP,
match_updated_files=False,
apply_windowing=False):
apply_windowing=False,
empty_match_treatment=EmptyMatchTreatment.ALLOW):
"""Initializes a MatchContinuously transform.

Args:
Expand All @@ -299,6 +300,7 @@ def __init__(
self.stop_ts = stop_timestamp
self.match_upd = match_updated_files
self.apply_windowing = apply_windowing
self.empty_match_treatment = empty_match_treatment

def expand(self, pbegin) -> beam.PCollection[filesystem.FileMetadata]:
# invoke periodic impulse
Expand All @@ -311,7 +313,7 @@ def expand(self, pbegin) -> beam.PCollection[filesystem.FileMetadata]:
match_files = (
impulse
| 'GetFilePattern' >> beam.Map(lambda x: self.file_pattern)
| MatchAll())
| MatchAll(self.empty_match_treatment))

# apply deduplication strategy if required
if self.has_deduplication:
Expand Down
22 changes: 13 additions & 9 deletions sdks/python/apache_beam/ml/inference/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,15 +440,19 @@ def test_run_inference_side_input_in_batch(self):
first_ts + 22,
])

sample_side_input_elements = [(
first_ts + 8,
base.ModelMetdata(
model_id='fake_model_id_1', model_name='fake_model_id_1')),
(
first_ts + 15,
base.ModelMetdata(
model_id='fake_model_id_2',
model_name='fake_model_id_2'))]
sample_side_input_elements = [
(first_ts + 1, base.ModelMetdata(model_id='', model_name='')),
# if model_id is empty string, we use the default model
# handler model URI.
(
first_ts + 8,
base.ModelMetdata(
model_id='fake_model_id_1', model_name='fake_model_id_1')),
(
first_ts + 15,
base.ModelMetdata(
model_id='fake_model_id_2', model_name='fake_model_id_2'))
]

model_handler = FakeModelHandlerReturnsPredictionResult()

Expand Down
Loading