Skip to content

Commit

Permalink
Avoid oversizing batch sizes with size estimation function (#31228)
Browse files Browse the repository at this point in the history
* Avoid oversizing batch sizes with size estimation function

* lint
  • Loading branch information
damccorm authored May 10, 2024
1 parent cc1024c commit 2d53926
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 11 deletions.
18 changes: 11 additions & 7 deletions sdks/python/apache_beam/transforms/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,14 +579,15 @@ def start_bundle(self):
self._batch_size_estimator.ignore_next_timing()

def process(self, element):
  """Buffer *element* into the current batch, flushing first if needed.

  The batch is emitted *before* appending when adding the element would
  push the running size past the target, so emitted batches stay at or
  under the target size instead of overshooting it.

  Args:
    element: the input element; its size is measured by
      ``self._element_size_fn``.

  Yields:
    The accumulated batch as a globally-windowed value, only when a flush
    is triggered.
  """
  element_size = self._element_size_fn(element)
  # Strict '>' so a batch that lands exactly on the target is still kept.
  if self._running_batch_size + element_size > self._target_batch_size:
    with self._batch_size_estimator.record_time(self._running_batch_size):
      yield window.GlobalWindows.windowed_value_at_end_of_window(self._batch)
    self._batch = []
    self._running_batch_size = 0
    # Let the estimator adapt the target from the timing just recorded.
    self._target_batch_size = self._batch_size_estimator.next_batch_size()
  # Append only after any flush, so the element lands in the next batch.
  self._batch.append(element)
  self._running_batch_size += element_size

def finish_bundle(self):
if self._batch:
Expand Down Expand Up @@ -621,15 +622,18 @@ def start_bundle(self):

def process(self, element, window=DoFn.WindowParam):
batch = self._batches[window]
batch.elements.append(element)
batch.size += self._element_size_fn(element)
if batch.size >= self._target_batch_size:
element_size = self._element_size_fn(element)
if batch.size + element_size > self._target_batch_size:
with self._batch_size_estimator.record_time(batch.size):
yield windowed_value.WindowedValue(
batch.elements, window.max_timestamp(), (window, ))
del self._batches[window]
self._target_batch_size = self._batch_size_estimator.next_batch_size()
elif len(self._batches) > self._MAX_LIVE_WINDOWS:

self._batches[window].elements.append(element)
self._batches[window].size += element_size

if len(self._batches) > self._MAX_LIVE_WINDOWS:
window, batch = max(
self._batches.items(),
key=lambda window_batch: window_batch[1].size)
Expand Down
33 changes: 29 additions & 4 deletions sdks/python/apache_beam/transforms/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,40 @@ def test_sized_batches(self):
res = (
p
| beam.Create([
'a', 'a', 'aaaaaaaaaa', # First batch.
'aaaaaa', 'aaaaa', # Second batch.
'a', 'aaaaaaa', 'a', 'a' # Third batch.
'a', 'a', # First batch.
'aaaaaaaaaa', # Second batch.
'aaaaa', 'aaaaa', # Third batch.
'a', 'aaaaaaa', 'a', 'a' # Fourth batch.
], reshuffle=False)
| util.BatchElements(
min_batch_size=10, max_batch_size=10, element_size_fn=len)
| beam.Map(lambda batch: ''.join(batch))
| beam.Map(len))
assert_that(res, equal_to([12, 11, 10]))
assert_that(res, equal_to([2, 10, 10, 10]))

def test_sized_windowed_batches(self):
  """Batches honor element_size_fn limits independently per window."""
  # Assumes a single bundle, in order...
  with TestPipeline() as p:
    # Elements 'a'*1 .. 'a'*7, each stamped at time t = its own length.
    sized_elements = (
        p
        | beam.Create(range(1, 8), reshuffle=False)
        | beam.Map(lambda t: window.TimestampedValue('a' * t, t)))
    joined_batches = (
        sized_elements
        | beam.WindowInto(window.FixedWindows(3))
        | util.BatchElements(
            min_batch_size=11,
            max_batch_size=11,
            element_size_fn=len,
            clock=FakeClock())
        | beam.Map(lambda batch: ''.join(batch)))
    assert_that(
        joined_batches,
        equal_to([
            'a' * (1+2),  # Elements in [1, 3)
            'a' * (3+4),  # Elements in [3, 6)
            'a' * 5,
            'a' * 6,  # Elements in [6, 9)
            'a' * 7,
        ]))

def test_target_duration(self):
clock = FakeClock()
Expand Down

0 comments on commit 2d53926

Please sign in to comment.