address issues in Storage API writes
reuvenlax committed May 25, 2021
1 parent 8922c1c commit 73db9c1
Showing 4 changed files with 107 additions and 46 deletions.
@@ -2846,7 +2846,6 @@ private <DestinationT> WriteResult continueExpandTyped(
StorageApiLoads<DestinationT, T> storageApiLoads =
new StorageApiLoads<DestinationT, T>(
destinationCoder,
elementCoder,
storageApiDynamicDestinations,
getCreateDisposition(),
getKmsKey(),
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.gcp.bigquery;

import org.apache.beam.sdk.io.gcp.bigquery.StorageApiDynamicDestinations.MessageConverter;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;

/**
* A transform that converts messages to protocol buffers in preparation for writing to BigQuery.
*/
public class StorageApiConvertMessages<DestinationT, ElementT>
extends PTransform<
PCollection<KV<DestinationT, ElementT>>, PCollection<KV<DestinationT, byte[]>>> {
private final StorageApiDynamicDestinations<ElementT, DestinationT> dynamicDestinations;

public StorageApiConvertMessages(
StorageApiDynamicDestinations<ElementT, DestinationT> dynamicDestinations) {
this.dynamicDestinations = dynamicDestinations;
}

@Override
public PCollection<KV<DestinationT, byte[]>> expand(
PCollection<KV<DestinationT, ElementT>> input) {
String operationName = input.getName() + "/" + getName();

return input.apply(
"Convert to message",
ParDo.of(new ConvertMessagesDoFn<>(dynamicDestinations, operationName))
.withSideInputs(dynamicDestinations.getSideInputs()));
}

public static class ConvertMessagesDoFn<DestinationT, ElementT>
extends DoFn<KV<DestinationT, ElementT>, KV<DestinationT, byte[]>> {
private final StorageApiDynamicDestinations<ElementT, DestinationT> dynamicDestinations;
private TwoLevelMessageConverterCache<DestinationT, ElementT> messageConverters;

ConvertMessagesDoFn(
StorageApiDynamicDestinations<ElementT, DestinationT> dynamicDestinations,
String operationName) {
this.dynamicDestinations = dynamicDestinations;
this.messageConverters = new TwoLevelMessageConverterCache<>(operationName);
}

@ProcessElement
public void processElement(
ProcessContext c,
@Element KV<DestinationT, ElementT> element,
OutputReceiver<KV<DestinationT, byte[]>> o)
throws Exception {
dynamicDestinations.setSideInputAccessorFromProcessContext(c);
MessageConverter<ElementT> messageConverter =
messageConverters.get(element.getKey(), dynamicDestinations);
o.output(
KV.of(element.getKey(), messageConverter.toMessage(element.getValue()).toByteArray()));
}
}
}
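The new StorageApiConvertMessages transform above maps each (destination, element) pair to the serialized bytes of its protocol-buffer message, so downstream steps only ever handle byte[]. Below is a minimal, self-contained sketch of that shape; it is not part of the commit. The ConvertMessagesSketch class, the destination key "table-a", and the string rows are illustrative, and protobuf's StringValue stands in for the real per-destination MessageConverter.

import com.google.protobuf.StringValue;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;

public class ConvertMessagesSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create();
    PCollection<KV<String, byte[]>> converted =
        p.apply(
                Create.of(KV.of("table-a", "row-1"), KV.of("table-a", "row-2"))
                    .withCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())))
            .apply(
                "Convert to message",
                ParDo.of(
                    new DoFn<KV<String, String>, KV<String, byte[]>>() {
                      @ProcessElement
                      public void processElement(
                          @Element KV<String, String> element,
                          OutputReceiver<KV<String, byte[]>> o) {
                        // StringValue stands in for the per-destination MessageConverter;
                        // the output value is the serialized proto payload.
                        o.output(
                            KV.of(
                                element.getKey(),
                                StringValue.of(element.getValue()).toByteArray()));
                      }
                    }))
            .setCoder(KvCoder.of(StringUtf8Coder.of(), ByteArrayCoder.of()));
    p.run().waitUntilFinish();
  }
}

Converting to raw bytes up front is what lets the later sharding and grouping stages in this commit use ByteArrayCoder and size-based batching instead of the element coder.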
@@ -21,19 +21,16 @@
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.GroupIntoBatches;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.windowing.AfterFirst;
import org.apache.beam.sdk.transforms.windowing.AfterPane;
import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.Repeatedly;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.util.ShardedKey;
import org.apache.beam.sdk.values.KV;
@@ -48,10 +45,9 @@
public class StorageApiLoads<DestinationT, ElementT>
extends PTransform<PCollection<KV<DestinationT, ElementT>>, WriteResult> {
private static final Logger LOG = LoggerFactory.getLogger(StorageApiLoads.class);
static final int FILE_TRIGGERING_RECORD_COUNT = 100;
static final int MAX_BATCH_SIZE_BYTES = 2 * 1024 * 1024;

private final Coder<DestinationT> destinationCoder;
private final Coder<ElementT> elementCoder;
private final StorageApiDynamicDestinations<ElementT, DestinationT> dynamicDestinations;
private final CreateDisposition createDisposition;
private final String kmsKey;
@@ -61,15 +57,13 @@ public class StorageApiLoads<DestinationT, ElementT>

public StorageApiLoads(
Coder<DestinationT> destinationCoder,
Coder<ElementT> elementCoder,
StorageApiDynamicDestinations<ElementT, DestinationT> dynamicDestinations,
CreateDisposition createDisposition,
String kmsKey,
Duration triggeringFrequency,
BigQueryServices bqServices,
int numShards) {
this.destinationCoder = destinationCoder;
this.elementCoder = elementCoder;
this.dynamicDestinations = dynamicDestinations;
this.createDisposition = createDisposition;
this.kmsKey = kmsKey;
@@ -86,25 +80,17 @@ public WriteResult expand(PCollection<KV<DestinationT, ElementT>> input) {
public WriteResult expandTriggered(PCollection<KV<DestinationT, ElementT>> input) {
// Handle triggered, low-latency loads into BigQuery.
PCollection<KV<DestinationT, ElementT>> inputInGlobalWindow =
input.apply(
"rewindowIntoGlobal",
Window.<KV<DestinationT, ElementT>>into(new GlobalWindows())
.triggering(
Repeatedly.forever(
AfterFirst.of(
AfterProcessingTime.pastFirstElementInPane()
.plusDelayOf(triggeringFrequency),
AfterPane.elementCountAtLeast(FILE_TRIGGERING_RECORD_COUNT))))
.discardingFiredPanes());
input.apply("rewindowIntoGlobal", Window.into(new GlobalWindows()));

// First shard all the records.
// TODO(reuvenlax): Add autosharding support so that users don't have to pick a shard count.
PCollection<KV<ShardedKey<DestinationT>, ElementT>> shardedRecords =
PCollection<KV<ShardedKey<DestinationT>, byte[]>> shardedRecords =
inputInGlobalWindow
.apply("Convert", new StorageApiConvertMessages<>(dynamicDestinations))
.apply(
"AddShard",
ParDo.of(
new DoFn<KV<DestinationT, ElementT>, KV<ShardedKey<DestinationT>, ElementT>>() {
new DoFn<KV<DestinationT, byte[]>, KV<ShardedKey<DestinationT>, byte[]>>() {
int shardNumber;

@Setup
@@ -114,19 +100,23 @@ public void setup() {

@ProcessElement
public void processElement(
@Element KV<DestinationT, ElementT> element,
OutputReceiver<KV<ShardedKey<DestinationT>, ElementT>> o) {
@Element KV<DestinationT, byte[]> element,
OutputReceiver<KV<ShardedKey<DestinationT>, byte[]>> o) {
DestinationT destination = element.getKey();
ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES);
buffer.putInt(++shardNumber % numShards);
o.output(
KV.of(ShardedKey.of(destination, buffer.array()), element.getValue()));
}
}))
.setCoder(KvCoder.of(ShardedKey.Coder.of(destinationCoder), elementCoder));
.setCoder(KvCoder.of(ShardedKey.Coder.of(destinationCoder), ByteArrayCoder.of()));
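The "AddShard" step above spreads a destination's records across numShards keys by packing a rotating shard number into a byte[] and wrapping it in a ShardedKey. A tiny standalone sketch of just that round-robin encoding follows; it is not part of the commit, and numShards = 4 and the eight-record loop are illustrative.

import java.nio.ByteBuffer;

public class ShardNumberSketch {
  public static void main(String[] args) {
    int numShards = 4;
    int shardNumber = 0;
    for (int record = 0; record < 8; record++) {
      // Pack the next shard number into a big-endian int, as the AddShard DoFn does.
      byte[] shardKeyBytes =
          ByteBuffer.allocate(Integer.BYTES).putInt(++shardNumber % numShards).array();
      // Decode it back out just to show the rotation: 1, 2, 3, 0, 1, 2, 3, 0.
      System.out.println("record " + record + " -> shard " + ByteBuffer.wrap(shardKeyBytes).getInt());
    }
  }
}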

PCollection<KV<ShardedKey<DestinationT>, Iterable<ElementT>>> groupedRecords =
shardedRecords.apply("GroupIntoShards", GroupByKey.create());
PCollection<KV<ShardedKey<DestinationT>, Iterable<byte[]>>> groupedRecords =
shardedRecords.apply(
"GroupIntoBatches",
GroupIntoBatches.<ShardedKey<DestinationT>, byte[]>ofByteSize(
MAX_BATCH_SIZE_BYTES, (byte[] e) -> (long) e.length)
.withMaxBufferingDuration(triggeringFrequency));
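The grouping step changes from a GroupByKey fed by a count/processing-time trigger to GroupIntoBatches, which bounds each batch by byte size and flushes after a maximum buffering duration. Below is a hedged, self-contained sketch of that pattern on plain keyed byte arrays, not part of the commit: the two-element input, the "shard-0" key, and the 5-second flush are illustrative, while the 2 MB bound mirrors MAX_BATCH_SIZE_BYTES above.

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.GroupIntoBatches;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.joda.time.Duration;

public class GroupIntoBatchesSketch {
  public static void main(String[] args) {
    long maxBatchSizeBytes = 2L * 1024 * 1024; // mirrors MAX_BATCH_SIZE_BYTES above

    Pipeline p = Pipeline.create();
    PCollection<KV<String, Iterable<byte[]>>> batches =
        p.apply(
                Create.of(
                        KV.of("shard-0", new byte[] {1, 2, 3}),
                        KV.of("shard-0", new byte[] {4, 5}))
                    .withCoder(KvCoder.of(StringUtf8Coder.of(), ByteArrayCoder.of())))
            .apply(
                "GroupIntoBatches",
                // Batches close when their accumulated payload crosses the byte bound,
                // or when the buffering duration expires, whichever comes first.
                GroupIntoBatches.<String, byte[]>ofByteSize(
                        maxBatchSizeBytes, (byte[] e) -> (long) e.length)
                    .withMaxBufferingDuration(Duration.standardSeconds(5)));
    p.run().waitUntilFinish();
  }
}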

groupedRecords.apply(
"StorageApiWriteSharded",
@@ -26,7 +26,6 @@
import com.google.cloud.bigquery.storage.v1beta2.WriteStream.Type;
import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Message;
import io.grpc.Status;
import io.grpc.Status.Code;
import java.io.IOException;
@@ -91,7 +90,7 @@
@SuppressWarnings("FutureReturnValueIgnored")
public class StorageApiWritesShardedRecords<DestinationT, ElementT>
extends PTransform<
PCollection<KV<ShardedKey<DestinationT>, Iterable<ElementT>>>, PCollection<Void>> {
PCollection<KV<ShardedKey<DestinationT>, Iterable<byte[]>>>, PCollection<Void>> {
private static final Logger LOG = LoggerFactory.getLogger(StorageApiWritesShardedRecords.class);

private final StorageApiDynamicDestinations<ElementT, DestinationT> dynamicDestinations;
@@ -104,7 +103,7 @@ public class StorageApiWritesShardedRecords<DestinationT, ElementT>

private static final Cache<String, StreamAppendClient> APPEND_CLIENTS =
CacheBuilder.newBuilder()
.expireAfterAccess(15, TimeUnit.MINUTES)
.expireAfterAccess(5, TimeUnit.MINUTES)
.removalListener(
(RemovalNotification<String, StreamAppendClient> removal) -> {
@Nullable final StreamAppendClient streamAppendClient = removal.getValue();
@@ -152,7 +151,7 @@ public StorageApiWritesShardedRecords(

@Override
public PCollection<Void> expand(
PCollection<KV<ShardedKey<DestinationT>, Iterable<ElementT>>> input) {
PCollection<KV<ShardedKey<DestinationT>, Iterable<byte[]>>> input) {
String operationName = input.getName() + "/" + getName();
// Append records to the Storage API streams.
PCollection<KV<String, Operation>> written =
@@ -194,19 +193,19 @@ public PCollection<Void> expand(
* parameter controls how many rows are batched into a single ProtoRows object before we move on
* to the next one.
*/
static class SplittingIterable<T extends Message> implements Iterable<ProtoRows> {
private final Iterable<T> underlying;
static class SplittingIterable implements Iterable<ProtoRows> {
private final Iterable<byte[]> underlying;
private final long splitSize;

public SplittingIterable(Iterable<T> underlying, long splitSize) {
public SplittingIterable(Iterable<byte[]> underlying, long splitSize) {
this.underlying = underlying;
this.splitSize = splitSize;
}

@Override
public Iterator<ProtoRows> iterator() {
return new Iterator<ProtoRows>() {
final Iterator<T> underlyingIterator = underlying.iterator();
final Iterator<byte[]> underlyingIterator = underlying.iterator();

@Override
public boolean hasNext() {
@@ -222,7 +221,7 @@ public ProtoRows next() {
ProtoRows.Builder inserts = ProtoRows.newBuilder();
long bytesSize = 0;
while (underlyingIterator.hasNext()) {
ByteString byteString = underlyingIterator.next().toByteString();
ByteString byteString = ByteString.copyFrom(underlyingIterator.next());
inserts.addSerializedRows(byteString);
bytesSize += byteString.size();
if (bytesSize > splitSize) {
@@ -236,7 +235,7 @@
}
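SplittingIterable above packs serialized rows into ProtoRows chunks, closing a chunk only after its accumulated size exceeds the split size, so a single oversized row still becomes its own chunk. The following plain-Java sketch of the same splitting logic drops the protobuf dependency and is not part of the commit; the 600-byte rows and 1000-byte bound are illustrative.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class SplitBySizeSketch {
  // Pack rows into chunks, closing a chunk once its accumulated size exceeds splitSize.
  static List<List<byte[]>> split(Iterable<byte[]> rows, long splitSize) {
    List<List<byte[]>> chunks = new ArrayList<>();
    List<byte[]> current = new ArrayList<>();
    long bytesSize = 0;
    for (byte[] row : rows) {
      current.add(row);
      bytesSize += row.length;
      if (bytesSize > splitSize) {
        chunks.add(current);
        current = new ArrayList<>();
        bytesSize = 0;
      }
    }
    if (!current.isEmpty()) {
      chunks.add(current);
    }
    return chunks;
  }

  public static void main(String[] args) {
    List<byte[]> rows = Arrays.asList(new byte[600], new byte[600], new byte[600]);
    // With a 1000-byte bound the first two rows share a chunk and the third gets its own.
    System.out.println(split(rows, 1000).size()); // prints 2
  }
}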

class WriteRecordsDoFn
extends DoFn<KV<ShardedKey<DestinationT>, Iterable<ElementT>>, KV<String, Operation>> {
extends DoFn<KV<ShardedKey<DestinationT>, Iterable<byte[]>>, KV<String, Operation>> {
private final Counter recordsAppended =
Metrics.counter(WriteRecordsDoFn.class, "recordsAppended");
private final Counter streamsCreated =
@@ -254,10 +253,10 @@ class WriteRecordsDoFn
private final Distribution appendSplitDistribution =
Metrics.distribution(WriteRecordsDoFn.class, "appendSplitDistribution");

private Map<DestinationT, TableDestination> destinations = Maps.newHashMap();

private TwoLevelMessageConverterCache<DestinationT, ElementT> messageConverters;

private Map<DestinationT, TableDestination> destinations = Maps.newHashMap();

// Stores the current stream for this key.
@StateId("streamName")
private final StateSpec<ValueState<String>> streamNameSpec = StateSpecs.value();
@@ -301,7 +300,7 @@ String getOrCreateStream(
public void process(
ProcessContext c,
final PipelineOptions pipelineOptions,
@Element KV<ShardedKey<DestinationT>, Iterable<ElementT>> element,
@Element KV<ShardedKey<DestinationT>, Iterable<byte[]>> element,
final @AlwaysFetched @StateId("streamName") ValueState<String> streamName,
final @AlwaysFetched @StateId("streamOffset") ValueState<Long> streamOffset,
final OutputReceiver<KV<String, Operation>> o)
@@ -336,12 +335,9 @@ public void process(

// Each ProtoRows object contains at most 1MB of rows.
// TODO: Push messageFromTableRow up to top level. That way we can skip TableRow entirely if
// already proto or
// already schema.
// already proto or already schema.
final long oneMb = 1024 * 1024;
Iterable<ProtoRows> messages =
new SplittingIterable<>(
Iterables.transform(element.getValue(), e -> messageConverter.toMessage(e)), oneMb);
Iterable<ProtoRows> messages = new SplittingIterable(element.getValue(), oneMb);

class AppendRowsContext extends RetryManager.Operation.Context<AppendRowsResponse> {
final ShardedKey<DestinationT> key;
@@ -412,7 +408,7 @@ public String toString() {
Instant now = Instant.now();
List<AppendRowsContext> contexts = Lists.newArrayList();
RetryManager<AppendRowsResponse, AppendRowsContext> retryManager =
new RetryManager<>(Duration.standardSeconds(1), Duration.standardMinutes(1), 1000);
new RetryManager<>(Duration.standardSeconds(1), Duration.standardSeconds(10), 1000);
int numSplits = 0;
for (ProtoRows protoRows : messages) {
++numSplits;
