From 252c713e9f88e3b38e699049fda988d3065387b1 Mon Sep 17 00:00:00 2001
From: Danny McCormick
Date: Mon, 28 Aug 2023 15:02:57 -0400
Subject: [PATCH] Shell out to model handlers to collect byte sizes (#28182)

* Shell out to model handlers to collect byte sizes

* naming

---
 sdks/python/apache_beam/ml/inference/base.py  | 16 +++++++++---
 .../apache_beam/ml/inference/base_test.py     | 25 +++++++++++++++++++
 2 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index fdfd79e1af1b..e7798f661203 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -477,11 +477,19 @@ def run_inference(
     return predictions
 
   def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
+    keys, unkeyed_batch = zip(*batch)
+    batch_bytes = len(pickle.dumps(keys))
     if self._single_model:
-      keys, unkeyed_batch = zip(*batch)
-      return len(
-          pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
-    return len(pickle.dumps(batch))
+      return batch_bytes + self._unkeyed.get_num_bytes(unkeyed_batch)
+
+    batch_by_key = defaultdict(list)
+    for key, examples in batch:
+      batch_by_key[key].append(examples)
+
+    for key, examples in batch_by_key.items():
+      mh_id = self._key_to_id_map[key]
+      batch_bytes += self._id_to_mh_map[mh_id].get_num_bytes(examples)
+    return batch_bytes
 
   def get_metrics_namespace(self) -> str:
     if self._single_model:
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index c79189718a90..d6e4b9a76673 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -70,6 +70,7 @@ def __init__(
       max_batch_size=9999,
       multi_process_shared=False,
       state=None,
+      num_bytes_per_element=None,
       **kwargs):
     self._fake_clock = clock
     self._min_batch_size = min_batch_size
@@ -77,6 +78,7 @@ def __init__(
     self._env_vars = kwargs.get('env_vars', {})
     self._multi_process_shared = multi_process_shared
     self._state = state
+    self._num_bytes_per_element = num_bytes_per_element
 
   def load_model(self):
     if self._fake_clock:
@@ -113,6 +115,11 @@ def batch_elements_kwargs(self):
   def share_model_across_processes(self):
     return self._multi_process_shared
 
+  def get_num_bytes(self, batch: Sequence[int]) -> int:
+    if self._num_bytes_per_element:
+      return self._num_bytes_per_element * len(batch)
+    return super().get_num_bytes(batch)
+
 
 class FakeModelHandlerReturnsPredictionResult(
     base.ModelHandler[int, base.PredictionResult, FakeModel]):
@@ -319,6 +326,24 @@ def mult_two(example: str) -> int:
     with self.assertRaises(ValueError):
       base.KeyedModelHandler(mhs)
 
+  def test_keyed_model_handler_get_num_bytes(self):
+    mh = base.KeyedModelHandler(FakeModelHandler(num_bytes_per_element=10))
+    batch = [('key1', 1), ('key2', 2), ('key1', 3)]
+    expected = len(pickle.dumps(('key1', 'key2', 'key1'))) + 30
+    actual = mh.get_num_bytes(batch)
+    self.assertEqual(expected, actual)
+
+  def test_keyed_model_handler_multiple_models_get_num_bytes(self):
+    mhs = [
+        base.KeyMhMapping(['key1'], FakeModelHandler(num_bytes_per_element=10)),
+        base.KeyMhMapping(['key2'], FakeModelHandler(num_bytes_per_element=20))
+    ]
+    mh = base.KeyedModelHandler(mhs)
+    batch = [('key1', 1), ('key2', 2), ('key1', 3)]
+    expected = len(pickle.dumps(('key1', 'key2', 'key1'))) + 40
+    actual = mh.get_num_bytes(batch)
+    self.assertEqual(expected, actual)
+
   def test_run_inference_impl_with_maybe_keyed_examples(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]
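Note: the multi-model branch of get_num_bytes above can be read in isolation. The self-contained sketch below mirrors its bookkeeping under simplifying assumptions: SizedHandler and keyed_get_num_bytes are hypothetical names, not part of the patch or the Beam API, and the real key -> mh_id -> handler indirection is flattened into a direct key-to-handler dict. The idea is the same: count the pickled key tuple once, group examples by key, then let each key's handler report the size of its own sub-batch.

import pickle
from collections import defaultdict


class SizedHandler:
  # Hypothetical stand-in for a ModelHandler: reports a fixed
  # per-element byte size, like FakeModelHandler in the tests above.
  def __init__(self, bytes_per_element):
    self._bytes_per_element = bytes_per_element

  def get_num_bytes(self, batch):
    return self._bytes_per_element * len(batch)


def keyed_get_num_bytes(batch, key_to_handler):
  # Mirrors the multi-model branch of KeyedModelHandler.get_num_bytes:
  # the pickled key tuple is counted once, then each per-key sub-batch
  # is sized by the handler mapped to that key.
  keys, _ = zip(*batch)
  total = len(pickle.dumps(keys))
  batch_by_key = defaultdict(list)
  for key, example in batch:
    batch_by_key[key].append(example)
  for key, examples in batch_by_key.items():
    total += key_to_handler[key].get_num_bytes(examples)
  return total


batch = [('key1', 1), ('key2', 2), ('key1', 3)]
handlers = {'key1': SizedHandler(10), 'key2': SizedHandler(20)}
# Two 'key1' elements at 10 bytes each plus one 'key2' element at 20
# bytes, on top of the pickled key tuple: the same 40-byte delta that
# test_keyed_model_handler_multiple_models_get_num_bytes asserts.
print(keyed_get_num_bytes(batch, handlers))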