Force BQIO to output elements in the correct window #32584

Merged · 4 commits · Sep 30, 2024
2 changes: 1 addition & 1 deletion .github/trigger_files/beam_PostCommit_Python.json
@@ -1,5 +1,5 @@
{
  "comment": "Modify this file in a trivial way to cause this test suite to run.",
-  "modification": 2
+  "modification": 3
}

3 changes: 2 additions & 1 deletion in a second trigger file
@@ -1,3 +1,4 @@
{
-  "comment": "Modify this file in a trivial way to cause this test suite to run"
+  "comment": "Modify this file in a trivial way to cause this test suite to run",
+  "modification": 1
}
59 changes: 34 additions & 25 deletions sdks/python/apache_beam/io/gcp/bigquery.py
@@ -418,7 +418,6 @@ def chain_after(result):
from apache_beam.transforms.sideinputs import SIDE_INPUT_PREFIX
from apache_beam.transforms.sideinputs import get_sideinput_index
from apache_beam.transforms.util import ReshufflePerKey
from apache_beam.transforms.window import GlobalWindows
from apache_beam.typehints.row_type import RowTypeConstraint
from apache_beam.typehints.schemas import schema_from_element_type
from apache_beam.utils import retry
@@ -1581,7 +1580,8 @@ def _create_table_if_needed(self, table_reference, schema=None):
additional_create_parameters=self.additional_bq_parameters)
_KNOWN_TABLES.add(str_table_reference)

def process(self, element, *schema_side_inputs):
def process(
self, element, window_value=DoFn.WindowedValueParam, *schema_side_inputs):
destination = bigquery_tools.get_hashable_destination(element[0])

if callable(self.schema):
@@ -1608,12 +1608,11 @@ def process(self, element, *schema_side_inputs):
return [
pvalue.TaggedOutput(
BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS,
GlobalWindows.windowed_value(
window_value.with_value(
(destination, row_and_insert_id[0], error))),
pvalue.TaggedOutput(
BigQueryWriteFn.FAILED_ROWS,
GlobalWindows.windowed_value(
(destination, row_and_insert_id[0])))
window_value.with_value((destination, row_and_insert_id[0])))
]

# Flush current batch first if adding this row will exceed our limits
@@ -1624,19 +1623,20 @@
flushed_batch = self._flush_batch(destination)
# After flushing our existing batch, we now buffer the current row
# for the next flush
self._rows_buffer[destination].append(row_and_insert_id)
self._rows_buffer[destination].append((row_and_insert_id, window_value))
self._destination_buffer_byte_size[destination] = row_byte_size
return flushed_batch

self._rows_buffer[destination].append(row_and_insert_id)
self._rows_buffer[destination].append((row_and_insert_id, window_value))
self._destination_buffer_byte_size[destination] += row_byte_size
self._total_buffered_rows += 1
if self._total_buffered_rows >= self._max_buffered_rows:
return self._flush_all_batches()
else:
# The input is already batched per destination, flush the rows now.
batched_rows = element[1]
self._rows_buffer[destination].extend(batched_rows)
for r in batched_rows:
self._rows_buffer[destination].append((r, window_value))
return self._flush_batch(destination)

def finish_bundle(self):
@@ -1659,7 +1659,7 @@ def _flush_all_batches(self):
def _flush_batch(self, destination):

# Flush the current batch of rows to BigQuery.
rows_and_insert_ids = self._rows_buffer[destination]
rows_and_insert_ids_with_windows = self._rows_buffer[destination]
table_reference = bigquery_tools.parse_table_reference(destination)
if table_reference.projectId is None:
table_reference.projectId = vp.RuntimeValueProvider.get_value(
@@ -1668,9 +1668,10 @@ def _flush_batch(self, destination):
_LOGGER.debug(
'Flushing data to %s. Total %s rows.',
destination,
len(rows_and_insert_ids))
self.batch_size_metric.update(len(rows_and_insert_ids))
len(rows_and_insert_ids_with_windows))
self.batch_size_metric.update(len(rows_and_insert_ids_with_windows))

rows_and_insert_ids, window_values = zip(*rows_and_insert_ids_with_windows)
rows = [r[0] for r in rows_and_insert_ids]
if self.ignore_insert_ids:
insert_ids = [None for r in rows_and_insert_ids]
@@ -1689,8 +1690,10 @@
ignore_unknown_values=self.ignore_unknown_columns)
self.batch_latency_metric.update((time.time() - start) * 1000)

failed_rows = [(rows[entry['index']], entry["errors"])
failed_rows = [(
rows[entry['index']], entry["errors"], window_values[entry['index']])
for entry in errors]
failed_insert_ids = [insert_ids[entry['index']] for entry in errors]
retry_backoff = next(self._backoff_calculator, None)

# If retry_backoff is None, then we will not retry and must log.
@@ -1721,27 +1724,33 @@
_LOGGER.info(
'Sleeping %s seconds before retrying insertion.', retry_backoff)
time.sleep(retry_backoff)
# We can now safely discard all information about successful rows and
# just focus on the failed ones
rows = [fr[0] for fr in failed_rows]
window_values = [fr[2] for fr in failed_rows]
insert_ids = failed_insert_ids
self._throttled_secs.inc(retry_backoff)

self._total_buffered_rows -= len(self._rows_buffer[destination])
del self._rows_buffer[destination]
if destination in self._destination_buffer_byte_size:
del self._destination_buffer_byte_size[destination]

return itertools.chain([
pvalue.TaggedOutput(
BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS,
GlobalWindows.windowed_value((destination, row, err))) for row,
err in failed_rows
],
[
pvalue.TaggedOutput(
BigQueryWriteFn.FAILED_ROWS,
GlobalWindows.windowed_value(
(destination, row))) for row,
unused_err in failed_rows
])
return itertools.chain(
[
pvalue.TaggedOutput(
BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS,
w.with_value((destination, row, err))) for row,
err,
w in failed_rows
],
[
pvalue.TaggedOutput(
BigQueryWriteFn.FAILED_ROWS, w.with_value((destination, row)))
for row,
unused_err,
w in failed_rows
])


# The number of shards per destination when writing via streaming inserts.
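
The essence of the bigquery.py change is the pattern visible in the hunks above: process now receives the element's windowed value via DoFn.WindowedValueParam, the per-destination buffer stores (row_and_insert_id, window_value) pairs instead of bare rows, and failed rows are emitted with window_value.with_value(...) instead of GlobalWindows.windowed_value(...), so they stay in the window of the element that produced them. Below is a minimal, hypothetical sketch of that pattern in isolation; the DoFn name, output payload, and flush threshold are invented for illustration, and this is not the BigQuery sink itself.

```python
# Minimal sketch (assumed names: BufferAndReemitFn, MAX_BUFFERED) of the
# buffer-then-reemit-in-original-window pattern used by the change above.
import apache_beam as beam


class BufferAndReemitFn(beam.DoFn):
  """Buffers elements and later re-emits them in their original windows."""

  MAX_BUFFERED = 10  # illustrative flush threshold

  def start_bundle(self):
    self._buffer = []  # holds WindowedValue objects, not bare payloads

  def process(self, element, window_value=beam.DoFn.WindowedValueParam):
    # window_value carries the element's value, timestamp, window(s) and pane.
    self._buffer.append(window_value)
    if len(self._buffer) >= self.MAX_BUFFERED:
      yield from self._flush()

  def finish_bundle(self):
    # finish_bundle must emit WindowedValue objects explicitly; with_value()
    # below produces exactly that.
    yield from self._flush()

  def _flush(self):
    buffered, self._buffer = self._buffer, []
    for wv in buffered:
      # with_value() keeps the original timestamp, windows, and pane info but
      # swaps in a new payload, so the output lands in the source element's
      # window rather than the global window.
      yield wv.with_value(('processed', wv.value))
```

Buffering the whole WindowedValue rather than just the payload is what lets a flush that happens later, possibly in finish_bundle, still attribute each output to the window it came from; that is why the diff stores (row_and_insert_id, window_value) tuples in _rows_buffer.
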
1 change: 1 addition & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py
@@ -491,6 +491,7 @@ def test_big_query_write_insert_non_transient_api_call_error(self):
# pylint: disable=expression-not-assigned
errors = (
p | 'create' >> beam.Create(input_data)
| beam.WindowInto(beam.transforms.window.FixedWindows(10))
| 'write' >> beam.io.WriteToBigQuery(
table_id,
schema=table_schema,
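
The integration-test change windows the input with FixedWindows(10) before WriteToBigQuery, so the failed-rows outputs are now exercised outside the global window. The sketch below is a hypothetical usage example, not part of the PR: the table, schema, and input data are placeholders, and it assumes the WriteResult.failed_rows accessor available for the streaming-inserts method.

```python
# Hypothetical pipeline showing windowed failed rows; table, schema, and data
# are placeholders, and rows whose insert fails are routed to failed_rows.
import apache_beam as beam
from apache_beam.io.gcp.bigquery_tools import RetryStrategy
from apache_beam.transforms.window import FixedWindows

with beam.Pipeline() as p:
  result = (
      p
      | beam.Create([{'name': 'a'}, {'name': 'b'}])
      | beam.WindowInto(FixedWindows(10))
      | beam.io.WriteToBigQuery(
          'project:dataset.table',  # placeholder table reference
          schema='name:STRING',
          method=beam.io.WriteToBigQuery.Method.STREAMING_INSERTS,
          insert_retry_strategy=RetryStrategy.RETRY_NEVER))
  # Each failed-rows element is a (destination, row) tuple. With this change
  # it keeps the FixedWindows(10) windowing applied above, so a window-aware
  # transform such as this GroupByKey sees it in the expected window rather
  # than the global window.
  _ = (
      result.failed_rows
      | beam.GroupByKey()
      | beam.Map(print))
```

Previously these elements were re-windowed into the global window on output, which is the behavior this PR fixes.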