diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
index fce8f1c5d40c..491768fc3ca2 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
@@ -104,6 +104,7 @@
 import org.apache.beam.sdk.io.gcp.bigquery.PassThroughThenCleanup.ContextContainer;
 import org.apache.beam.sdk.io.gcp.bigquery.RowWriterFactory.OutputType;
 import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.options.ValueProvider;
 import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
 import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
@@ -3815,12 +3816,41 @@ private WriteResult continueExpandTyped(
       if (numShards == 0) {
         enableAutoSharding = true;
       }
+      if (StreamingOptions.updateCompatibilityVersionLessThan(
+          input.getPipeline().getOptions(), "2.57.0")) {
+        if (getFormatRecordOnFailureFunction() != null) {
+          throw new IllegalArgumentException(
+              "Formatting records on Failure is not supported on Beam Versions Less than 2.57");
+        }
+        StorageApiLoads256 legacy =
+            new StorageApiLoads256(
+                destinationCoder,
+                storageApiDynamicDestinations,
+                getRowMutationInformationFn(),
+                getCreateDisposition(),
+                getKmsKey(),
+                getStorageApiTriggeringFrequency(bqOptions),
+                getBigQueryServices(),
+                getStorageApiNumStreams(bqOptions),
+                method == Method.STORAGE_API_AT_LEAST_ONCE,
+                enableAutoSharding,
+                getAutoSchemaUpdate(),
+                getIgnoreUnknownValues(),
+                getPropagateSuccessfulStorageApiWrites(),
+                getRowMutationInformationFn() != null,
+                getDefaultMissingValueInterpretation(),
+                getBadRecordRouter(),
+                getBadRecordErrorHandler());
+        return input.apply("StorageApiLoads", legacy);
+      }
       StorageApiLoads storageApiLoads =
           new StorageApiLoads<>(
               destinationCoder,
+              elementCoder,
               storageApiDynamicDestinations,
               getRowMutationInformationFn(),
+              getFormatRecordOnFailureFunction(),
               getCreateDisposition(),
               getKmsKey(),
               getStorageApiTriggeringFrequency(bqOptions),
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SplittingIterable.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SplittingIterable.java
index b8eeb2522cf2..9a17d5903a0b 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SplittingIterable.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SplittingIterable.java
@@ -21,11 +21,15 @@
 import com.google.auto.value.AutoValue;
 import com.google.cloud.bigquery.storage.v1.ProtoRows;
 import com.google.protobuf.ByteString;
+import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
 import java.util.NoSuchElementException;
 import java.util.function.BiConsumer;
 import java.util.function.Function;
+import org.apache.beam.sdk.io.gcp.bigquery.SplittingIterable.Value;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.TimestampedValue;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
 import org.checkerframework.checker.nullness.qual.Nullable;
@@ -36,11 +40,13 @@
  * parameter controls how many rows are batched into a
single ProtoRows object before we move on to * the next one. */ -class SplittingIterable implements Iterable { +class SplittingIterable implements Iterable> { @AutoValue - abstract static class Value { + abstract static class Value { abstract ProtoRows getProtoRows(); + abstract List getOriginalElements(); + abstract List getTimestamps(); } @@ -49,7 +55,7 @@ ByteString convert(TableRow tableRow, boolean ignoreUnknownValues) throws TableRowToStorageApiProto.SchemaConversionException; } - private final Iterable underlying; + private final Iterable> underlying; private final long splitSize; private final ConvertUnknownFields unknownFieldsToMessage; @@ -60,15 +66,18 @@ ByteString convert(TableRow tableRow, boolean ignoreUnknownValues) private final Instant elementsTimestamp; + @Nullable SerializableFunction formatRecordOnFailureFunction; + public SplittingIterable( - Iterable underlying, + Iterable> underlying, long splitSize, ConvertUnknownFields unknownFieldsToMessage, Function protoToTableRow, BiConsumer, String> failedRowsConsumer, boolean autoUpdateSchema, boolean ignoreUnknownValues, - Instant elementsTimestamp) { + Instant elementsTimestamp, + @Nullable SerializableFunction formatRecordOnFailureFunction) { this.underlying = underlying; this.splitSize = splitSize; this.unknownFieldsToMessage = unknownFieldsToMessage; @@ -77,12 +86,14 @@ public SplittingIterable( this.autoUpdateSchema = autoUpdateSchema; this.ignoreUnknownValues = ignoreUnknownValues; this.elementsTimestamp = elementsTimestamp; + this.formatRecordOnFailureFunction = formatRecordOnFailureFunction; } @Override - public Iterator iterator() { - return new Iterator() { - final Iterator underlyingIterator = underlying.iterator(); + public Iterator> iterator() { + return new Iterator>() { + final Iterator> underlyingIterator = + underlying.iterator(); @Override public boolean hasNext() { @@ -90,16 +101,17 @@ public boolean hasNext() { } @Override - public Value next() { + public Value next() { if (!hasNext()) { throw new NoSuchElementException(); } List timestamps = Lists.newArrayList(); ProtoRows.Builder inserts = ProtoRows.newBuilder(); + List originalElements = new ArrayList<>(); long bytesSize = 0; while (underlyingIterator.hasNext()) { - StorageApiWritePayload payload = underlyingIterator.next(); + StorageApiWritePayload payload = underlyingIterator.next(); ByteString byteString = ByteString.copyFrom(payload.getPayload()); if (autoUpdateSchema) { try { @@ -116,7 +128,12 @@ public Value next() { // This generally implies that ignoreUnknownValues=false and there were still // unknown values here. // Reconstitute the TableRow and send it to the failed-rows consumer. - TableRow tableRow = protoToTableRow.apply(byteString); + TableRow tableRow; + if (formatRecordOnFailureFunction != null) { + tableRow = formatRecordOnFailureFunction.apply(payload.originalElement()); + } else { + tableRow = protoToTableRow.apply(byteString); + } // TODO(24926, reuvenlax): We need to merge the unknown fields in! Currently we // only execute this // codepath when ignoreUnknownFields==true, so we should never hit this codepath. 
@@ -137,6 +154,7 @@ public Value next() { } } inserts.addSerializedRows(byteString); + originalElements.add(payload.originalElement()); Instant timestamp = payload.getTimestamp(); if (timestamp == null) { timestamp = elementsTimestamp; @@ -147,7 +165,8 @@ public Value next() { break; } } - return new AutoValue_SplittingIterable_Value(inserts.build(), timestamps); + return new AutoValue_SplittingIterable_Value( + inserts.build(), originalElements, timestamps); } }; } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SplittingIterable256.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SplittingIterable256.java new file mode 100644 index 000000000000..5caaf7db2e74 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SplittingIterable256.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.auto.value.AutoValue; +import com.google.cloud.bigquery.storage.v1.ProtoRows; +import com.google.protobuf.ByteString; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.function.BiConsumer; +import java.util.function.Function; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; + +/** + * Takes in an iterable and batches the results into multiple ProtoRows objects. The splitSize + * parameter controls how many rows are batched into a single ProtoRows object before we move on to + * the next one. 
+ */ +class SplittingIterable256 implements Iterable { + @AutoValue + abstract static class Value { + abstract ProtoRows getProtoRows(); + + abstract List getTimestamps(); + } + + interface ConvertUnknownFields { + ByteString convert(TableRow tableRow, boolean ignoreUnknownValues) + throws TableRowToStorageApiProto.SchemaConversionException; + } + + private final Iterable underlying; + private final long splitSize; + + private final ConvertUnknownFields unknownFieldsToMessage; + private final Function protoToTableRow; + private final BiConsumer, String> failedRowsConsumer; + private final boolean autoUpdateSchema; + private final boolean ignoreUnknownValues; + + private final Instant elementsTimestamp; + + public SplittingIterable256( + Iterable underlying, + long splitSize, + ConvertUnknownFields unknownFieldsToMessage, + Function protoToTableRow, + BiConsumer, String> failedRowsConsumer, + boolean autoUpdateSchema, + boolean ignoreUnknownValues, + Instant elementsTimestamp) { + this.underlying = underlying; + this.splitSize = splitSize; + this.unknownFieldsToMessage = unknownFieldsToMessage; + this.protoToTableRow = protoToTableRow; + this.failedRowsConsumer = failedRowsConsumer; + this.autoUpdateSchema = autoUpdateSchema; + this.ignoreUnknownValues = ignoreUnknownValues; + this.elementsTimestamp = elementsTimestamp; + } + + @Override + public Iterator iterator() { + return new Iterator() { + final Iterator underlyingIterator = underlying.iterator(); + + @Override + public boolean hasNext() { + return underlyingIterator.hasNext(); + } + + @Override + public Value next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + List timestamps = Lists.newArrayList(); + ProtoRows.Builder inserts = ProtoRows.newBuilder(); + long bytesSize = 0; + while (underlyingIterator.hasNext()) { + StorageApiWritePayload payload = underlyingIterator.next(); + ByteString byteString = ByteString.copyFrom(payload.getPayload()); + if (autoUpdateSchema) { + try { + @Nullable TableRow unknownFields = payload.getUnknownFields(); + if (unknownFields != null && !unknownFields.isEmpty()) { + // Protocol buffer serialization format supports concatenation. We serialize any new + // "known" fields + // into a proto and concatenate to the existing proto. + try { + byteString = + byteString.concat( + unknownFieldsToMessage.convert(unknownFields, ignoreUnknownValues)); + } catch (TableRowToStorageApiProto.SchemaConversionException e) { + // This generally implies that ignoreUnknownValues=false and there were still + // unknown values here. + // Reconstitute the TableRow and send it to the failed-rows consumer. + TableRow tableRow = protoToTableRow.apply(byteString); + // TODO(24926, reuvenlax): We need to merge the unknown fields in! Currently we + // only execute this + // codepath when ignoreUnknownFields==true, so we should never hit this codepath. + // However once + // 24926 is fixed, we need to merge the unknownFields back into the main row + // before outputting to the + // failed-rows consumer. 
+ Instant timestamp = payload.getTimestamp(); + if (timestamp == null) { + timestamp = elementsTimestamp; + } + failedRowsConsumer.accept(TimestampedValue.of(tableRow, timestamp), e.toString()); + continue; + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + inserts.addSerializedRows(byteString); + Instant timestamp = payload.getTimestamp(); + if (timestamp == null) { + timestamp = elementsTimestamp; + } + timestamps.add(timestamp); + bytesSize += byteString.size(); + if (bytesSize > splitSize) { + break; + } + } + return new AutoValue_SplittingIterable256_Value(inserts.build(), timestamps); + } + }; + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages.java index aefdb79c535c..8c0a9262b1ff 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages.java @@ -50,22 +50,26 @@ public class StorageApiConvertMessages private final StorageApiDynamicDestinations dynamicDestinations; private final BigQueryServices bqServices; private final TupleTag failedWritesTag; - private final TupleTag> successfulWritesTag; + private final TupleTag>> + successfulWritesTag; private final Coder errorCoder; - private final Coder> successCoder; + private final Coder>> successCoder; private final @Nullable SerializableFunction rowMutationFn; private final BadRecordRouter badRecordRouter; + private final @Nullable SerializableFunction formatRecordOnFailureFunction; + public StorageApiConvertMessages( StorageApiDynamicDestinations dynamicDestinations, BigQueryServices bqServices, TupleTag failedWritesTag, - TupleTag> successfulWritesTag, + TupleTag>> successfulWritesTag, Coder errorCoder, - Coder> successCoder, + Coder>> successCoder, @Nullable SerializableFunction rowMutationFn, - BadRecordRouter badRecordRouter) { + BadRecordRouter badRecordRouter, + @Nullable SerializableFunction formatRecordOnFailureFunction) { this.dynamicDestinations = dynamicDestinations; this.bqServices = bqServices; this.failedWritesTag = failedWritesTag; @@ -74,6 +78,7 @@ public StorageApiConvertMessages( this.successCoder = successCoder; this.rowMutationFn = rowMutationFn; this.badRecordRouter = badRecordRouter; + this.formatRecordOnFailureFunction = formatRecordOnFailureFunction; } @Override @@ -92,7 +97,8 @@ public PCollectionTuple expand(PCollection> input) { successfulWritesTag, rowMutationFn, badRecordRouter, - input.getCoder())) + input.getCoder(), + formatRecordOnFailureFunction)) .withOutputTags( successfulWritesTag, TupleTagList.of(ImmutableList.of(failedWritesTag, BAD_RECORD_TAG))) @@ -104,26 +110,31 @@ public PCollectionTuple expand(PCollection> input) { } public static class ConvertMessagesDoFn - extends DoFn, KV> { + extends DoFn< + KV, KV>> { private final StorageApiDynamicDestinations dynamicDestinations; private TwoLevelMessageConverterCache messageConverters; private final BigQueryServices bqServices; private final TupleTag failedWritesTag; - private final TupleTag> successfulWritesTag; + private final TupleTag>> + successfulWritesTag; private final @Nullable SerializableFunction rowMutationFn; private final BadRecordRouter badRecordRouter; Coder> elementCoder; private transient @Nullable DatasetService 
datasetServiceInternal = null; + private final @Nullable SerializableFunction formatRecordOnFailureFunction; + ConvertMessagesDoFn( StorageApiDynamicDestinations dynamicDestinations, BigQueryServices bqServices, String operationName, TupleTag failedWritesTag, - TupleTag> successfulWritesTag, + TupleTag>> successfulWritesTag, @Nullable SerializableFunction rowMutationFn, BadRecordRouter badRecordRouter, - Coder> elementCoder) { + Coder> elementCoder, + @Nullable SerializableFunction formatRecordOnFailureFunction) { this.dynamicDestinations = dynamicDestinations; this.messageConverters = new TwoLevelMessageConverterCache<>(operationName); this.bqServices = bqServices; @@ -132,6 +143,7 @@ public static class ConvertMessagesDoFn payload = messageConverter .toMessage(element.getValue(), rowMutationInformation) .withTimestamp(timestamp); - o.get(successfulWritesTag).output(KV.of(element.getKey(), payload)); + if (formatRecordOnFailureFunction != null) { + payload = payload.toBuilder().setOriginalElement(element.getValue()).build(); + } + o.get(successfulWritesTag) + .output(KV.of(element.getKey(), payload)); } catch (TableRowToStorageApiProto.SchemaConversionException conversionException) { TableRow tableRow; try { - tableRow = messageConverter.toTableRow(element.getValue()); + if (formatRecordOnFailureFunction != null) { + tableRow = formatRecordOnFailureFunction.apply(element.getValue()); + } else { + tableRow = messageConverter.toTableRow(element.getValue()); + } } catch (Exception e) { badRecordRouter.route(o, element, elementCoder, e, "Unable to convert value to TableRow"); return; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages256.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages256.java new file mode 100644 index 000000000000..f3878494ab64 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages256.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.BAD_RECORD_TAG; + +import com.google.api.services.bigquery.model.TableRow; +import java.io.IOException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; +import org.apache.beam.sdk.io.gcp.bigquery.StorageApiDynamicDestinations.MessageConverter; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; +import org.apache.beam.sdk.util.Preconditions; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; + +/** + * A transform that converts messages to protocol buffers in preparation for writing to BigQuery. + */ +public class StorageApiConvertMessages256 + extends PTransform>, PCollectionTuple> { + private final StorageApiDynamicDestinations dynamicDestinations; + private final BigQueryServices bqServices; + private final TupleTag failedWritesTag; + private final TupleTag> successfulWritesTag; + private final Coder errorCoder; + private final Coder> successCoder; + + private final @Nullable SerializableFunction rowMutationFn; + private final BadRecordRouter badRecordRouter; + + public StorageApiConvertMessages256( + StorageApiDynamicDestinations dynamicDestinations, + BigQueryServices bqServices, + TupleTag failedWritesTag, + TupleTag> successfulWritesTag, + Coder errorCoder, + Coder> successCoder, + @Nullable SerializableFunction rowMutationFn, + BadRecordRouter badRecordRouter) { + this.dynamicDestinations = dynamicDestinations; + this.bqServices = bqServices; + this.failedWritesTag = failedWritesTag; + this.successfulWritesTag = successfulWritesTag; + this.errorCoder = errorCoder; + this.successCoder = successCoder; + this.rowMutationFn = rowMutationFn; + this.badRecordRouter = badRecordRouter; + } + + @Override + public PCollectionTuple expand(PCollection> input) { + String operationName = input.getName() + "/" + getName(); + + PCollectionTuple result = + input.apply( + "Convert to message", + ParDo.of( + new ConvertMessagesDoFn<>( + dynamicDestinations, + bqServices, + operationName, + failedWritesTag, + successfulWritesTag, + rowMutationFn, + badRecordRouter, + input.getCoder())) + .withOutputTags( + successfulWritesTag, + TupleTagList.of(ImmutableList.of(failedWritesTag, BAD_RECORD_TAG))) + .withSideInputs(dynamicDestinations.getSideInputs())); + result.get(successfulWritesTag).setCoder(successCoder); + result.get(failedWritesTag).setCoder(errorCoder); + result.get(BAD_RECORD_TAG).setCoder(BadRecord.getCoder(input.getPipeline())); + return result; + } + + public static class ConvertMessagesDoFn + extends DoFn, KV> { + private final StorageApiDynamicDestinations dynamicDestinations; + private TwoLevelMessageConverterCache messageConverters; + private 
final BigQueryServices bqServices; + private final TupleTag failedWritesTag; + private final TupleTag> successfulWritesTag; + private final @Nullable SerializableFunction rowMutationFn; + private final BadRecordRouter badRecordRouter; + Coder> elementCoder; + private transient @Nullable DatasetService datasetServiceInternal = null; + + ConvertMessagesDoFn( + StorageApiDynamicDestinations dynamicDestinations, + BigQueryServices bqServices, + String operationName, + TupleTag failedWritesTag, + TupleTag> successfulWritesTag, + @Nullable SerializableFunction rowMutationFn, + BadRecordRouter badRecordRouter, + Coder> elementCoder) { + this.dynamicDestinations = dynamicDestinations; + this.messageConverters = new TwoLevelMessageConverterCache<>(operationName); + this.bqServices = bqServices; + this.failedWritesTag = failedWritesTag; + this.successfulWritesTag = successfulWritesTag; + this.rowMutationFn = rowMutationFn; + this.badRecordRouter = badRecordRouter; + this.elementCoder = elementCoder; + } + + private DatasetService getDatasetService(PipelineOptions pipelineOptions) throws IOException { + if (datasetServiceInternal == null) { + datasetServiceInternal = + bqServices.getDatasetService(pipelineOptions.as(BigQueryOptions.class)); + } + return datasetServiceInternal; + } + + @Teardown + public void onTeardown() { + try { + if (datasetServiceInternal != null) { + datasetServiceInternal.close(); + datasetServiceInternal = null; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @ProcessElement + public void processElement( + ProcessContext c, + PipelineOptions pipelineOptions, + @Element KV element, + @Timestamp Instant timestamp, + MultiOutputReceiver o) + throws Exception { + dynamicDestinations.setSideInputAccessorFromProcessContext(c); + MessageConverter messageConverter = + messageConverters.get( + element.getKey(), dynamicDestinations, getDatasetService(pipelineOptions)); + + RowMutationInformation rowMutationInformation = null; + if (rowMutationFn != null) { + rowMutationInformation = + Preconditions.checkStateNotNull(rowMutationFn).apply(element.getValue()); + } + try { + StorageApiWritePayload payload = + messageConverter + .toMessage(element.getValue(), rowMutationInformation) + .withTimestamp(timestamp); + o.get(successfulWritesTag).output(KV.of(element.getKey(), payload)); + } catch (TableRowToStorageApiProto.SchemaConversionException conversionException) { + TableRow tableRow; + try { + tableRow = messageConverter.toTableRow(element.getValue()); + } catch (Exception e) { + badRecordRouter.route(o, element, elementCoder, e, "Unable to convert value to TableRow"); + return; + } + o.get(failedWritesTag) + .output(new BigQueryStorageApiInsertError(tableRow, conversionException.toString())); + } catch (Exception e) { + badRecordRouter.route( + o, element, elementCoder, e, "Unable to convert value to StorageWriteApiPayload"); + } + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java index 62174b5c917a..815f6e29943c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java @@ -47,21 +47,24 @@ import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionTuple; import 
org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TypeDescriptor; import org.joda.time.Duration; /** This {@link PTransform} manages loads into BigQuery using the Storage API. */ public class StorageApiLoads extends PTransform>, WriteResult> { - final TupleTag> successfulConvertedRowsTag = - new TupleTag<>("successfulRows"); + final TupleTag>> + successfulConvertedRowsTag = new TupleTag<>("successfulRows"); final TupleTag failedRowsTag = new TupleTag<>("failedRows"); @Nullable TupleTag successfulWrittenRowsTag; private final Coder destinationCoder; + private final Coder elementCoder; private final StorageApiDynamicDestinations dynamicDestinations; private final @Nullable SerializableFunction rowUpdateFn; + private final @Nullable SerializableFunction formatRecordOnFailureFunction; private final CreateDisposition createDisposition; private final String kmsKey; private final Duration triggeringFrequency; @@ -81,8 +84,10 @@ public class StorageApiLoads public StorageApiLoads( Coder destinationCoder, + Coder elementCoder, StorageApiDynamicDestinations dynamicDestinations, @Nullable SerializableFunction rowUpdateFn, + @Nullable SerializableFunction formatRecordOnFailureFunction, CreateDisposition createDisposition, String kmsKey, Duration triggeringFrequency, @@ -98,8 +103,10 @@ public StorageApiLoads( BadRecordRouter badRecordRouter, ErrorHandler badRecordErrorHandler) { this.destinationCoder = destinationCoder; + this.elementCoder = elementCoder; this.dynamicDestinations = dynamicDestinations; this.rowUpdateFn = rowUpdateFn; + this.formatRecordOnFailureFunction = formatRecordOnFailureFunction; this.createDisposition = createDisposition; this.kmsKey = kmsKey; this.triggeringFrequency = triggeringFrequency; @@ -128,14 +135,16 @@ public boolean usesErrorHandler() { @Override public WriteResult expand(PCollection> input) { - Coder payloadCoder; + Coder> payloadCoder; try { payloadCoder = - input.getPipeline().getSchemaRegistry().getSchemaCoder(StorageApiWritePayload.class); + input.getPipeline().getSchemaRegistry().getSchemaCoder( + new TypeDescriptor>() { + }); } catch (NoSuchSchemaException e) { throw new RuntimeException(e); } - Coder> successCoder = + Coder>> successCoder = KvCoder.of(destinationCoder, payloadCoder); if (allowInconsistentWrites) { return expandInconsistent(input, successCoder); @@ -148,7 +157,7 @@ public WriteResult expand(PCollection> input) { public WriteResult expandInconsistent( PCollection> input, - Coder> successCoder) { + Coder>> successCoder) { PCollection> inputInGlobalWindow = input.apply("rewindowIntoGlobal", Window.into(new GlobalWindows())); @@ -163,7 +172,8 @@ public WriteResult expandInconsistent( BigQueryStorageApiInsertErrorCoder.of(), successCoder, rowUpdateFn, - badRecordRouter)); + badRecordRouter, + formatRecordOnFailureFunction)); PCollectionTuple writeRecordsResult = convertMessagesResult .get(successfulConvertedRowsTag) @@ -181,7 +191,8 @@ public WriteResult expandInconsistent( createDisposition, kmsKey, usesCdc, - defaultMissingValueInterpretation)); + defaultMissingValueInterpretation, + formatRecordOnFailureFunction)); PCollection insertErrors = PCollectionList.of(convertMessagesResult.get(failedRowsTag)) @@ -209,8 +220,8 @@ public WriteResult expandInconsistent( public WriteResult expandTriggered( PCollection> input, - Coder> successCoder, - Coder payloadCoder) { + Coder>> successCoder, + Coder> payloadCoder) { // Handle triggered, low-latency loads into BigQuery. 
PCollection> inputInGlobalWindow = input.apply("rewindowIntoGlobal", Window.into(new GlobalWindows())); @@ -225,9 +236,11 @@ public WriteResult expandTriggered( BigQueryStorageApiInsertErrorCoder.of(), successCoder, rowUpdateFn, - badRecordRouter)); + badRecordRouter, + formatRecordOnFailureFunction)); - PCollection, Iterable>> groupedRecords; + PCollection, Iterable>>> + groupedRecords; int maxAppendBytes = input @@ -241,21 +254,29 @@ public WriteResult expandTriggered( .get(successfulConvertedRowsTag) .apply( "GroupIntoBatches", - GroupIntoBatches.ofByteSize( + GroupIntoBatches.>ofByteSize( maxAppendBytes, - (StorageApiWritePayload e) -> (long) e.getPayload().length) + (StorageApiWritePayload e) -> + (long) e.getPayload().length) .withMaxBufferingDuration(triggeringFrequency) .withShardedKey()); } else { - PCollection, StorageApiWritePayload>> shardedRecords = - createShardedKeyValuePairs(convertMessagesResult) - .setCoder(KvCoder.of(ShardedKey.Coder.of(destinationCoder), payloadCoder)); + PCollection, StorageApiWritePayload>> + shardedRecords = + createShardedKeyValuePairs(convertMessagesResult) + .setCoder( + KvCoder.of( + ShardedKey.Coder.of(destinationCoder), + payloadCoder)); groupedRecords = shardedRecords.apply( "GroupIntoBatches", - GroupIntoBatches., StorageApiWritePayload>ofByteSize( - maxAppendBytes, (StorageApiWritePayload e) -> (long) e.getPayload().length) + GroupIntoBatches + ., StorageApiWritePayload>ofByteSize( + maxAppendBytes, + (StorageApiWritePayload e) -> + (long) e.getPayload().length) .withMaxBufferingDuration(triggeringFrequency)); } PCollectionTuple writeRecordsResult = @@ -273,7 +294,8 @@ public WriteResult expandTriggered( successfulWrittenRowsTag, autoUpdateSchema, ignoreUnknownValues, - defaultMissingValueInterpretation)); + defaultMissingValueInterpretation, + formatRecordOnFailureFunction)); PCollection insertErrors = PCollectionList.of(convertMessagesResult.get(failedRowsTag)) @@ -300,7 +322,7 @@ public WriteResult expandTriggered( successfulWrittenRows); } - private PCollection, StorageApiWritePayload>> + private PCollection,StorageApiWritePayload>> createShardedKeyValuePairs(PCollectionTuple pCollection) { return pCollection .get(successfulConvertedRowsTag) @@ -308,8 +330,8 @@ public WriteResult expandTriggered( "AddShard", ParDo.of( new DoFn< - KV, - KV, StorageApiWritePayload>>() { + KV>, + KV, StorageApiWritePayload>>() { int shardNumber; @Setup @@ -319,8 +341,10 @@ public void setup() { @ProcessElement public void processElement( - @Element KV element, - OutputReceiver, StorageApiWritePayload>> o) { + @Element KV> element, + OutputReceiver< + KV, KV>> + o) { DestinationT destination = element.getKey(); ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES); buffer.putInt(++shardNumber % numShards); @@ -331,7 +355,7 @@ public void processElement( public WriteResult expandUntriggered( PCollection> input, - Coder> successCoder) { + Coder>> successCoder) { PCollection> inputInGlobalWindow = input.apply( "rewindowIntoGlobal", Window.>into(new GlobalWindows())); @@ -346,7 +370,8 @@ public WriteResult expandUntriggered( BigQueryStorageApiInsertErrorCoder.of(), successCoder, rowUpdateFn, - badRecordRouter)); + badRecordRouter, + formatRecordOnFailureFunction)); PCollectionTuple writeRecordsResult = convertMessagesResult @@ -365,7 +390,8 @@ public WriteResult expandUntriggered( createDisposition, kmsKey, usesCdc, - defaultMissingValueInterpretation)); + defaultMissingValueInterpretation, + formatRecordOnFailureFunction)); PCollection insertErrors = 
PCollectionList.of(convertMessagesResult.get(failedRowsTag)) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads256.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads256.java new file mode 100644 index 000000000000..8c4efd64e00d --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads256.java @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.BAD_RECORD_TAG; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.bigquery.storage.v1.AppendRowsRequest; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.concurrent.ThreadLocalRandom; +import javax.annotation.Nullable; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.transforms.GroupIntoBatches; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.ThrowingBadRecordRouter; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.util.ShardedKey; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionList; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.TupleTag; +import org.joda.time.Duration; + +/** This {@link PTransform} manages loads into BigQuery using the Storage API. 
*/ +public class StorageApiLoads256 + extends PTransform>, WriteResult> { + final TupleTag> successfulConvertedRowsTag = + new TupleTag<>("successfulRows"); + + final TupleTag failedRowsTag = new TupleTag<>("failedRows"); + + @Nullable TupleTag successfulWrittenRowsTag; + private final Coder destinationCoder; + private final StorageApiDynamicDestinations dynamicDestinations; + + private final @Nullable SerializableFunction rowUpdateFn; + private final CreateDisposition createDisposition; + private final String kmsKey; + private final Duration triggeringFrequency; + private final BigQueryServices bqServices; + private final int numShards; + private final boolean allowInconsistentWrites; + private final boolean allowAutosharding; + private final boolean autoUpdateSchema; + private final boolean ignoreUnknownValues; + private final boolean usesCdc; + + private final AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation; + + private final BadRecordRouter badRecordRouter; + + private final ErrorHandler badRecordErrorHandler; + + public StorageApiLoads256( + Coder destinationCoder, + StorageApiDynamicDestinations dynamicDestinations, + @Nullable SerializableFunction rowUpdateFn, + CreateDisposition createDisposition, + String kmsKey, + Duration triggeringFrequency, + BigQueryServices bqServices, + int numShards, + boolean allowInconsistentWrites, + boolean allowAutosharding, + boolean autoUpdateSchema, + boolean ignoreUnknownValues, + boolean propagateSuccessfulStorageApiWrites, + boolean usesCdc, + AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation, + BadRecordRouter badRecordRouter, + ErrorHandler badRecordErrorHandler) { + this.destinationCoder = destinationCoder; + this.dynamicDestinations = dynamicDestinations; + this.rowUpdateFn = rowUpdateFn; + this.createDisposition = createDisposition; + this.kmsKey = kmsKey; + this.triggeringFrequency = triggeringFrequency; + this.bqServices = bqServices; + this.numShards = numShards; + this.allowInconsistentWrites = allowInconsistentWrites; + this.allowAutosharding = allowAutosharding; + this.autoUpdateSchema = autoUpdateSchema; + this.ignoreUnknownValues = ignoreUnknownValues; + if (propagateSuccessfulStorageApiWrites) { + this.successfulWrittenRowsTag = new TupleTag<>("successfulPublishedRowsTag"); + } + this.usesCdc = usesCdc; + this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; + this.badRecordRouter = badRecordRouter; + this.badRecordErrorHandler = badRecordErrorHandler; + } + + public TupleTag getFailedRowsTag() { + return failedRowsTag; + } + + public boolean usesErrorHandler() { + return !(badRecordRouter instanceof ThrowingBadRecordRouter); + } + + @Override + public WriteResult expand(PCollection> input) { + Coder payloadCoder; + try { + payloadCoder = + input.getPipeline().getSchemaRegistry().getSchemaCoder(StorageApiWritePayload.class); + } catch (NoSuchSchemaException e) { + throw new RuntimeException(e); + } + Coder> successCoder = + KvCoder.of(destinationCoder, payloadCoder); + if (allowInconsistentWrites) { + return expandInconsistent(input, successCoder); + } else { + return triggeringFrequency != null + ? 
expandTriggered(input, successCoder, payloadCoder) + : expandUntriggered(input, successCoder); + } + } + + public WriteResult expandInconsistent( + PCollection> input, + Coder> successCoder) { + PCollection> inputInGlobalWindow = + input.apply("rewindowIntoGlobal", Window.into(new GlobalWindows())); + + PCollectionTuple convertMessagesResult = + inputInGlobalWindow.apply( + "Convert", + new StorageApiConvertMessages256<>( + dynamicDestinations, + bqServices, + failedRowsTag, + successfulConvertedRowsTag, + BigQueryStorageApiInsertErrorCoder.of(), + successCoder, + rowUpdateFn, + badRecordRouter)); + PCollectionTuple writeRecordsResult = + convertMessagesResult + .get(successfulConvertedRowsTag) + .apply( + "StorageApiWriteInconsistent", + new StorageApiWriteRecordsInconsistent256<>( + dynamicDestinations, + bqServices, + failedRowsTag, + successfulWrittenRowsTag, + BigQueryStorageApiInsertErrorCoder.of(), + TableRowJsonCoder.of(), + autoUpdateSchema, + ignoreUnknownValues, + createDisposition, + kmsKey, + usesCdc, + defaultMissingValueInterpretation)); + + PCollection insertErrors = + PCollectionList.of(convertMessagesResult.get(failedRowsTag)) + .and(writeRecordsResult.get(failedRowsTag)) + .apply("flattenErrors", Flatten.pCollections()); + @Nullable PCollection successfulWrittenRows = null; + if (successfulWrittenRowsTag != null) { + successfulWrittenRows = writeRecordsResult.get(successfulWrittenRowsTag); + } + + addErrorCollections(convertMessagesResult, writeRecordsResult); + + return WriteResult.in( + input.getPipeline(), + null, + null, + null, + null, + null, + failedRowsTag, + insertErrors, + successfulWrittenRowsTag, + successfulWrittenRows); + } + + public WriteResult expandTriggered( + PCollection> input, + Coder> successCoder, + Coder payloadCoder) { + // Handle triggered, low-latency loads into BigQuery. 
+ PCollection> inputInGlobalWindow = + input.apply("rewindowIntoGlobal", Window.into(new GlobalWindows())); + PCollectionTuple convertMessagesResult = + inputInGlobalWindow.apply( + "Convert", + new StorageApiConvertMessages256<>( + dynamicDestinations, + bqServices, + failedRowsTag, + successfulConvertedRowsTag, + BigQueryStorageApiInsertErrorCoder.of(), + successCoder, + rowUpdateFn, + badRecordRouter)); + + PCollection, Iterable>> groupedRecords; + + int maxAppendBytes = + input + .getPipeline() + .getOptions() + .as(BigQueryOptions.class) + .getStorageApiAppendThresholdBytes(); + if (this.allowAutosharding) { + groupedRecords = + convertMessagesResult + .get(successfulConvertedRowsTag) + .apply( + "GroupIntoBatches", + GroupIntoBatches.ofByteSize( + maxAppendBytes, + (StorageApiWritePayload e) -> (long) e.getPayload().length) + .withMaxBufferingDuration(triggeringFrequency) + .withShardedKey()); + + } else { + PCollection, StorageApiWritePayload>> shardedRecords = + createShardedKeyValuePairs(convertMessagesResult) + .setCoder(KvCoder.of(ShardedKey.Coder.of(destinationCoder), payloadCoder)); + groupedRecords = + shardedRecords.apply( + "GroupIntoBatches", + GroupIntoBatches., StorageApiWritePayload>ofByteSize( + maxAppendBytes, (StorageApiWritePayload e) -> (long) e.getPayload().length) + .withMaxBufferingDuration(triggeringFrequency)); + } + PCollectionTuple writeRecordsResult = + groupedRecords.apply( + "StorageApiWriteSharded", + new StorageApiWritesShardedRecords256<>( + dynamicDestinations, + createDisposition, + kmsKey, + bqServices, + destinationCoder, + BigQueryStorageApiInsertErrorCoder.of(), + TableRowJsonCoder.of(), + failedRowsTag, + successfulWrittenRowsTag, + autoUpdateSchema, + ignoreUnknownValues, + defaultMissingValueInterpretation)); + + PCollection insertErrors = + PCollectionList.of(convertMessagesResult.get(failedRowsTag)) + .and(writeRecordsResult.get(failedRowsTag)) + .apply("flattenErrors", Flatten.pCollections()); + + @Nullable PCollection successfulWrittenRows = null; + if (successfulWrittenRowsTag != null) { + successfulWrittenRows = writeRecordsResult.get(successfulWrittenRowsTag); + } + + addErrorCollections(convertMessagesResult, writeRecordsResult); + + return WriteResult.in( + input.getPipeline(), + null, + null, + null, + null, + null, + failedRowsTag, + insertErrors, + successfulWrittenRowsTag, + successfulWrittenRows); + } + + private PCollection, StorageApiWritePayload>> + createShardedKeyValuePairs(PCollectionTuple pCollection) { + return pCollection + .get(successfulConvertedRowsTag) + .apply( + "AddShard", + ParDo.of( + new DoFn< + KV, + KV, StorageApiWritePayload>>() { + int shardNumber; + + @Setup + public void setup() { + shardNumber = ThreadLocalRandom.current().nextInt(numShards); + } + + @ProcessElement + public void processElement( + @Element KV element, + OutputReceiver, StorageApiWritePayload>> o) { + DestinationT destination = element.getKey(); + ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES); + buffer.putInt(++shardNumber % numShards); + o.output(KV.of(ShardedKey.of(destination, buffer.array()), element.getValue())); + } + })); + } + + public WriteResult expandUntriggered( + PCollection> input, + Coder> successCoder) { + PCollection> inputInGlobalWindow = + input.apply( + "rewindowIntoGlobal", Window.>into(new GlobalWindows())); + PCollectionTuple convertMessagesResult = + inputInGlobalWindow.apply( + "Convert", + new StorageApiConvertMessages256<>( + dynamicDestinations, + bqServices, + failedRowsTag, + 
successfulConvertedRowsTag, + BigQueryStorageApiInsertErrorCoder.of(), + successCoder, + rowUpdateFn, + badRecordRouter)); + + PCollectionTuple writeRecordsResult = + convertMessagesResult + .get(successfulConvertedRowsTag) + .apply( + "StorageApiWriteUnsharded", + new StorageApiWriteUnshardedRecords256<>( + dynamicDestinations, + bqServices, + failedRowsTag, + successfulWrittenRowsTag, + BigQueryStorageApiInsertErrorCoder.of(), + TableRowJsonCoder.of(), + autoUpdateSchema, + ignoreUnknownValues, + createDisposition, + kmsKey, + usesCdc, + defaultMissingValueInterpretation)); + + PCollection insertErrors = + PCollectionList.of(convertMessagesResult.get(failedRowsTag)) + .and(writeRecordsResult.get(failedRowsTag)) + .apply("flattenErrors", Flatten.pCollections()); + + @Nullable PCollection successfulWrittenRows = null; + if (successfulWrittenRowsTag != null) { + successfulWrittenRows = writeRecordsResult.get(successfulWrittenRowsTag); + } + + addErrorCollections(convertMessagesResult, writeRecordsResult); + + return WriteResult.in( + input.getPipeline(), + null, + null, + null, + null, + null, + failedRowsTag, + insertErrors, + successfulWrittenRowsTag, + successfulWrittenRows); + } + + private void addErrorCollections( + PCollectionTuple convertMessagesResult, PCollectionTuple writeRecordsResult) { + if (usesErrorHandler()) { + PCollection badRecords = + PCollectionList.of( + convertMessagesResult + .get(failedRowsTag) + .apply( + "ConvertMessageFailuresToBadRecord", + ParDo.of( + new ConvertInsertErrorToBadRecord( + "Failed to Convert to Storage API Message")))) + .and(convertMessagesResult.get(BAD_RECORD_TAG)) + .and( + writeRecordsResult + .get(failedRowsTag) + .apply( + "WriteRecordFailuresToBadRecord", + ParDo.of( + new ConvertInsertErrorToBadRecord( + "Failed to Write Message to Storage API")))) + .apply("flattenBadRecords", Flatten.pCollections()); + badRecordErrorHandler.addErrorCollection(badRecords); + } + } + + private static class ConvertInsertErrorToBadRecord + extends DoFn { + + private final String errorMessage; + + public ConvertInsertErrorToBadRecord(String errorMessage) { + this.errorMessage = errorMessage; + } + + @ProcessElement + public void processElement( + @Element BigQueryStorageApiInsertError bigQueryStorageApiInsertError, + OutputReceiver outputReceiver) + throws IOException { + outputReceiver.output( + BadRecord.fromExceptionInformation( + bigQueryStorageApiInsertError, + BigQueryStorageApiInsertErrorCoder.of(), + null, + errorMessage)); + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritePayload.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritePayload.java index 5b6f27949870..797f58b8435e 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritePayload.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritePayload.java @@ -30,7 +30,7 @@ /** Class used to wrap elements being sent to the Storage API sinks. 
*/ @AutoValue @DefaultSchema(AutoValueSchema.class) -public abstract class StorageApiWritePayload { +public abstract class StorageApiWritePayload { @SuppressWarnings("mutable") public abstract byte[] getPayload(); @@ -39,34 +39,38 @@ public abstract class StorageApiWritePayload { public abstract @Nullable Instant getTimestamp(); + public abstract @Nullable ElementT originalElement(); + @AutoValue.Builder - public abstract static class Builder { - public abstract Builder setPayload(byte[] value); + public abstract static class Builder { + public abstract Builder setPayload(byte[] value); + + public abstract Builder setUnknownFieldsPayload(@Nullable byte[] value); - public abstract Builder setUnknownFieldsPayload(@Nullable byte[] value); + public abstract Builder setTimestamp(@Nullable Instant value); - public abstract Builder setTimestamp(@Nullable Instant value); + public abstract Builder setOriginalElement(@Nullable ElementT originalElement); - public abstract StorageApiWritePayload build(); + public abstract StorageApiWritePayload build(); } - public abstract Builder toBuilder(); + public abstract Builder toBuilder(); @SuppressWarnings("nullness") - static StorageApiWritePayload of(byte[] payload, @Nullable TableRow unknownFields) + static StorageApiWritePayload of(byte[] payload, @Nullable TableRow unknownFields) throws IOException { @Nullable byte[] unknownFieldsPayload = null; if (unknownFields != null) { unknownFieldsPayload = CoderUtils.encodeToByteArray(TableRowJsonCoder.of(), unknownFields); } - return new AutoValue_StorageApiWritePayload.Builder() + return new AutoValue_StorageApiWritePayload.Builder() .setPayload(payload) .setUnknownFieldsPayload(unknownFieldsPayload) .setTimestamp(null) .build(); } - public StorageApiWritePayload withTimestamp(Instant instant) { + public StorageApiWritePayload withTimestamp(Instant instant) { return toBuilder().setTimestamp(instant).build(); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java index 022ee1fbed08..2a54c886d980 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java @@ -23,6 +23,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; @@ -37,7 +38,8 @@ */ @SuppressWarnings("FutureReturnValueIgnored") public class StorageApiWriteRecordsInconsistent - extends PTransform>, PCollectionTuple> { + extends PTransform< + PCollection>>, PCollectionTuple> { private final StorageApiDynamicDestinations dynamicDestinations; private final BigQueryServices bqServices; private final TupleTag failedRowsTag; @@ -51,6 +53,7 @@ public class StorageApiWriteRecordsInconsistent private final @Nullable String kmsKey; private final boolean usesCdc; private final AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation; + private final @Nullable SerializableFunction formatRecordOnFailureFunction; public StorageApiWriteRecordsInconsistent( 
StorageApiDynamicDestinations dynamicDestinations, @@ -64,7 +67,8 @@ public StorageApiWriteRecordsInconsistent( BigQueryIO.Write.CreateDisposition createDisposition, @Nullable String kmsKey, boolean usesCdc, - AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation) { + AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation, + @Nullable SerializableFunction formatRecordOnFailureFunction) { this.dynamicDestinations = dynamicDestinations; this.bqServices = bqServices; this.failedRowsTag = failedRowsTag; @@ -77,10 +81,12 @@ public StorageApiWriteRecordsInconsistent( this.kmsKey = kmsKey; this.usesCdc = usesCdc; this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; + this.formatRecordOnFailureFunction = formatRecordOnFailureFunction; } @Override - public PCollectionTuple expand(PCollection> input) { + public PCollectionTuple expand( + PCollection>> input) { String operationName = input.getName() + "/" + getName(); BigQueryOptions bigQueryOptions = input.getPipeline().getOptions().as(BigQueryOptions.class); // Append records to the Storage API streams. @@ -108,7 +114,8 @@ public PCollectionTuple expand(PCollection + extends PTransform>, PCollectionTuple> { + private final StorageApiDynamicDestinations dynamicDestinations; + private final BigQueryServices bqServices; + private final TupleTag failedRowsTag; + private final @Nullable TupleTag successfulRowsTag; + private final TupleTag> finalizeTag = new TupleTag<>("finalizeTag"); + private final Coder failedRowsCoder; + private final Coder successfulRowsCoder; + private final boolean autoUpdateSchema; + private final boolean ignoreUnknownValues; + private final BigQueryIO.Write.CreateDisposition createDisposition; + private final @Nullable String kmsKey; + private final boolean usesCdc; + private final AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation; + + public StorageApiWriteRecordsInconsistent256( + StorageApiDynamicDestinations dynamicDestinations, + BigQueryServices bqServices, + TupleTag failedRowsTag, + @Nullable TupleTag successfulRowsTag, + Coder failedRowsCoder, + Coder successfulRowsCoder, + boolean autoUpdateSchema, + boolean ignoreUnknownValues, + BigQueryIO.Write.CreateDisposition createDisposition, + @Nullable String kmsKey, + boolean usesCdc, + AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation) { + this.dynamicDestinations = dynamicDestinations; + this.bqServices = bqServices; + this.failedRowsTag = failedRowsTag; + this.failedRowsCoder = failedRowsCoder; + this.successfulRowsCoder = successfulRowsCoder; + this.successfulRowsTag = successfulRowsTag; + this.autoUpdateSchema = autoUpdateSchema; + this.ignoreUnknownValues = ignoreUnknownValues; + this.createDisposition = createDisposition; + this.kmsKey = kmsKey; + this.usesCdc = usesCdc; + this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; + } + + @Override + public PCollectionTuple expand(PCollection> input) { + String operationName = input.getName() + "/" + getName(); + BigQueryOptions bigQueryOptions = input.getPipeline().getOptions().as(BigQueryOptions.class); + // Append records to the Storage API streams. 
+ TupleTagList tupleTagList = TupleTagList.of(failedRowsTag); + if (successfulRowsTag != null) { + tupleTagList = tupleTagList.and(successfulRowsTag); + } + PCollectionTuple result = + input.apply( + "Write Records", + ParDo.of( + new StorageApiWriteUnshardedRecords256.WriteRecordsDoFn<>( + operationName, + dynamicDestinations, + bqServices, + true, + bigQueryOptions.getStorageApiAppendThresholdBytes(), + bigQueryOptions.getStorageApiAppendThresholdRecordCount(), + bigQueryOptions.getNumStorageWriteApiStreamAppendClients(), + finalizeTag, + failedRowsTag, + successfulRowsTag, + autoUpdateSchema, + ignoreUnknownValues, + createDisposition, + kmsKey, + usesCdc, + defaultMissingValueInterpretation)) + .withOutputTags(finalizeTag, tupleTagList) + .withSideInputs(dynamicDestinations.getSideInputs())); + result.get(failedRowsTag).setCoder(failedRowsCoder); + if (successfulRowsTag != null) { + result.get(successfulRowsTag).setCoder(successfulRowsCoder); + } + return result; + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java index 846e7e3bddcb..6ea818c0cd93 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords.java @@ -36,6 +36,7 @@ import io.grpc.Status; import java.io.IOException; import java.time.Instant; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -67,6 +68,7 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Reshuffle; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.Preconditions; @@ -96,7 +98,8 @@ */ @SuppressWarnings({"FutureReturnValueIgnored"}) public class StorageApiWriteUnshardedRecords - extends PTransform>, PCollectionTuple> { + extends PTransform< + PCollection>>, PCollectionTuple> { private static final Logger LOG = LoggerFactory.getLogger(StorageApiWriteUnshardedRecords.class); private final StorageApiDynamicDestinations dynamicDestinations; @@ -114,6 +117,8 @@ public class StorageApiWriteUnshardedRecords private final boolean usesCdc; private final AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation; + private final @Nullable SerializableFunction formatRecordOnFailureFunction; + /** * The Guava cache object is thread-safe. 
However our protocol requires that client pin the * StreamAppendClient after looking up the cache, and we must ensure that the cache is not @@ -171,7 +176,8 @@ public StorageApiWriteUnshardedRecords( BigQueryIO.Write.CreateDisposition createDisposition, @Nullable String kmsKey, boolean usesCdc, - AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation) { + AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation, + @Nullable SerializableFunction formatRecordOnFailureFunction) { this.dynamicDestinations = dynamicDestinations; this.bqServices = bqServices; this.failedRowsTag = failedRowsTag; @@ -184,10 +190,12 @@ public StorageApiWriteUnshardedRecords( this.kmsKey = kmsKey; this.usesCdc = usesCdc; this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; + this.formatRecordOnFailureFunction = formatRecordOnFailureFunction; } @Override - public PCollectionTuple expand(PCollection> input) { + public PCollectionTuple expand( + PCollection>> input) { String operationName = input.getName() + "/" + getName(); BigQueryOptions options = input.getPipeline().getOptions().as(BigQueryOptions.class); org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument( @@ -217,7 +225,8 @@ public PCollectionTuple expand(PCollection - extends DoFn, KV> { + extends DoFn>, KV> { private final Counter forcedFlushes = Metrics.counter(WriteRecordsDoFn.class, "forcedFlushes"); private final TupleTag> finalizeTag; private final TupleTag failedRowsTag; @@ -248,19 +257,27 @@ static class WriteRecordsDoFn private final @Nullable String kmsKey; private final boolean usesCdc; private final AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation; + private final @Nullable SerializableFunction formatRecordOnFailureFunction; - static class AppendRowsContext extends RetryManager.Operation.Context { + static class AppendRowsContext + extends RetryManager.Operation.Context { long offset; ProtoRows protoRows; List timestamps; + List originalMessages; + int failureCount; public AppendRowsContext( - long offset, ProtoRows protoRows, List timestamps) { + long offset, + ProtoRows protoRows, + List timestamps, + List originalMessages) { this.offset = offset; this.protoRows = protoRows; this.timestamps = timestamps; + this.originalMessages = originalMessages; this.failureCount = 0; } } @@ -272,6 +289,8 @@ class DestinationState { private @Nullable AppendClientInfo appendClientInfo = null; private long currentOffset = 0; private List pendingMessages; + + private List originalMessages; private List pendingTimestamps; private transient @Nullable WriteStreamService maybeWriteStreamService; private final Counter recordsAppended = @@ -311,6 +330,7 @@ public DestinationState( this.tableUrn = tableUrn; this.shortTableUrn = shortTableUrn; this.pendingMessages = Lists.newArrayList(); + this.originalMessages = Lists.newArrayList(); this.pendingTimestamps = Lists.newArrayList(); this.maybeWriteStreamService = writeStreamService; this.useDefaultStream = useDefaultStream; @@ -531,7 +551,7 @@ void invalidateWriteStream() { } void addMessage( - StorageApiWritePayload payload, + StorageApiWritePayload payload, org.joda.time.Instant elementTs, OutputReceiver failedRowsReceiver) throws Exception { @@ -567,12 +587,13 @@ void addMessage( } } pendingMessages.add(payloadBytes); + originalMessages.add(payload.originalElement()); org.joda.time.Instant timestamp = payload.getTimestamp(); pendingTimestamps.add(timestamp != null ? 
timestamp : elementTs); } long flush( - RetryManager retryManager, + RetryManager> retryManager, OutputReceiver failedRowsReceiver, @Nullable OutputReceiver successfulRowsReceiver) throws Exception { @@ -602,13 +623,18 @@ long flush( for (int i = 0; i < inserts.getSerializedRowsCount(); ++i) { ByteString rowBytes = inserts.getSerializedRows(i); org.joda.time.Instant timestamp = insertTimestamps.get(i); - TableRow failedRow = - TableRowToStorageApiProto.tableRowFromMessage( - DynamicMessage.parseFrom( - TableRowToStorageApiProto.wrapDescriptorProto( - getAppendClientInfo(true, null).getDescriptor()), - rowBytes), - true); + TableRow failedRow; + if (formatRecordOnFailureFunction != null) { + failedRow = formatRecordOnFailureFunction.apply(originalMessages.get(i)); + } else { + failedRow = + TableRowToStorageApiProto.tableRowFromMessage( + DynamicMessage.parseFrom( + TableRowToStorageApiProto.wrapDescriptorProto( + getAppendClientInfo(true, null).getDescriptor()), + rowBytes), + true); + } failedRowsReceiver.outputWithTimestamp( new BigQueryStorageApiInsertError( failedRow, "Row payload too large. Maximum size " + maxRequestSize), @@ -621,6 +647,7 @@ long flush( shortTableUrn) .inc(numRowsFailed); rowsSentToFailedRowsCollection.inc(numRowsFailed); + originalMessages.clear(); return 0; } @@ -630,8 +657,10 @@ long flush( offset = this.currentOffset; this.currentOffset += inserts.getSerializedRowsCount(); } - AppendRowsContext appendRowsContext = - new AppendRowsContext(offset, inserts, insertTimestamps); + AppendRowsContext appendRowsContext = + new AppendRowsContext<>( + offset, inserts, insertTimestamps, new ArrayList<>(originalMessages)); + originalMessages.clear(); retryManager.addOperation( c -> { @@ -660,7 +689,7 @@ long flush( } }, contexts -> { - AppendRowsContext failedContext = + AppendRowsContext failedContext = Preconditions.checkStateNotNull(Iterables.getFirst(contexts, null)); BigQuerySinkMetrics.reportFailedRPCMetrics( failedContext, BigQuerySinkMetrics.RpcMethod.APPEND_ROWS, shortTableUrn); @@ -679,14 +708,21 @@ long flush( ByteString protoBytes = failedContext.protoRows.getSerializedRows(failedIndex); org.joda.time.Instant timestamp = failedContext.timestamps.get(failedIndex); try { - TableRow failedRow = - TableRowToStorageApiProto.tableRowFromMessage( - DynamicMessage.parseFrom( - TableRowToStorageApiProto.wrapDescriptorProto( - Preconditions.checkStateNotNull(appendClientInfo) - .getDescriptor()), - protoBytes), - true); + TableRow failedRow; + if (formatRecordOnFailureFunction != null) { + failedRow = + formatRecordOnFailureFunction.apply( + failedContext.originalMessages.get(failedIndex)); + } else { + failedRow = + TableRowToStorageApiProto.tableRowFromMessage( + DynamicMessage.parseFrom( + TableRowToStorageApiProto.wrapDescriptorProto( + Preconditions.checkStateNotNull(appendClientInfo) + .getDescriptor()), + protoBytes), + true); + } failedRowsReceiver.outputWithTimestamp( new BigQueryStorageApiInsertError( failedRow, error.getRowIndexToErrorMessage().get(failedIndex)), @@ -722,7 +758,7 @@ long flush( // Since we removed rows, we need to update the insert offsets for all remaining // rows. 
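// Illustrative sketch (not part of the patch): the pattern repeated in the flush paths
// above — prefer the user's formatRecordOnFailureFunction applied to the retained original
// element, and fall back to decoding the serialized proto row only when no formatter was
// supplied — factored into a small helper. TableRowDecoder and the byte[] row are
// hypothetical stand-ins for the DynamicMessage-based decoding done in the real code.
import com.google.api.services.bigquery.model.TableRow;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.checkerframework.checker.nullness.qual.Nullable;

class FailedRowFormatter<ElementT> {
  interface TableRowDecoder {                          // hypothetical fallback decoder
    TableRow decode(byte[] serializedRow) throws Exception;
  }

  private final @Nullable SerializableFunction<ElementT, TableRow> formatOnFailure;
  private final TableRowDecoder fallbackDecoder;

  FailedRowFormatter(
      @Nullable SerializableFunction<ElementT, TableRow> formatOnFailure,
      TableRowDecoder fallbackDecoder) {
    this.formatOnFailure = formatOnFailure;
    this.fallbackDecoder = fallbackDecoder;
  }

  // Returns the row to publish to the failed-rows PCollection for one failed append.
  TableRow format(ElementT originalElement, byte[] serializedRow) throws Exception {
    if (formatOnFailure != null) {
      // The original element is retained alongside the payload precisely for this case.
      return formatOnFailure.apply(originalElement);
    }
    // No user formatter: reconstruct a TableRow from the proto bytes, as before this change.
    return fallbackDecoder.decode(serializedRow);
  }
}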
long newOffset = failedContext.offset; - for (AppendRowsContext context : contexts) { + for (AppendRowsContext context : contexts) { context.offset = newOffset; newOffset += context.protoRows.getSerializedRowsCount(); } @@ -833,7 +869,7 @@ long flush( return inserts.getSerializedRowsCount(); } - String retrieveErrorDetails(Iterable failedContext) { + String retrieveErrorDetails(Iterable> failedContext) { return StreamSupport.stream(failedContext.spliterator(), false) .<@Nullable Throwable>map(AppendRowsContext::getError) .filter(Objects::nonNull) @@ -899,7 +935,8 @@ void postFlush() { BigQueryIO.Write.CreateDisposition createDisposition, @Nullable String kmsKey, boolean usesCdc, - AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation) { + AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation, + @Nullable SerializableFunction formatRecordOnFailureFunction) { this.messageConverters = new TwoLevelMessageConverterCache<>(operationName); this.dynamicDestinations = dynamicDestinations; this.bqServices = bqServices; @@ -916,6 +953,7 @@ void postFlush() { this.kmsKey = kmsKey; this.usesCdc = usesCdc; this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; + this.formatRecordOnFailureFunction = formatRecordOnFailureFunction; } boolean shouldFlush() { @@ -939,12 +977,12 @@ void flushAll( OutputReceiver failedRowsReceiver, @Nullable OutputReceiver successfulRowsReceiver) throws Exception { - List> retryManagers = + List>> retryManagers = Lists.newArrayListWithCapacity(Preconditions.checkStateNotNull(destinations).size()); long numRowsWritten = 0; for (DestinationState destinationState : Preconditions.checkStateNotNull(destinations).values()) { - RetryManager retryManager = + RetryManager> retryManager = new RetryManager<>( Duration.standardSeconds(1), Duration.standardSeconds(10), @@ -961,7 +999,8 @@ void flushAll( // await is called, so // this approach means that if one call fais, it has to wait for all prior calls to complete // before a retry happens. - for (RetryManager retryManager : retryManagers) { + for (RetryManager> retryManager : + retryManagers) { retryManager.await(); } } @@ -1048,7 +1087,7 @@ DestinationState createDestinationState( public void process( ProcessContext c, PipelineOptions pipelineOptions, - @Element KV element, + @Element KV> element, @Timestamp org.joda.time.Instant elementTs, MultiOutputReceiver o) throws Exception { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords256.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords256.java new file mode 100644 index 000000000000..f1ab023e30f5 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteUnshardedRecords256.java @@ -0,0 +1,1154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutures; +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.bigquery.storage.v1.AppendRowsRequest; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.Exceptions; +import com.google.cloud.bigquery.storage.v1.ProtoRows; +import com.google.cloud.bigquery.storage.v1.TableSchema; +import com.google.cloud.bigquery.storage.v1.WriteStream; +import com.google.cloud.bigquery.storage.v1.WriteStream.Type; +import com.google.protobuf.ByteString; +import com.google.protobuf.DescriptorProtos; +import com.google.protobuf.DynamicMessage; +import io.grpc.Status; +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StreamAppendClient; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.WriteStreamService; +import org.apache.beam.sdk.io.gcp.bigquery.RetryManager.RetryType; +import org.apache.beam.sdk.io.gcp.bigquery.StorageApiDynamicDestinations.MessageConverter; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Distribution; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Reshuffle; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.util.Preconditions; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; +import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalNotification; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Write records to the Storage API using a standard batch approach. PENDING streams are used, which + * do not become visible until they are finalized and committed. Each input bundle to the DoFn + * creates a stream per output table, appends all records in the bundle to the stream, and schedules + * a finalize/commit operation at the end. + */ +@SuppressWarnings({"FutureReturnValueIgnored"}) +public class StorageApiWriteUnshardedRecords256 + extends PTransform>, PCollectionTuple> { + private static final Logger LOG = + LoggerFactory.getLogger(StorageApiWriteUnshardedRecords256.class); + + private final StorageApiDynamicDestinations dynamicDestinations; + private final BigQueryServices bqServices; + private final TupleTag failedRowsTag; + private final @Nullable TupleTag successfulRowsTag; + private final TupleTag> finalizeTag = new TupleTag<>("finalizeTag"); + private final Coder failedRowsCoder; + private final Coder successfulRowsCoder; + private final boolean autoUpdateSchema; + private final boolean ignoreUnknownValues; + private static final ExecutorService closeWriterExecutor = Executors.newCachedThreadPool(); + private final BigQueryIO.Write.CreateDisposition createDisposition; + private final @Nullable String kmsKey; + private final boolean usesCdc; + private final AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation; + + /** + * The Guava cache object is thread-safe. However our protocol requires that client pin the + * StreamAppendClient after looking up the cache, and we must ensure that the cache is not + * accessed in between the lookup and the pin (any access of the cache could trigger element + * expiration). Therefore most used of APPEND_CLIENTS should synchronize. + */ + private static final Cache APPEND_CLIENTS = + CacheBuilder.newBuilder() + .expireAfterAccess(15, TimeUnit.MINUTES) + .removalListener( + (RemovalNotification removal) -> { + LOG.info("Expiring append client for " + removal.getKey()); + final @Nullable AppendClientInfo appendClientInfo = removal.getValue(); + if (appendClientInfo != null) { + appendClientInfo.close(); + } + }) + .build(); + + static void clearCache() { + APPEND_CLIENTS.invalidateAll(); + } + + // Run a closure asynchronously, ignoring failures. + private interface ThrowingRunnable { + void run() throws Exception; + } + + private static void runAsyncIgnoreFailure(ExecutorService executor, ThrowingRunnable task) { + executor.submit( + () -> { + try { + task.run(); + } catch (Exception e) { + String msg = + e.toString() + + "\n" + + Arrays.stream(e.getStackTrace()) + .map(StackTraceElement::toString) + .collect(Collectors.joining("\n")); + System.err.println("Exception happened while executing async task. 
Ignoring: " + msg); + } + }); + } + + public StorageApiWriteUnshardedRecords256( + StorageApiDynamicDestinations dynamicDestinations, + BigQueryServices bqServices, + TupleTag failedRowsTag, + @Nullable TupleTag successfulRowsTag, + Coder failedRowsCoder, + Coder successfulRowsCoder, + boolean autoUpdateSchema, + boolean ignoreUnknownValues, + BigQueryIO.Write.CreateDisposition createDisposition, + @Nullable String kmsKey, + boolean usesCdc, + AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation) { + this.dynamicDestinations = dynamicDestinations; + this.bqServices = bqServices; + this.failedRowsTag = failedRowsTag; + this.successfulRowsTag = successfulRowsTag; + this.failedRowsCoder = failedRowsCoder; + this.successfulRowsCoder = successfulRowsCoder; + this.autoUpdateSchema = autoUpdateSchema; + this.ignoreUnknownValues = ignoreUnknownValues; + this.createDisposition = createDisposition; + this.kmsKey = kmsKey; + this.usesCdc = usesCdc; + this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; + } + + @Override + public PCollectionTuple expand(PCollection> input) { + String operationName = input.getName() + "/" + getName(); + BigQueryOptions options = input.getPipeline().getOptions().as(BigQueryOptions.class); + org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument( + !options.getUseStorageApiConnectionPool(), + "useStorageApiConnectionPool only supported " + "when using STORAGE_API_AT_LEAST_ONCE"); + TupleTagList tupleTagList = TupleTagList.of(failedRowsTag); + if (successfulRowsTag != null) { + tupleTagList = tupleTagList.and(successfulRowsTag); + } + PCollectionTuple writeResults = + input.apply( + "Write Records", + ParDo.of( + new WriteRecordsDoFn<>( + operationName, + dynamicDestinations, + bqServices, + false, + options.getStorageApiAppendThresholdBytes(), + options.getStorageApiAppendThresholdRecordCount(), + options.getNumStorageWriteApiStreamAppendClients(), + finalizeTag, + failedRowsTag, + successfulRowsTag, + autoUpdateSchema, + ignoreUnknownValues, + createDisposition, + kmsKey, + usesCdc, + defaultMissingValueInterpretation)) + .withOutputTags(finalizeTag, tupleTagList) + .withSideInputs(dynamicDestinations.getSideInputs())); + + writeResults + .get(finalizeTag) + .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())) + // Calling Reshuffle makes the output stable - once this completes, the append operations + // will not retry. + // TODO(reuvenlax): This should use RequiresStableInput instead. 
+ .apply("Reshuffle", Reshuffle.of()) + .apply("Finalize writes", ParDo.of(new StorageApiFinalizeWritesDoFn(bqServices))); + writeResults.get(failedRowsTag).setCoder(failedRowsCoder); + if (successfulRowsTag != null) { + writeResults.get(successfulRowsTag).setCoder(successfulRowsCoder); + } + return writeResults; + } + + static class WriteRecordsDoFn + extends DoFn, KV> { + private final Counter forcedFlushes = Metrics.counter(WriteRecordsDoFn.class, "forcedFlushes"); + private final TupleTag> finalizeTag; + private final TupleTag failedRowsTag; + private final @Nullable TupleTag successfulRowsTag; + private final boolean autoUpdateSchema; + private final boolean ignoreUnknownValues; + private final BigQueryIO.Write.CreateDisposition createDisposition; + private final @Nullable String kmsKey; + private final boolean usesCdc; + private final AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation; + + static class AppendRowsContext extends RetryManager.Operation.Context { + long offset; + ProtoRows protoRows; + List timestamps; + + int failureCount; + + public AppendRowsContext( + long offset, ProtoRows protoRows, List timestamps) { + this.offset = offset; + this.protoRows = protoRows; + this.timestamps = timestamps; + this.failureCount = 0; + } + } + + class DestinationState { + private final String tableUrn; + private final String shortTableUrn; + private String streamName = ""; + private @Nullable AppendClientInfo appendClientInfo = null; + private long currentOffset = 0; + private List pendingMessages; + private List pendingTimestamps; + private transient @Nullable WriteStreamService maybeWriteStreamService; + private final Counter recordsAppended = + Metrics.counter(WriteRecordsDoFn.class, "recordsAppended"); + private final Counter appendFailures = + Metrics.counter(WriteRecordsDoFn.class, "appendFailures"); + private final Distribution inflightWaitSecondsDistribution = + Metrics.distribution(WriteRecordsDoFn.class, "streamWriterWaitSeconds"); + private final Counter rowsSentToFailedRowsCollection = + Metrics.counter( + StorageApiWritesShardedRecords.WriteRecordsDoFn.class, + "rowsSentToFailedRowsCollection"); + private final Callable tryCreateTable; + + private final boolean useDefaultStream; + private TableSchema initialTableSchema; + private DescriptorProtos.DescriptorProto initialDescriptor; + private Instant nextCacheTickle = Instant.MAX; + private final int clientNumber; + private final boolean usingMultiplexing; + private final long maxRequestSize; + + private final boolean includeCdcColumns; + + public DestinationState( + String tableUrn, + String shortTableUrn, + MessageConverter messageConverter, + WriteStreamService writeStreamService, + boolean useDefaultStream, + int streamAppendClientCount, + boolean usingMultiplexing, + long maxRequestSize, + Callable tryCreateTable, + boolean includeCdcColumns) + throws Exception { + this.tableUrn = tableUrn; + this.shortTableUrn = shortTableUrn; + this.pendingMessages = Lists.newArrayList(); + this.pendingTimestamps = Lists.newArrayList(); + this.maybeWriteStreamService = writeStreamService; + this.useDefaultStream = useDefaultStream; + this.initialTableSchema = messageConverter.getTableSchema(); + this.initialDescriptor = messageConverter.getDescriptor(includeCdcColumns); + this.clientNumber = new Random().nextInt(streamAppendClientCount); + this.usingMultiplexing = usingMultiplexing; + this.maxRequestSize = maxRequestSize; + this.tryCreateTable = tryCreateTable; + this.includeCdcColumns = includeCdcColumns; + 
if (includeCdcColumns) { + checkState(useDefaultStream); + } + } + + void teardown() { + maybeTickleCache(); + if (appendClientInfo != null) { + StreamAppendClient client = appendClientInfo.getStreamAppendClient(); + if (client != null) { + runAsyncIgnoreFailure(closeWriterExecutor, client::unpin); + } + appendClientInfo = null; + } + } + + String getDefaultStreamName() { + return BigQueryHelpers.stripPartitionDecorator(tableUrn) + "/streams/_default"; + } + + String getStreamAppendClientCacheEntryKey() { + if (useDefaultStream) { + String defaultStreamKey = getDefaultStreamName() + "-client" + clientNumber; + // The storage write API doesn't currently allow both inserts and updates/deletes on the + // same connection. + // Once this limitation is removed, we can remove this. + return includeCdcColumns ? defaultStreamKey + "-cdc" : defaultStreamKey; + } + return this.streamName; + } + + String getOrCreateStreamName() throws Exception { + if (Strings.isNullOrEmpty(this.streamName)) { + CreateTableHelpers.createTableWrapper( + () -> { + if (!useDefaultStream) { + this.streamName = + Preconditions.checkStateNotNull(maybeWriteStreamService) + .createWriteStream(tableUrn, Type.PENDING) + .getName(); + this.currentOffset = 0; + } else { + this.streamName = getDefaultStreamName(); + } + return null; + }, + tryCreateTable); + } + return this.streamName; + } + + AppendClientInfo generateClient(@Nullable TableSchema updatedSchema) throws Exception { + SchemaAndDescriptor schemaAndDescriptor = getCurrentTableSchema(streamName, updatedSchema); + + AtomicReference appendClientInfo = + new AtomicReference<>( + AppendClientInfo.of( + schemaAndDescriptor.tableSchema, + schemaAndDescriptor.descriptor, + // Make sure that the client is always closed in a different thread to avoid + // blocking. + client -> + runAsyncIgnoreFailure( + closeWriterExecutor, + () -> { + synchronized (APPEND_CLIENTS) { + // Remove the pin owned by the cache. + client.unpin(); + client.close(); + } + }))); + + CreateTableHelpers.createTableWrapper( + () -> { + appendClientInfo.set( + appendClientInfo + .get() + .withAppendClient( + Preconditions.checkStateNotNull(maybeWriteStreamService), + () -> streamName, + usingMultiplexing, + defaultMissingValueInterpretation)); + Preconditions.checkStateNotNull(appendClientInfo.get().getStreamAppendClient()); + return null; + }, + tryCreateTable); + + // This pin is "owned" by the cache. 
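// Illustrative sketch (not part of the patch): the naming scheme used above. The default
// write stream for a table is addressed as "<table>/streams/_default" after any partition
// decorator ("$20240101"-style suffix) is stripped, and the append-client cache key adds a
// client number (to spread load over several connections) plus a "-cdc" suffix, because
// inserts and CDC updates/deletes cannot currently share a connection. The decorator
// stripping here is a simplified reimplementation, not the actual BigQueryHelpers method.
class StreamCacheKeys {
  static String stripPartitionDecorator(String tableUrn) {
    int idx = tableUrn.indexOf('$');
    return idx < 0 ? tableUrn : tableUrn.substring(0, idx);
  }

  static String defaultStreamName(String tableUrn) {
    return stripPartitionDecorator(tableUrn) + "/streams/_default";
  }

  static String cacheEntryKey(String tableUrn, int clientNumber, boolean includeCdcColumns) {
    String key = defaultStreamName(tableUrn) + "-client" + clientNumber;
    return includeCdcColumns ? key + "-cdc" : key;
  }

  public static void main(String[] args) {
    System.out.println(cacheEntryKey("projects/p/datasets/d/tables/t$20240101", 3, false));
    // -> projects/p/datasets/d/tables/t/streams/_default-client3
  }
}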
+ Preconditions.checkStateNotNull(appendClientInfo.get().getStreamAppendClient()).pin(); + return appendClientInfo.get(); + } + + private class SchemaAndDescriptor { + private final TableSchema tableSchema; + private final DescriptorProtos.DescriptorProto descriptor; + + private SchemaAndDescriptor( + TableSchema tableSchema, DescriptorProtos.DescriptorProto descriptor) { + this.tableSchema = tableSchema; + this.descriptor = descriptor; + } + } + + SchemaAndDescriptor getCurrentTableSchema(String stream, @Nullable TableSchema updatedSchema) + throws Exception { + if (updatedSchema != null) { + return new SchemaAndDescriptor( + updatedSchema, + TableRowToStorageApiProto.descriptorSchemaFromTableSchema( + updatedSchema, true, includeCdcColumns)); + } + + AtomicReference currentSchema = new AtomicReference<>(initialTableSchema); + AtomicBoolean updated = new AtomicBoolean(); + CreateTableHelpers.createTableWrapper( + () -> { + if (autoUpdateSchema) { + @Nullable + WriteStream writeStream = + Preconditions.checkStateNotNull(maybeWriteStreamService) + .getWriteStream(streamName); + if (writeStream != null && writeStream.hasTableSchema()) { + TableSchema updatedFromStream = writeStream.getTableSchema(); + currentSchema.set(updatedFromStream); + updated.set(true); + LOG.debug( + "Fetched updated schema for table {}:\n\t{}", tableUrn, updatedFromStream); + } + } + return null; + }, + tryCreateTable); + // Note: While it may appear that these two branches are the same, it's important to return + // the actual + // initial descriptor if the schema has not changed. Simply converting the schema back into + // a descriptor isn't + // the same, and would break the direct-from-proto ingestion path. + DescriptorProtos.DescriptorProto descriptor = + updated.get() + ? TableRowToStorageApiProto.descriptorSchemaFromTableSchema( + currentSchema.get(), true, includeCdcColumns) + : initialDescriptor; + return new SchemaAndDescriptor(currentSchema.get(), descriptor); + } + + AppendClientInfo getAppendClientInfo( + boolean lookupCache, final @Nullable TableSchema updatedSchema) { + try { + if (this.appendClientInfo == null) { + getOrCreateStreamName(); + final AppendClientInfo newAppendClientInfo; + synchronized (APPEND_CLIENTS) { + if (lookupCache) { + newAppendClientInfo = + APPEND_CLIENTS.get( + getStreamAppendClientCacheEntryKey(), () -> generateClient(updatedSchema)); + } else { + newAppendClientInfo = generateClient(updatedSchema); + // override the clients in the cache. + APPEND_CLIENTS.put(getStreamAppendClientCacheEntryKey(), newAppendClientInfo); + } + // This pin is "owned" by the current DoFn. + Preconditions.checkStateNotNull(newAppendClientInfo.getStreamAppendClient()).pin(); + } + nextCacheTickle = Instant.now().plus(java.time.Duration.ofMinutes(1)); + this.appendClientInfo = newAppendClientInfo; + } + return Preconditions.checkStateNotNull(appendClientInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + void maybeTickleCache() { + if (appendClientInfo != null && Instant.now().isAfter(nextCacheTickle)) { + synchronized (APPEND_CLIENTS) { + APPEND_CLIENTS.getIfPresent(getStreamAppendClientCacheEntryKey()); + } + nextCacheTickle = Instant.now().plus(java.time.Duration.ofMinutes(1)); + } + } + + void invalidateWriteStream() { + if (appendClientInfo != null) { + synchronized (APPEND_CLIENTS) { + // Unpin in a different thread, as it may execute a blocking close. 
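// Illustrative sketch (not part of the patch) of the pinning protocol described in the
// cache comment above: entries expire after access, so the DoFn "tickles" the cache
// periodically to keep its client alive, and every lookup-then-pin happens under a lock so
// expiration cannot race between the lookup and the pin. PinnableClient is a hypothetical
// stand-in for AppendClientInfo/StreamAppendClient; the real code uses the vendored Guava
// cache rather than plain Guava.
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.RemovalListener;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class PinnedClientCache {
  static class PinnableClient {                        // hypothetical pinned resource
    private final AtomicInteger pins = new AtomicInteger();
    void pin() { pins.incrementAndGet(); }
    void unpin() { if (pins.decrementAndGet() == 0) { /* close the underlying connection */ } }
  }

  // The removal listener drops the pin owned by the cache itself.
  private static final RemovalListener<String, PinnableClient> ON_EVICT =
      notification -> {
        PinnableClient evicted = notification.getValue();
        if (evicted != null) {
          evicted.unpin();
        }
      };

  private static final Cache<String, PinnableClient> CLIENTS =
      CacheBuilder.newBuilder()
          .expireAfterAccess(15, TimeUnit.MINUTES)
          .removalListener(ON_EVICT)
          .build();

  // Look up (or create) and pin atomically with respect to expiration.
  static PinnableClient acquire(String key) throws ExecutionException {
    synchronized (CLIENTS) {
      PinnableClient client =
          CLIENTS.get(key, () -> { PinnableClient c = new PinnableClient(); c.pin(); return c; });
      client.pin();                                    // this pin is owned by the caller
      return client;
    }
  }

  // Touch the entry so expireAfterAccess does not evict a client that is still in use.
  static void tickle(String key) {
    synchronized (CLIENTS) {
      CLIENTS.getIfPresent(key);
    }
  }
}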
+ StreamAppendClient client = appendClientInfo.getStreamAppendClient(); + if (client != null) { + runAsyncIgnoreFailure(closeWriterExecutor, client::unpin); + } + // The default stream is cached across multiple different DoFns. If they all try and + // invalidate, then we can get races between threads invalidating and recreating + // streams. For this reason, + // we check to see that the cache still contains the object we created before + // invalidating (in case another + // thread has already invalidated and recreated the stream). + String cacheEntryKey = getStreamAppendClientCacheEntryKey(); + @Nullable + AppendClientInfo cachedAppendClient = APPEND_CLIENTS.getIfPresent(cacheEntryKey); + if (cachedAppendClient != null + && System.identityHashCode(cachedAppendClient) + == System.identityHashCode(appendClientInfo)) { + APPEND_CLIENTS.invalidate(cacheEntryKey); + } + } + appendClientInfo = null; + } + } + + void addMessage( + StorageApiWritePayload payload, + org.joda.time.Instant elementTs, + OutputReceiver failedRowsReceiver) + throws Exception { + maybeTickleCache(); + ByteString payloadBytes = ByteString.copyFrom(payload.getPayload()); + if (autoUpdateSchema) { + if (appendClientInfo == null) { + appendClientInfo = getAppendClientInfo(true, null); + } + @Nullable TableRow unknownFields = payload.getUnknownFields(); + if (unknownFields != null) { + try { + payloadBytes = + payloadBytes.concat( + Preconditions.checkStateNotNull(appendClientInfo) + .encodeUnknownFields(unknownFields, ignoreUnknownValues)); + } catch (TableRowToStorageApiProto.SchemaConversionException e) { + TableRow tableRow = appendClientInfo.toTableRow(payloadBytes); + // TODO(24926, reuvenlax): We need to merge the unknown fields in! Currently we only + // execute this + // codepath when ignoreUnknownFields==true, so we should never hit this codepath. + // However once + // 24926 is fixed, we need to merge the unknownFields back into the main row before + // outputting to the + // failed-rows consumer. + org.joda.time.Instant timestamp = payload.getTimestamp(); + rowsSentToFailedRowsCollection.inc(); + failedRowsReceiver.outputWithTimestamp( + new BigQueryStorageApiInsertError(tableRow, e.toString()), + timestamp != null ? timestamp : elementTs); + return; + } + } + } + pendingMessages.add(payloadBytes); + org.joda.time.Instant timestamp = payload.getTimestamp(); + pendingTimestamps.add(timestamp != null ? timestamp : elementTs); + } + + long flush( + RetryManager retryManager, + OutputReceiver failedRowsReceiver, + @Nullable OutputReceiver successfulRowsReceiver) + throws Exception { + if (pendingMessages.isEmpty()) { + return 0; + } + + final ProtoRows.Builder insertsBuilder = ProtoRows.newBuilder(); + insertsBuilder.addAllSerializedRows(pendingMessages); + pendingMessages.clear(); + final ProtoRows inserts = insertsBuilder.build(); + List insertTimestamps = pendingTimestamps; + pendingTimestamps = Lists.newArrayList(); + + // Handle the case where the request is too large. + if (inserts.getSerializedSize() >= maxRequestSize) { + if (inserts.getSerializedRowsCount() > 1) { + // TODO(reuvenlax): Is it worth trying to handle this case by splitting the protoRows? + // Given that we split + // the ProtoRows iterable at 2MB and the max request size is 10MB, this scenario seems + // nearly impossible. + LOG.error( + "A request containing more than one row is over the request size limit of {}. " + + "This is unexpected. 
All rows in the request will be sent to the failed-rows PCollection.", + maxRequestSize); + } + for (int i = 0; i < inserts.getSerializedRowsCount(); ++i) { + ByteString rowBytes = inserts.getSerializedRows(i); + org.joda.time.Instant timestamp = insertTimestamps.get(i); + TableRow failedRow = + TableRowToStorageApiProto.tableRowFromMessage( + DynamicMessage.parseFrom( + TableRowToStorageApiProto.wrapDescriptorProto( + getAppendClientInfo(true, null).getDescriptor()), + rowBytes), + true); + failedRowsReceiver.outputWithTimestamp( + new BigQueryStorageApiInsertError( + failedRow, "Row payload too large. Maximum size " + maxRequestSize), + timestamp); + } + int numRowsFailed = inserts.getSerializedRowsCount(); + BigQuerySinkMetrics.appendRowsRowStatusCounter( + BigQuerySinkMetrics.RowStatus.FAILED, + BigQuerySinkMetrics.PAYLOAD_TOO_LARGE, + shortTableUrn) + .inc(numRowsFailed); + rowsSentToFailedRowsCollection.inc(numRowsFailed); + return 0; + } + + long offset = -1; + if (!this.useDefaultStream) { + getOrCreateStreamName(); // Force creation of the stream before we get offsets. + offset = this.currentOffset; + this.currentOffset += inserts.getSerializedRowsCount(); + } + AppendRowsContext appendRowsContext = + new AppendRowsContext(offset, inserts, insertTimestamps); + + retryManager.addOperation( + c -> { + if (c.protoRows.getSerializedRowsCount() == 0) { + // This might happen if all rows in a batch failed and were sent to the failed-rows + // PCollection. + return ApiFutures.immediateFuture(AppendRowsResponse.newBuilder().build()); + } + try { + StreamAppendClient writeStream = + Preconditions.checkStateNotNull( + getAppendClientInfo(true, null).getStreamAppendClient()); + ApiFuture response = + writeStream.appendRows(c.offset, c.protoRows); + inflightWaitSecondsDistribution.update(writeStream.getInflightWaitSeconds()); + if (!usingMultiplexing) { + if (writeStream.getInflightWaitSeconds() > 5) { + LOG.warn( + "Storage Api write delay more than {} seconds.", + writeStream.getInflightWaitSeconds()); + } + } + return response; + } catch (Exception e) { + throw new RuntimeException(e); + } + }, + contexts -> { + AppendRowsContext failedContext = + Preconditions.checkStateNotNull(Iterables.getFirst(contexts, null)); + BigQuerySinkMetrics.reportFailedRPCMetrics( + failedContext, BigQuerySinkMetrics.RpcMethod.APPEND_ROWS, shortTableUrn); + String errorCode = + BigQuerySinkMetrics.throwableToGRPCCodeString(failedContext.getError()); + + if (failedContext.getError() != null + && failedContext.getError() instanceof Exceptions.AppendSerializationError) { + Exceptions.AppendSerializationError error = + Preconditions.checkStateNotNull( + (Exceptions.AppendSerializationError) failedContext.getError()); + + Set failedRowIndices = error.getRowIndexToErrorMessage().keySet(); + for (int failedIndex : failedRowIndices) { + // Convert the message to a TableRow and send it to the failedRows collection. 
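// Illustrative sketch (not part of the patch): the retry path below drops the row indices
// reported in the AppendSerializationError and rebuilds the batch from the survivors,
// keeping the serialized rows, their timestamps and (with this change) the retained
// original elements aligned by index. A simplified, list-based version of that filtering:
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

class RetryBatch<ElementT> {
  final List<byte[]> rows = new ArrayList<>();         // serialized proto rows
  final List<Long> timestampMillis = new ArrayList<>();
  final List<ElementT> originals = new ArrayList<>();  // original elements, same order

  // Build a new batch containing only the rows whose index is not in failedIndices.
  RetryBatch<ElementT> withoutFailedRows(Set<Integer> failedIndices) {
    RetryBatch<ElementT> retry = new RetryBatch<>();
    for (int i = 0; i < rows.size(); ++i) {
      if (!failedIndices.contains(i)) {
        retry.rows.add(rows.get(i));
        retry.timestampMillis.add(timestampMillis.get(i));
        retry.originals.add(originals.get(i));
      }
    }
    return retry;
  }
}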
+ ByteString protoBytes = failedContext.protoRows.getSerializedRows(failedIndex); + org.joda.time.Instant timestamp = failedContext.timestamps.get(failedIndex); + try { + TableRow failedRow = + TableRowToStorageApiProto.tableRowFromMessage( + DynamicMessage.parseFrom( + TableRowToStorageApiProto.wrapDescriptorProto( + Preconditions.checkStateNotNull(appendClientInfo) + .getDescriptor()), + protoBytes), + true); + failedRowsReceiver.outputWithTimestamp( + new BigQueryStorageApiInsertError( + failedRow, error.getRowIndexToErrorMessage().get(failedIndex)), + timestamp); + } catch (Exception e) { + LOG.error("Failed to insert row and could not parse the result!", e); + } + } + int numRowsFailed = failedRowIndices.size(); + rowsSentToFailedRowsCollection.inc(numRowsFailed); + BigQuerySinkMetrics.appendRowsRowStatusCounter( + BigQuerySinkMetrics.RowStatus.FAILED, errorCode, shortTableUrn) + .inc(numRowsFailed); + + // Remove the failed row from the payload, so we retry the batch without the failed + // rows. + ProtoRows.Builder retryRows = ProtoRows.newBuilder(); + List retryTimestamps = Lists.newArrayList(); + for (int i = 0; i < failedContext.protoRows.getSerializedRowsCount(); ++i) { + if (!failedRowIndices.contains(i)) { + ByteString rowBytes = failedContext.protoRows.getSerializedRows(i); + retryRows.addSerializedRows(rowBytes); + retryTimestamps.add(failedContext.timestamps.get(i)); + } + } + failedContext.protoRows = retryRows.build(); + failedContext.timestamps = retryTimestamps; + int numRowsRetried = failedContext.protoRows.getSerializedRowsCount(); + BigQuerySinkMetrics.appendRowsRowStatusCounter( + BigQuerySinkMetrics.RowStatus.RETRIED, errorCode, shortTableUrn) + .inc(numRowsRetried); + + // Since we removed rows, we need to update the insert offsets for all remaining + // rows. + long newOffset = failedContext.offset; + for (AppendRowsContext context : contexts) { + context.offset = newOffset; + newOffset += context.protoRows.getSerializedRowsCount(); + } + this.currentOffset = newOffset; + return RetryType.RETRY_ALL_OPERATIONS; + } + + LOG.warn( + "Append to stream {} by client #{} failed with error, operations will be retried.\n{}", + streamName, + clientNumber, + retrieveErrorDetails(contexts)); + failedContext.failureCount += 1; + + boolean quotaError = false; + Throwable error = failedContext.getError(); + Status.Code statusCode = Status.Code.OK; + if (error != null) { + statusCode = Status.fromThrowable(error).getCode(); + quotaError = statusCode.equals(Status.Code.RESOURCE_EXHAUSTED); + } + + if (!quotaError) { + // This forces us to close and reopen all gRPC connections to Storage API on error, + // which empirically fixes random stuckness issues. + invalidateWriteStream(); + } + + // Maximum number of times we retry before we fail the work item. + if (failedContext.failureCount > 5) { + throw new RuntimeException("More than 5 attempts to call AppendRows failed."); + } + + // The following errors are known to be persistent, so always fail the work item in + // this case. 
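// Illustrative sketch (not part of the patch): the error handling below distinguishes
// quota exhaustion (retry without tearing down the connection), offset and stream errors
// that are known to be persistent (fail the work item), and everything else (invalidate
// the cached client and retry, up to a bounded number of attempts). A condensed classifier
// over the same status codes, using io.grpc directly; the control flow is simplified
// relative to the inline handler.
import io.grpc.Status;

class AppendErrorClassifier {
  enum Action { RETRY_KEEP_CONNECTION, RETRY_INVALIDATE_CONNECTION, FAIL_WORK_ITEM }

  static Action classify(Throwable error, int failureCount) {
    Status.Code code =
        error == null ? Status.Code.OK : Status.fromThrowable(error).getCode();
    boolean quotaError = code == Status.Code.RESOURCE_EXHAUSTED;
    if (failureCount > 5) {
      return Action.FAIL_WORK_ITEM;                    // bounded retries per work item
    }
    if (code == Status.Code.OUT_OF_RANGE || code == Status.Code.ALREADY_EXISTS) {
      return Action.FAIL_WORK_ITEM;                    // append at an invalid offset
    }
    if (code == Status.Code.INVALID_ARGUMENT
        || code == Status.Code.NOT_FOUND
        || code == Status.Code.FAILED_PRECONDITION) {
      return Action.FAIL_WORK_ITEM;                    // stream finalized or missing
    }
    // Quota exhaustion: back off and retry on the same connection; for other errors the
    // cached client is invalidated, which empirically clears transient stuckness.
    return quotaError ? Action.RETRY_KEEP_CONNECTION : Action.RETRY_INVALIDATE_CONNECTION;
  }
}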
+ if (statusCode.equals(Status.Code.OUT_OF_RANGE) + || statusCode.equals(Status.Code.ALREADY_EXISTS)) { + throw new RuntimeException( + "Append to stream " + + this.streamName + + " failed with invalid " + + "offset of " + + failedContext.offset); + } + + boolean hasPersistentErrors = + failedContext.getError() instanceof Exceptions.StreamFinalizedException + || statusCode.equals(Status.Code.INVALID_ARGUMENT) + || statusCode.equals(Status.Code.NOT_FOUND) + || statusCode.equals(Status.Code.FAILED_PRECONDITION); + if (hasPersistentErrors) { + throw new RuntimeException( + String.format( + "Append to stream %s failed with Status Code %s. The stream may not exist.", + this.streamName, statusCode), + error); + } + // TODO: Only do this on explicit NOT_FOUND errors once BigQuery reliably produces + // them. + try { + tryCreateTable.call(); + } catch (Exception e) { + throw new RuntimeException(e); + } + + int numRowsRetried = failedContext.protoRows.getSerializedRowsCount(); + BigQuerySinkMetrics.appendRowsRowStatusCounter( + BigQuerySinkMetrics.RowStatus.RETRIED, errorCode, shortTableUrn) + .inc(numRowsRetried); + + appendFailures.inc(); + return RetryType.RETRY_ALL_OPERATIONS; + }, + c -> { + int numRecordsAppended = c.protoRows.getSerializedRowsCount(); + recordsAppended.inc(numRecordsAppended); + BigQuerySinkMetrics.appendRowsRowStatusCounter( + BigQuerySinkMetrics.RowStatus.SUCCESSFUL, + BigQuerySinkMetrics.OK, + shortTableUrn) + .inc(numRecordsAppended); + + BigQuerySinkMetrics.reportSuccessfulRpcMetrics( + c, BigQuerySinkMetrics.RpcMethod.APPEND_ROWS, shortTableUrn); + + if (successfulRowsReceiver != null) { + for (int i = 0; i < c.protoRows.getSerializedRowsCount(); ++i) { + ByteString rowBytes = c.protoRows.getSerializedRowsList().get(i); + try { + TableRow row = + TableRowToStorageApiProto.tableRowFromMessage( + DynamicMessage.parseFrom( + TableRowToStorageApiProto.wrapDescriptorProto( + Preconditions.checkStateNotNull(appendClientInfo) + .getDescriptor()), + rowBytes), + true); + org.joda.time.Instant timestamp = c.timestamps.get(i); + successfulRowsReceiver.outputWithTimestamp(row, timestamp); + } catch (Exception e) { + LOG.warn("Failure parsing TableRow", e); + } + } + } + }, + appendRowsContext); + maybeTickleCache(); + return inserts.getSerializedRowsCount(); + } + + String retrieveErrorDetails(Iterable failedContext) { + return StreamSupport.stream(failedContext.spliterator(), false) + .<@Nullable Throwable>map(AppendRowsContext::getError) + .filter(Objects::nonNull) + .map( + thrw -> + Preconditions.checkStateNotNull(thrw).toString() + + "\n" + + Arrays.stream(Preconditions.checkStateNotNull(thrw).getStackTrace()) + .map(StackTraceElement::toString) + .collect(Collectors.joining("\n"))) + .collect(Collectors.joining("\n")); + } + + void postFlush() { + // If we got a response indicating an updated schema, recreate the client. + if (this.appendClientInfo != null && autoUpdateSchema) { + @Nullable + StreamAppendClient streamAppendClient = appendClientInfo.getStreamAppendClient(); + @Nullable + TableSchema updatedTableSchemaReturned = + (streamAppendClient != null) ? 
streamAppendClient.getUpdatedSchema() : null; + if (updatedTableSchemaReturned != null) { + Optional updatedTableSchema = + TableSchemaUpdateUtils.getUpdatedSchema( + this.initialTableSchema, updatedTableSchemaReturned); + if (updatedTableSchema.isPresent()) { + invalidateWriteStream(); + appendClientInfo = + Preconditions.checkStateNotNull( + getAppendClientInfo(false, updatedTableSchema.get())); + } + } + } + } + } + + private @Nullable Map destinations = Maps.newHashMap(); + private final TwoLevelMessageConverterCache messageConverters; + private transient @Nullable DatasetService maybeDatasetService; + private transient @Nullable WriteStreamService maybeWriteStreamService; + private int numPendingRecords = 0; + private int numPendingRecordBytes = 0; + private final int flushThresholdBytes; + private final int flushThresholdCount; + private final StorageApiDynamicDestinations dynamicDestinations; + private final BigQueryServices bqServices; + private final boolean useDefaultStream; + private int streamAppendClientCount; + + WriteRecordsDoFn( + String operationName, + StorageApiDynamicDestinations dynamicDestinations, + BigQueryServices bqServices, + boolean useDefaultStream, + int flushThresholdBytes, + int flushThresholdCount, + int streamAppendClientCount, + TupleTag> finalizeTag, + TupleTag failedRowsTag, + @Nullable TupleTag successfulRowsTag, + boolean autoUpdateSchema, + boolean ignoreUnknownValues, + BigQueryIO.Write.CreateDisposition createDisposition, + @Nullable String kmsKey, + boolean usesCdc, + AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation) { + this.messageConverters = new TwoLevelMessageConverterCache<>(operationName); + this.dynamicDestinations = dynamicDestinations; + this.bqServices = bqServices; + this.useDefaultStream = useDefaultStream; + this.flushThresholdBytes = flushThresholdBytes; + this.flushThresholdCount = flushThresholdCount; + this.streamAppendClientCount = streamAppendClientCount; + this.finalizeTag = finalizeTag; + this.failedRowsTag = failedRowsTag; + this.successfulRowsTag = successfulRowsTag; + this.autoUpdateSchema = autoUpdateSchema; + this.ignoreUnknownValues = ignoreUnknownValues; + this.createDisposition = createDisposition; + this.kmsKey = kmsKey; + this.usesCdc = usesCdc; + this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; + } + + boolean shouldFlush() { + return numPendingRecords > flushThresholdCount || numPendingRecordBytes > flushThresholdBytes; + } + + void flushIfNecessary( + OutputReceiver failedRowsReceiver, + @Nullable OutputReceiver successfulRowsReceiver) + throws Exception { + if (shouldFlush()) { + forcedFlushes.inc(); + // Too much memory being used. Flush the state and wait for it to drain out. + // TODO(reuvenlax): Consider waiting for memory usage to drop instead of waiting for all the + // appends to finish. 
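// Illustrative sketch (not part of the patch): the bundle-level buffering here flushes
// whenever either the pending record count or the pending byte count crosses its threshold
// (storageApiAppendThresholdRecordCount / storageApiAppendThresholdBytes from
// BigQueryOptions). A minimal tracker with the same shape:
class FlushThresholds {
  private final int thresholdBytes;
  private final int thresholdCount;
  private int pendingBytes;
  private int pendingCount;

  FlushThresholds(int thresholdBytes, int thresholdCount) {
    this.thresholdBytes = thresholdBytes;
    this.thresholdCount = thresholdCount;
  }

  void add(int recordSizeBytes) {
    ++pendingCount;
    pendingBytes += recordSizeBytes;
  }

  boolean shouldFlush() {
    // Mirrors WriteRecordsDoFn.shouldFlush(): strictly-greater-than comparisons.
    return pendingCount > thresholdCount || pendingBytes > thresholdBytes;
  }

  void reset() {                                       // called after flushAll()
    pendingCount = 0;
    pendingBytes = 0;
  }
}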
+ flushAll(failedRowsReceiver, successfulRowsReceiver); + } + } + + void flushAll( + OutputReceiver failedRowsReceiver, + @Nullable OutputReceiver successfulRowsReceiver) + throws Exception { + List> retryManagers = + Lists.newArrayListWithCapacity(Preconditions.checkStateNotNull(destinations).size()); + long numRowsWritten = 0; + for (DestinationState destinationState : + Preconditions.checkStateNotNull(destinations).values()) { + RetryManager retryManager = + new RetryManager<>( + Duration.standardSeconds(1), + Duration.standardSeconds(10), + 1000, + BigQuerySinkMetrics.throttledTimeCounter( + BigQuerySinkMetrics.RpcMethod.APPEND_ROWS)); + retryManagers.add(retryManager); + numRowsWritten += + destinationState.flush(retryManager, failedRowsReceiver, successfulRowsReceiver); + retryManager.run(false); + } + if (numRowsWritten > 0) { + // TODO(reuvenlax): Can we await in parallel instead? Failure retries aren't triggered until + // await is called, so + // this approach means that if one call fais, it has to wait for all prior calls to complete + // before a retry happens. + for (RetryManager retryManager : retryManagers) { + retryManager.await(); + } + } + for (DestinationState destinationState : + Preconditions.checkStateNotNull(destinations).values()) { + destinationState.postFlush(); + } + numPendingRecords = 0; + numPendingRecordBytes = 0; + } + + private DatasetService initializeDatasetService(PipelineOptions pipelineOptions) { + if (maybeDatasetService == null) { + maybeDatasetService = + bqServices.getDatasetService(pipelineOptions.as(BigQueryOptions.class)); + } + return maybeDatasetService; + } + + private WriteStreamService initializeWriteStreamService(PipelineOptions pipelineOptions) { + if (maybeWriteStreamService == null) { + maybeWriteStreamService = + bqServices.getWriteStreamService(pipelineOptions.as(BigQueryOptions.class)); + } + return maybeWriteStreamService; + } + + @StartBundle + public void startBundle() throws IOException { + destinations = Maps.newHashMap(); + numPendingRecords = 0; + numPendingRecordBytes = 0; + } + + DestinationState createDestinationState( + ProcessContext c, + DestinationT destination, + boolean useCdc, + DatasetService datasetService, + WriteStreamService writeStreamService, + BigQueryOptions bigQueryOptions) { + TableDestination tableDestination1 = dynamicDestinations.getTable(destination); + checkArgument( + tableDestination1 != null, + "DynamicDestinations.getTable() may not return null, " + + "but %s returned null for destination %s", + dynamicDestinations, + destination); + @Nullable Coder destinationCoder = dynamicDestinations.getDestinationCoder(); + Callable tryCreateTable = + () -> { + CreateTableHelpers.possiblyCreateTable( + c.getPipelineOptions().as(BigQueryOptions.class), + tableDestination1, + () -> dynamicDestinations.getSchema(destination), + () -> dynamicDestinations.getTableConstraints(destination), + createDisposition, + destinationCoder, + kmsKey, + bqServices); + return true; + }; + + MessageConverter messageConverter; + try { + messageConverter = messageConverters.get(destination, dynamicDestinations, datasetService); + return new DestinationState( + tableDestination1.getTableUrn(bigQueryOptions), + tableDestination1.getShortTableUrn(), + messageConverter, + writeStreamService, + useDefaultStream, + streamAppendClientCount, + bigQueryOptions.getUseStorageApiConnectionPool(), + bigQueryOptions.getStorageWriteApiMaxRequestSize(), + tryCreateTable, + useCdc); + } catch (Exception e) { + throw new RuntimeException(e); + } + 
} + + @ProcessElement + public void process( + ProcessContext c, + PipelineOptions pipelineOptions, + @Element KV element, + @Timestamp org.joda.time.Instant elementTs, + MultiOutputReceiver o) + throws Exception { + DatasetService initializedDatasetService = initializeDatasetService(pipelineOptions); + WriteStreamService initializedWriteStreamService = + initializeWriteStreamService(pipelineOptions); + dynamicDestinations.setSideInputAccessorFromProcessContext(c); + DestinationState state = + Preconditions.checkStateNotNull(destinations) + .computeIfAbsent( + element.getKey(), + destination -> + createDestinationState( + c, + destination, + usesCdc, + initializedDatasetService, + initializedWriteStreamService, + pipelineOptions.as(BigQueryOptions.class))); + + OutputReceiver failedRowsReceiver = o.get(failedRowsTag); + @Nullable + OutputReceiver successfulRowsReceiver = + (successfulRowsTag != null) ? o.get(successfulRowsTag) : null; + flushIfNecessary(failedRowsReceiver, successfulRowsReceiver); + state.addMessage(element.getValue(), elementTs, failedRowsReceiver); + ++numPendingRecords; + numPendingRecordBytes += element.getValue().getPayload().length; + } + + @FinishBundle + public void finishBundle(FinishBundleContext context) throws Exception { + OutputReceiver failedRowsReceiver = + new OutputReceiver() { + @Override + public void output(BigQueryStorageApiInsertError output) { + outputWithTimestamp(output, GlobalWindow.INSTANCE.maxTimestamp()); + } + + @Override + public void outputWithTimestamp( + BigQueryStorageApiInsertError output, org.joda.time.Instant timestamp) { + context.output(failedRowsTag, output, timestamp, GlobalWindow.INSTANCE); + } + }; + @Nullable OutputReceiver successfulRowsReceiver = null; + if (successfulRowsTag != null) { + successfulRowsReceiver = + new OutputReceiver() { + @Override + public void output(TableRow output) { + outputWithTimestamp(output, GlobalWindow.INSTANCE.maxTimestamp()); + } + + @Override + public void outputWithTimestamp(TableRow output, org.joda.time.Instant timestamp) { + context.output(successfulRowsTag, output, timestamp, GlobalWindow.INSTANCE); + } + }; + } + + flushAll(failedRowsReceiver, successfulRowsReceiver); + + final Map destinations = + Preconditions.checkStateNotNull(this.destinations); + for (DestinationState state : destinations.values()) { + if (!useDefaultStream && !Strings.isNullOrEmpty(state.streamName)) { + context.output( + finalizeTag, + KV.of(state.tableUrn, state.streamName), + GlobalWindow.INSTANCE.maxTimestamp(), + GlobalWindow.INSTANCE); + } + state.teardown(); + } + destinations.clear(); + this.destinations = null; + } + + @Teardown + public void teardown() { + destinations = null; + try { + if (maybeWriteStreamService != null) { + maybeWriteStreamService.close(); + maybeWriteStreamService = null; + } + if (maybeDatasetService != null) { + maybeDatasetService.close(); + maybeDatasetService = null; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public Duration getAllowedTimestampSkew() { + return Duration.millis(BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()); + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java index b612f199a29f..d95ea703a47f 100644 --- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java @@ -79,6 +79,7 @@ import org.apache.beam.sdk.transforms.Max; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.Repeatedly; @@ -114,7 +115,7 @@ }) public class StorageApiWritesShardedRecords extends PTransform< - PCollection, Iterable>>, + PCollection, Iterable>>>, PCollectionTuple> { private static final Logger LOG = LoggerFactory.getLogger(StorageApiWritesShardedRecords.class); private static final Duration DEFAULT_STREAM_IDLE_TIME = Duration.standardHours(1); @@ -129,6 +130,8 @@ public class StorageApiWritesShardedRecords formatRecordOnFailureFunction; + private final Duration streamIdleTime = DEFAULT_STREAM_IDLE_TIME; private final TupleTag failedRowsTag; private final @Nullable TupleTag successfulRowsTag; @@ -147,12 +150,18 @@ class AppendRowsContext extends RetryManager.Operation.Context originalElements; + List timestamps; AppendRowsContext( - ShardedKey key, ProtoRows protoRows, List timestamps) { + ShardedKey key, + ProtoRows protoRows, + List originalElements, + List timestamps) { this.key = key; this.protoRows = protoRows; + this.originalElements = originalElements; this.timestamps = timestamps; } @@ -221,7 +230,8 @@ public StorageApiWritesShardedRecords( @Nullable TupleTag successfulRowsTag, boolean autoUpdateSchema, boolean ignoreUnknownValues, - AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation) { + AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation, + @Nullable SerializableFunction formatRecordOnFailureFunction) { this.dynamicDestinations = dynamicDestinations; this.createDisposition = createDisposition; this.kmsKey = kmsKey; @@ -234,11 +244,13 @@ public StorageApiWritesShardedRecords( this.autoUpdateSchema = autoUpdateSchema; this.ignoreUnknownValues = ignoreUnknownValues; this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; + this.formatRecordOnFailureFunction = formatRecordOnFailureFunction; } @Override public PCollectionTuple expand( - PCollection, Iterable>> input) { + PCollection, Iterable>>> + input) { BigQueryOptions bigQueryOptions = input.getPipeline().getOptions().as(BigQueryOptions.class); final long splitSize = bigQueryOptions.getStorageApiAppendThresholdBytes(); final long maxRequestSize = bigQueryOptions.getStorageWriteApiMaxRequestSize(); @@ -292,7 +304,8 @@ public PCollectionTuple expand( class WriteRecordsDoFn extends DoFn< - KV, Iterable>, KV> { + KV, Iterable>>, + KV> { private final Counter recordsAppended = Metrics.counter(WriteRecordsDoFn.class, "recordsAppended"); private final Counter streamsCreated = @@ -427,7 +440,8 @@ public void onTeardown() { public void process( ProcessContext c, final PipelineOptions pipelineOptions, - @Element KV, Iterable> element, + @Element + KV, Iterable>> element, @Timestamp org.joda.time.Instant elementTs, final @AlwaysFetched @StateId("streamName") ValueState streamName, final @AlwaysFetched @StateId("streamOffset") ValueState streamOffset, @@ -551,8 +565,8 @@ public void process( // Each ProtoRows object contains 
at most 1MB of rows. // TODO: Push messageFromTableRow up to top level. That we we cans skip TableRow entirely if // already proto or already schema. - Iterable messages = - new SplittingIterable( + Iterable> messages = + new SplittingIterable( element.getValue(), splitSize, (fields, ignore) -> appendClientInfo.get().encodeUnknownFields(fields, ignore), @@ -571,7 +585,8 @@ public void process( }, autoUpdateSchema, ignoreUnknownValues, - elementTs); + elementTs, + formatRecordOnFailureFunction); // Initialize stream names and offsets for all contexts. This will be called initially, but // will also be called if we roll over to a new stream on a retry. @@ -671,7 +686,14 @@ public void process( for (int failedIndex : failedRowIndices) { // Convert the message to a TableRow and send it to the failedRows collection. ByteString protoBytes = failedContext.protoRows.getSerializedRows(failedIndex); - TableRow failedRow = appendClientInfo.get().toTableRow(protoBytes); + TableRow failedRow; + if (formatRecordOnFailureFunction != null) { + failedRow = + formatRecordOnFailureFunction.apply( + failedContext.originalElements.get(failedIndex)); + } else { + failedRow = appendClientInfo.get().toTableRow(protoBytes); + } org.joda.time.Instant timestamp = failedContext.timestamps.get(failedIndex); o.get(failedRowsTag) .outputWithTimestamp( @@ -822,7 +844,7 @@ public void process( 1000, BigQuerySinkMetrics.throttledTimeCounter(BigQuerySinkMetrics.RpcMethod.APPEND_ROWS)); int numAppends = 0; - for (SplittingIterable.Value splitValue : messages) { + for (SplittingIterable.Value splitValue : messages) { // Handle the case of a row that is too large. if (splitValue.getProtoRows().getSerializedSize() >= maxRequestSize) { if (splitValue.getProtoRows().getSerializedRowsCount() > 1) { @@ -838,7 +860,13 @@ public void process( for (int i = 0; i < splitValue.getProtoRows().getSerializedRowsCount(); ++i) { ByteString rowBytes = splitValue.getProtoRows().getSerializedRows(i); org.joda.time.Instant timestamp = splitValue.getTimestamps().get(i); - TableRow failedRow = appendClientInfo.get().toTableRow(rowBytes); + TableRow failedRow; + if (formatRecordOnFailureFunction != null) { + failedRow = + formatRecordOnFailureFunction.apply(splitValue.getOriginalElements().get(i)); + } else { + failedRow = appendClientInfo.get().toTableRow(rowBytes); + } o.get(failedRowsTag) .outputWithTimestamp( new BigQueryStorageApiInsertError( @@ -857,7 +885,10 @@ public void process( // RetryManager AppendRowsContext context = new AppendRowsContext( - element.getKey(), splitValue.getProtoRows(), splitValue.getTimestamps()); + element.getKey(), + splitValue.getProtoRows(), + splitValue.getOriginalElements(), + splitValue.getTimestamps()); contexts.add(context); retryManager.addOperation(runOperation, onError, onSuccess, context); recordsAppended.inc(splitValue.getProtoRows().getSerializedRowsCount()); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords256.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords256.java new file mode 100644 index 000000000000..2b173f326e70 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords256.java @@ -0,0 +1,970 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; + +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutures; +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.bigquery.storage.v1.AppendRowsRequest; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.Exceptions; +import com.google.cloud.bigquery.storage.v1.Exceptions.StreamFinalizedException; +import com.google.cloud.bigquery.storage.v1.ProtoRows; +import com.google.cloud.bigquery.storage.v1.TableSchema; +import com.google.cloud.bigquery.storage.v1.WriteStream.Type; +import com.google.protobuf.ByteString; +import com.google.protobuf.DescriptorProtos; +import io.grpc.Status; +import io.grpc.Status.Code; +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.extensions.protobuf.ProtoCoder; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StreamAppendClient; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.WriteStreamService; +import org.apache.beam.sdk.io.gcp.bigquery.RetryManager.RetryType; +import org.apache.beam.sdk.io.gcp.bigquery.StorageApiFlushAndFinalizeDoFn.Operation; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Distribution; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.schemas.SchemaRegistry; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.Combine; 
+import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Max; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.Repeatedly; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.util.Preconditions; +import org.apache.beam.sdk.util.ShardedKey; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalNotification; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** A transform to write sharded records to BigQuery using the Storage API (Streaming). */ +@SuppressWarnings({ + "FutureReturnValueIgnored", + // TODO(https://github.com/apache/beam/issues/21230): Remove when new version of + // errorprone is released (2.11.0) + "unused" +}) +public class StorageApiWritesShardedRecords256 + extends PTransform< + PCollection, Iterable>>, + PCollectionTuple> { + private static final Logger LOG = + LoggerFactory.getLogger(StorageApiWritesShardedRecords256.class); + private static final Duration DEFAULT_STREAM_IDLE_TIME = Duration.standardHours(1); + + private final StorageApiDynamicDestinations dynamicDestinations; + private final CreateDisposition createDisposition; + private final String kmsKey; + private final BigQueryServices bqServices; + private final Coder destinationCoder; + private final Coder failedRowsCoder; + private final boolean autoUpdateSchema; + private final boolean ignoreUnknownValues; + private final AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation; + + private final Duration streamIdleTime = DEFAULT_STREAM_IDLE_TIME; + private final TupleTag failedRowsTag; + private final @Nullable TupleTag successfulRowsTag; + private final Coder succussfulRowsCoder; + + private final TupleTag> flushTag = new TupleTag<>("flushTag"); + private static final ExecutorService closeWriterExecutor = Executors.newCachedThreadPool(); + + // Context passed into RetryManager for each call. 
+ class AppendRowsContext extends RetryManager.Operation.Context { + final ShardedKey key; + String streamName = ""; + @Nullable StreamAppendClient client = null; + long offset = -1; + long numRows = 0; + long tryIteration = 0; + ProtoRows protoRows; + + List timestamps; + + AppendRowsContext( + ShardedKey key, ProtoRows protoRows, List timestamps) { + this.key = key; + this.protoRows = protoRows; + this.timestamps = timestamps; + } + + @Override + public String toString() { + return "Context: key=" + + key + + " streamName=" + + streamName + + " offset=" + + offset + + " numRows=" + + numRows + + " tryIteration: " + + tryIteration; + } + }; + + private static final Cache, AppendClientInfo> APPEND_CLIENTS = + CacheBuilder.newBuilder() + .expireAfterAccess(5, TimeUnit.MINUTES) + .removalListener( + (RemovalNotification, AppendClientInfo> removal) -> { + final @Nullable AppendClientInfo appendClientInfo = removal.getValue(); + if (appendClientInfo != null) { + appendClientInfo.close(); + } + }) + .build(); + + static void clearCache() { + APPEND_CLIENTS.invalidateAll(); + } + + // Run a closure asynchronously, ignoring failures. + private interface ThrowingRunnable { + void run() throws Exception; + } + + private static void runAsyncIgnoreFailure(ExecutorService executor, ThrowingRunnable task) { + executor.submit( + () -> { + try { + task.run(); + } catch (Exception e) { + String msg = + e.toString() + + "\n" + + Arrays.stream(e.getStackTrace()) + .map(StackTraceElement::toString) + .collect(Collectors.joining("\n")); + System.err.println("Exception happened while executing async task. Ignoring: " + msg); + } + }); + } + + public StorageApiWritesShardedRecords256( + StorageApiDynamicDestinations dynamicDestinations, + CreateDisposition createDisposition, + String kmsKey, + BigQueryServices bqServices, + Coder destinationCoder, + Coder failedRowsCoder, + Coder successfulRowsCoder, + TupleTag failedRowsTag, + @Nullable TupleTag successfulRowsTag, + boolean autoUpdateSchema, + boolean ignoreUnknownValues, + AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation) { + this.dynamicDestinations = dynamicDestinations; + this.createDisposition = createDisposition; + this.kmsKey = kmsKey; + this.bqServices = bqServices; + this.destinationCoder = destinationCoder; + this.failedRowsCoder = failedRowsCoder; + this.failedRowsTag = failedRowsTag; + this.successfulRowsTag = successfulRowsTag; + this.succussfulRowsCoder = successfulRowsCoder; + this.autoUpdateSchema = autoUpdateSchema; + this.ignoreUnknownValues = ignoreUnknownValues; + this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; + } + + @Override + public PCollectionTuple expand( + PCollection, Iterable>> input) { + BigQueryOptions bigQueryOptions = input.getPipeline().getOptions().as(BigQueryOptions.class); + final long splitSize = bigQueryOptions.getStorageApiAppendThresholdBytes(); + final long maxRequestSize = bigQueryOptions.getStorageWriteApiMaxRequestSize(); + + String operationName = input.getName() + "/" + getName(); + TupleTagList tupleTagList = TupleTagList.of(failedRowsTag); + if (successfulRowsTag != null) { + tupleTagList = tupleTagList.and(successfulRowsTag); + } + // Append records to the Storage API streams. 
+ PCollectionTuple writeRecordsResult = + input.apply( + "Write Records", + ParDo.of(new WriteRecordsDoFn(operationName, streamIdleTime, splitSize, maxRequestSize)) + .withSideInputs(dynamicDestinations.getSideInputs()) + .withOutputTags(flushTag, tupleTagList)); + + SchemaCoder operationCoder; + try { + SchemaRegistry schemaRegistry = input.getPipeline().getSchemaRegistry(); + operationCoder = + SchemaCoder.of( + schemaRegistry.getSchema(Operation.class), + TypeDescriptor.of(Operation.class), + schemaRegistry.getToRowFunction(Operation.class), + schemaRegistry.getFromRowFunction(Operation.class)); + } catch (NoSuchSchemaException e) { + throw new RuntimeException(e); + } + + // Send all successful writes to be flushed. + writeRecordsResult + .get(flushTag) + .setCoder(KvCoder.of(StringUtf8Coder.of(), operationCoder)) + .apply( + Window.>configure() + .triggering( + Repeatedly.forever( + AfterProcessingTime.pastFirstElementInPane() + .plusDelayOf(Duration.standardSeconds(1)))) + .discardingFiredPanes()) + .apply("maxFlushPosition", Combine.perKey(Max.naturalOrder(new Operation(-1, false)))) + .apply( + "Flush and finalize writes", ParDo.of(new StorageApiFlushAndFinalizeDoFn(bqServices))); + writeRecordsResult.get(failedRowsTag).setCoder(failedRowsCoder); + if (successfulRowsTag != null) { + writeRecordsResult.get(successfulRowsTag).setCoder(succussfulRowsCoder); + } + return writeRecordsResult; + } + + class WriteRecordsDoFn + extends DoFn< + KV, Iterable>, KV> { + private final Counter recordsAppended = + Metrics.counter(WriteRecordsDoFn.class, "recordsAppended"); + private final Counter streamsCreated = + Metrics.counter(WriteRecordsDoFn.class, "streamsCreated"); + private final Counter streamsIdle = + Metrics.counter(WriteRecordsDoFn.class, "idleStreamsFinalized"); + private final Counter appendFailures = + Metrics.counter(WriteRecordsDoFn.class, "appendFailures"); + private final Counter appendOffsetFailures = + Metrics.counter(WriteRecordsDoFn.class, "appendOffsetFailures"); + private final Counter flushesScheduled = + Metrics.counter(WriteRecordsDoFn.class, "flushesScheduled"); + private final Distribution appendLatencyDistribution = + Metrics.distribution(WriteRecordsDoFn.class, "appendLatencyDistributionMs"); + private final Distribution appendSizeDistribution = + Metrics.distribution(WriteRecordsDoFn.class, "appendSizeDistribution"); + private final Distribution appendSplitDistribution = + Metrics.distribution(WriteRecordsDoFn.class, "appendSplitDistribution"); + private final Counter rowsSentToFailedRowsCollection = + Metrics.counter(WriteRecordsDoFn.class, "rowsSentToFailedRowsCollection"); + + private TwoLevelMessageConverterCache messageConverters; + + private Map destinations = Maps.newHashMap(); + + private transient @Nullable DatasetService datasetServiceInternal = null; + + private transient @Nullable WriteStreamService writeStreamServiceInternal = null; + + // Stores the current stream for this key. + @StateId("streamName") + private final StateSpec> streamNameSpec = StateSpecs.value(); + + // Stores the current stream offset. 
+ @StateId("streamOffset") + private final StateSpec> streamOffsetSpec = StateSpecs.value(); + + @StateId("updatedSchema") + private final StateSpec> updatedSchema = + StateSpecs.value(ProtoCoder.of(TableSchema.class)); + + @TimerId("idleTimer") + private final TimerSpec idleTimer = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); + + private final Duration streamIdleTime; + private final long splitSize; + private final long maxRequestSize; + + public WriteRecordsDoFn( + String operationName, Duration streamIdleTime, long splitSize, long maxRequestSize) { + this.messageConverters = new TwoLevelMessageConverterCache<>(operationName); + this.streamIdleTime = streamIdleTime; + this.splitSize = splitSize; + this.maxRequestSize = maxRequestSize; + } + + @StartBundle + public void startBundle() throws IOException { + destinations = Maps.newHashMap(); + } + + // Get the current stream for this key. If there is no current stream, create one and store the + // stream name in persistent state. + String getOrCreateStream( + String tableId, + ValueState streamName, + ValueState streamOffset, + Timer streamIdleTimer, + WriteStreamService writeStreamService, + Callable tryCreateTable) { + try { + final @Nullable String streamValue = streamName.read(); + AtomicReference stream = new AtomicReference<>(); + if (streamValue == null || "".equals(streamValue)) { + // In a buffered stream, data is only visible up to the offset to which it was flushed. + CreateTableHelpers.createTableWrapper( + () -> { + stream.set(writeStreamService.createWriteStream(tableId, Type.BUFFERED).getName()); + return null; + }, + tryCreateTable); + + streamName.write(stream.get()); + streamOffset.write(0L); + streamsCreated.inc(); + } else { + stream.set(streamValue); + } + // Reset the idle timer. 
+ streamIdleTimer.offset(streamIdleTime).withNoOutputTimestamp().setRelative(); + + return stream.get(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private DatasetService getDatasetService(PipelineOptions pipelineOptions) throws IOException { + if (datasetServiceInternal == null) { + datasetServiceInternal = + bqServices.getDatasetService(pipelineOptions.as(BigQueryOptions.class)); + } + return datasetServiceInternal; + } + + private WriteStreamService getWriteStreamService(PipelineOptions pipelineOptions) + throws IOException { + if (writeStreamServiceInternal == null) { + writeStreamServiceInternal = + bqServices.getWriteStreamService(pipelineOptions.as(BigQueryOptions.class)); + } + return writeStreamServiceInternal; + } + + @Teardown + public void onTeardown() { + try { + if (writeStreamServiceInternal != null) { + writeStreamServiceInternal.close(); + writeStreamServiceInternal = null; + } + if (datasetServiceInternal != null) { + datasetServiceInternal.close(); + datasetServiceInternal = null; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @ProcessElement + public void process( + ProcessContext c, + final PipelineOptions pipelineOptions, + @Element KV, Iterable> element, + @Timestamp org.joda.time.Instant elementTs, + final @AlwaysFetched @StateId("streamName") ValueState streamName, + final @AlwaysFetched @StateId("streamOffset") ValueState streamOffset, + final @StateId("updatedSchema") ValueState updatedSchema, + @TimerId("idleTimer") Timer idleTimer, + final MultiOutputReceiver o) + throws Exception { + BigQueryOptions bigQueryOptions = pipelineOptions.as(BigQueryOptions.class); + + if (autoUpdateSchema) { + updatedSchema.readLater(); + } + + dynamicDestinations.setSideInputAccessorFromProcessContext(c); + TableDestination tableDestination = + destinations.computeIfAbsent( + element.getKey().getKey(), + dest -> { + TableDestination tableDestination1 = dynamicDestinations.getTable(dest); + checkArgument( + tableDestination1 != null, + "DynamicDestinations.getTable() may not return null, " + + "but %s returned null for destination %s", + dynamicDestinations, + dest); + return tableDestination1; + }); + final String tableId = tableDestination.getTableUrn(bigQueryOptions); + final String shortTableId = tableDestination.getShortTableUrn(); + final DatasetService datasetService = getDatasetService(pipelineOptions); + final WriteStreamService writeStreamService = getWriteStreamService(pipelineOptions); + + Coder destinationCoder = dynamicDestinations.getDestinationCoder(); + Callable tryCreateTable = + () -> { + DestinationT dest = element.getKey().getKey(); + CreateTableHelpers.possiblyCreateTable( + c.getPipelineOptions().as(BigQueryOptions.class), + tableDestination, + () -> dynamicDestinations.getSchema(dest), + () -> dynamicDestinations.getTableConstraints(dest), + createDisposition, + destinationCoder, + kmsKey, + bqServices); + return true; + }; + + Supplier getOrCreateStream = + () -> + getOrCreateStream( + tableId, streamName, streamOffset, idleTimer, writeStreamService, tryCreateTable); + Callable getAppendClientInfo = + () -> { + @Nullable TableSchema tableSchema; + DescriptorProtos.DescriptorProto descriptor; + TableSchema updatedSchemaValue = updatedSchema.read(); + if (autoUpdateSchema && updatedSchemaValue != null) { + // We've seen an updated schema, so we use that instead of querying the + // MessageConverter. 
+ tableSchema = updatedSchemaValue; + descriptor = + TableRowToStorageApiProto.descriptorSchemaFromTableSchema( + tableSchema, true, false); + } else { + // Start off with the base schema. As we get notified of schema updates, we + // will update the descriptor. + StorageApiDynamicDestinations.MessageConverter converter = + messageConverters.get( + element.getKey().getKey(), dynamicDestinations, datasetService); + tableSchema = converter.getTableSchema(); + descriptor = converter.getDescriptor(false); + } + AppendClientInfo info = + AppendClientInfo.of( + Preconditions.checkStateNotNull(tableSchema), + descriptor, + // Make sure that the client is always closed in a different thread + // to + // avoid blocking. + client -> + runAsyncIgnoreFailure( + closeWriterExecutor, + () -> { + // Remove the pin that is "owned" by the cache. + client.unpin(); + client.close(); + })) + .withAppendClient( + writeStreamService, + getOrCreateStream, + false, + defaultMissingValueInterpretation); + // This pin is "owned" by the cache. + Preconditions.checkStateNotNull(info.getStreamAppendClient()).pin(); + return info; + }; + + AtomicReference appendClientInfo = + new AtomicReference<>(APPEND_CLIENTS.get(element.getKey(), getAppendClientInfo)); + String currentStream = getOrCreateStream.get(); + if (!currentStream.equals(appendClientInfo.get().getStreamName())) { + // Cached append client is inconsistent with persisted state. Throw away cached item and + // force it to be + // recreated. + APPEND_CLIENTS.invalidate(element.getKey()); + appendClientInfo.set(APPEND_CLIENTS.get(element.getKey(), getAppendClientInfo)); + } + + TableSchema updatedSchemaValue = updatedSchema.read(); + if (autoUpdateSchema && updatedSchemaValue != null) { + if (appendClientInfo.get().hasSchemaChanged(updatedSchemaValue)) { + appendClientInfo.set( + AppendClientInfo.of( + updatedSchemaValue, appendClientInfo.get().getCloseAppendClient(), false)); + APPEND_CLIENTS.invalidate(element.getKey()); + APPEND_CLIENTS.put(element.getKey(), appendClientInfo.get()); + } + } + + // Each ProtoRows object contains at most 1MB of rows. + // TODO: Push messageFromTableRow up to top level. That we we cans skip TableRow entirely if + // already proto or already schema. + Iterable messages = + new SplittingIterable256( + element.getValue(), + splitSize, + (fields, ignore) -> appendClientInfo.get().encodeUnknownFields(fields, ignore), + bytes -> appendClientInfo.get().toTableRow(bytes), + (failedRow, errorMessage) -> { + o.get(failedRowsTag) + .outputWithTimestamp( + new BigQueryStorageApiInsertError(failedRow.getValue(), errorMessage), + failedRow.getTimestamp()); + rowsSentToFailedRowsCollection.inc(); + BigQuerySinkMetrics.appendRowsRowStatusCounter( + BigQuerySinkMetrics.RowStatus.FAILED, + BigQuerySinkMetrics.PAYLOAD_TOO_LARGE, + shortTableId) + .inc(1); + }, + autoUpdateSchema, + ignoreUnknownValues, + elementTs); + + // Initialize stream names and offsets for all contexts. This will be called initially, but + // will also be called if we roll over to a new stream on a retry. + BiConsumer, Boolean> initializeContexts = + (contexts, isFailure) -> { + try { + if (isFailure) { + // Clear the stream name, forcing a new one to be created. 
+ streamName.write(""); + } + appendClientInfo.set( + appendClientInfo + .get() + .withAppendClient( + writeStreamService, + getOrCreateStream, + false, + defaultMissingValueInterpretation)); + StreamAppendClient streamAppendClient = + Preconditions.checkArgumentNotNull( + appendClientInfo.get().getStreamAppendClient()); + String streamNameRead = Preconditions.checkArgumentNotNull(streamName.read()); + long currentOffset = Preconditions.checkArgumentNotNull(streamOffset.read()); + for (AppendRowsContext context : contexts) { + context.streamName = streamNameRead; + streamAppendClient.pin(); + context.client = appendClientInfo.get().getStreamAppendClient(); + context.offset = currentOffset; + ++context.tryIteration; + currentOffset = context.offset + context.protoRows.getSerializedRowsCount(); + } + streamOffset.write(currentOffset); + } catch (Exception e) { + throw new RuntimeException(e); + } + }; + + Consumer> clearClients = + contexts -> { + APPEND_CLIENTS.invalidate(element.getKey()); + appendClientInfo.set(appendClientInfo.get().withNoAppendClient()); + APPEND_CLIENTS.put(element.getKey(), appendClientInfo.get()); + for (AppendRowsContext context : contexts) { + if (context.client != null) { + // Unpin in a different thread, as it may execute a blocking close. + runAsyncIgnoreFailure(closeWriterExecutor, context.client::unpin); + context.client = null; + } + } + }; + + Function> runOperation = + context -> { + if (context.protoRows.getSerializedRowsCount() == 0) { + // This might happen if all rows in a batch failed and were sent to the failed-rows + // PCollection. + return ApiFutures.immediateFuture(AppendRowsResponse.newBuilder().build()); + } + try { + appendClientInfo.set( + appendClientInfo + .get() + .withAppendClient( + writeStreamService, + getOrCreateStream, + false, + defaultMissingValueInterpretation)); + return Preconditions.checkStateNotNull(appendClientInfo.get().getStreamAppendClient()) + .appendRows(context.offset, context.protoRows); + } catch (Exception e) { + throw new RuntimeException(e); + } + }; + + Function, RetryType> onError = + failedContexts -> { + // The first context is always the one that fails. + AppendRowsContext failedContext = + Preconditions.checkStateNotNull(Iterables.getFirst(failedContexts, null)); + BigQuerySinkMetrics.reportFailedRPCMetrics( + failedContext, BigQuerySinkMetrics.RpcMethod.APPEND_ROWS, shortTableId); + String errorCode = + BigQuerySinkMetrics.throwableToGRPCCodeString(failedContext.getError()); + + // AppendSerializationError means that BigQuery detected errors on individual rows, e.g. + // a row not conforming + // to bigQuery invariants. These errors are persistent, so we redirect those rows to the + // failedInserts + // PCollection, and retry with the remaining rows. + if (failedContext.getError() != null + && failedContext.getError() instanceof Exceptions.AppendSerializationError) { + Exceptions.AppendSerializationError error = + Preconditions.checkArgumentNotNull( + (Exceptions.AppendSerializationError) failedContext.getError()); + + Set failedRowIndices = error.getRowIndexToErrorMessage().keySet(); + for (int failedIndex : failedRowIndices) { + // Convert the message to a TableRow and send it to the failedRows collection. 
+ ByteString protoBytes = failedContext.protoRows.getSerializedRows(failedIndex); + TableRow failedRow = appendClientInfo.get().toTableRow(protoBytes); + org.joda.time.Instant timestamp = failedContext.timestamps.get(failedIndex); + o.get(failedRowsTag) + .outputWithTimestamp( + new BigQueryStorageApiInsertError( + failedRow, error.getRowIndexToErrorMessage().get(failedIndex)), + timestamp); + } + int failedRows = failedRowIndices.size(); + rowsSentToFailedRowsCollection.inc(failedRows); + BigQuerySinkMetrics.appendRowsRowStatusCounter( + BigQuerySinkMetrics.RowStatus.FAILED, errorCode, shortTableId) + .inc(failedRows); + + // Remove the failed row from the payload, so we retry the batch without the failed + // rows. + ProtoRows.Builder retryRows = ProtoRows.newBuilder(); + @Nullable List timestamps = Lists.newArrayList(); + for (int i = 0; i < failedContext.protoRows.getSerializedRowsCount(); ++i) { + if (!failedRowIndices.contains(i)) { + ByteString rowBytes = failedContext.protoRows.getSerializedRows(i); + retryRows.addSerializedRows(rowBytes); + timestamps.add(failedContext.timestamps.get(i)); + } + } + failedContext.protoRows = retryRows.build(); + failedContext.timestamps = timestamps; + int retriedRows = failedContext.protoRows.getSerializedRowsCount(); + BigQuerySinkMetrics.appendRowsRowStatusCounter( + BigQuerySinkMetrics.RowStatus.RETRIED, errorCode, shortTableId) + .inc(retriedRows); + + // Since we removed rows, we need to update the insert offsets for all remaining rows. + long offset = failedContext.offset; + for (AppendRowsContext context : failedContexts) { + context.offset = offset; + offset += context.protoRows.getSerializedRowsCount(); + } + streamOffset.write(offset); + return RetryType.RETRY_ALL_OPERATIONS; + } + + Throwable error = Preconditions.checkStateNotNull(failedContext.getError()); + Status.Code statusCode = Status.fromThrowable(error).getCode(); + + // This means that the offset we have stored does not match the current end of + // the stream in the Storage API. Usually this happens because a crash or a bundle + // failure + // happened after an append but before the worker could checkpoint it's + // state. The records that were appended in a failed bundle will be retried, + // meaning that the unflushed tail of the stream must be discarded to prevent + // duplicates. + boolean offsetMismatch = + statusCode.equals(Code.OUT_OF_RANGE) || statusCode.equals(Code.ALREADY_EXISTS); + + boolean quotaError = statusCode.equals(Code.RESOURCE_EXHAUSTED); + if (!offsetMismatch) { + // Don't log errors for expected offset mismatch. These will be logged as warnings + // below. + LOG.error( + "Got error " + failedContext.getError() + " closing " + failedContext.streamName); + } + + try { + // TODO: Only do this on explicit NOT_FOUND errors once BigQuery reliably produces + // them. + tryCreateTable.call(); + } catch (Exception e) { + throw new RuntimeException(e); + } + + if (!quotaError) { + // This forces us to close and reopen all gRPC connections to Storage API on error, + // which empirically fixes random stuckness issues. + clearClients.accept(failedContexts); + } + appendFailures.inc(); + int retriedRows = failedContext.protoRows.getSerializedRowsCount(); + BigQuerySinkMetrics.appendRowsRowStatusCounter( + BigQuerySinkMetrics.RowStatus.RETRIED, errorCode, shortTableId) + .inc(retriedRows); + + boolean explicitStreamFinalized = + failedContext.getError() instanceof StreamFinalizedException; + // This implies that the stream doesn't exist or has already been finalized. 
In this + // case we have no choice but to create a new stream. + boolean streamDoesNotExist = + explicitStreamFinalized + || statusCode.equals(Code.INVALID_ARGUMENT) + || statusCode.equals(Code.NOT_FOUND) + || statusCode.equals(Code.FAILED_PRECONDITION); + if (offsetMismatch || streamDoesNotExist) { + appendOffsetFailures.inc(); + LOG.warn( + "Append to " + + failedContext + + " failed with " + + failedContext.getError() + + " Will retry with a new stream"); + // Finalize the stream and clear streamName so a new stream will be created. + o.get(flushTag) + .output( + KV.of( + failedContext.streamName, new Operation(failedContext.offset - 1, true))); + // Reinitialize all contexts with the new stream and new offsets. + initializeContexts.accept(failedContexts, true); + + // Offset failures imply that all subsequent parallel appends will also fail. + // Retry them all. + return RetryType.RETRY_ALL_OPERATIONS; + } + + return RetryType.RETRY_ALL_OPERATIONS; + }; + + Consumer onSuccess = + context -> { + AppendRowsResponse response = Preconditions.checkStateNotNull(context.getResult()); + o.get(flushTag) + .output( + KV.of( + context.streamName, + new Operation( + context.offset + context.protoRows.getSerializedRowsCount() - 1, + false))); + int flushedRows = context.protoRows.getSerializedRowsCount(); + flushesScheduled.inc(flushedRows); + BigQuerySinkMetrics.reportSuccessfulRpcMetrics( + context, BigQuerySinkMetrics.RpcMethod.APPEND_ROWS, shortTableId); + BigQuerySinkMetrics.appendRowsRowStatusCounter( + BigQuerySinkMetrics.RowStatus.SUCCESSFUL, BigQuerySinkMetrics.OK, shortTableId) + .inc(flushedRows); + + if (successfulRowsTag != null) { + for (int i = 0; i < context.protoRows.getSerializedRowsCount(); ++i) { + ByteString protoBytes = context.protoRows.getSerializedRows(i); + org.joda.time.Instant timestamp = context.timestamps.get(i); + o.get(successfulRowsTag) + .outputWithTimestamp(appendClientInfo.get().toTableRow(protoBytes), timestamp); + } + } + }; + Instant now = Instant.now(); + List contexts = Lists.newArrayList(); + RetryManager retryManager = + new RetryManager<>( + Duration.standardSeconds(1), + Duration.standardSeconds(10), + 1000, + BigQuerySinkMetrics.throttledTimeCounter(BigQuerySinkMetrics.RpcMethod.APPEND_ROWS)); + int numAppends = 0; + for (SplittingIterable256.Value splitValue : messages) { + // Handle the case of a row that is too large. + if (splitValue.getProtoRows().getSerializedSize() >= maxRequestSize) { + if (splitValue.getProtoRows().getSerializedRowsCount() > 1) { + // TODO(reuvenlax): Is it worth trying to handle this case by splitting the protoRows? + // Given that we split + // the ProtoRows iterable at 2MB and the max request size is 10MB, this scenario seems + // nearly impossible. + LOG.error( + "A request containing more than one row is over the request size limit of " + + maxRequestSize + + ". This is unexpected. All rows in the request will be sent to the failed-rows PCollection."); + } + for (int i = 0; i < splitValue.getProtoRows().getSerializedRowsCount(); ++i) { + ByteString rowBytes = splitValue.getProtoRows().getSerializedRows(i); + org.joda.time.Instant timestamp = splitValue.getTimestamps().get(i); + TableRow failedRow = appendClientInfo.get().toTableRow(rowBytes); + o.get(failedRowsTag) + .outputWithTimestamp( + new BigQueryStorageApiInsertError( + failedRow, "Row payload too large. 
Maximum size " + maxRequestSize), + timestamp); + } + int numRowsFailed = splitValue.getProtoRows().getSerializedRowsCount(); + rowsSentToFailedRowsCollection.inc(numRowsFailed); + BigQuerySinkMetrics.appendRowsRowStatusCounter( + BigQuerySinkMetrics.RowStatus.FAILED, + BigQuerySinkMetrics.PAYLOAD_TOO_LARGE, + shortTableId) + .inc(numRowsFailed); + } else { + ++numAppends; + // RetryManager + AppendRowsContext context = + new AppendRowsContext( + element.getKey(), splitValue.getProtoRows(), splitValue.getTimestamps()); + contexts.add(context); + retryManager.addOperation(runOperation, onError, onSuccess, context); + recordsAppended.inc(splitValue.getProtoRows().getSerializedRowsCount()); + appendSizeDistribution.update(context.protoRows.getSerializedRowsCount()); + } + } + + if (numAppends > 0) { + initializeContexts.accept(contexts, false); + try { + retryManager.run(true); + } finally { + // Make sure that all pins are removed. + for (AppendRowsContext context : contexts) { + if (context.client != null) { + runAsyncIgnoreFailure(closeWriterExecutor, context.client::unpin); + } + } + } + appendSplitDistribution.update(numAppends); + + if (autoUpdateSchema) { + @Nullable + StreamAppendClient streamAppendClient = appendClientInfo.get().getStreamAppendClient(); + TableSchema originalSchema = appendClientInfo.get().getTableSchema(); + ; + @Nullable + TableSchema updatedSchemaReturned = + (streamAppendClient != null) ? streamAppendClient.getUpdatedSchema() : null; + // Update the table schema and clear the append client. + if (updatedSchemaReturned != null) { + Optional newSchema = + TableSchemaUpdateUtils.getUpdatedSchema(originalSchema, updatedSchemaReturned); + if (newSchema.isPresent()) { + appendClientInfo.set( + AppendClientInfo.of( + newSchema.get(), appendClientInfo.get().getCloseAppendClient(), false)); + APPEND_CLIENTS.invalidate(element.getKey()); + APPEND_CLIENTS.put(element.getKey(), appendClientInfo.get()); + LOG.debug( + "Fetched updated schema for table {}:\n\t{}", tableId, updatedSchemaReturned); + updatedSchema.write(newSchema.get()); + } + } + } + + java.time.Duration timeElapsed = java.time.Duration.between(now, Instant.now()); + appendLatencyDistribution.update(timeElapsed.toMillis()); + } + idleTimer.offset(streamIdleTime).withNoOutputTimestamp().setRelative(); + } + + // called by the idleTimer and window-expiration handlers. + private void finalizeStream( + @AlwaysFetched @StateId("streamName") ValueState streamName, + @AlwaysFetched @StateId("streamOffset") ValueState streamOffset, + ShardedKey key, + MultiOutputReceiver o, + org.joda.time.Instant finalizeElementTs) { + String stream = MoreObjects.firstNonNull(streamName.read(), ""); + + if (!Strings.isNullOrEmpty(stream)) { + // Finalize the stream + long nextOffset = MoreObjects.firstNonNull(streamOffset.read(), 0L); + o.get(flushTag) + .outputWithTimestamp( + KV.of(stream, new Operation(nextOffset - 1, true)), finalizeElementTs); + streamName.clear(); + streamOffset.clear(); + // Make sure that the stream object is closed. + APPEND_CLIENTS.invalidate(key); + } + } + + @OnTimer("idleTimer") + public void onTimer( + @Key ShardedKey key, + @AlwaysFetched @StateId("streamName") ValueState streamName, + @AlwaysFetched @StateId("streamOffset") ValueState streamOffset, + MultiOutputReceiver o, + BoundedWindow window) { + // Stream is idle - clear it. + // Note: this is best effort. 
We are explicitly emiting a timestamp that is before + // the default output timestamp, which means that in some cases (usually when draining + // a pipeline) this finalize element will be dropped as late. This is usually ok as + // BigQuery will eventually garbage collect the stream. We attempt to finalize idle streams + // merely to remove the pressure of large numbers of orphaned streams from BigQuery. + finalizeStream(streamName, streamOffset, key, o, window.maxTimestamp()); + streamsIdle.inc(); + } + + @OnWindowExpiration + public void onWindowExpiration( + @Key ShardedKey key, + @AlwaysFetched @StateId("streamName") ValueState streamName, + @AlwaysFetched @StateId("streamOffset") ValueState streamOffset, + MultiOutputReceiver o, + BoundedWindow window) { + // Window is done - usually because the pipeline has been drained. Make sure to clean up + // streams so that they are not leaked. + finalizeStream(streamName, streamOffset, key, o, window.maxTimestamp()); + } + + @Override + public Duration getAllowedTimestampSkew() { + return Duration.millis(BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()); + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkFailedRowsIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkFailedRowsIT.java index f721f57147e3..c6cdaac818b9 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkFailedRowsIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiSinkFailedRowsIT.java @@ -33,11 +33,13 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.Nullable; import org.hamcrest.Matchers; import org.joda.time.Duration; import org.junit.AfterClass; @@ -102,6 +104,17 @@ public static Iterable data() { private static final byte[] BIG_BYTES = new byte[11 * 1024 * 1024]; + private static final SerializableFunction ERROR_FN = + new SerializableFunction() { + @Override + public TableRow apply(TableRow input) { + if (input.containsKey("bytes")) { + return new TableRow().set("bytes", new byte[10]).set("modified", "true"); + } + return input.clone().set("modified", "true"); + } + }; + private BigQueryIO.Write.Method getMethod() { return useAtLeastOnce ? 
BigQueryIO.Write.Method.STORAGE_API_AT_LEAST_ONCE @@ -128,34 +141,106 @@ public static void cleanup() { @Test public void testSchemaMismatchCaughtByBeam() throws IOException, InterruptedException { String tableSpec = createTable(BASE_TABLE_SCHEMA); + + Iterable goodRows = getSchemaMismatchGoodRows(); + Iterable badRows = getSchemaMismatchBadRows(); + + runPipeline( + getMethod(), + useStreamingExactlyOnce, + tableSpec, + Iterables.concat(goodRows, badRows), + badRows, + null); + assertGoodRowsWritten(tableSpec, goodRows); + } + + @Test + public void testSchemaMismatchCaughtByBeamWithCustomErrorHandling() + throws IOException, InterruptedException { + String tableSpec = createTable(BASE_TABLE_SCHEMA); + + Iterable goodRows = getSchemaMismatchGoodRows(); + Iterable badRows = getSchemaMismatchBadRows(); + + List modifiedBadRows = getSchemaMismatchBadRows(); + + for (TableRow tr : modifiedBadRows) { + tr.set("modified", "true"); + } + + runPipeline( + getMethod(), + useStreamingExactlyOnce, + tableSpec, + Iterables.concat(goodRows, badRows), + modifiedBadRows, + ERROR_FN); + assertGoodRowsWritten(tableSpec, goodRows); + } + + public List getSchemaMismatchGoodRows() { TableRow good1 = new TableRow().set("str", "foo").set("i64", "42"); TableRow good2 = new TableRow().set("str", "foo").set("i64", "43"); - Iterable goodRows = - ImmutableList.of( - good1.clone().set("inner", new TableRow()), - good2.clone().set("inner", new TableRow()), - new TableRow().set("inner", good1), - new TableRow().set("inner", good2)); + return ImmutableList.of( + good1.clone().set("inner", new TableRow()), + good2.clone().set("inner", new TableRow()), + new TableRow().set("inner", good1), + new TableRow().set("inner", good2)); + } + public List getSchemaMismatchBadRows() { TableRow bad1 = new TableRow().set("str", "foo").set("i64", "baad"); TableRow bad2 = new TableRow().set("str", "foo").set("i64", "42").set("unknown", "foobar"); - Iterable badRows = - ImmutableList.of( - bad1, bad2, new TableRow().set("inner", bad1), new TableRow().set("inner", bad2)); + return ImmutableList.of( + bad1, + bad2, + new TableRow().set("inner", bad1.clone()), + new TableRow().set("inner", bad2.clone())); + } + + @Test + public void testInvalidRowCaughtByBigquery() throws IOException, InterruptedException { + String tableSpec = createTable(BASE_TABLE_SCHEMA); + + Iterable goodRows = getInvalidBQGoodRows(); + Iterable badRows = getInvalidBQBadRows(false); runPipeline( getMethod(), useStreamingExactlyOnce, tableSpec, Iterables.concat(goodRows, badRows), - badRows); + badRows, + null); assertGoodRowsWritten(tableSpec, goodRows); } @Test - public void testInvalidRowCaughtByBigquery() throws IOException, InterruptedException { + public void testInvalidRowCaughtByBigqueryWithCustomErrorHandling() + throws IOException, InterruptedException { String tableSpec = createTable(BASE_TABLE_SCHEMA); + Iterable goodRows = getInvalidBQGoodRows(); + Iterable badRows = getInvalidBQBadRows(false); + + List modifiedBadRows = getInvalidBQBadRows(true); + + for (TableRow tr : modifiedBadRows) { + tr.set("modified", "true"); + } + + runPipeline( + getMethod(), + useStreamingExactlyOnce, + tableSpec, + Iterables.concat(goodRows, badRows), + modifiedBadRows, + ERROR_FN); + assertGoodRowsWritten(tableSpec, goodRows); + } + + private List getInvalidBQGoodRows() { TableRow good1 = new TableRow() .set("str", "foo") @@ -164,13 +249,14 @@ public void testInvalidRowCaughtByBigquery() throws IOException, InterruptedExce .set("stronearray", Lists.newArrayList()); TableRow 
good2 = new TableRow().set("str", "foo").set("i64", "43").set("stronearray", Lists.newArrayList()); - Iterable goodRows = - ImmutableList.of( - good1.clone().set("inner", new TableRow().set("stronearray", Lists.newArrayList())), - good2.clone().set("inner", new TableRow().set("stronearray", Lists.newArrayList())), - new TableRow().set("inner", good1).set("stronearray", Lists.newArrayList()), - new TableRow().set("inner", good2).set("stronearray", Lists.newArrayList())); + return ImmutableList.of( + good1.clone().set("inner", new TableRow().set("stronearray", Lists.newArrayList())), + good2.clone().set("inner", new TableRow().set("stronearray", Lists.newArrayList())), + new TableRow().set("inner", good1).set("stronearray", Lists.newArrayList()), + new TableRow().set("inner", good2).set("stronearray", Lists.newArrayList())); + } + private List getInvalidBQBadRows(boolean isModifiedRows) { TableRow bad1 = new TableRow().set("str", "foo").set("i64", "42").set("date", "10001-08-16"); TableRow bad2 = new TableRow().set("str", "foo").set("i64", "42").set("strone", "ab"); TableRow bad3 = new TableRow().set("str", "foo").set("i64", "42").set("json", "BAADF00D"); @@ -179,25 +265,16 @@ public void testInvalidRowCaughtByBigquery() throws IOException, InterruptedExce .set("str", "foo") .set("i64", "42") .set("stronearray", Lists.newArrayList("toolong")); - TableRow bad5 = new TableRow().set("bytes", BIG_BYTES); - Iterable badRows = - ImmutableList.of( - bad1, - bad2, - bad3, - bad4, - bad5, - new TableRow().set("inner", bad1), - new TableRow().set("inner", bad2), - new TableRow().set("inner", bad3)); - - runPipeline( - getMethod(), - useStreamingExactlyOnce, - tableSpec, - Iterables.concat(goodRows, badRows), - badRows); - assertGoodRowsWritten(tableSpec, goodRows); + TableRow bad5 = new TableRow().set("bytes", isModifiedRows ? new byte[10] : BIG_BYTES); + return ImmutableList.of( + bad1, + bad2, + bad3, + bad4, + bad5, + new TableRow().set("inner", bad1.clone()), + new TableRow().set("inner", bad2.clone()), + new TableRow().set("inner", bad3.clone())); } private static String createTable(TableSchema tableSchema) @@ -240,7 +317,8 @@ private static void runPipeline( boolean triggered, String tableSpec, Iterable tableRows, - Iterable expectedFailedRows) { + Iterable expectedFailedRows, + @Nullable SerializableFunction userProvidedErrorFunction) { Pipeline p = Pipeline.create(); BigQueryIO.Write write = @@ -249,6 +327,9 @@ private static void runPipeline( .withSchema(BASE_TABLE_SCHEMA) .withMethod(method) .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_NEVER); + if (userProvidedErrorFunction != null) { + write = write.withFormatRecordOnFailureFunction(userProvidedErrorFunction); + } if (method == BigQueryIO.Write.Method.STORAGE_WRITE_API) { write = write.withNumStorageWriteApiStreams(1); if (triggered) {
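
The integration-test changes above exercise the new failure-formatting hook end to end. For readers who want the user-facing shape, below is a minimal, hypothetical sketch of wiring withFormatRecordOnFailureFunction into a pipeline; the project/dataset/table spec, the schema, and the formatter body are illustrative placeholders, not values taken from this change.

import java.util.Arrays;
import com.google.api.services.bigquery.model.TableFieldSchema;
import com.google.api.services.bigquery.model.TableRow;
import com.google.api.services.bigquery.model.TableSchema;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryStorageApiInsertError;
import org.apache.beam.sdk.io.gcp.bigquery.TableRowJsonCoder;
import org.apache.beam.sdk.io.gcp.bigquery.WriteResult;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PCollection;

public class FormatRecordOnFailureExample {

  // Redact/annotate the original record before it reaches the failed-rows output,
  // instead of emitting the TableRow reconstructed from the proto payload.
  private static final SerializableFunction<TableRow, TableRow> FORMAT_ON_FAILURE =
      row -> row.clone().set("bytes", new byte[0]).set("modified", "true");

  // Placeholder schema for the hypothetical destination table.
  private static final TableSchema SCHEMA =
      new TableSchema()
          .setFields(
              Arrays.asList(
                  new TableFieldSchema().setName("str").setType("STRING"),
                  new TableFieldSchema().setName("i64").setType("INT64"),
                  new TableFieldSchema().setName("bytes").setType("BYTES"),
                  new TableFieldSchema().setName("modified").setType("STRING")));

  public static void main(String[] args) {
    Pipeline p = Pipeline.create();

    WriteResult result =
        p.apply(
                Create.of(new TableRow().set("str", "foo").set("i64", "42"))
                    .withCoder(TableRowJsonCoder.of()))
            .apply(
                BigQueryIO.writeTableRows()
                    .to("my-project:my_dataset.my_table") // hypothetical table spec
                    .withSchema(SCHEMA)
                    .withMethod(BigQueryIO.Write.Method.STORAGE_WRITE_API)
                    .withNumStorageWriteApiStreams(1)
                    .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_NEVER)
                    .withFormatRecordOnFailureFunction(FORMAT_ON_FAILURE));

    // With the function set, the failed-inserts output carries the user-formatted rows.
    PCollection<BigQueryStorageApiInsertError> failedRows = result.getFailedStorageApiInserts();

    p.run().waitUntilFinish();
  }
}

As the sharded-records changes above show, when the function is set the sink applies it to the original element (rather than converting the serialized proto back to a TableRow) before emitting to the failed-rows collection, so oversized or sensitive payloads can be trimmed or annotated, as the ERROR_FN in the integration test does.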