From 2d53926542f82a8b955eb541f13475f9bef091a7 Mon Sep 17 00:00:00 2001
From: Danny McCormick
Date: Fri, 10 May 2024 10:12:17 -0400
Subject: [PATCH] Avoid oversizing batch sizes with size estimation function
 (#31228)

* Avoid oversizing batch sizes with size estimation function

* lint
---
 sdks/python/apache_beam/transforms/util.py      | 18 ++++++----
 .../apache_beam/transforms/util_test.py         | 33 ++++++++++++++++---
 2 files changed, 40 insertions(+), 11 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index edf79b7c7981..750d98f0789c 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -579,14 +579,15 @@ def start_bundle(self):
     self._batch_size_estimator.ignore_next_timing()
 
   def process(self, element):
-    self._batch.append(element)
-    self._running_batch_size += self._element_size_fn(element)
-    if self._running_batch_size >= self._target_batch_size:
+    element_size = self._element_size_fn(element)
+    if self._running_batch_size + element_size > self._target_batch_size:
       with self._batch_size_estimator.record_time(self._running_batch_size):
         yield window.GlobalWindows.windowed_value_at_end_of_window(self._batch)
       self._batch = []
       self._running_batch_size = 0
       self._target_batch_size = self._batch_size_estimator.next_batch_size()
+    self._batch.append(element)
+    self._running_batch_size += element_size
 
   def finish_bundle(self):
     if self._batch:
@@ -621,15 +622,18 @@ def start_bundle(self):
 
   def process(self, element, window=DoFn.WindowParam):
     batch = self._batches[window]
-    batch.elements.append(element)
-    batch.size += self._element_size_fn(element)
-    if batch.size >= self._target_batch_size:
+    element_size = self._element_size_fn(element)
+    if batch.size + element_size > self._target_batch_size:
       with self._batch_size_estimator.record_time(batch.size):
         yield windowed_value.WindowedValue(
             batch.elements, window.max_timestamp(), (window, ))
       del self._batches[window]
       self._target_batch_size = self._batch_size_estimator.next_batch_size()
-    elif len(self._batches) > self._MAX_LIVE_WINDOWS:
+
+    self._batches[window].elements.append(element)
+    self._batches[window].size += element_size
+
+    if len(self._batches) > self._MAX_LIVE_WINDOWS:
       window, batch = max(
           self._batches.items(),
           key=lambda window_batch: window_batch[1].size)
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 53898d579983..74d9f438a5df 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -299,15 +299,40 @@ def test_sized_batches(self):
       res = (
           p
           | beam.Create([
-              'a', 'a', 'aaaaaaaaaa',  # First batch.
-              'aaaaaa', 'aaaaa',  # Second batch.
-              'a', 'aaaaaaa', 'a', 'a'  # Third batch.
+              'a', 'a',  # First batch.
+              'aaaaaaaaaa',  # Second batch.
+              'aaaaa', 'aaaaa',  # Third batch.
+              'a', 'aaaaaaa', 'a', 'a'  # Fourth batch.
           ], reshuffle=False)
           | util.BatchElements(
               min_batch_size=10, max_batch_size=10, element_size_fn=len)
           | beam.Map(lambda batch: ''.join(batch))
           | beam.Map(len))
-      assert_that(res, equal_to([12, 11, 10]))
+      assert_that(res, equal_to([2, 10, 10, 10]))
+
+  def test_sized_windowed_batches(self):
+    # Assumes a single bundle, in order...
+    with TestPipeline() as p:
+      res = (
+          p
+          | beam.Create(range(1, 8), reshuffle=False)
+          | beam.Map(lambda t: window.TimestampedValue('a' * t, t))
+          | beam.WindowInto(window.FixedWindows(3))
+          | util.BatchElements(
+              min_batch_size=11,
+              max_batch_size=11,
+              element_size_fn=len,
+              clock=FakeClock())
+          | beam.Map(lambda batch: ''.join(batch)))
+      assert_that(
+          res,
+          equal_to([
+              'a' * (1+2),  # Elements in [1, 3)
+              'a' * (3+4),  # Elements in [3, 6)
+              'a' * 5,
+              'a' * 6,  # Elements in [6, 9)
+              'a' * 7,
+          ]))
 
   def test_target_duration(self):
     clock = FakeClock()
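
Reviewer note (not part of the patch): the standalone sketch below mirrors the
control flow both process() methods now share, flushing the current batch
before appending an element whose size would push the running total past the
target, so batches are no longer oversized by their last element. The
batch_by_size helper is hypothetical (not a Beam API) and, like the tests
above, assumes no single element exceeds the target size.

# Illustrative sketch of the new batching rule, minus the size estimator.
def batch_by_size(elements, target_size, size_fn=len):
  batches = []
  batch, running_size = [], 0
  for element in elements:
    element_size = size_fn(element)
    # Flush *before* appending, so a batch never exceeds target_size;
    # previously the element was appended first, oversizing the batch.
    if running_size + element_size > target_size:
      batches.append(batch)
      batch, running_size = [], 0
    batch.append(element)
    running_size += element_size
  if batch:
    batches.append(batch)
  return batches

# Same input as test_sized_batches: batch sizes become [2, 10, 10, 10]
# instead of the previously oversized [12, 11, 10].
elements = ['a', 'a', 'aaaaaaaaaa', 'aaaaa', 'aaaaa', 'a', 'aaaaaaa', 'a', 'a']
sizes = [sum(map(len, b)) for b in batch_by_size(elements, target_size=10)]
assert sizes == [2, 10, 10, 10], sizes

The window-aware variant applies the same rule per window: after a flush it
re-creates the batch entry for that window, so the triggering element always
lands in a fresh batch rather than oversizing the one just emitted.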