Pass user specified destination type to UpdateSchemaDestination (#22624), fixing #22543

* Keep hold of the user-specified dynamic destination type so it can be used in UpdateSchemaDestination (see the sketch after this list)

* Fix testWriteTables

* Clean up, and fall back to the default project when the table reference does not include one (a distilled sketch of the fallback follows the BatchLoads.java diff below)

* Allow side inputs to be accessed from getTable()

* Style fixes
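
For context, the underlying bug (#22543) bites pipelines whose DynamicDestinations implementation resolves tables from a side input inside getTable(). Below is a minimal sketch of such a destination, with hypothetical class and field names that are not part of this commit; it can only work when UpdateSchemaDestination receives the user-specified DestinationT key (here a String) rather than an already-resolved TableDestination:

    import com.google.api.services.bigquery.model.TableRow;
    import com.google.api.services.bigquery.model.TableSchema;
    import java.util.Collections;
    import java.util.List;
    import java.util.Map;
    import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinations;
    import org.apache.beam.sdk.io.gcp.bigquery.TableDestination;
    import org.apache.beam.sdk.values.PCollectionView;
    import org.apache.beam.sdk.values.ValueInSingleWindow;

    // Hypothetical destination: routes each row to a per-team table whose
    // table spec is looked up from a side-input map inside getTable().
    class TeamDestinations extends DynamicDestinations<TableRow, String> {
      private final PCollectionView<Map<String, String>> tableSpecsView;

      TeamDestinations(PCollectionView<Map<String, String>> tableSpecsView) {
        this.tableSpecsView = tableSpecsView;
      }

      @Override
      public List<PCollectionView<?>> getSideInputs() {
        return Collections.singletonList(tableSpecsView);
      }

      @Override
      public String getDestination(ValueInSingleWindow<TableRow> element) {
        // The String returned here is the user-specified DestinationT that
        // this commit now carries through to UpdateSchemaDestination.
        return (String) element.getValue().get("team");
      }

      @Override
      public TableDestination getTable(String team) {
        // Reads a side input; per #22543 this failed when
        // UpdateSchemaDestination re-invoked getTable() without the user key
        // and without side-input access.
        String tableSpec = sideInput(tableSpecsView).get(team);
        return new TableDestination(tableSpec, "Table for team " + team);
      }

      @Override
      public TableSchema getSchema(String team) {
        return new TableSchema(); // real schema lookup elided in this sketch
      }
    }

Because getTable() consults a side input, any downstream step that re-invokes it, as UpdateSchemaDestination now does, must be handed the String key and have the same side inputs registered.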
ahmedabu98 authored Aug 26, 2022
1 parent 8347b9e commit 3217017
Showing 4 changed files with 94 additions and 28 deletions.
sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java
@@ -21,6 +21,7 @@
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;

+ import com.google.api.services.bigquery.model.TableReference;
import com.google.api.services.bigquery.model.TableRow;
import java.util.Collections;
import java.util.List;
@@ -354,7 +355,7 @@ private WriteResult expandTriggered(PCollection<KV<DestinationT, ElementT>> input
rowWriterFactory))
.withSideInputs(tempFilePrefixView)
.withOutputTags(multiPartitionsTag, TupleTagList.of(singlePartitionTag)));
- PCollection<KV<TableDestination, WriteTables.Result>> tempTables =
+ PCollection<KV<DestinationT, WriteTables.Result>> tempTables =
writeTempTables(partitions.get(multiPartitionsTag), tempLoadJobIdPrefixView);

List<PCollectionView<?>> sideInputsForUpdateSchema =
@@ -366,15 +367,15 @@ private WriteResult expandTriggered(PCollection<KV<DestinationT, ElementT>> input
// Now that the load job has happened, we want the rename to happen immediately.
.apply(
"Window Into Global Windows",
- Window.<KV<TableDestination, WriteTables.Result>>into(new GlobalWindows())
+ Window.<KV<DestinationT, WriteTables.Result>>into(new GlobalWindows())
.triggering(Repeatedly.forever(AfterPane.elementCountAtLeast(1))))
.apply("Add Void Key", WithKeys.of((Void) null))
.setCoder(KvCoder.of(VoidCoder.of(), tempTables.getCoder()))
.apply("GroupByKey", GroupByKey.create())
.apply("Extract Values", Values.create())
.apply(
ParDo.of(
- new UpdateSchemaDestination(
+ new UpdateSchemaDestination<DestinationT>(
bigQueryServices,
tempLoadJobIdPrefixView,
loadJobProjectId,
@@ -470,7 +471,7 @@ public WriteResult expandUntriggered(PCollection<KV<DestinationT, ElementT>> input
.apply("ReifyRenameInput", new ReifyAsIterable<>())
.apply(
ParDo.of(
- new UpdateSchemaDestination(
+ new UpdateSchemaDestination<DestinationT>(
bigQueryServices,
tempLoadJobIdPrefixView,
loadJobProjectId,
@@ -702,7 +703,7 @@ public KV<DestinationT, Iterable<ElementT>> apply(
}

// Take in a list of files and write them to temporary tables.
- private PCollection<KV<TableDestination, WriteTables.Result>> writeTempTables(
+ private PCollection<KV<DestinationT, WriteTables.Result>> writeTempTables(
PCollection<KV<ShardedKey<DestinationT>, WritePartition.Result>> input,
PCollectionView<String> jobIdTokenView) {
List<PCollectionView<?>> sideInputs = Lists.newArrayList(jobIdTokenView);
@@ -713,9 +714,6 @@ private PCollection<KV<TableDestination, WriteTables.Result>> writeTempTables(
ShardedKeyCoder.of(NullableCoder.of(destinationCoder)),
WritePartition.ResultCoder.INSTANCE);

- Coder<TableDestination> tableDestinationCoder =
- clusteringEnabled ? TableDestinationCoderV3.of() : TableDestinationCoderV2.of();

// If WriteBundlesToFiles produced more than DEFAULT_MAX_FILES_PER_PARTITION files or
// DEFAULT_MAX_BYTES_PER_PARTITION bytes, then
// the import needs to be split into multiple partitions, and those partitions will be
@@ -746,7 +744,7 @@ private PCollection<KV<TableDestination, WriteTables.Result>> writeTempTables(
// https://github.com/apache/beam/issues/21105 for additional details.
schemaUpdateOptions,
tempDataset))
- .setCoder(KvCoder.of(tableDestinationCoder, WriteTables.ResultCoder.INSTANCE));
+ .setCoder(KvCoder.of(destinationCoder, WriteTables.ResultCoder.INSTANCE));
}

// In the case where the files fit into a single load job, there's no need to write temporary
@@ -765,7 +763,7 @@ PCollection<TableDestination> writeSinglePartition(
ShardedKeyCoder.of(NullableCoder.of(destinationCoder)),
WritePartition.ResultCoder.INSTANCE);
// Write single partition to final table
- PCollection<KV<TableDestination, WriteTables.Result>> successfulWrites =
+ PCollection<KV<DestinationT, WriteTables.Result>> successfulWrites =
input
.setCoder(partitionsCoder)
// Reshuffle will distribute this among multiple workers, and also guard against
@@ -789,9 +787,35 @@ PCollection<TableDestination> writeSinglePartition(
useAvroLogicalTypes,
schemaUpdateOptions,
null))
- .setCoder(KvCoder.of(tableDestinationCoder, WriteTables.ResultCoder.INSTANCE));
+ .setCoder(KvCoder.of(destinationCoder, WriteTables.ResultCoder.INSTANCE));

- return successfulWrites.apply(Keys.create());
+ BigQueryOptions options = input.getPipeline().getOptions().as(BigQueryOptions.class);
+ String defaultProjectId =
+ options.getBigQueryProject() == null ? options.getProject() : options.getBigQueryProject();
+
+ return successfulWrites
+ .apply(Keys.create())
+ .apply(
+ "Convert to TableDestinations",
+ ParDo.of(
+ new DoFn<DestinationT, TableDestination>() {
+ @ProcessElement
+ public void processElement(ProcessContext c) {
+ dynamicDestinations.setSideInputAccessorFromProcessContext(c);
+ TableDestination tableDestination =
+ dynamicDestinations.getTable(c.element());
+ TableReference tableReference = tableDestination.getTableReference();
+
+ // get project ID from options if it's not included in the table reference
+ if (Strings.isNullOrEmpty(tableReference.getProjectId())) {
+ tableReference.setProjectId(defaultProjectId);
+ tableDestination = tableDestination.withTableReference(tableReference);
+ }
+ c.output(tableDestination);
+ }
+ })
+ .withSideInputs(sideInputs))
+ .setCoder(tableDestinationCoder);
}

private WriteResult writeResult(Pipeline p, PCollection<TableDestination> successfulWrites) {
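The block above also introduces the default-project fallback: a table reference that omits a project inherits the --bigQueryProject option when set, otherwise --project. A standalone sketch of just that rule, using a hypothetical helper name (the commit inlines this logic in BatchLoads and UpdateSchemaDestination rather than sharing a helper):

    import com.google.api.services.bigquery.model.TableReference;
    import org.apache.beam.sdk.io.gcp.bigquery.BigQueryOptions;
    import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;

    final class DefaultProjects {
      // Hypothetical helper mirroring the fallback added in this commit.
      static TableReference withDefaultProject(TableReference ref, BigQueryOptions options) {
        if (Strings.isNullOrEmpty(ref.getProjectId())) {
          // Prefer the dedicated BigQuery project flag, then the pipeline project.
          ref.setProjectId(
              options.getBigQueryProject() == null
                  ? options.getProject()
                  : options.getBigQueryProject());
        }
        return ref;
      }
    }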
sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/UpdateSchemaDestination.java
@@ -27,6 +27,7 @@
import com.google.api.services.bigquery.model.TableSchema;
import com.google.api.services.bigquery.model.TimePartitioning;
import java.io.IOException;
+ import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
@@ -38,14 +39,15 @@
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;
+ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@SuppressWarnings({"nullness", "rawtypes"})
- public class UpdateSchemaDestination
+ public class UpdateSchemaDestination<DestinationT>
extends DoFn<
- Iterable<KV<TableDestination, WriteTables.Result>>,
+ Iterable<KV<DestinationT, WriteTables.Result>>,
Iterable<KV<TableDestination, WriteTables.Result>>> {

private static final Logger LOG = LoggerFactory.getLogger(UpdateSchemaDestination.class);
@@ -104,21 +106,37 @@ public void startBundle(StartBundleContext c) {
pendingJobs.clear();
}

+ TableDestination getTableWithDefaultProject(DestinationT destination, BigQueryOptions options) {
+ TableDestination tableDestination = dynamicDestinations.getTable(destination);
+ TableReference tableReference = tableDestination.getTableReference();
+
+ if (Strings.isNullOrEmpty(tableReference.getProjectId())) {
+ tableReference.setProjectId(
+ options.getBigQueryProject() == null
+ ? options.getProject()
+ : options.getBigQueryProject());
+ tableDestination = tableDestination.withTableReference(tableReference);
+ }
+
+ return tableDestination;
+ }

@ProcessElement
public void processElement(
- @Element Iterable<KV<TableDestination, WriteTables.Result>> element,
+ @Element Iterable<KV<DestinationT, WriteTables.Result>> element,
ProcessContext context,
BoundedWindow window)
throws IOException {
- Object destination = null;
- for (KV<TableDestination, WriteTables.Result> entry : element) {
+ DestinationT destination = null;
+ BigQueryOptions options = context.getPipelineOptions().as(BigQueryOptions.class);
+ for (KV<DestinationT, WriteTables.Result> entry : element) {
destination = entry.getKey();
if (destination != null) {
break;
}
}
if (destination != null) {
- TableDestination tableDestination = dynamicDestinations.getTable(destination);
+ TableDestination tableDestination = getTableWithDefaultProject(destination, options);
TableSchema schema = dynamicDestinations.getSchema(destination);
TableReference tableReference = tableDestination.getTableReference();
String jobIdPrefix =
@@ -143,8 +161,13 @@ public void processElement(
if (updateSchemaDestinationJob != null) {
pendingJobs.add(new PendingJobData(updateSchemaDestinationJob, tableDestination, window));
}
- context.output(element);
}
+ List<KV<TableDestination, WriteTables.Result>> tableDestinations = new ArrayList<>();
+ for (KV<DestinationT, WriteTables.Result> entry : element) {
+ tableDestinations.add(
+ KV.of(getTableWithDefaultProject(destination, options), entry.getValue()));
+ }
+ context.output(tableDestinations);
}

@Teardown
sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java
@@ -97,7 +97,7 @@
class WriteTables<DestinationT extends @NonNull Object>
extends PTransform<
PCollection<KV<ShardedKey<DestinationT>, WritePartition.Result>>,
- PCollection<KV<TableDestination, WriteTables.Result>>> {
+ PCollection<KV<DestinationT, WriteTables.Result>>> {
@AutoValue
abstract static class Result {
abstract String getTableName();
@@ -135,7 +135,7 @@ public Result decode(@UnknownKeyFor @NonNull @Initialized InputStream inStream)
private final Set<SchemaUpdateOption> schemaUpdateOptions;
private final DynamicDestinations<?, DestinationT> dynamicDestinations;
private final List<PCollectionView<?>> sideInputs;
- private final TupleTag<KV<TableDestination, WriteTables.Result>> mainOutputTag;
+ private final TupleTag<KV<DestinationT, WriteTables.Result>> mainOutputTag;
private final TupleTag<String> temporaryFilesTag;
private final @Nullable ValueProvider<String> loadJobProjectId;
private final int maxRetryJobs;
@@ -148,8 +148,7 @@ public Result decode(@UnknownKeyFor @NonNull @Initialized InputStream inStream)
private final @Nullable String tempDataset;

private class WriteTablesDoFn
- extends DoFn<
- KV<ShardedKey<DestinationT>, WritePartition.Result>, KV<TableDestination, Result>> {
+ extends DoFn<KV<ShardedKey<DestinationT>, WritePartition.Result>, KV<DestinationT, Result>> {

private Map<DestinationT, String> jsonSchemas = Maps.newHashMap();

@@ -160,6 +159,7 @@ private class PendingJobData {
final List<String> partitionFiles;
final TableDestination tableDestination;
final TableReference tableReference;
+ final DestinationT destinationT;
final boolean isFirstPane;

public PendingJobData(
Expand All @@ -168,12 +168,14 @@ public PendingJobData(
List<String> partitionFiles,
TableDestination tableDestination,
TableReference tableReference,
+ DestinationT destinationT,
boolean isFirstPane) {
this.window = window;
this.retryJob = retryJob;
this.partitionFiles = partitionFiles;
this.tableDestination = tableDestination;
this.tableReference = tableReference;
+ this.destinationT = destinationT;
this.isFirstPane = isFirstPane;
}
}
@@ -292,6 +294,7 @@ public void processElement(
partitionFiles,
tableDestination,
tableReference,
+ destination,
element.getValue().isFirstPane()));
}

@@ -359,7 +362,7 @@ public void finishBundle(FinishBundleContext c) throws Exception {
pendingJob.isFirstPane);
c.output(
mainOutputTag,
- KV.of(pendingJob.tableDestination, result),
+ KV.of(pendingJob.destinationT, result),
pendingJob.window.maxTimestamp(),
pendingJob.window);
for (String file : pendingJob.partitionFiles) {
@@ -423,7 +426,7 @@ public WriteTables(
}

@Override
public PCollection<KV<TableDestination, Result>> expand(
public PCollection<KV<DestinationT, Result>> expand(
PCollection<KV<ShardedKey<DestinationT>, WritePartition.Result>> input) {
PCollectionTuple writeTablesOutputs =
input.apply(
sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
@@ -104,8 +104,10 @@
import org.apache.beam.sdk.testing.TestStream;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.Distinct;
+ import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnTester;
import org.apache.beam.sdk.transforms.MapElements;
+ import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.SimpleFunction;
@@ -2209,6 +2211,8 @@ public void testWriteTables() throws Exception {
p.apply("CreateJobId", Create.of("jobId")).apply(View.asSingleton());
List<PCollectionView<?>> sideInputs = ImmutableList.of(jobIdTokenView);

+ DynamicDestinations<String, String> dynamicDestinations = new IdentityDynamicTables();

fakeJobService.setNumFailuresExpected(3);
WriteTables<String> writeTables =
new WriteTables<>(
@@ -2218,7 +2222,7 @@
BigQueryIO.Write.WriteDisposition.WRITE_EMPTY,
BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED,
sideInputs,
- new IdentityDynamicTables(),
+ dynamicDestinations,
null,
4,
false,
@@ -2231,12 +2235,24 @@
PCollection<KV<TableDestination, WriteTables.Result>> writeTablesOutput =
writeTablesInput
.apply(writeTables)
- .setCoder(KvCoder.of(TableDestinationCoderV3.of(), WriteTables.ResultCoder.INSTANCE));
+ .setCoder(KvCoder.of(StringUtf8Coder.of(), WriteTables.ResultCoder.INSTANCE))
+ .apply(
+ ParDo.of(
+ new DoFn<
+ KV<String, WriteTables.Result>,
+ KV<TableDestination, WriteTables.Result>>() {
+ @ProcessElement
+ public void processElement(
+ @Element KV<String, WriteTables.Result> e,
+ OutputReceiver<KV<TableDestination, WriteTables.Result>> o) {
+ o.output(KV.of(dynamicDestinations.getTable(e.getKey()), e.getValue()));
+ }
+ }));

PAssert.thatMultimap(writeTablesOutput)
.satisfies(
input -> {
- assertEquals(input.keySet(), expectedTempTables.keySet());
+ assertEquals(expectedTempTables.keySet(), input.keySet());
for (Map.Entry<TableDestination, Iterable<WriteTables.Result>> entry :
input.entrySet()) {
Iterable<String> tableNames =
