Skip to content

Commit

Permalink
Support dynamic destinations with Python Storage API (#30045)
Browse files Browse the repository at this point in the history
* support dynamic destinations and add tests

* put all relevant logic in StorageWriteToBigQuery
  • Loading branch information
ahmedabu98 authored Jan 24, 2024
1 parent e85d070 commit 2721414
Show file tree
Hide file tree
Showing 7 changed files with 334 additions and 151 deletions.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
## I/Os

* Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Added support for writing to BigQuery dynamic destinations with Python's Storage Write API ([#30045](https://github.com/apache/beam/pull/30045))
* Adding support for Tuples DataType in ClickHouse (Java) ([#29715](https://github.com/apache/beam/pull/29715)).
* Added support for handling bad records to FileIO, TextIO, AvroIO ([#29670](https://github.com/apache/beam/pull/29670)).
* Added support for handling bad records to BigtableIO ([#29885](https://github.com/apache/beam/pull/29885)).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;

import com.google.api.services.bigquery.model.TableSchema;
import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import java.util.Arrays;
Expand All @@ -35,6 +36,8 @@
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryStorageApiInsertError;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils;
import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinations;
import org.apache.beam.sdk.io.gcp.bigquery.TableDestination;
import org.apache.beam.sdk.io.gcp.bigquery.WriteResult;
import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransformConfiguration;
import org.apache.beam.sdk.metrics.Counter;
Expand All @@ -56,6 +59,7 @@
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.values.ValueInSingleWindow;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
Expand All @@ -81,6 +85,8 @@ public class BigQueryStorageWriteApiSchemaTransformProvider
private static final String INPUT_ROWS_TAG = "input";
private static final String FAILED_ROWS_TAG = "FailedRows";
private static final String FAILED_ROWS_WITH_ERRORS_TAG = "FailedRowsWithErrors";
// magic string that tells us to write to dynamic destinations
protected static final String DYNAMIC_DESTINATIONS = "DYNAMIC_DESTINATIONS";

@Override
protected Class<BigQueryStorageWriteApiSchemaTransformConfiguration> configurationClass() {
Expand Down Expand Up @@ -161,7 +167,11 @@ public void validate() {
checkArgument(
!Strings.isNullOrEmpty(this.getTable()),
invalidConfigMessage + "Table spec for a BigQuery Write must be specified.");
checkNotNull(BigQueryHelpers.parseTableSpec(this.getTable()));

// if we have an input table spec, validate it
if (!this.getTable().equals(DYNAMIC_DESTINATIONS)) {
checkNotNull(BigQueryHelpers.parseTableSpec(this.getTable()));
}

// validate create and write dispositions
if (!Strings.isNullOrEmpty(this.getCreateDisposition())) {
Expand Down Expand Up @@ -337,13 +347,36 @@ private static class NoOutputDoFn<T> extends DoFn<T, Row> {
public void process(ProcessContext c) {}
}

private static class RowDynamicDestinations extends DynamicDestinations<Row, String> {
Schema schema;

RowDynamicDestinations(Schema schema) {
this.schema = schema;
}

@Override
public String getDestination(ValueInSingleWindow<Row> element) {
return element.getValue().getString("destination");
}

@Override
public TableDestination getTable(String destination) {
return new TableDestination(destination, null);
}

@Override
public TableSchema getSchema(String destination) {
return BigQueryUtils.toTableSchema(schema);
}
}

@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
// Check that the input exists
checkArgument(input.has(INPUT_ROWS_TAG), "Missing expected input tag: %s", INPUT_ROWS_TAG);
PCollection<Row> inputRows = input.get(INPUT_ROWS_TAG);

BigQueryIO.Write<Row> write = createStorageWriteApiTransform();
BigQueryIO.Write<Row> write = createStorageWriteApiTransform(inputRows.getSchema());

if (inputRows.isBounded() == IsBounded.UNBOUNDED) {
Long triggeringFrequency = configuration.getTriggeringFrequencySeconds();
Expand All @@ -358,9 +391,8 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
}

boolean useAtLeastOnceSemantics =
configuration.getUseAtLeastOnceSemantics() == null
? false
: configuration.getUseAtLeastOnceSemantics();
configuration.getUseAtLeastOnceSemantics() != null
&& configuration.getUseAtLeastOnceSemantics();
// Triggering frequency is only applicable for exactly-once
if (!useAtLeastOnceSemantics) {
write =
Expand Down Expand Up @@ -433,7 +465,7 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
}
}

BigQueryIO.Write<Row> createStorageWriteApiTransform() {
BigQueryIO.Write<Row> createStorageWriteApiTransform(Schema schema) {
Method writeMethod =
configuration.getUseAtLeastOnceSemantics() != null
&& configuration.getUseAtLeastOnceSemantics()
Expand All @@ -442,12 +474,23 @@ BigQueryIO.Write<Row> createStorageWriteApiTransform() {

BigQueryIO.Write<Row> write =
BigQueryIO.<Row>write()
.to(configuration.getTable())
.withMethod(writeMethod)
.useBeamSchema()
.withFormatFunction(BigQueryUtils.toTableRow())
.withWriteDisposition(WriteDisposition.WRITE_APPEND);

if (configuration.getTable().equals(DYNAMIC_DESTINATIONS)) {
checkArgument(
schema.getFieldNames().equals(Arrays.asList("destination", "record")),
"When writing to dynamic destinations, we expect Row Schema with a "
+ "\"destination\" string field and a \"record\" Row field.");
write =
write
.to(new RowDynamicDestinations(schema.getField("record").getType().getRowSchema()))
.withFormatFunction(row -> BigQueryUtils.toTableRow(row.getRow("record")));
} else {
write = write.to(configuration.getTable()).useBeamSchema();
}

if (!Strings.isNullOrEmpty(configuration.getCreateDisposition())) {
CreateDisposition createDisposition =
BigQueryStorageWriteApiSchemaTransformConfiguration.CREATE_DISPOSITIONS.get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers;
import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransform;
Expand Down Expand Up @@ -136,8 +137,8 @@ public PCollectionRowTuple runWithConfig(

writeTransform.setBigQueryServices(fakeBigQueryServices);
String tag = provider.inputCollectionNames().get(0);

PCollection<Row> rows = p.apply(Create.of(inputRows).withRowSchema(SCHEMA));
PCollection<Row> rows =
p.apply(Create.of(inputRows).withRowSchema(inputRows.get(0).getSchema()));

PCollectionRowTuple input = PCollectionRowTuple.of(tag, rows);
PCollectionRowTuple result = input.apply(writeTransform);
Expand All @@ -155,16 +156,20 @@ public Boolean rowsEquals(List<Row> expectedRows, List<TableRow> actualRows) {
TableRow actualRow = actualRows.get(i);
Row expectedRow = expectedRows.get(Integer.parseInt(actualRow.get("number").toString()) - 1);

if (!expectedRow.getValue("name").equals(actualRow.get("name"))
|| !expectedRow
.getValue("number")
.equals(Long.parseLong(actualRow.get("number").toString()))) {
if (!rowEquals(expectedRow, actualRow)) {
return false;
}
}
return true;
}

public boolean rowEquals(Row expectedRow, TableRow actualRow) {
return expectedRow.getValue("name").equals(actualRow.get("name"))
&& expectedRow
.getValue("number")
.equals(Long.parseLong(actualRow.get("number").toString()));
}

@Test
public void testSimpleWrite() throws Exception {
String tableSpec = "project:dataset.simple_write";
Expand All @@ -179,6 +184,43 @@ public void testSimpleWrite() throws Exception {
rowsEquals(ROWS, fakeDatasetService.getAllRows("project", "dataset", "simple_write")));
}

@Test
public void testWriteToDynamicDestinations() throws Exception {
String dynamic = BigQueryStorageWriteApiSchemaTransformProvider.DYNAMIC_DESTINATIONS;
BigQueryStorageWriteApiSchemaTransformConfiguration config =
BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(dynamic).build();

String baseTableSpec = "project:dataset.dynamic_write_";

Schema schemaWithDestinations =
Schema.builder().addStringField("destination").addRowField("record", SCHEMA).build();
List<Row> rowsWithDestinations =
ROWS.stream()
.map(
row ->
Row.withSchema(schemaWithDestinations)
.withFieldValue("destination", baseTableSpec + row.getInt64("number"))
.withFieldValue("record", row)
.build())
.collect(Collectors.toList());

runWithConfig(config, rowsWithDestinations);
p.run().waitUntilFinish();

assertTrue(
rowEquals(
ROWS.get(0),
fakeDatasetService.getAllRows("project", "dataset", "dynamic_write_1").get(0)));
assertTrue(
rowEquals(
ROWS.get(1),
fakeDatasetService.getAllRows("project", "dataset", "dynamic_write_2").get(0)));
assertTrue(
rowEquals(
ROWS.get(2),
fakeDatasetService.getAllRows("project", "dataset", "dynamic_write_3").get(0)));
}

@Test
public void testInputElementCount() throws Exception {
String tableSpec = "project:dataset.input_count";
Expand Down
77 changes: 71 additions & 6 deletions sdks/python/apache_beam/io/external/xlang_bigqueryio_it_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import pytest
from hamcrest.core import assert_that as hamcrest_assert
from hamcrest.core.core.allof import all_of

import apache_beam as beam
from apache_beam.io.gcp.bigquery import StorageWriteToBigQuery
Expand All @@ -52,9 +53,6 @@


@pytest.mark.uses_gcp_java_expansion_service
# @unittest.skipUnless(
# os.environ.get('EXPANSION_PORT'),
# "EXPANSION_PORT environment var is not provided.")
class BigQueryXlangStorageWriteIT(unittest.TestCase):
BIGQUERY_DATASET = 'python_xlang_storage_write'

Expand Down Expand Up @@ -114,7 +112,8 @@ def setUp(self):
_LOGGER.info(
"Created dataset %s in project %s", self.dataset_id, self.project)

_LOGGER.info("expansion port: %s", os.environ.get('EXPANSION_PORT'))
self.assertTrue(
os.environ.get('EXPANSION_PORT'), "Expansion service port not found!")
self.expansion_service = ('localhost:%s' % os.environ.get('EXPANSION_PORT'))

def tearDown(self):
Expand All @@ -132,6 +131,8 @@ def tearDown(self):
self.project)

def parse_expected_data(self, expected_elements):
if not isinstance(expected_elements, list):
expected_elements = [expected_elements]
data = []
for row in expected_elements:
values = list(row.values())
Expand Down Expand Up @@ -246,6 +247,66 @@ def test_write_with_beam_rows(self):
table=table_id, expansion_service=self.expansion_service))
hamcrest_assert(p, bq_matcher)

def test_write_to_dynamic_destinations(self):
base_table_spec = '{}.dynamic_dest_'.format(self.dataset_id)
spec_with_project = '{}:{}'.format(self.project, base_table_spec)
tables = [base_table_spec + str(record['int']) for record in self.ELEMENTS]

bq_matchers = [
BigqueryFullResultMatcher(
project=self.project,
query="SELECT * FROM %s" % tables[i],
data=self.parse_expected_data(self.ELEMENTS[i]))
for i in range(len(tables))
]

with beam.Pipeline(argv=self.args) as p:
_ = (
p
| beam.Create(self.ELEMENTS)
| beam.io.WriteToBigQuery(
table=lambda record: spec_with_project + str(record['int']),
method=beam.io.WriteToBigQuery.Method.STORAGE_WRITE_API,
schema=self.ALL_TYPES_SCHEMA,
use_at_least_once=False,
expansion_service=self.expansion_service))
hamcrest_assert(p, all_of(*bq_matchers))

def test_write_to_dynamic_destinations_with_beam_rows(self):
base_table_spec = '{}.dynamic_dest_'.format(self.dataset_id)
spec_with_project = '{}:{}'.format(self.project, base_table_spec)
tables = [base_table_spec + str(record['int']) for record in self.ELEMENTS]

bq_matchers = [
BigqueryFullResultMatcher(
project=self.project,
query="SELECT * FROM %s" % tables[i],
data=self.parse_expected_data(self.ELEMENTS[i]))
for i in range(len(tables))
]

row_elements = [
beam.Row(
my_int=e['int'],
my_float=e['float'],
my_numeric=e['numeric'],
my_string=e['str'],
my_bool=e['bool'],
my_bytes=e['bytes'],
my_timestamp=e['timestamp']) for e in self.ELEMENTS
]

with beam.Pipeline(argv=self.args) as p:
_ = (
p
| beam.Create(row_elements)
| beam.io.WriteToBigQuery(
table=lambda record: spec_with_project + str(record.my_int),
method=beam.io.WriteToBigQuery.Method.STORAGE_WRITE_API,
use_at_least_once=False,
expansion_service=self.expansion_service))
hamcrest_assert(p, all_of(*bq_matchers))

def run_streaming(self, table_name, num_streams=0, use_at_least_once=False):
elements = self.ELEMENTS.copy()
schema = self.ALL_TYPES_SCHEMA
Expand Down Expand Up @@ -278,20 +339,24 @@ def run_streaming(self, table_name, num_streams=0, use_at_least_once=False):
expansion_service=self.expansion_service))
hamcrest_assert(p, bq_matcher)

def test_streaming_with_fixed_num_streams(self):
def skip_if_not_dataflow_runner(self) -> bool:
# skip if dataflow runner is not specified
if not self._runner or "dataflowrunner" not in self._runner.lower():
self.skipTest(
"The exactly-once route has the requirement "
"Streaming with exactly-once route has the requirement "
"`beam:requirement:pardo:on_window_expiration:v1`, "
"which is currently only supported by the Dataflow runner")

def test_streaming_with_fixed_num_streams(self):
self.skip_if_not_dataflow_runner()
table = 'streaming_fixed_num_streams'
self.run_streaming(table_name=table, num_streams=4)

@unittest.skip(
"Streaming to the Storage Write API sink with autosharding is broken "
"with Dataflow Runner V2.")
def test_streaming_with_auto_sharding(self):
self.skip_if_not_dataflow_runner()
table = 'streaming_with_auto_sharding'
self.run_streaming(table_name=table)

Expand Down
Loading

0 comments on commit 2721414

Please sign in to comment.