Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AnandInguva committed Oct 12, 2023
1 parent a25d73a commit cee70f6
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions sdks/python/apache_beam/ml/inference/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,6 +1513,28 @@ def test_model_manager_evicts_correct_num_of_models_after_being_incremented(
mh3.load_model, tag=tag3).acquire()
self.assertEqual(8, model3.predict(10))

def test_run_inference_watch_file_pattern_side_input_label(self):
pipeline = TestPipeline()
# label of the WatchPattern transform.
side_input_str = 'WatchFilePattern/ApplyGlobalWindow'
from apache_beam.ml.inference.utils import WatchFilePattern
file_pattern_side_input = (
pipeline
| 'WatchFilePattern' >> WatchFilePattern(file_pattern='fake/path/*'))
pcoll = pipeline | 'start' >> beam.Create([1, 2, 3])
result_pcoll = pcoll | base.RunInference(
FakeModelHandler(), model_metadata_pcoll=file_pattern_side_input)
assert side_input_str in str(result_pcoll.producer.side_inputs[0])

def test_run_inference_watch_file_pattern_keyword_arg_side_input_label(self):
# label of the WatchPattern transform.
side_input_str = 'WatchFilePattern/ApplyGlobalWindow'
pipeline = TestPipeline()
pcoll = pipeline | 'start' >> beam.Create([1, 2, 3])
result_pcoll = pcoll | base.RunInference(
FakeModelHandler(), watch_model_pattern='fake/path/*')
assert side_input_str in str(result_pcoll.producer.side_inputs[0])


if __name__ == '__main__':
unittest.main()

0 comments on commit cee70f6

Please sign in to comment.