Add WatchFilePattern #25393

Merged · 17 commits · Feb 13, 2023
Changes from 4 commits
@@ -26,8 +26,11 @@
This pipeline follows the pattern from
https://beam.apache.org/documentation/patterns/side-inputs/

This pipeline expects a PubSub topic as source, which emits an image
path(UTF-8 encoded) that is accessible by the pipeline.
To use a Pub/Sub topic as the pipeline source, publish a path to the
model (resnet152, loaded in the pipeline via torchvision.models.resnet152)
to that topic, then pass the topic via the command line arg --topic.
The published path (str) must be UTF-8 encoded.
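
A minimal sketch of the publishing step (assuming the google-cloud-pubsub
client; the project, topic, and path below are placeholders):

  from google.cloud import pubsub_v1

  publisher = pubsub_v1.PublisherClient()
  topic_path = publisher.topic_path('my-project', 'my-topic')
  # The pipeline expects the published path as a UTF-8 encoded string.
  future = publisher.publish(
      topic_path, data='gs://my-bucket/resnet152_v2.pth'.encode('utf-8'))
  future.result()  # Blocks until the message has been published.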

To run the example on DataflowRunner,

@@ -43,6 +46,16 @@
--requirements_file=apache_beam/ml/inference/torch_tests_requirements.txt
--topic=<pubsub_topic>
--file_pattern=<glob_pattern>

file_pattern is a path (which can contain glob characters) that is passed
to the WatchFilePattern transform for model updates. WatchFilePattern
watches the file_pattern and emits the latest file path by timestamp.
Files that were already read and are later updated under the same name
are ignored as updates.

The pipeline expects at least one file matching the file_pattern to be
present before pipeline startup; typically this would be the
`initial_model_path`. If no file matches before pipeline startup, the
pipeline will fail.
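
A sketch of how the watched pattern feeds RunInference as a side input
(illustrative wiring only; preprocess and model_handler stand in for the
example's own preprocessing step and keyed handler defined below):

  import apache_beam as beam
  from apache_beam.ml.inference.base import RunInference
  from apache_beam.ml.inference.utils import WatchFilePattern
  from apache_beam.options.pipeline_options import PipelineOptions

  with beam.Pipeline(options=PipelineOptions(streaming=True)) as p:
    model_updates = (
        p
        | 'WatchModelUpdates' >> WatchFilePattern(
            file_pattern='gs://my-bucket/models/*.pth', interval=600))
    _ = (
        p
        | 'ReadPaths' >> beam.io.ReadFromPubSub(
            topic='projects/my-project/topics/my-topic')
        | 'Preprocess' >> beam.Map(preprocess)
        | 'RunInference' >> RunInference(
            model_handler, model_metadata_pcoll=model_updates))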
"""

import argparse
@@ -119,9 +132,11 @@ def parse_known_args(argv):
'Path must be accessible by the pipeline.')
parser.add_argument(
'--model_path',
'--initial_model_path',
dest='model_path',
default='gs://apache-beam-samples/run_inference/resnet152.pth',
help="Path to the model's state_dict.")
help="Path to the initial model's state_dict. "
"This will be used until the first model update occurs.")
parser.add_argument(
'--file_pattern', help='Glob pattern to watch for an update.')
parser.add_argument(
@@ -159,18 +174,16 @@ def run(
model_class = models.resnet152
model_params = {'num_classes': 1000}

class PytorchModelHandlerTensorWithBatchSize(PytorchModelHandlerTensor):
def batch_elements_kwargs(self):
return {'min_batch_size': 10, 'max_batch_size': 100}

# In this example we pass keyed inputs to RunInference transform.
# Therefore, we use KeyedModelHandler wrapper over PytorchModelHandler.
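# The batching bounds are passed to the handler directly:
# PytorchModelHandlerTensor accepts min_batch_size and max_batch_size
# kwargs, so no subclass overriding batch_elements_kwargs is needed.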
model_handler = KeyedModelHandler(
PytorchModelHandlerTensorWithBatchSize(
PytorchModelHandlerTensor(
state_dict_path=known_args.model_path,
model_class=model_class,
model_params=model_params,
device=device))
device=device,
min_batch_size=10,
max_batch_size=100))

pipeline = test_pipeline
if not test_pipeline:
Expand Down
20 changes: 12 additions & 8 deletions sdks/python/apache_beam/ml/inference/utils.py
@@ -94,7 +94,7 @@ class _GetLatestFileByTimeStamp(beam.DoFn):
started. If no such files are found, it returns a default file as fallback.
"""
TIME_STATE = CombiningValueStateSpec(
'count', combine_fn=partial(max, default=_START_TIME_STAMP))
'max', combine_fn=partial(max, default=_START_TIME_STAMP))
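# State id 'max' reflects the combine_fn semantics: this state tracks
# the maximum last-updated timestamp seen so far, not a count.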

def process(
self, element, time_state=beam.DoFn.StateParam(TIME_STATE)
@@ -103,7 +103,7 @@ def process(
new_ts = file_metadata.last_updated_in_seconds
old_ts = time_state.read()
if new_ts > old_ts:
# time_state.clear()
time_state.clear()
time_state.add(new_ts)
model_path = file_metadata.path
else:
@@ -125,17 +125,21 @@ def __init__(
"""
Watches a directory for updates to files matching a given file pattern.

**Note**: The start timestamp defaults to the time at which the pipeline
was started. All files matching file_pattern that were uploaded before
the pipeline started are discarded.

Args:
file_pattern: A glob pattern used to watch a directory for model
updates.
file_pattern: The file path to read from as a local file path or a
GCS ``gs://`` path. The path can contain glob characters
(``*``, ``?``, and ``[...]`` sets).
interval: Interval, in seconds, at which to check for files matching
file_pattern.
stop_timestamp: Timestamp after which no more files will be checked.

Constraints:
1. If a file has been read and is later updated under the same name,
this transform ignores that update. Always upload an updated file
under a unique name.
2. WatchFilePattern expects at least one file matching the file_pattern
to be present before pipeline startup.

**Note**: This transform is supported in streaming mode, since
MatchContinuously produces an unbounded source. Running in batch
mode can lead to undesired results or leave the pipeline stuck.
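
A usage sketch consistent with these constraints (bucket and file names
are placeholders; each update is uploaded under a new, unique name so the
transform sees it as a new file):

  model_updates = (
      p
      | 'WatchForModels' >> WatchFilePattern(
          file_pattern='gs://my-bucket/models/model-*.pth',
          interval=600))
  # Upload updates as, e.g., model-v2.pth rather than overwriting
  # model-v1.pth in place.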
8 changes: 6 additions & 2 deletions sdks/python/apache_beam/ml/inference/utils_test.py
@@ -80,7 +80,11 @@ def test_emitting_singleton_output(self):
FileMetadata(
'path3.py',
10,
last_updated_in_seconds=utils._START_TIME_STAMP + 10)
last_updated_in_seconds=utils._START_TIME_STAMP + 10),
FileMetadata(
'path4.py',
10,
last_updated_in_seconds=utils._START_TIME_STAMP + 20)
]
# returns path3.py, then path4.py

@@ -92,4 +96,4 @@
| beam.ParDo(utils._GetLatestFileByTimeStamp())
| beam.ParDo(utils._ConvertIterToSingleton())
| beam.Map(lambda x: x[0]))
assert_that(files_pc, equal_to(['', 'path3.py']))
assert_that(files_pc, equal_to(['', 'path3.py', 'path4.py']))