Shell out to model handlers to collect byte sizes (#28182)
* Shell out to model handlers to collect byte sizes

* naming
damccorm authored Aug 28, 2023
1 parent 48da872 commit 252c713
Showing 2 changed files with 37 additions and 4 deletions.
sdks/python/apache_beam/ml/inference/base.py (12 additions & 4 deletions)
@@ -477,11 +477,19 @@ def run_inference(
     return predictions
 
   def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
+    keys, unkeyed_batch = zip(*batch)
+    batch_bytes = len(pickle.dumps(keys))
     if self._single_model:
-      keys, unkeyed_batch = zip(*batch)
-      return len(
-          pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
-    return len(pickle.dumps(batch))
+      return batch_bytes + self._unkeyed.get_num_bytes(unkeyed_batch)
+
+    batch_by_key = defaultdict(list)
+    for key, examples in batch:
+      batch_by_key[key].append(examples)
+
+    for key, examples in batch_by_key.items():
+      mh_id = self._key_to_id_map[key]
+      batch_bytes += self._id_to_mh_map[mh_id].get_num_bytes(examples)
+    return batch_bytes
 
   def get_metrics_namespace(self) -> str:
     if self._single_model:
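For reference, the multi-model branch added above groups a keyed batch by key and then asks each key's own model handler to size its examples, instead of pickling the whole batch. Below is a minimal standalone sketch of that accounting; the free function and the plain-dict maps are illustrative stand-ins for KeyedModelHandler and its _key_to_id_map / _id_to_mh_map state.

import pickle
from collections import defaultdict

def keyed_batch_bytes(batch, key_to_id_map, id_to_mh_map):
  # Charge the pickled key tuple to the batch, as the keyed handler does.
  keys = tuple(key for key, _ in batch)
  batch_bytes = len(pickle.dumps(keys))

  # Group examples by key so each handler only sizes the examples it owns.
  batch_by_key = defaultdict(list)
  for key, example in batch:
    batch_by_key[key].append(example)

  # Delegate the per-key byte count to the handler mapped to that key.
  for key, examples in batch_by_key.items():
    handler = id_to_mh_map[key_to_id_map[key]]
    batch_bytes += handler.get_num_bytes(examples)
  return batch_bytes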
sdks/python/apache_beam/ml/inference/base_test.py (25 additions & 0 deletions)
@@ -70,13 +70,15 @@ def __init__(
       max_batch_size=9999,
       multi_process_shared=False,
       state=None,
+      num_bytes_per_element=None,
       **kwargs):
     self._fake_clock = clock
     self._min_batch_size = min_batch_size
     self._max_batch_size = max_batch_size
     self._env_vars = kwargs.get('env_vars', {})
     self._multi_process_shared = multi_process_shared
     self._state = state
+    self._num_bytes_per_element = num_bytes_per_element
 
   def load_model(self):
     if self._fake_clock:
@@ -113,6 +115,11 @@ def batch_elements_kwargs(self):
   def share_model_across_processes(self):
     return self._multi_process_shared
 
+  def get_num_bytes(self, batch: Sequence[int]) -> int:
+    if self._num_bytes_per_element:
+      return self._num_bytes_per_element * len(batch)
+    return super().get_num_bytes(batch)
+
 
 class FakeModelHandlerReturnsPredictionResult(
     base.ModelHandler[int, base.PredictionResult, FakeModel]):
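The override added above gives the fake handler a deterministic size: with num_bytes_per_element set it reports that many bytes per example, otherwise it falls back to the base handler's default of pickling the batch. A hypothetical quick check, not part of the change:

FakeModelHandler(num_bytes_per_element=10).get_num_bytes([1, 2, 3])  # 3 * 10 == 30
FakeModelHandler().get_num_bytes([1, 2, 3])  # defers to len(pickle.dumps([1, 2, 3]))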
@@ -319,6 +326,24 @@ def mult_two(example: str) -> int:
     with self.assertRaises(ValueError):
       base.KeyedModelHandler(mhs)
 
+  def test_keyed_model_handler_get_num_bytes(self):
+    mh = base.KeyedModelHandler(FakeModelHandler(num_bytes_per_element=10))
+    batch = [('key1', 1), ('key2', 2), ('key1', 3)]
+    expected = len(pickle.dumps(('key1', 'key2', 'key1'))) + 30
+    actual = mh.get_num_bytes(batch)
+    self.assertEqual(expected, actual)
+
+  def test_keyed_model_handler_multiple_models_get_num_bytes(self):
+    mhs = [
+        base.KeyMhMapping(['key1'], FakeModelHandler(num_bytes_per_element=10)),
+        base.KeyMhMapping(['key2'], FakeModelHandler(num_bytes_per_element=20))
+    ]
+    mh = base.KeyedModelHandler(mhs)
+    batch = [('key1', 1), ('key2', 2), ('key1', 3)]
+    expected = len(pickle.dumps(('key1', 'key2', 'key1'))) + 40
+    actual = mh.get_num_bytes(batch)
+    self.assertEqual(expected, actual)
+
   def test_run_inference_impl_with_maybe_keyed_examples(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]
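To spell out the expected values in the two tests above: both start from the pickled key tuple ('key1', 'key2', 'key1'). The single-handler case then adds 3 examples * 10 bytes = 30. The multi-model case adds 2 examples * 10 bytes for key1 plus 1 example * 20 bytes for key2, i.e. 40.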
