diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java index 04a81821d8e9..6745f7aceea0 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java @@ -68,6 +68,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.extensions.avro.io.AvroSource; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; @@ -75,7 +76,9 @@ import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath; import org.apache.beam.sdk.extensions.protobuf.ProtoCoder; import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; import org.apache.beam.sdk.io.FileSystems; +import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.io.fs.MoveOptions; import org.apache.beam.sdk.io.fs.ResolveOptions; import org.apache.beam.sdk.io.fs.ResourceId; @@ -1179,11 +1182,11 @@ public PCollection expand(PBegin input) { // if both toRowFn and fromRowFn values are set, enable Beam schema support Pipeline p = input.getPipeline(); + BigQueryOptions bqOptions = p.getOptions().as(BigQueryOptions.class); final BigQuerySourceDef sourceDef = createSourceDef(); Schema beamSchema = null; if (getTypeDescriptor() != null && getToBeamRowFn() != null && getFromBeamRowFn() != null) { - BigQueryOptions bqOptions = p.getOptions().as(BigQueryOptions.class); beamSchema = sourceDef.getBeamSchema(bqOptions); beamSchema = getFinalSchema(beamSchema, getSelectedFields()); } @@ -1191,7 +1194,7 @@ public PCollection expand(PBegin input) { final Coder coder = inferCoder(p.getCoderRegistry()); if (getMethod() == TypedRead.Method.DIRECT_READ) { - return expandForDirectRead(input, coder, beamSchema); + return expandForDirectRead(input, coder, beamSchema, bqOptions); } checkArgument( @@ -1369,7 +1372,7 @@ private static Schema getFinalSchema( } private PCollection expandForDirectRead( - PBegin input, Coder outputCoder, Schema beamSchema) { + PBegin input, Coder outputCoder, Schema beamSchema, BigQueryOptions bqOptions) { ValueProvider tableProvider = getTableProvider(); Pipeline p = input.getPipeline(); if (tableProvider != null) { @@ -1416,6 +1419,7 @@ private PCollection expandForDirectRead( // PCollectionView jobIdTokenView; + PCollectionTuple tuple; PCollection rows; if (!getWithTemplateCompatibility()) { @@ -1446,108 +1450,46 @@ public String apply(String input) { jobIdTokenView = jobIdTokenCollection.apply("ViewId", View.asSingleton()); TupleTag readStreamsTag = new TupleTag<>(); + TupleTag> listReadStreamsTag = new TupleTag<>(); TupleTag readSessionTag = new TupleTag<>(); TupleTag tableSchemaTag = new TupleTag<>(); - PCollectionTuple tuple = - jobIdTokenCollection.apply( - "RunQueryJob", - ParDo.of( - new DoFn() { - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - BigQueryOptions options = - c.getPipelineOptions().as(BigQueryOptions.class); - String jobUuid = c.element(); - // Execute the query and get the destination table holding the results. 
- // The getTargetTable call runs a new instance of the query and returns - // the destination table created to hold the results. - BigQueryStorageQuerySource querySource = - createStorageQuerySource(jobUuid, outputCoder); - Table queryResultTable = querySource.getTargetTable(options); - - // Create a read session without specifying a desired stream count and - // let the BigQuery storage server pick the number of streams. - CreateReadSessionRequest request = - CreateReadSessionRequest.newBuilder() - .setParent( - BigQueryHelpers.toProjectResourceName( - options.getBigQueryProject() == null - ? options.getProject() - : options.getBigQueryProject())) - .setReadSession( - ReadSession.newBuilder() - .setTable( - BigQueryHelpers.toTableResourceName( - queryResultTable.getTableReference())) - .setDataFormat(DataFormat.AVRO)) - .setMaxStreamCount(0) - .build(); - - ReadSession readSession; - try (StorageClient storageClient = - getBigQueryServices().getStorageClient(options)) { - readSession = storageClient.createReadSession(request); - } - - for (ReadStream readStream : readSession.getStreamsList()) { - c.output(readStream); - } - - c.output(readSessionTag, readSession); - c.output( - tableSchemaTag, - BigQueryHelpers.toJsonString(queryResultTable.getSchema())); - } - }) - .withOutputTags( - readStreamsTag, TupleTagList.of(readSessionTag).and(tableSchemaTag))); + if (!bqOptions.getEnableBundling()) { + tuple = + createTupleForDirectRead( + jobIdTokenCollection, + outputCoder, + readStreamsTag, + readSessionTag, + tableSchemaTag); + tuple.get(readStreamsTag).setCoder(ProtoCoder.of(ReadStream.class)); + } else { + tuple = + createTupleForDirectReadWithStreamBundle( + jobIdTokenCollection, + outputCoder, + listReadStreamsTag, + readSessionTag, + tableSchemaTag); + tuple.get(listReadStreamsTag).setCoder(ListCoder.of(ProtoCoder.of(ReadStream.class))); + } - tuple.get(readStreamsTag).setCoder(ProtoCoder.of(ReadStream.class)); tuple.get(readSessionTag).setCoder(ProtoCoder.of(ReadSession.class)); tuple.get(tableSchemaTag).setCoder(StringUtf8Coder.of()); - PCollectionView readSessionView = tuple.get(readSessionTag).apply("ReadSessionView", View.asSingleton()); PCollectionView tableSchemaView = tuple.get(tableSchemaTag).apply("TableSchemaView", View.asSingleton()); - rows = - tuple - .get(readStreamsTag) - .apply(Reshuffle.viaRandomKey()) - .apply( - ParDo.of( - new DoFn() { - @ProcessElement - public void processElement(ProcessContext c) throws Exception { - ReadSession readSession = c.sideInput(readSessionView); - TableSchema tableSchema = - BigQueryHelpers.fromJsonString( - c.sideInput(tableSchemaView), TableSchema.class); - ReadStream readStream = c.element(); - - BigQueryStorageStreamSource streamSource = - BigQueryStorageStreamSource.create( - readSession, - readStream, - tableSchema, - getParseFn(), - outputCoder, - getBigQueryServices()); - - // Read all of the data from the stream. In the event that this work - // item fails and is rescheduled, the same rows will be returned in - // the same order. 
- BoundedSource.BoundedReader reader = - streamSource.createReader(c.getPipelineOptions()); - for (boolean more = reader.start(); more; more = reader.advance()) { - c.output(reader.getCurrent()); - } - } - }) - .withSideInputs(readSessionView, tableSchemaView)) - .setCoder(outputCoder); + if (!bqOptions.getEnableBundling()) { + rows = + createPCollectionForDirectRead( + tuple, outputCoder, readStreamsTag, readSessionView, tableSchemaView); + } else { + rows = + createPCollectionForDirectReadWithStreamBundle( + tuple, outputCoder, listReadStreamsTag, readSessionView, tableSchemaView); + } } PassThroughThenCleanup.CleanupOperation cleanupOperation = @@ -1593,6 +1535,235 @@ void cleanup(ContextContainer c) throws Exception { return rows.apply(new PassThroughThenCleanup<>(cleanupOperation, jobIdTokenView)); } + private PCollectionTuple createTupleForDirectRead( + PCollection jobIdTokenCollection, + Coder outputCoder, + TupleTag readStreamsTag, + TupleTag readSessionTag, + TupleTag tableSchemaTag) { + PCollectionTuple tuple = + jobIdTokenCollection.apply( + "RunQueryJob", + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + BigQueryOptions options = + c.getPipelineOptions().as(BigQueryOptions.class); + String jobUuid = c.element(); + // Execute the query and get the destination table holding the results. + // The getTargetTable call runs a new instance of the query and returns + // the destination table created to hold the results. + BigQueryStorageQuerySource querySource = + createStorageQuerySource(jobUuid, outputCoder); + Table queryResultTable = querySource.getTargetTable(options); + + // Create a read session without specifying a desired stream count and + // let the BigQuery storage server pick the number of streams. + CreateReadSessionRequest request = + CreateReadSessionRequest.newBuilder() + .setParent( + BigQueryHelpers.toProjectResourceName( + options.getBigQueryProject() == null + ? options.getProject() + : options.getBigQueryProject())) + .setReadSession( + ReadSession.newBuilder() + .setTable( + BigQueryHelpers.toTableResourceName( + queryResultTable.getTableReference())) + .setDataFormat(DataFormat.AVRO)) + .setMaxStreamCount(0) + .build(); + + ReadSession readSession; + try (StorageClient storageClient = + getBigQueryServices().getStorageClient(options)) { + readSession = storageClient.createReadSession(request); + } + + for (ReadStream readStream : readSession.getStreamsList()) { + c.output(readStream); + } + + c.output(readSessionTag, readSession); + c.output( + tableSchemaTag, + BigQueryHelpers.toJsonString(queryResultTable.getSchema())); + } + }) + .withOutputTags( + readStreamsTag, TupleTagList.of(readSessionTag).and(tableSchemaTag))); + + return tuple; + } + + private PCollectionTuple createTupleForDirectReadWithStreamBundle( + PCollection jobIdTokenCollection, + Coder outputCoder, + TupleTag> listReadStreamsTag, + TupleTag readSessionTag, + TupleTag tableSchemaTag) { + + PCollectionTuple tuple = + jobIdTokenCollection.apply( + "RunQueryJob", + ParDo.of( + new DoFn>() { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + BigQueryOptions options = + c.getPipelineOptions().as(BigQueryOptions.class); + String jobUuid = c.element(); + // Execute the query and get the destination table holding the results. + // The getTargetTable call runs a new instance of the query and returns + // the destination table created to hold the results. 
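                            // Illustrative (assumed numbers, not part of the patch): the ReadStreams
                            // returned for this session are emitted further down in groups of
                            // streamsPerBundle, which is fixed at 10 in this DoFn, so a session that
                            // returns, say, 25 streams yields bundles of 10, 10 and 5 ReadStreams.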
+ BigQueryStorageQuerySource querySource = + createStorageQuerySource(jobUuid, outputCoder); + Table queryResultTable = querySource.getTargetTable(options); + + // Create a read session without specifying a desired stream count and + // let the BigQuery storage server pick the number of streams. + CreateReadSessionRequest request = + CreateReadSessionRequest.newBuilder() + .setParent( + BigQueryHelpers.toProjectResourceName( + options.getBigQueryProject() == null + ? options.getProject() + : options.getBigQueryProject())) + .setReadSession( + ReadSession.newBuilder() + .setTable( + BigQueryHelpers.toTableResourceName( + queryResultTable.getTableReference())) + .setDataFormat(DataFormat.AVRO)) + .setMaxStreamCount(0) + .build(); + + ReadSession readSession; + try (StorageClient storageClient = + getBigQueryServices().getStorageClient(options)) { + readSession = storageClient.createReadSession(request); + } + int streamIndex = 0; + int streamsPerBundle = 10; + List streamBundle = Lists.newArrayList(); + for (ReadStream readStream : readSession.getStreamsList()) { + streamIndex++; + streamBundle.add(readStream); + if (streamIndex % streamsPerBundle == 0) { + c.output(streamBundle); + streamBundle = Lists.newArrayList(); + } + } + if (streamIndex % streamsPerBundle != 0) { + c.output(streamBundle); + } + c.output(readSessionTag, readSession); + c.output( + tableSchemaTag, + BigQueryHelpers.toJsonString(queryResultTable.getSchema())); + } + }) + .withOutputTags( + listReadStreamsTag, TupleTagList.of(readSessionTag).and(tableSchemaTag))); + + return tuple; + } + + private PCollection createPCollectionForDirectRead( + PCollectionTuple tuple, + Coder outputCoder, + TupleTag readStreamsTag, + PCollectionView readSessionView, + PCollectionView tableSchemaView) { + PCollection rows = + tuple + .get(readStreamsTag) + .apply(Reshuffle.viaRandomKey()) + .apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + ReadSession readSession = c.sideInput(readSessionView); + TableSchema tableSchema = + BigQueryHelpers.fromJsonString( + c.sideInput(tableSchemaView), TableSchema.class); + ReadStream readStream = c.element(); + + BigQueryStorageStreamSource streamSource = + BigQueryStorageStreamSource.create( + readSession, + readStream, + tableSchema, + getParseFn(), + outputCoder, + getBigQueryServices()); + + // Read all of the data from the stream. In the event that this work + // item fails and is rescheduled, the same rows will be returned in + // the same order. 
+ BoundedSource.BoundedReader reader = + streamSource.createReader(c.getPipelineOptions()); + for (boolean more = reader.start(); more; more = reader.advance()) { + c.output(reader.getCurrent()); + } + } + }) + .withSideInputs(readSessionView, tableSchemaView)) + .setCoder(outputCoder); + + return rows; + } + + private PCollection createPCollectionForDirectReadWithStreamBundle( + PCollectionTuple tuple, + Coder outputCoder, + TupleTag> listReadStreamsTag, + PCollectionView readSessionView, + PCollectionView tableSchemaView) { + PCollection rows = + tuple + .get(listReadStreamsTag) + .apply(Reshuffle.viaRandomKey()) + .apply( + ParDo.of( + new DoFn, T>() { + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + ReadSession readSession = c.sideInput(readSessionView); + TableSchema tableSchema = + BigQueryHelpers.fromJsonString( + c.sideInput(tableSchemaView), TableSchema.class); + List streamBundle = c.element(); + + BigQueryStorageStreamBundleSource streamSource = + BigQueryStorageStreamBundleSource.create( + readSession, + streamBundle, + tableSchema, + getParseFn(), + outputCoder, + getBigQueryServices(), + 1L); + + // Read all of the data from the stream. In the event that this work + // item fails and is rescheduled, the same rows will be returned in + // the same order. + BoundedReader reader = + streamSource.createReader(c.getPipelineOptions()); + for (boolean more = reader.start(); more; more = reader.advance()) { + c.output(reader.getCurrent()); + } + } + }) + .withSideInputs(readSessionView, tableSchemaView)) + .setCoder(outputCoder); + + return rows; + } + @Override public void populateDisplayData(DisplayData.Builder builder) { super.populateDisplayData(builder); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryOptions.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryOptions.java index bf09bf4d9e37..938d131a0da5 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryOptions.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryOptions.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.options.ApplicationNameOptions; import org.apache.beam.sdk.options.Default; @@ -163,4 +165,13 @@ public interface BigQueryOptions Long getStorageWriteApiMaxRequestSize(); void setStorageWriteApiMaxRequestSize(Long value); + + @Experimental(Kind.UNSPECIFIED) + @Description( + "If set, BigQueryIO.Read will use the StreamBundle based" + + "implementation of the Read API Source") + @Default.Boolean(false) + Boolean getEnableBundling(); + + void setEnableBundling(Boolean value); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java index 27b88dc39600..834409062ccd 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java @@ -100,7 +100,7 @@ public Coder getOutputCoder() { 
} @Override - public List> split( + public List> split( long desiredBundleSizeBytes, PipelineOptions options) throws Exception { BigQueryOptions bqOptions = options.as(BigQueryOptions.class); @Nullable Table targetTable = getTargetTable(bqOptions); @@ -133,13 +133,18 @@ public List> split( readSessionBuilder.setDataFormat(format); } + // Setting the requested max stream count to 0, implies that the Read API backend will select + // an appropriate number of streams for the Session to produce reasonable throughput. + // This is required when using the Read API Source V2. int streamCount = 0; - if (desiredBundleSizeBytes > 0) { - long tableSizeBytes = (targetTable != null) ? targetTable.getNumBytes() : 0; - streamCount = (int) Math.min(tableSizeBytes / desiredBundleSizeBytes, MAX_SPLIT_COUNT); - } + if (!bqOptions.getEnableBundling()) { + if (desiredBundleSizeBytes > 0) { + long tableSizeBytes = (targetTable != null) ? targetTable.getNumBytes() : 0; + streamCount = (int) Math.min(tableSizeBytes / desiredBundleSizeBytes, MAX_SPLIT_COUNT); + } - streamCount = Math.max(streamCount, MIN_SPLIT_COUNT); + streamCount = Math.max(streamCount, MIN_SPLIT_COUNT); + } CreateReadSessionRequest createReadSessionRequest = CreateReadSessionRequest.newBuilder() @@ -166,6 +171,25 @@ public List> split( return ImmutableList.of(); } + streamCount = readSession.getStreamsList().size(); + int streamsPerBundle = 0; + double bytesPerStream = 0; + LOG.info( + "Estimated bytes this ReadSession will scan when all Streams are consumed: '{}'", + readSession.getEstimatedTotalBytesScanned()); + if (bqOptions.getEnableBundling()) { + if (desiredBundleSizeBytes > 0) { + bytesPerStream = + (double) readSession.getEstimatedTotalBytesScanned() / readSession.getStreamsCount(); + LOG.info("Estimated bytes each Stream will consume: '{}'", bytesPerStream); + streamsPerBundle = (int) Math.ceil(desiredBundleSizeBytes / bytesPerStream); + } else { + streamsPerBundle = (int) Math.ceil((double) streamCount / 10); + } + streamsPerBundle = Math.min(streamCount, streamsPerBundle); + LOG.info("Distributing '{}' Streams per StreamBundle.", streamsPerBundle); + } + Schema sessionSchema; if (readSession.getDataFormat() == DataFormat.ARROW) { org.apache.arrow.vector.types.pojo.Schema schema = @@ -180,18 +204,37 @@ public List> split( throw new IllegalArgumentException( "data is not in a supported dataFormat: " + readSession.getDataFormat()); } - + int streamIndex = 0; Preconditions.checkStateNotNull( targetTable); // TODO: this is inconsistent with method above, where it can be null TableSchema trimmedSchema = BigQueryAvroUtils.trimBigQueryTableSchema(targetTable.getSchema(), sessionSchema); - List> sources = Lists.newArrayList(); + if (!bqOptions.getEnableBundling()) { + List> sources = Lists.newArrayList(); + for (ReadStream readStream : readSession.getStreamsList()) { + sources.add( + BigQueryStorageStreamSource.create( + readSession, readStream, trimmedSchema, parseFn, outputCoder, bqServices)); + } + return ImmutableList.copyOf(sources); + } + List streamBundle = Lists.newArrayList(); + List> sources = Lists.newArrayList(); for (ReadStream readStream : readSession.getStreamsList()) { + streamIndex++; + streamBundle.add(readStream); + if (streamIndex % streamsPerBundle == 0) { + sources.add( + BigQueryStorageStreamBundleSource.create( + readSession, streamBundle, trimmedSchema, parseFn, outputCoder, bqServices, 1L)); + streamBundle = Lists.newArrayList(); + } + } + if (streamIndex % streamsPerBundle != 0) { sources.add( - 
BigQueryStorageStreamSource.create( - readSession, readStream, trimmedSchema, parseFn, outputCoder, bqServices)); + BigQueryStorageStreamBundleSource.create( + readSession, streamBundle, trimmedSchema, parseFn, outputCoder, bqServices, 1L)); } - return ImmutableList.copyOf(sources); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamBundleSource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamBundleSource.java new file mode 100644 index 000000000000..42e99b6aae38 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamBundleSource.java @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.fromJsonString; +import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.toJsonString; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; + +import com.google.api.gax.rpc.ApiException; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.bigquery.storage.v1.ReadRowsRequest; +import com.google.cloud.bigquery.storage.v1.ReadRowsResponse; +import com.google.cloud.bigquery.storage.v1.ReadSession; +import com.google.cloud.bigquery.storage.v1.ReadStream; +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import org.apache.beam.runners.core.metrics.ServiceCallMetric; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.io.OffsetBasedSource; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.BigQueryServerStream; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StorageClient; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.util.Preconditions; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.RequiresNonNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@link org.apache.beam.sdk.io.Source} representing a bundle of Streams in a BigQuery ReadAPI + * Session. This Source ONLY supports splitting at the StreamBundle level. + * + *
<p>
{@link BigQueryStorageStreamBundleSource} defines a split-point as the starting offset of each + * Stream. As a result, the number of valid split points in the Source is equal to the number of + * Streams in the StreamBundle and this Source does NOT support sub-Stream splitting. + * + *
<p>
Additionally, the underlying {@link org.apache.beam.sdk.io.range.OffsetRangeTracker} and + * {@link OffsetBasedSource} operate in the split point space and do NOT directly interact with the + * Streams constituting the StreamBundle. Consequently, fractional values used in + * `splitAtFraction()` are translated into StreamBundleIndices and the underlying RangeTracker + * handles the split operation by checking the validity of the split point. This has the following + * implications for the `splitAtFraction()` operation: + * + *
<p>
1. Fraction values that point to the "middle" of a Stream will be translated to the + * appropriate Stream boundary by the RangeTracker. + * + *
<p>
2. Once a Stream is being read from, the RangeTracker will only accept `splitAtFraction()` + * calls that point to StreamBundleIndices that are greater than the StreamBundleIndex of the + * current Stream + * + * @param Type of records represented by the source. + * @see OffsetBasedSource + * @see org.apache.beam.sdk.io.range.OffsetRangeTracker + * @see org.apache.beam.sdk.io.BlockBasedSource (semantically similar to {@link + * BigQueryStorageStreamBundleSource}) + */ +class BigQueryStorageStreamBundleSource extends OffsetBasedSource { + + public static BigQueryStorageStreamBundleSource create( + ReadSession readSession, + List streamBundle, + TableSchema tableSchema, + SerializableFunction parseFn, + Coder outputCoder, + BigQueryServices bqServices, + long minBundleSize) { + return new BigQueryStorageStreamBundleSource<>( + readSession, + streamBundle, + toJsonString(Preconditions.checkArgumentNotNull(tableSchema, "tableSchema")), + parseFn, + outputCoder, + bqServices, + minBundleSize); + } + + /** + * Creates a new source with the same properties as this one, except with a different {@link + * List}. + */ + public BigQueryStorageStreamBundleSource fromExisting(List newStreamBundle) { + return new BigQueryStorageStreamBundleSource<>( + readSession, + newStreamBundle, + jsonTableSchema, + parseFn, + outputCoder, + bqServices, + getMinBundleSize()); + } + + private final ReadSession readSession; + private final List streamBundle; + private final String jsonTableSchema; + private final SerializableFunction parseFn; + private final Coder outputCoder; + private final BigQueryServices bqServices; + + private BigQueryStorageStreamBundleSource( + ReadSession readSession, + List streamBundle, + String jsonTableSchema, + SerializableFunction parseFn, + Coder outputCoder, + BigQueryServices bqServices, + long minBundleSize) { + super(0, streamBundle.size(), minBundleSize); + this.readSession = Preconditions.checkArgumentNotNull(readSession, "readSession"); + this.streamBundle = Preconditions.checkArgumentNotNull(streamBundle, "streams"); + this.jsonTableSchema = Preconditions.checkArgumentNotNull(jsonTableSchema, "jsonTableSchema"); + this.parseFn = Preconditions.checkArgumentNotNull(parseFn, "parseFn"); + this.outputCoder = Preconditions.checkArgumentNotNull(outputCoder, "outputCoder"); + this.bqServices = Preconditions.checkArgumentNotNull(bqServices, "bqServices"); + } + + @Override + public Coder getOutputCoder() { + return outputCoder; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder + .add(DisplayData.item("table", readSession.getTable()).withLabel("Table")) + .add(DisplayData.item("readSession", readSession.getName()).withLabel("Read session")); + for (ReadStream readStream : streamBundle) { + builder.add(DisplayData.item("stream", readStream.getName()).withLabel("Stream")); + } + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) { + // The size of stream source can't be estimated due to server-side liquid sharding. + // TODO: Implement progress reporting. + return 0L; + } + + @Override + public List> split( + long desiredBundleSizeBytes, PipelineOptions options) { + // This method is only called for initial splits. Since this class will always be a child source + // of BigQueryStorageSourceBase, all splits here will be handled by `splitAtFraction()`. As a + // result, this is a no-op. 
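    // Rough illustration of the split-point behaviour described in the class Javadoc (assumed
    // numbers): a bundle of 4 Streams covers offsets [0, 4), so a splitAtFraction(0.6) call on the
    // reader maps to offset ceil(0.6 * 4) = 3; the primary source then keeps Streams [0, 3) and the
    // residual is rebuilt from Streams [3, 4) via createSourceForSubrange().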
+ return ImmutableList.of(this); + } + + @Override + public long getMaxEndOffset(PipelineOptions options) throws Exception { + return this.streamBundle.size(); + } + + @Override + public OffsetBasedSource createSourceForSubrange(long start, long end) { + List newStreamBundle = streamBundle.subList((int) start, (int) end); + return fromExisting(newStreamBundle); + } + + @Override + public BigQueryStorageStreamBundleReader createReader(PipelineOptions options) + throws IOException { + return new BigQueryStorageStreamBundleReader<>(this, options.as(BigQueryOptions.class)); + } + + public static class BigQueryStorageStreamBundleReader extends OffsetBasedReader { + private static final Logger LOG = + LoggerFactory.getLogger(BigQueryStorageStreamBundleReader.class); + + private final BigQueryStorageReader reader; + private final SerializableFunction parseFn; + private final StorageClient storageClient; + private final TableSchema tableSchema; + + private BigQueryStorageStreamBundleSource source; + private @Nullable BigQueryServerStream responseStream = null; + private @Nullable Iterator responseIterator = null; + private @Nullable T current = null; + private int currentStreamBundleIndex; + private long currentStreamOffset; + + // Values used for progress reporting. + private double fractionOfStreamBundleConsumed; + + private double progressAtResponseStart; + private double progressAtResponseEnd; + private long rowsConsumedFromCurrentResponse; + private long totalRowsInCurrentResponse; + + private @Nullable TableReference tableReference; + private @Nullable ServiceCallMetric serviceCallMetric; + + private BigQueryStorageStreamBundleReader( + BigQueryStorageStreamBundleSource source, BigQueryOptions options) throws IOException { + super(source); + this.source = source; + this.reader = BigQueryStorageReaderFactory.getReader(source.readSession); + this.parseFn = source.parseFn; + this.storageClient = source.bqServices.getStorageClient(options); + this.tableSchema = fromJsonString(source.jsonTableSchema, TableSchema.class); + this.currentStreamBundleIndex = 0; + this.fractionOfStreamBundleConsumed = 0d; + this.progressAtResponseStart = 0d; + this.progressAtResponseEnd = 0d; + this.rowsConsumedFromCurrentResponse = 0L; + this.totalRowsInCurrentResponse = 0L; + } + + @Override + public T getCurrent() throws NoSuchElementException { + if (current == null) { + throw new NoSuchElementException(); + } + return current; + } + + @Override + protected long getCurrentOffset() throws NoSuchElementException { + return currentStreamBundleIndex; + } + + @Override + protected boolean isAtSplitPoint() throws NoSuchElementException { + if (currentStreamOffset == 0) { + return true; + } + return false; + } + + @Override + public boolean startImpl() throws IOException { + return readNextStream(); + } + + @Override + public boolean advanceImpl() throws IOException { + Preconditions.checkStateNotNull(responseIterator); + currentStreamOffset += totalRowsInCurrentResponse; + return readNextRecord(); + } + + private boolean readNextStream() throws IOException { + BigQueryStorageStreamBundleSource source = getCurrentSource(); + if (currentStreamBundleIndex == source.streamBundle.size()) { + fractionOfStreamBundleConsumed = 1d; + return false; + } + ReadRowsRequest request = + ReadRowsRequest.newBuilder() + .setReadStream(source.streamBundle.get(currentStreamBundleIndex).getName()) + .build(); + tableReference = BigQueryUtils.toTableReference(source.readSession.getTable()); + serviceCallMetric = 
BigQueryUtils.readCallMetric(tableReference); + LOG.info( + "Started BigQuery Storage API read from stream {}.", + source.streamBundle.get(currentStreamBundleIndex).getName()); + responseStream = storageClient.readRows(request, source.readSession.getTable()); + responseIterator = responseStream.iterator(); + return readNextRecord(); + } + + @RequiresNonNull("responseIterator") + private boolean readNextRecord() throws IOException { + Iterator responseIterator = this.responseIterator; + if (responseIterator == null) { + LOG.info("Received null responseIterator for stream {}", currentStreamBundleIndex); + return false; + } + while (reader.readyForNextReadResponse()) { + if (!responseIterator.hasNext()) { + synchronized (this) { + currentStreamOffset = 0; + currentStreamBundleIndex++; + } + return readNextStream(); + } + + ReadRowsResponse response; + try { + response = responseIterator.next(); + // Since we don't have a direct hook to the underlying + // API call, record success every time we read a record successfully. + if (serviceCallMetric != null) { + serviceCallMetric.call("ok"); + } + } catch (ApiException e) { + // Occasionally the iterator will fail and raise an exception. + // Capture it here and record the error in the metric. + if (serviceCallMetric != null) { + serviceCallMetric.call(e.getStatusCode().getCode().name()); + } + throw e; + } + + progressAtResponseStart = response.getStats().getProgress().getAtResponseStart(); + progressAtResponseEnd = response.getStats().getProgress().getAtResponseEnd(); + totalRowsInCurrentResponse = response.getRowCount(); + rowsConsumedFromCurrentResponse = 0L; + + checkArgument( + totalRowsInCurrentResponse >= 0, + "Row count from current response (%s) must be non-negative.", + totalRowsInCurrentResponse); + + checkArgument( + 0f <= progressAtResponseStart && progressAtResponseStart <= 1f, + "Progress at response start (%s) is not in the range [0.0, 1.0].", + progressAtResponseStart); + + checkArgument( + 0f <= progressAtResponseEnd && progressAtResponseEnd <= 1f, + "Progress at response end (%s) is not in the range [0.0, 1.0].", + progressAtResponseEnd); + reader.processReadRowsResponse(response); + } + + SchemaAndRecord schemaAndRecord = new SchemaAndRecord(reader.readSingleRecord(), tableSchema); + + current = parseFn.apply(schemaAndRecord); + + // Calculates the fraction of the current stream that has been consumed. This value is + // calculated by interpolating between the fraction consumed value from the previous server + // response (or zero if we're consuming the first response) and the fractional value in the + // current response based on how many of the rows in the current response have been consumed. + rowsConsumedFromCurrentResponse++; + + double fractionOfCurrentStreamConsumed = + progressAtResponseStart + + ((progressAtResponseEnd - progressAtResponseStart) + * (rowsConsumedFromCurrentResponse * 1.0 / totalRowsInCurrentResponse)); + + // We now calculate the progress made over the entire StreamBundle by assuming that each + // Stream in the StreamBundle has approximately the same amount of data. Given this, merely + // counting the number of Streams that have been read and linearly interpolating with the + // progress made in the current Stream gives us the overall StreamBundle progress. 
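      // For example (assumed values): with 4 Streams in the bundle, having finished Stream 0 and
      // being halfway through Stream 1 yields (1 + 0.5) / 4 = 0.375 for the whole StreamBundle.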
+ fractionOfStreamBundleConsumed = + (currentStreamBundleIndex + fractionOfCurrentStreamConsumed) / source.streamBundle.size(); + return true; + } + + @Override + public synchronized void close() { + // Because superclass cannot have preconditions around these variables, cannot use + // @RequiresNonNull + Preconditions.checkStateNotNull(storageClient); + Preconditions.checkStateNotNull(reader); + storageClient.close(); + reader.close(); + } + + @Override + public synchronized BigQueryStorageStreamBundleSource getCurrentSource() { + return source; + } + + @Override + public synchronized Double getFractionConsumed() { + return fractionOfStreamBundleConsumed; + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadWithStreamBundleSourceTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadWithStreamBundleSourceTest.java new file mode 100644 index 000000000000..fc1ccd3c8914 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadWithStreamBundleSourceTest.java @@ -0,0 +1,2156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static java.util.Arrays.asList; +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import com.google.api.services.bigquery.model.Streamingbuffer; +import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.bigquery.storage.v1.ArrowRecordBatch; +import com.google.cloud.bigquery.storage.v1.ArrowSchema; +import com.google.cloud.bigquery.storage.v1.AvroRows; +import com.google.cloud.bigquery.storage.v1.AvroSchema; +import com.google.cloud.bigquery.storage.v1.CreateReadSessionRequest; +import com.google.cloud.bigquery.storage.v1.DataFormat; +import com.google.cloud.bigquery.storage.v1.ReadRowsRequest; +import com.google.cloud.bigquery.storage.v1.ReadRowsResponse; +import com.google.cloud.bigquery.storage.v1.ReadSession; +import com.google.cloud.bigquery.storage.v1.ReadStream; +import com.google.cloud.bigquery.storage.v1.StreamStats; +import com.google.cloud.bigquery.storage.v1.StreamStats.Progress; +import com.google.protobuf.ByteString; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.math.BigInteger; +import java.nio.channels.Channels; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.util.Text; +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericData.Record; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.io.Encoder; +import org.apache.avro.io.EncoderFactory; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.extensions.protobuf.ByteStringCoder; +import org.apache.beam.sdk.extensions.protobuf.ProtoCoder; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TableRowParser; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Method; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StorageClient; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryStorageStreamBundleSource.BigQueryStorageStreamBundleReader; 
+import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils.ConversionOptions; +import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; +import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices.FakeBigQueryServerStream; +import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.schemas.transforms.Convert; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.junit.runners.model.Statement; + +/** + * Tests for {@link BigQueryIO#readTableRows() using {@link Method#DIRECT_READ}} AND {@link + * BigQueryOptions#setEnableBundling(Boolean)} (Boolean)} set to True. + */ +@RunWith(JUnit4.class) +public class BigQueryIOStorageReadWithStreamBundleSourceTest { + + private transient PipelineOptions options; + private final transient TemporaryFolder testFolder = new TemporaryFolder(); + private transient TestPipeline p; + private BufferAllocator allocator; + + @Rule + public final transient TestRule folderThenPipeline = + new TestRule() { + @Override + public Statement apply(Statement base, Description description) { + // We need to set up the temporary folder, and then set up the TestPipeline based on the + // chosen folder. Unfortunately, since rule evaluation order is unspecified and unrelated + // to field order, and is separate from construction, that requires manually creating this + // TestRule. 
+ Statement withPipeline = + new Statement() { + @Override + public void evaluate() throws Throwable { + options = TestPipeline.testingPipelineOptions(); + options.as(BigQueryOptions.class).setProject("project-id"); + if (description.getAnnotations().stream() + .anyMatch(a -> a.annotationType().equals(ProjectOverride.class))) { + options.as(BigQueryOptions.class).setBigQueryProject("bigquery-project-id"); + } + options + .as(BigQueryOptions.class) + .setTempLocation(testFolder.getRoot().getAbsolutePath()); + options.as(BigQueryOptions.class).setEnableBundling(true); + p = TestPipeline.fromOptions(options); + p.apply(base, description).evaluate(); + } + }; + return testFolder.apply(withPipeline, description); + } + }; + + @Rule public transient ExpectedException thrown = ExpectedException.none(); + + private final FakeDatasetService fakeDatasetService = new FakeDatasetService(); + + @Before + public void setUp() throws Exception { + FakeDatasetService.setUp(); + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @After + public void teardown() { + allocator.close(); + } + + @Test + public void testBuildTableBasedSource() { + BigQueryIO.TypedRead typedRead = + BigQueryIO.read(new TableRowParser()) + .withCoder(TableRowJsonCoder.of()) + .withMethod(Method.DIRECT_READ) + .from("foo.com:project:dataset.table"); + checkTypedReadTableObject(typedRead, "foo.com:project", "dataset", "table"); + assertTrue(typedRead.getValidate()); + } + + @Test + public void testBuildTableBasedSourceWithoutValidation() { + BigQueryIO.TypedRead typedRead = + BigQueryIO.read(new TableRowParser()) + .withCoder(TableRowJsonCoder.of()) + .withMethod(Method.DIRECT_READ) + .from("foo.com:project:dataset.table") + .withoutValidation(); + checkTypedReadTableObject(typedRead, "foo.com:project", "dataset", "table"); + assertFalse(typedRead.getValidate()); + } + + @Test + public void testBuildTableBasedSourceWithDefaultProject() { + BigQueryIO.TypedRead typedRead = + BigQueryIO.read(new TableRowParser()) + .withCoder(TableRowJsonCoder.of()) + .withMethod(Method.DIRECT_READ) + .from("myDataset.myTable"); + checkTypedReadTableObject(typedRead, null, "myDataset", "myTable"); + } + + @Test + public void testBuildTableBasedSourceWithTableReference() { + TableReference tableReference = + new TableReference() + .setProjectId("foo.com:project") + .setDatasetId("dataset") + .setTableId("table"); + BigQueryIO.TypedRead typedRead = + BigQueryIO.read(new TableRowParser()) + .withCoder(TableRowJsonCoder.of()) + .withMethod(Method.DIRECT_READ) + .from(tableReference); + checkTypedReadTableObject(typedRead, "foo.com:project", "dataset", "table"); + } + + private void checkTypedReadTableObject( + TypedRead typedRead, String project, String dataset, String table) { + assertEquals(project, typedRead.getTable().getProjectId()); + assertEquals(dataset, typedRead.getTable().getDatasetId()); + assertEquals(table, typedRead.getTable().getTableId()); + assertNull(typedRead.getQuery()); + assertEquals(Method.DIRECT_READ, typedRead.getMethod()); + } + + @Test + public void testBuildSourceWithTableAndFlatten() { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage( + "Invalid BigQueryIO.Read: Specifies a table with a result flattening preference," + + " which only applies to queries"); + p.apply( + "ReadMyTable", + BigQueryIO.read(new TableRowParser()) + .withCoder(TableRowJsonCoder.of()) + .withMethod(Method.DIRECT_READ) + .from("foo.com:project:dataset.table") + .withoutResultFlattening()); + p.run(); + } + + @Test + public void 
testBuildSourceWithTableAndSqlDialect() { + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage( + "Invalid BigQueryIO.Read: Specifies a table with a SQL dialect preference," + + " which only applies to queries"); + p.apply( + "ReadMyTable", + BigQueryIO.read(new TableRowParser()) + .withCoder(TableRowJsonCoder.of()) + .withMethod(Method.DIRECT_READ) + .from("foo.com:project:dataset.table") + .usingStandardSql()); + p.run(); + } + + @Test + public void testDisplayData() { + String tableSpec = "foo.com:project:dataset.table"; + BigQueryIO.TypedRead typedRead = + BigQueryIO.read(new TableRowParser()) + .withCoder(TableRowJsonCoder.of()) + .withMethod(Method.DIRECT_READ) + .withSelectedFields(ImmutableList.of("foo", "bar")) + .withProjectionPushdownApplied() + .from(tableSpec); + DisplayData displayData = DisplayData.from(typedRead); + assertThat(displayData, hasDisplayItem("table", tableSpec)); + assertThat(displayData, hasDisplayItem("selectedFields", "foo, bar")); + assertThat(displayData, hasDisplayItem("projectionPushdownApplied", true)); + } + + @Test + public void testName() { + assertEquals( + "BigQueryIO.TypedRead", + BigQueryIO.read(new TableRowParser()) + .withCoder(TableRowJsonCoder.of()) + .withMethod(Method.DIRECT_READ) + .from("foo.com:project:dataset.table") + .getName()); + } + + @Test + public void testCoderInference() { + // Lambdas erase too much type information -- use an anonymous class here. + SerializableFunction> parseFn = + new SerializableFunction>() { + @Override + public KV apply(SchemaAndRecord input) { + return null; + } + }; + + assertEquals( + KvCoder.of(ByteStringCoder.of(), ProtoCoder.of(ReadSession.class)), + BigQueryIO.read(parseFn).inferCoder(CoderRegistry.createDefault())); + } + + @Test + public void testTableSourceEstimatedSize() throws Exception { + doTableSourceEstimatedSizeTest(false); + } + + @Test + public void testTableSourceEstimatedSize_IgnoresStreamingBuffer() throws Exception { + doTableSourceEstimatedSizeTest(true); + } + + private void doTableSourceEstimatedSizeTest(boolean useStreamingBuffer) throws Exception { + fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null); + TableReference tableRef = BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table"); + Table table = new Table().setTableReference(tableRef).setNumBytes(100L); + if (useStreamingBuffer) { + table.setStreamingBuffer(new Streamingbuffer().setEstimatedBytes(BigInteger.TEN)); + } + + fakeDatasetService.createTable(table); + + BigQueryStorageTableSource tableSource = + BigQueryStorageTableSource.create( + ValueProvider.StaticValueProvider.of(tableRef), + null, + null, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withDatasetService(fakeDatasetService)); + + assertEquals(100, tableSource.getEstimatedSizeBytes(options)); + } + + @Test + @ProjectOverride + public void testTableSourceEstimatedSize_WithBigQueryProject() throws Exception { + fakeDatasetService.createDataset("bigquery-project-id", "dataset", "", "", null); + TableReference tableRef = BigQueryHelpers.parseTableSpec("bigquery-project-id:dataset.table"); + Table table = new Table().setTableReference(tableRef).setNumBytes(100L); + fakeDatasetService.createTable(table); + + BigQueryStorageTableSource tableSource = + BigQueryStorageTableSource.create( + ValueProvider.StaticValueProvider.of(BigQueryHelpers.parseTableSpec("dataset.table")), + null, + null, + new TableRowParser(), + TableRowJsonCoder.of(), + new 
FakeBigQueryServices().withDatasetService(fakeDatasetService)); + + assertEquals(100, tableSource.getEstimatedSizeBytes(options)); + } + + @Test + public void testTableSourceEstimatedSize_WithDefaultProject() throws Exception { + fakeDatasetService.createDataset("project-id", "dataset", "", "", null); + TableReference tableRef = BigQueryHelpers.parseTableSpec("project-id:dataset.table"); + Table table = new Table().setTableReference(tableRef).setNumBytes(100L); + fakeDatasetService.createTable(table); + + BigQueryStorageTableSource tableSource = + BigQueryStorageTableSource.create( + ValueProvider.StaticValueProvider.of(BigQueryHelpers.parseTableSpec("dataset.table")), + null, + null, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withDatasetService(fakeDatasetService)); + + assertEquals(100, tableSource.getEstimatedSizeBytes(options)); + } + + private static final String AVRO_SCHEMA_STRING = + "{\"namespace\": \"example.avro\",\n" + + " \"type\": \"record\",\n" + + " \"name\": \"RowRecord\",\n" + + " \"fields\": [\n" + + " {\"name\": \"name\", \"type\": \"string\"},\n" + + " {\"name\": \"number\", \"type\": \"long\"}\n" + + " ]\n" + + "}"; + + private static final Schema AVRO_SCHEMA = new Schema.Parser().parse(AVRO_SCHEMA_STRING); + + private static final String TRIMMED_AVRO_SCHEMA_STRING = + "{\"namespace\": \"example.avro\",\n" + + "\"type\": \"record\",\n" + + "\"name\": \"RowRecord\",\n" + + "\"fields\": [\n" + + " {\"name\": \"name\", \"type\": \"string\"}\n" + + " ]\n" + + "}"; + + private static final Schema TRIMMED_AVRO_SCHEMA = + new Schema.Parser().parse(TRIMMED_AVRO_SCHEMA_STRING); + + private static final TableSchema TABLE_SCHEMA = + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("name").setType("STRING").setMode("REQUIRED"), + new TableFieldSchema().setName("number").setType("INTEGER").setMode("REQUIRED"))); + + private static final org.apache.arrow.vector.types.pojo.Schema ARROW_SCHEMA = + new org.apache.arrow.vector.types.pojo.Schema( + asList( + field("name", new ArrowType.Utf8()), field("number", new ArrowType.Int(64, true)))); + + private void doTableSourceInitialSplitTest(long bundleSize, long tableSize, int streamCount) + throws Exception { + fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null); + TableReference tableRef = BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table"); + + Table table = + new Table().setTableReference(tableRef).setNumBytes(tableSize).setSchema(TABLE_SCHEMA); + + fakeDatasetService.createTable(table); + + CreateReadSessionRequest expectedRequest = + CreateReadSessionRequest.newBuilder() + .setParent("projects/project-id") + .setReadSession( + ReadSession.newBuilder() + .setTable("projects/foo.com:project/datasets/dataset/tables/table")) + .setMaxStreamCount(0) + .build(); + + ReadSession.Builder builder = + ReadSession.newBuilder() + .setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA_STRING)) + .setDataFormat(DataFormat.AVRO) + .setEstimatedTotalBytesScanned(tableSize); + for (int i = 0; i < streamCount; i++) { + builder.addStreams(ReadStream.newBuilder().setName("stream-" + i)); + } + + StorageClient fakeStorageClient = mock(StorageClient.class); + when(fakeStorageClient.createReadSession(expectedRequest)).thenReturn(builder.build()); + + BigQueryStorageTableSource tableSource = + BigQueryStorageTableSource.create( + ValueProvider.StaticValueProvider.of(tableRef), + null, + null, + new TableRowParser(), + TableRowJsonCoder.of(), + 
new FakeBigQueryServices() + .withDatasetService(fakeDatasetService) + .withStorageClient(fakeStorageClient)); + + List> sources = tableSource.split(bundleSize, options); + // Each StreamBundle is expected to contain a single stream. + assertEquals(streamCount, sources.size()); + } + + @Test + public void testTableSourceInitialSplit() throws Exception { + doTableSourceInitialSplitTest(1024L, 1024L * 1024L, 1024); + } + + @Test + public void testTableSourceInitialSplit_MinSplitCount() throws Exception { + doTableSourceInitialSplitTest(1024L, 1024L * 1024L, 10); + } + + @Test + public void testTableSourceInitialSplit_MaxSplitCount() throws Exception { + doTableSourceInitialSplitTest(10L, 1024L * 1024L, 10_000); + } + + @Test + public void testTableSourceInitialSplit_WithSelectedFieldsAndRowRestriction() throws Exception { + fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null); + TableReference tableRef = BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table"); + + Table table = new Table().setTableReference(tableRef).setNumBytes(200L).setSchema(TABLE_SCHEMA); + + fakeDatasetService.createTable(table); + + CreateReadSessionRequest expectedRequest = + CreateReadSessionRequest.newBuilder() + .setParent("projects/project-id") + .setReadSession( + ReadSession.newBuilder() + .setTable("projects/foo.com:project/datasets/dataset/tables/table") + .setReadOptions( + ReadSession.TableReadOptions.newBuilder() + .addSelectedFields("name") + .setRowRestriction("number > 5"))) + .setMaxStreamCount(0) + .build(); + + ReadSession.Builder builder = + ReadSession.newBuilder() + .setAvroSchema(AvroSchema.newBuilder().setSchema(TRIMMED_AVRO_SCHEMA_STRING)) + .setDataFormat(DataFormat.AVRO) + .setEstimatedTotalBytesScanned(100L); + for (int i = 0; i < 10; i++) { + builder.addStreams(ReadStream.newBuilder().setName("stream-" + i)); + } + + StorageClient fakeStorageClient = mock(StorageClient.class); + when(fakeStorageClient.createReadSession(expectedRequest)).thenReturn(builder.build()); + + BigQueryStorageTableSource tableSource = + BigQueryStorageTableSource.create( + ValueProvider.StaticValueProvider.of(tableRef), + StaticValueProvider.of(Lists.newArrayList("name")), + StaticValueProvider.of("number > 5"), + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices() + .withDatasetService(fakeDatasetService) + .withStorageClient(fakeStorageClient)); + + List> sources = tableSource.split(20L, options); + assertEquals(5, sources.size()); + } + + @Test + public void testTableSourceInitialSplit_WithDefaultProject() throws Exception { + fakeDatasetService.createDataset("project-id", "dataset", "", "", null); + TableReference tableRef = BigQueryHelpers.parseTableSpec("project-id:dataset.table"); + + Table table = + new Table().setTableReference(tableRef).setNumBytes(1024L).setSchema(TABLE_SCHEMA); + + fakeDatasetService.createTable(table); + + CreateReadSessionRequest expectedRequest = + CreateReadSessionRequest.newBuilder() + .setParent("projects/project-id") + .setReadSession( + ReadSession.newBuilder() + .setTable("projects/project-id/datasets/dataset/tables/table")) + .setMaxStreamCount(0) + .build(); + + ReadSession.Builder builder = + ReadSession.newBuilder() + .setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA_STRING)) + .setDataFormat(DataFormat.AVRO) + .setEstimatedTotalBytesScanned(1024L); + for (int i = 0; i < 50; i++) { + builder.addStreams(ReadStream.newBuilder().setName("stream-" + i)); + } + + StorageClient fakeStorageClient = 
mock(StorageClient.class); + when(fakeStorageClient.createReadSession(expectedRequest)).thenReturn(builder.build()); + + BigQueryStorageTableSource tableSource = + BigQueryStorageTableSource.create( + ValueProvider.StaticValueProvider.of(BigQueryHelpers.parseTableSpec("dataset.table")), + null, + null, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices() + .withDatasetService(fakeDatasetService) + .withStorageClient(fakeStorageClient)); + + List> sources = tableSource.split(4096L, options); + // A single StreamBundle containing all the Streams. + assertEquals(1, sources.size()); + } + + @Test + public void testTableSourceInitialSplit_EmptyTable() throws Exception { + fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null); + TableReference tableRef = BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table"); + + Table table = + new Table() + .setTableReference(tableRef) + .setNumBytes(1024L * 1024L) + .setSchema(new TableSchema()); + + fakeDatasetService.createTable(table); + + CreateReadSessionRequest expectedRequest = + CreateReadSessionRequest.newBuilder() + .setParent("projects/project-id") + .setReadSession( + ReadSession.newBuilder() + .setTable("projects/foo.com:project/datasets/dataset/tables/table")) + .setMaxStreamCount(0) + .build(); + + ReadSession emptyReadSession = ReadSession.newBuilder().build(); + StorageClient fakeStorageClient = mock(StorageClient.class); + when(fakeStorageClient.createReadSession(expectedRequest)).thenReturn(emptyReadSession); + + BigQueryStorageTableSource tableSource = + BigQueryStorageTableSource.create( + ValueProvider.StaticValueProvider.of(tableRef), + null, + null, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices() + .withDatasetService(fakeDatasetService) + .withStorageClient(fakeStorageClient)); + + List> sources = tableSource.split(1024L, options); + assertTrue(sources.isEmpty()); + } + + @Test + public void testTableSourceCreateReader() throws Exception { + BigQueryStorageTableSource tableSource = + BigQueryStorageTableSource.create( + ValueProvider.StaticValueProvider.of( + BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table")), + null, + null, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withDatasetService(fakeDatasetService)); + + thrown.expect(UnsupportedOperationException.class); + thrown.expectMessage("BigQuery storage source must be split before reading"); + tableSource.createReader(options); + } + + private static GenericRecord createRecord(String name, Schema schema) { + GenericRecord genericRecord = new Record(schema); + genericRecord.put("name", name); + return genericRecord; + } + + private static GenericRecord createRecord(String name, long number, Schema schema) { + GenericRecord genericRecord = new Record(schema); + genericRecord.put("name", name); + genericRecord.put("number", number); + return genericRecord; + } + + private static ByteString serializeArrowSchema( + org.apache.arrow.vector.types.pojo.Schema arrowSchema) { + ByteArrayOutputStream byteOutputStream = new ByteArrayOutputStream(); + try { + MessageSerializer.serialize( + new WriteChannel(Channels.newChannel(byteOutputStream)), arrowSchema); + } catch (IOException ex) { + throw new RuntimeException("Failed to serialize arrow schema.", ex); + } + return ByteString.copyFrom(byteOutputStream.toByteArray()); + } + + private static final EncoderFactory ENCODER_FACTORY = EncoderFactory.get(); + + private static ReadRowsResponse createResponse( + 
Schema schema, + Collection genericRecords, + double progressAtResponseStart, + double progressAtResponseEnd) + throws Exception { + GenericDatumWriter writer = new GenericDatumWriter<>(schema); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + Encoder binaryEncoder = ENCODER_FACTORY.binaryEncoder(outputStream, null); + for (GenericRecord genericRecord : genericRecords) { + writer.write(genericRecord, binaryEncoder); + } + + binaryEncoder.flush(); + + return ReadRowsResponse.newBuilder() + .setAvroRows( + AvroRows.newBuilder() + .setSerializedBinaryRows(ByteString.copyFrom(outputStream.toByteArray())) + .setRowCount(genericRecords.size())) + .setRowCount(genericRecords.size()) + .setStats( + StreamStats.newBuilder() + .setProgress( + Progress.newBuilder() + .setAtResponseStart(progressAtResponseStart) + .setAtResponseEnd(progressAtResponseEnd))) + .build(); + } + + private ReadRowsResponse createResponseArrow( + org.apache.arrow.vector.types.pojo.Schema arrowSchema, + List name, + List number, + double progressAtResponseStart, + double progressAtResponseEnd) { + ArrowRecordBatch serializedRecord; + try (VectorSchemaRoot schemaRoot = VectorSchemaRoot.create(arrowSchema, allocator)) { + schemaRoot.allocateNew(); + schemaRoot.setRowCount(name.size()); + VarCharVector strVector = (VarCharVector) schemaRoot.getFieldVectors().get(0); + BigIntVector bigIntVector = (BigIntVector) schemaRoot.getFieldVectors().get(1); + for (int i = 0; i < name.size(); i++) { + bigIntVector.set(i, number.get(i)); + strVector.set(i, new Text(name.get(i))); + } + + VectorUnloader unLoader = new VectorUnloader(schemaRoot); + try (org.apache.arrow.vector.ipc.message.ArrowRecordBatch records = + unLoader.getRecordBatch()) { + try (ByteArrayOutputStream os = new ByteArrayOutputStream()) { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(os)), records); + serializedRecord = + ArrowRecordBatch.newBuilder() + .setRowCount(records.getLength()) + .setSerializedRecordBatch(ByteString.copyFrom(os.toByteArray())) + .build(); + } catch (IOException e) { + throw new RuntimeException("Error writing to byte array output stream", e); + } + } + } + + return ReadRowsResponse.newBuilder() + .setArrowRecordBatch(serializedRecord) + .setRowCount(name.size()) + .setStats( + StreamStats.newBuilder() + .setProgress( + Progress.newBuilder() + .setAtResponseStart(progressAtResponseStart) + .setAtResponseEnd(progressAtResponseEnd))) + .build(); + } + + @Test + public void testStreamSourceEstimatedSizeBytes() throws Exception { + List streamBundle = Lists.newArrayList(ReadStream.getDefaultInstance()); + BigQueryStorageStreamBundleSource streamSource = + BigQueryStorageStreamBundleSource.create( + ReadSession.getDefaultInstance(), + streamBundle, + TABLE_SCHEMA, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices(), + 1L); + + assertEquals(0, streamSource.getEstimatedSizeBytes(options)); + } + + @Test + public void testStreamSourceSplit() throws Exception { + List streamBundle = Lists.newArrayList(ReadStream.getDefaultInstance()); + BigQueryStorageStreamBundleSource streamSource = + BigQueryStorageStreamBundleSource.create( + ReadSession.getDefaultInstance(), + streamBundle, + TABLE_SCHEMA, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices(), + 1L); + + assertThat(streamSource.split(0, options), containsInAnyOrder(streamSource)); + } + + @Test + public void testReadFromStreamSource() throws Exception { + ReadSession readSession = + ReadSession.newBuilder() + 
.setName("readSession") + .setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA_STRING)) + .build(); + + ReadRowsRequest expectedRequestOne = + ReadRowsRequest.newBuilder().setReadStream("readStream1").setOffset(0).build(); + ReadRowsRequest expectedRequestTwo = + ReadRowsRequest.newBuilder().setReadStream("readStream2").setOffset(0).build(); + + List records = + Lists.newArrayList( + createRecord("A", 1, AVRO_SCHEMA), + createRecord("B", 2, AVRO_SCHEMA), + createRecord("C", 3, AVRO_SCHEMA), + createRecord("D", 4, AVRO_SCHEMA), + createRecord("E", 5, AVRO_SCHEMA), + createRecord("F", 6, AVRO_SCHEMA)); + + List responsesOne = + Lists.newArrayList( + createResponse(AVRO_SCHEMA, records.subList(0, 2), 0.0, 0.50), + createResponse(AVRO_SCHEMA, records.subList(2, 3), 0.5, 0.75)); + List responsesTwo = + Lists.newArrayList( + createResponse(AVRO_SCHEMA, records.subList(3, 5), 0.0, 0.50), + createResponse(AVRO_SCHEMA, records.subList(5, 6), 0.5, 0.75)); + + StorageClient fakeStorageClient = mock(StorageClient.class); + when(fakeStorageClient.readRows(expectedRequestOne, "")) + .thenReturn(new FakeBigQueryServerStream<>(responsesOne)); + when(fakeStorageClient.readRows(expectedRequestTwo, "")) + .thenReturn(new FakeBigQueryServerStream<>(responsesTwo)); + + List streamBundle = + Lists.newArrayList( + ReadStream.newBuilder().setName("readStream1").build(), + ReadStream.newBuilder().setName("readStream2").build()); + BigQueryStorageStreamBundleSource streamSource = + BigQueryStorageStreamBundleSource.create( + readSession, + streamBundle, + TABLE_SCHEMA, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withStorageClient(fakeStorageClient), + 1L); + + List rows = new ArrayList<>(); + BigQueryStorageStreamBundleReader reader = streamSource.createReader(options); + for (boolean hasNext = reader.start(); hasNext; hasNext = reader.advance()) { + rows.add(reader.getCurrent()); + } + + System.out.println("Rows: " + rows); + + assertEquals(6, rows.size()); + } + + private static final double DELTA = 1e-6; + + @Test + public void testFractionConsumedWithOneStreamInBundle() throws Exception { + ReadSession readSession = + ReadSession.newBuilder() + .setName("readSession") + .setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA_STRING)) + .build(); + + ReadRowsRequest expectedRequest = + ReadRowsRequest.newBuilder().setReadStream("readStream").build(); + + List records = + Lists.newArrayList( + createRecord("A", 1, AVRO_SCHEMA), + createRecord("B", 2, AVRO_SCHEMA), + createRecord("C", 3, AVRO_SCHEMA), + createRecord("D", 4, AVRO_SCHEMA), + createRecord("E", 5, AVRO_SCHEMA), + createRecord("F", 6, AVRO_SCHEMA), + createRecord("G", 7, AVRO_SCHEMA)); + + List responses = + Lists.newArrayList( + createResponse(AVRO_SCHEMA, records.subList(0, 2), 0.0, 0.25), + // Some responses may contain zero results, so we must ensure that we can are resilient + // to such responses. 
+ createResponse(AVRO_SCHEMA, Lists.newArrayList(), 0.25, 0.25), + createResponse(AVRO_SCHEMA, records.subList(2, 4), 0.3, 0.5), + createResponse(AVRO_SCHEMA, records.subList(4, 7), 0.7, 1.0)); + + StorageClient fakeStorageClient = mock(StorageClient.class); + when(fakeStorageClient.readRows(expectedRequest, "")) + .thenReturn(new FakeBigQueryServerStream<>(responses)); + + List streamBundle = + Lists.newArrayList(ReadStream.newBuilder().setName("readStream").build()); + BigQueryStorageStreamBundleSource streamSource = + BigQueryStorageStreamBundleSource.create( + readSession, + streamBundle, + TABLE_SCHEMA, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withStorageClient(fakeStorageClient), + 1L); + + BoundedReader reader = streamSource.createReader(options); + + // Before call to BoundedReader#start, fraction consumed must be zero. + assertEquals(0.0, reader.getFractionConsumed(), DELTA); + + assertTrue(reader.start()); // Reads A. + assertEquals(0.125, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads B. + assertEquals(0.25, reader.getFractionConsumed(), DELTA); + + assertTrue(reader.advance()); // Reads C. + assertEquals(0.4, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads D. + assertEquals(0.5, reader.getFractionConsumed(), DELTA); + + assertTrue(reader.advance()); // Reads E. + assertEquals(0.8, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads F. + assertEquals(0.9, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads G. + assertEquals(1.0, reader.getFractionConsumed(), DELTA); + + assertFalse(reader.advance()); // Reaches the end. + + // We are done with the stream, so we should report 100% consumption. + assertEquals(Double.valueOf(1.0), reader.getFractionConsumed()); + } + + @Test + public void testFractionConsumedWithMultipleStreamsInBundle() throws Exception { + ReadSession readSession = + ReadSession.newBuilder() + .setName("readSession") + .setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA_STRING)) + .build(); + + ReadRowsRequest expectedRequestOne = + ReadRowsRequest.newBuilder().setReadStream("readStream1").build(); + ReadRowsRequest expectedRequestTwo = + ReadRowsRequest.newBuilder().setReadStream("readStream2").build(); + + List records = + Lists.newArrayList( + createRecord("A", 1, AVRO_SCHEMA), + createRecord("B", 2, AVRO_SCHEMA), + createRecord("C", 3, AVRO_SCHEMA), + createRecord("D", 4, AVRO_SCHEMA), + createRecord("E", 5, AVRO_SCHEMA), + createRecord("F", 6, AVRO_SCHEMA), + createRecord("G", 7, AVRO_SCHEMA)); + + List responsesOne = + Lists.newArrayList( + createResponse(AVRO_SCHEMA, records.subList(0, 2), 0.0, 0.5), + // Some responses may contain zero results, so we must ensure that we are resilient + // to such responses. 
+ createResponse(AVRO_SCHEMA, Lists.newArrayList(), 0.5, 0.5), + createResponse(AVRO_SCHEMA, records.subList(2, 4), 0.5, 1.0)); + + List responsesTwo = + Lists.newArrayList(createResponse(AVRO_SCHEMA, records.subList(4, 7), 0.0, 1.0)); + + StorageClient fakeStorageClient = mock(StorageClient.class); + when(fakeStorageClient.readRows(expectedRequestOne, "")) + .thenReturn(new FakeBigQueryServerStream<>(responsesOne)); + when(fakeStorageClient.readRows(expectedRequestTwo, "")) + .thenReturn(new FakeBigQueryServerStream<>(responsesTwo)); + + List streamBundle = + Lists.newArrayList( + ReadStream.newBuilder().setName("readStream1").build(), + ReadStream.newBuilder().setName("readStream2").build()); + + BigQueryStorageStreamBundleSource streamSource = + BigQueryStorageStreamBundleSource.create( + readSession, + streamBundle, + TABLE_SCHEMA, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withStorageClient(fakeStorageClient), + 1L); + + BoundedReader reader = streamSource.createReader(options); + + // Before call to BoundedReader#start, fraction consumed must be zero. + assertEquals(0.0, reader.getFractionConsumed(), DELTA); + + assertTrue(reader.start()); // Reads A. + assertEquals(0.125, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads B. + assertEquals(0.25, reader.getFractionConsumed(), DELTA); + + assertTrue(reader.advance()); // Reads C. + assertEquals(0.375, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads D. + assertEquals(0.5, reader.getFractionConsumed(), DELTA); + + assertTrue(reader.advance()); // Reads E. + assertEquals(0.6666666666666666, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads F. + assertEquals(0.8333333333333333, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads G. + assertEquals(1.0, reader.getFractionConsumed(), DELTA); + + assertFalse(reader.advance()); // Reaches the end. + + // We are done with the streams, so we should report 100% consumption. + assertEquals(Double.valueOf(1.0), reader.getFractionConsumed()); + } + + @Test + public void testStreamSourceSplitAtFractionNoOpWithOneStreamInBundle() throws Exception { + List responses = + Lists.newArrayList( + createResponse( + AVRO_SCHEMA, + Lists.newArrayList( + createRecord("A", 1, AVRO_SCHEMA), createRecord("B", 2, AVRO_SCHEMA)), + 0.0, + 0.25), + createResponse( + AVRO_SCHEMA, Lists.newArrayList(createRecord("C", 3, AVRO_SCHEMA)), 0.25, 0.50), + createResponse( + AVRO_SCHEMA, + Lists.newArrayList( + createRecord("D", 4, AVRO_SCHEMA), createRecord("E", 5, AVRO_SCHEMA)), + 0.50, + 0.75)); + + StorageClient fakeStorageClient = mock(StorageClient.class); + when(fakeStorageClient.readRows( + ReadRowsRequest.newBuilder().setReadStream("parentStream").build(), "")) + .thenReturn(new FakeBigQueryServerStream<>(responses)); + + List parentStreamBundle = + Lists.newArrayList(ReadStream.newBuilder().setName("parentStream").build()); + BigQueryStorageStreamBundleSource streamBundleSource = + BigQueryStorageStreamBundleSource.create( + ReadSession.newBuilder() + .setName("readSession") + .setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA_STRING)) + .build(), + parentStreamBundle, + TABLE_SCHEMA, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withStorageClient(fakeStorageClient), + 1L); + + // Read a few records from the parent stream and ensure that records are returned in the + // prescribed order. 
+ BoundedReader primary = streamBundleSource.createReader(options);
+ assertTrue(primary.start());
+ assertEquals("A", primary.getCurrent().get("name"));
+ assertTrue(primary.advance());
+ assertEquals("B", primary.getCurrent().get("name"));
+
+ // Now split the stream. Since we do NOT split below the granularity of a single stream,
+ // this will be a no-op and the primary source should be read to completion.
+ BoundedSource secondary = primary.splitAtFraction(0.5);
+ assertNull(secondary);
+
+ assertTrue(primary.advance());
+ assertEquals("C", primary.getCurrent().get("name"));
+ assertTrue(primary.advance());
+ assertEquals("D", primary.getCurrent().get("name"));
+ assertTrue(primary.advance());
+ assertEquals("E", primary.getCurrent().get("name"));
+ assertFalse(primary.advance());
+ }
+
+ @Test
+ public void testStreamSourceSplitAtFractionWithMultipleStreamsInBundle() throws Exception {
+ List responses =
+ Lists.newArrayList(
+ createResponse(
+ AVRO_SCHEMA,
+ Lists.newArrayList(
+ createRecord("A", 1, AVRO_SCHEMA), createRecord("B", 2, AVRO_SCHEMA)),
+ 0.0,
+ 0.6),
+ createResponse(
+ AVRO_SCHEMA, Lists.newArrayList(createRecord("C", 3, AVRO_SCHEMA)), 0.6, 1.0),
+ createResponse(
+ AVRO_SCHEMA,
+ Lists.newArrayList(
+ createRecord("D", 4, AVRO_SCHEMA),
+ createRecord("E", 5, AVRO_SCHEMA),
+ createRecord("F", 6, AVRO_SCHEMA)),
+ 0.0,
+ 1.0),
+ createResponse(
+ AVRO_SCHEMA, Lists.newArrayList(createRecord("G", 7, AVRO_SCHEMA)), 0.0, 1.0));
+
+ StorageClient fakeStorageClient = mock(StorageClient.class);
+ when(fakeStorageClient.readRows(
+ ReadRowsRequest.newBuilder().setReadStream("readStream1").build(), ""))
+ .thenReturn(new FakeBigQueryServerStream<>(responses.subList(0, 2)));
+ when(fakeStorageClient.readRows(
+ ReadRowsRequest.newBuilder().setReadStream("readStream2").build(), ""))
+ .thenReturn(new FakeBigQueryServerStream<>(responses.subList(2, 3)));
+ when(fakeStorageClient.readRows(
+ ReadRowsRequest.newBuilder().setReadStream("readStream3").build(), ""))
+ .thenReturn(new FakeBigQueryServerStream<>(responses.subList(3, 4)));
+
+ List primaryStreamBundle =
+ Lists.newArrayList(
+ ReadStream.newBuilder().setName("readStream1").build(),
+ ReadStream.newBuilder().setName("readStream2").build(),
+ ReadStream.newBuilder().setName("readStream3").build());
+
+ BigQueryStorageStreamBundleSource primarySource =
+ BigQueryStorageStreamBundleSource.create(
+ ReadSession.newBuilder()
+ .setName("readSession")
+ .setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA_STRING))
+ .build(),
+ primaryStreamBundle,
+ TABLE_SCHEMA,
+ new TableRowParser(),
+ TableRowJsonCoder.of(),
+ new FakeBigQueryServices().withStorageClient(fakeStorageClient),
+ 1L);
+
+ // Read a few records from the primary Source and ensure that records are returned in the
+ // prescribed order.
+ BoundedReader primary = primarySource.createReader(options);
+
+ assertTrue(primary.start());
+
+ // Attempt to split at a sub-Stream level, which is NOT supported by
+ // `BigQueryStorageStreamBundleSource`. In other words, since there are exactly 3 Streams
+ // in the Source, a split will only occur for fraction > 0.33.
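+ // With fraction 0.05, the proposed split point falls within the first Stream, so the
+ // request is rejected and null is returned.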
+ BoundedSource secondarySource = primary.splitAtFraction(0.05); + assertNull(secondarySource); + + assertEquals("A", primary.getCurrent().get("name")); + assertTrue(primary.advance()); + assertEquals("B", primary.getCurrent().get("name")); + assertTrue(primary.advance()); + assertEquals("C", primary.getCurrent().get("name")); + + // Now split the primary Source, and ensure that the returned source points to a non-null + // StreamBundle containing Streams 2 & 3. + secondarySource = primary.splitAtFraction(0.5); + assertNotNull(secondarySource); + BoundedReader secondary = secondarySource.createReader(options); + + // Since the last two streams were split out the Primary source has been exhausted. + assertFalse(primary.advance()); + + assertTrue(secondary.start()); + assertEquals("D", secondary.getCurrent().get("name")); + assertTrue(secondary.advance()); + assertEquals("E", secondary.getCurrent().get("name")); + assertTrue(secondary.advance()); + assertEquals("F", secondary.getCurrent().get("name")); + assertTrue((secondary.advance())); + + // Since we have already started reading from the last Stream in the StreamBundle, splitting + // is now a no-op. + BoundedSource tertiarySource = secondary.splitAtFraction(0.55); + assertNull(tertiarySource); + + assertEquals("G", secondary.getCurrent().get("name")); + assertFalse((secondary.advance())); + } + + @Test + public void testStreamSourceSplitAtFractionRepeatedWithMultipleStreamInBundle() throws Exception { + List responses = + Lists.newArrayList( + createResponse( + AVRO_SCHEMA, + Lists.newArrayList( + createRecord("A", 1, AVRO_SCHEMA), createRecord("B", 2, AVRO_SCHEMA)), + 0.0, + 0.6), + createResponse( + AVRO_SCHEMA, Lists.newArrayList(createRecord("C", 3, AVRO_SCHEMA)), 0.6, 1.0), + createResponse( + AVRO_SCHEMA, + Lists.newArrayList( + createRecord("D", 4, AVRO_SCHEMA), + createRecord("E", 5, AVRO_SCHEMA), + createRecord("F", 6, AVRO_SCHEMA)), + 0.0, + 1.0), + createResponse( + AVRO_SCHEMA, Lists.newArrayList(createRecord("G", 7, AVRO_SCHEMA)), 0.0, 1.0)); + + StorageClient fakeStorageClient = mock(StorageClient.class); + when(fakeStorageClient.readRows( + ReadRowsRequest.newBuilder().setReadStream("readStream1").build(), "")) + .thenReturn(new FakeBigQueryServerStream<>(responses.subList(0, 2))); + when(fakeStorageClient.readRows( + ReadRowsRequest.newBuilder().setReadStream("readStream2").build(), "")) + .thenReturn(new FakeBigQueryServerStream<>(responses.subList(2, 3))); + when(fakeStorageClient.readRows( + ReadRowsRequest.newBuilder().setReadStream("readStream3").build(), "")) + .thenReturn(new FakeBigQueryServerStream<>(responses.subList(3, 4))); + + List primaryStreamBundle = + Lists.newArrayList( + ReadStream.newBuilder().setName("readStream1").build(), + ReadStream.newBuilder().setName("readStream2").build(), + ReadStream.newBuilder().setName("readStream3").build()); + + BigQueryStorageStreamBundleSource primarySource = + BigQueryStorageStreamBundleSource.create( + ReadSession.newBuilder() + .setName("readSession") + .setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA_STRING)) + .build(), + primaryStreamBundle, + TABLE_SCHEMA, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withStorageClient(fakeStorageClient), + 1L); + + // Read a few records from the primary Source and ensure that records are returned in the + // prescribed order. 
+ BoundedReader primary = primarySource.createReader(options);
+
+ assertTrue(primary.start());
+ assertEquals("A", primary.getCurrent().get("name"));
+ assertTrue(primary.advance());
+ assertEquals("B", primary.getCurrent().get("name"));
+ assertTrue(primary.advance());
+ assertEquals("C", primary.getCurrent().get("name"));
+
+ // Now split the primary Source, and ensure that the returned source points to a non-null
+ // StreamBundle containing ONLY Stream 3. Since there are exactly 3 Streams in the Source,
+ // a split will only occur for fraction > 0.33.
+ BoundedSource secondarySource = primary.splitAtFraction(0.7);
+ assertNotNull(secondarySource);
+ BoundedReader secondary = secondarySource.createReader(options);
+ assertTrue(secondary.start());
+ assertEquals("G", secondary.getCurrent().get("name"));
+ assertFalse((secondary.advance()));
+
+ // A second splitAtFraction() call on the primary source. The resulting source should
+ // contain a StreamBundle containing ONLY Stream 2. Since there are 2 Streams in the Source,
+ // a split will only occur for fraction > 0.50.
+ BoundedSource tertiarySource = primary.splitAtFraction(0.55);
+ assertNotNull(tertiarySource);
+ BoundedReader tertiary = tertiarySource.createReader(options);
+ assertTrue(tertiary.start());
+ assertEquals("D", tertiary.getCurrent().get("name"));
+ assertTrue(tertiary.advance());
+ assertEquals("E", tertiary.getCurrent().get("name"));
+ assertTrue(tertiary.advance());
+ assertEquals("F", tertiary.getCurrent().get("name"));
+ assertFalse(tertiary.advance());
+
+ // A third attempt to split the primary source. This will be ignored since the primary
+ // Source now contains only a single Stream and `BigQueryStorageStreamBundleSource`
+ // does NOT support sub-stream splitting.
+ tertiarySource = primary.splitAtFraction(0.9);
+ assertNull(tertiarySource);
+
+ // All the rows in the primary Source have been read.
+ assertFalse(primary.advance()); + } + + @Test + public void testStreamSourceSplitAtFractionFailsWhenParentIsPastSplitPoint() throws Exception { + List responses = + Lists.newArrayList( + createResponse( + AVRO_SCHEMA, + Lists.newArrayList( + createRecord("A", 1, AVRO_SCHEMA), createRecord("B", 2, AVRO_SCHEMA)), + 0.0, + 0.66), + createResponse( + AVRO_SCHEMA, Lists.newArrayList(createRecord("C", 3, AVRO_SCHEMA)), 0.66, 1.0), + createResponse( + AVRO_SCHEMA, + Lists.newArrayList( + createRecord("D", 4, AVRO_SCHEMA), createRecord("E", 5, AVRO_SCHEMA)), + 0.0, + 1.0)); + + StorageClient fakeStorageClient = mock(StorageClient.class); + when(fakeStorageClient.readRows( + ReadRowsRequest.newBuilder().setReadStream("readStream1").build(), "")) + .thenReturn(new FakeBigQueryServerStream<>(responses.subList(0, 2))); + when(fakeStorageClient.readRows( + ReadRowsRequest.newBuilder().setReadStream("readStream2").build(), "")) + .thenReturn(new FakeBigQueryServerStream<>(responses.subList(2, 3))); + + List parentStreamBundle = + Lists.newArrayList( + ReadStream.newBuilder().setName("readStream1").build(), + ReadStream.newBuilder().setName("readStream2").build()); + + BigQueryStorageStreamBundleSource streamBundleSource = + BigQueryStorageStreamBundleSource.create( + ReadSession.newBuilder() + .setName("readSession") + .setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA_STRING)) + .build(), + parentStreamBundle, + TABLE_SCHEMA, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withStorageClient(fakeStorageClient), + 1L); + + // Read a few records from the parent bundle and ensure the records are returned in + // the prescribed order. + BoundedReader primary = streamBundleSource.createReader(options); + assertTrue(primary.start()); + assertEquals("A", primary.getCurrent().get("name")); + assertTrue(primary.advance()); + assertEquals("B", primary.getCurrent().get("name")); + assertTrue(primary.advance()); + assertEquals("C", primary.getCurrent().get("name")); + assertTrue(primary.advance()); + assertEquals("D", primary.getCurrent().get("name")); + + // We attempt to split the StreamBundle after starting to read the contents of the second + // stream. 
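+ // The reader is already consuming the last Stream in the bundle, so the proposed split
+ // point has been passed and null is returned.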
+ BoundedSource secondarySource = primary.splitAtFraction(0.5); + assertNull(secondarySource); + + assertTrue(primary.advance()); + assertEquals("E", primary.getCurrent().get("name")); + assertFalse(primary.advance()); + } + + private static final class ParseKeyValue + implements SerializableFunction> { + + @Override + public KV apply(SchemaAndRecord input) { + return KV.of( + input.getRecord().get("name").toString(), (Long) input.getRecord().get("number")); + } + } + + @Test + public void testReadFromBigQueryIO() throws Exception { + fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null); + TableReference tableRef = BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table"); + Table table = new Table().setTableReference(tableRef).setNumBytes(10L).setSchema(TABLE_SCHEMA); + fakeDatasetService.createTable(table); + + CreateReadSessionRequest expectedCreateReadSessionRequest = + CreateReadSessionRequest.newBuilder() + .setParent("projects/project-id") + .setReadSession( + ReadSession.newBuilder() + .setTable("projects/foo.com:project/datasets/dataset/tables/table") + .setDataFormat(DataFormat.AVRO)) + .setMaxStreamCount(0) + .build(); + + ReadSession readSession = + ReadSession.newBuilder() + .setName("readSessionName") + .setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA_STRING)) + .addStreams(ReadStream.newBuilder().setName("streamName1")) + .addStreams(ReadStream.newBuilder().setName("streamName2")) + .setDataFormat(DataFormat.AVRO) + .setEstimatedTotalBytesScanned(10L) + .build(); + + ReadRowsRequest expectedReadRowsRequestOne = + ReadRowsRequest.newBuilder().setReadStream("streamName1").build(); + ReadRowsRequest expectedReadRowsRequestTwo = + ReadRowsRequest.newBuilder().setReadStream("streamName2").build(); + + List records = + Lists.newArrayList( + createRecord("A", 1, AVRO_SCHEMA), + createRecord("B", 2, AVRO_SCHEMA), + createRecord("C", 3, AVRO_SCHEMA), + createRecord("D", 4, AVRO_SCHEMA), + createRecord("E", 5, AVRO_SCHEMA), + createRecord("F", 6, AVRO_SCHEMA), + createRecord("G", 7, AVRO_SCHEMA)); + + List readRowsResponsesOne = + Lists.newArrayList( + createResponse(AVRO_SCHEMA, records.subList(0, 2), 0.0, 0.50), + createResponse(AVRO_SCHEMA, records.subList(2, 4), 0.5, 1.0)); + List readRowsResponsesTwo = + Lists.newArrayList( + createResponse(AVRO_SCHEMA, records.subList(4, 5), 0.0, 0.33), + createResponse(AVRO_SCHEMA, records.subList(5, 7), 0.33, 1.0)); + + StorageClient fakeStorageClient = mock(StorageClient.class, withSettings().serializable()); + when(fakeStorageClient.createReadSession(expectedCreateReadSessionRequest)) + .thenReturn(readSession); + when(fakeStorageClient.readRows(expectedReadRowsRequestOne, "")) + .thenReturn(new FakeBigQueryServerStream<>(readRowsResponsesOne)); + when(fakeStorageClient.readRows(expectedReadRowsRequestTwo, "")) + .thenReturn(new FakeBigQueryServerStream<>(readRowsResponsesTwo)); + + PCollection> output = + p.apply( + BigQueryIO.read(new ParseKeyValue()) + .from("foo.com:project:dataset.table") + .withMethod(Method.DIRECT_READ) + .withFormat(DataFormat.AVRO) + .withTestServices( + new FakeBigQueryServices() + .withDatasetService(fakeDatasetService) + .withStorageClient(fakeStorageClient))); + + PAssert.that(output) + .containsInAnyOrder( + ImmutableList.of( + KV.of("A", 1L), + KV.of("B", 2L), + KV.of("C", 3L), + KV.of("D", 4L), + KV.of("E", 5L), + KV.of("F", 6L), + KV.of("G", 7L))); + + p.run(); + } + + @Test + public void testReadFromBigQueryIOWithTrimmedSchema() throws Exception { + 
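// Selecting only the "name" field should trim the read schema and produce TableRows that contain just that column. +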
fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null); + TableReference tableRef = BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table"); + Table table = new Table().setTableReference(tableRef).setNumBytes(10L).setSchema(TABLE_SCHEMA); + fakeDatasetService.createTable(table); + + CreateReadSessionRequest expectedCreateReadSessionRequest = + CreateReadSessionRequest.newBuilder() + .setParent("projects/project-id") + .setReadSession( + ReadSession.newBuilder() + .setTable("projects/foo.com:project/datasets/dataset/tables/table") + .setReadOptions( + ReadSession.TableReadOptions.newBuilder().addSelectedFields("name")) + .setDataFormat(DataFormat.AVRO)) + .setMaxStreamCount(0) + .build(); + + ReadSession readSession = + ReadSession.newBuilder() + .setName("readSessionName") + .setAvroSchema(AvroSchema.newBuilder().setSchema(TRIMMED_AVRO_SCHEMA_STRING)) + .addStreams(ReadStream.newBuilder().setName("streamName1")) + .addStreams(ReadStream.newBuilder().setName("streamName2")) + .setDataFormat(DataFormat.AVRO) + .build(); + + ReadRowsRequest expectedReadRowsRequestOne = + ReadRowsRequest.newBuilder().setReadStream("streamName1").build(); + ReadRowsRequest expectedReadRowsRequestTwo = + ReadRowsRequest.newBuilder().setReadStream("streamName2").build(); + + List records = + Lists.newArrayList( + createRecord("A", TRIMMED_AVRO_SCHEMA), + createRecord("B", TRIMMED_AVRO_SCHEMA), + createRecord("C", TRIMMED_AVRO_SCHEMA), + createRecord("D", TRIMMED_AVRO_SCHEMA), + createRecord("E", TRIMMED_AVRO_SCHEMA), + createRecord("F", TRIMMED_AVRO_SCHEMA), + createRecord("G", TRIMMED_AVRO_SCHEMA)); + + List readRowsResponsesOne = + Lists.newArrayList( + createResponse(TRIMMED_AVRO_SCHEMA, records.subList(0, 2), 0.0, 0.50), + createResponse(TRIMMED_AVRO_SCHEMA, records.subList(2, 4), 0.5, 0.75)); + List readRowsResponsesTwo = + Lists.newArrayList( + createResponse(TRIMMED_AVRO_SCHEMA, records.subList(4, 5), 0.0, 0.33), + createResponse(TRIMMED_AVRO_SCHEMA, records.subList(5, 7), 0.33, 1.0)); + + StorageClient fakeStorageClient = mock(StorageClient.class, withSettings().serializable()); + when(fakeStorageClient.createReadSession(expectedCreateReadSessionRequest)) + .thenReturn(readSession); + when(fakeStorageClient.readRows(expectedReadRowsRequestOne, "")) + .thenReturn(new FakeBigQueryServerStream<>(readRowsResponsesOne)); + when(fakeStorageClient.readRows(expectedReadRowsRequestTwo, "")) + .thenReturn(new FakeBigQueryServerStream<>(readRowsResponsesTwo)); + + PCollection output = + p.apply( + BigQueryIO.readTableRows() + .from("foo.com:project:dataset.table") + .withMethod(Method.DIRECT_READ) + .withSelectedFields(Lists.newArrayList("name")) + .withFormat(DataFormat.AVRO) + .withTestServices( + new FakeBigQueryServices() + .withDatasetService(fakeDatasetService) + .withStorageClient(fakeStorageClient))); + + PAssert.that(output) + .containsInAnyOrder( + ImmutableList.of( + new TableRow().set("name", "A"), + new TableRow().set("name", "B"), + new TableRow().set("name", "C"), + new TableRow().set("name", "D"), + new TableRow().set("name", "E"), + new TableRow().set("name", "F"), + new TableRow().set("name", "G"))); + + p.run(); + } + + @Test + public void testReadFromBigQueryIOWithBeamSchema() throws Exception { + fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null); + TableReference tableRef = BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table"); + Table table = new Table().setTableReference(tableRef).setNumBytes(10L).setSchema(TABLE_SCHEMA); + 
fakeDatasetService.createTable(table); + + CreateReadSessionRequest expectedCreateReadSessionRequest = + CreateReadSessionRequest.newBuilder() + .setParent("projects/project-id") + .setReadSession( + ReadSession.newBuilder() + .setTable("projects/foo.com:project/datasets/dataset/tables/table") + .setReadOptions( + ReadSession.TableReadOptions.newBuilder().addSelectedFields("name")) + .setDataFormat(DataFormat.AVRO)) + .setMaxStreamCount(0) + .build(); + + ReadSession readSession = + ReadSession.newBuilder() + .setName("readSessionName") + .setAvroSchema(AvroSchema.newBuilder().setSchema(TRIMMED_AVRO_SCHEMA_STRING)) + .addStreams(ReadStream.newBuilder().setName("streamName1")) + .addStreams(ReadStream.newBuilder().setName("streamName2")) + .setDataFormat(DataFormat.AVRO) + .build(); + + ReadRowsRequest expectedReadRowsRequestOne = + ReadRowsRequest.newBuilder().setReadStream("streamName1").build(); + ReadRowsRequest expectedReadRowsRequestTwo = + ReadRowsRequest.newBuilder().setReadStream("streamName2").build(); + + List records = + Lists.newArrayList( + createRecord("A", TRIMMED_AVRO_SCHEMA), + createRecord("B", TRIMMED_AVRO_SCHEMA), + createRecord("C", TRIMMED_AVRO_SCHEMA), + createRecord("D", TRIMMED_AVRO_SCHEMA), + createRecord("E", TRIMMED_AVRO_SCHEMA), + createRecord("F", TRIMMED_AVRO_SCHEMA), + createRecord("G", TRIMMED_AVRO_SCHEMA)); + + List readRowsResponsesOne = + Lists.newArrayList( + createResponse(TRIMMED_AVRO_SCHEMA, records.subList(0, 2), 0.0, 0.50), + createResponse(TRIMMED_AVRO_SCHEMA, records.subList(2, 4), 0.5, 0.75)); + List readRowsResponsesTwo = + Lists.newArrayList( + createResponse(TRIMMED_AVRO_SCHEMA, records.subList(4, 5), 0.0, 0.33), + createResponse(TRIMMED_AVRO_SCHEMA, records.subList(5, 7), 0.33, 1.0)); + + StorageClient fakeStorageClient = mock(StorageClient.class, withSettings().serializable()); + when(fakeStorageClient.createReadSession(expectedCreateReadSessionRequest)) + .thenReturn(readSession); + when(fakeStorageClient.readRows(expectedReadRowsRequestOne, "")) + .thenReturn(new FakeBigQueryServerStream<>(readRowsResponsesOne)); + when(fakeStorageClient.readRows(expectedReadRowsRequestTwo, "")) + .thenReturn(new FakeBigQueryServerStream<>(readRowsResponsesTwo)); + + PCollection output = + p.apply( + BigQueryIO.readTableRowsWithSchema() + .from("foo.com:project:dataset.table") + .withMethod(Method.DIRECT_READ) + .withSelectedFields(Lists.newArrayList("name")) + .withFormat(DataFormat.AVRO) + .withTestServices( + new FakeBigQueryServices() + .withDatasetService(fakeDatasetService) + .withStorageClient(fakeStorageClient))) + .apply(Convert.toRows()); + + org.apache.beam.sdk.schemas.Schema beamSchema = + org.apache.beam.sdk.schemas.Schema.of( + org.apache.beam.sdk.schemas.Schema.Field.of( + "name", org.apache.beam.sdk.schemas.Schema.FieldType.STRING)); + PAssert.that(output) + .containsInAnyOrder( + ImmutableList.of( + Row.withSchema(beamSchema).addValue("A").build(), + Row.withSchema(beamSchema).addValue("B").build(), + Row.withSchema(beamSchema).addValue("C").build(), + Row.withSchema(beamSchema).addValue("D").build(), + Row.withSchema(beamSchema).addValue("E").build(), + Row.withSchema(beamSchema).addValue("F").build(), + Row.withSchema(beamSchema).addValue("G").build())); + + p.run(); + } + + @Test + public void testReadFromBigQueryIOArrow() throws Exception { + fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null); + TableReference tableRef = BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table"); + Table table = new 
Table().setTableReference(tableRef).setNumBytes(10L).setSchema(TABLE_SCHEMA); + fakeDatasetService.createTable(table); + + CreateReadSessionRequest expectedCreateReadSessionRequest = + CreateReadSessionRequest.newBuilder() + .setParent("projects/project-id") + .setReadSession( + ReadSession.newBuilder() + .setTable("projects/foo.com:project/datasets/dataset/tables/table") + .setDataFormat(DataFormat.ARROW)) + .setMaxStreamCount(0) + .build(); + + ReadSession readSession = + ReadSession.newBuilder() + .setName("readSessionName") + .setArrowSchema( + ArrowSchema.newBuilder() + .setSerializedSchema(serializeArrowSchema(ARROW_SCHEMA)) + .build()) + .addStreams(ReadStream.newBuilder().setName("streamName1")) + .addStreams(ReadStream.newBuilder().setName("streamName2")) + .setDataFormat(DataFormat.ARROW) + .build(); + + ReadRowsRequest expectedReadRowsRequestOne = + ReadRowsRequest.newBuilder().setReadStream("streamName1").build(); + ReadRowsRequest expectedReadRowsRequestTwo = + ReadRowsRequest.newBuilder().setReadStream("streamName2").build(); + + List names = Arrays.asList("A", "B", "C", "D", "E", "F", "G"); + List values = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L); + List readRowsResponsesOne = + Lists.newArrayList( + createResponseArrow(ARROW_SCHEMA, names.subList(0, 2), values.subList(0, 2), 0.0, 0.50), + createResponseArrow( + ARROW_SCHEMA, names.subList(2, 4), values.subList(2, 4), 0.5, 0.75)); + List readRowsResponsesTwo = + Lists.newArrayList( + createResponseArrow(ARROW_SCHEMA, names.subList(4, 5), values.subList(4, 5), 0.0, 0.33), + createResponseArrow( + ARROW_SCHEMA, names.subList(5, 6), values.subList(5, 6), 0.33, 0.66), + createResponseArrow( + ARROW_SCHEMA, names.subList(6, 7), values.subList(6, 7), 0.66, 1.0)); + + StorageClient fakeStorageClient = mock(StorageClient.class, withSettings().serializable()); + when(fakeStorageClient.createReadSession(expectedCreateReadSessionRequest)) + .thenReturn(readSession); + when(fakeStorageClient.readRows(expectedReadRowsRequestOne, "")) + .thenReturn(new FakeBigQueryServerStream<>(readRowsResponsesOne)); + when(fakeStorageClient.readRows(expectedReadRowsRequestTwo, "")) + .thenReturn(new FakeBigQueryServerStream<>(readRowsResponsesTwo)); + + PCollection> output = + p.apply( + BigQueryIO.read(new ParseKeyValue()) + .from("foo.com:project:dataset.table") + .withMethod(Method.DIRECT_READ) + .withFormat(DataFormat.ARROW) + .withTestServices( + new FakeBigQueryServices() + .withDatasetService(fakeDatasetService) + .withStorageClient(fakeStorageClient))); + + PAssert.that(output) + .containsInAnyOrder( + ImmutableList.of( + KV.of("A", 1L), + KV.of("B", 2L), + KV.of("C", 3L), + KV.of("D", 4L), + KV.of("E", 5L), + KV.of("F", 6L), + KV.of("G", 7L))); + + p.run(); + } + + @Test + public void testReadFromStreamSourceArrow() throws Exception { + + ReadSession readSession = + ReadSession.newBuilder() + .setName("readSession") + .setArrowSchema( + ArrowSchema.newBuilder() + .setSerializedSchema(serializeArrowSchema(ARROW_SCHEMA)) + .build()) + .setDataFormat(DataFormat.ARROW) + .build(); + + ReadRowsRequest expectedRequest = + ReadRowsRequest.newBuilder().setReadStream("readStream").build(); + + List names = Arrays.asList("A", "B", "C"); + List values = Arrays.asList(1L, 2L, 3L); + List responses = + Lists.newArrayList( + createResponseArrow(ARROW_SCHEMA, names.subList(0, 2), values.subList(0, 2), 0.0, 0.50), + createResponseArrow( + ARROW_SCHEMA, names.subList(2, 3), values.subList(2, 3), 0.5, 0.75)); + + StorageClient fakeStorageClient = 
mock(StorageClient.class); + when(fakeStorageClient.readRows(expectedRequest, "")) + .thenReturn(new FakeBigQueryServerStream<>(responses)); + + List streamBundle = + Lists.newArrayList(ReadStream.newBuilder().setName("readStream").build()); + BigQueryStorageStreamBundleSource streamSource = + BigQueryStorageStreamBundleSource.create( + readSession, + streamBundle, + TABLE_SCHEMA, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withStorageClient(fakeStorageClient), + 1L); + + List rows = new ArrayList<>(); + BoundedReader reader = streamSource.createReader(options); + for (boolean hasNext = reader.start(); hasNext; hasNext = reader.advance()) { + rows.add(reader.getCurrent()); + } + + System.out.println("Rows: " + rows); + + assertEquals(3, rows.size()); + } + + @Test + public void testFractionConsumedWithArrowAndOneStreamInBundle() throws Exception { + ReadSession readSession = + ReadSession.newBuilder() + .setName("readSession") + .setArrowSchema( + ArrowSchema.newBuilder() + .setSerializedSchema(serializeArrowSchema(ARROW_SCHEMA)) + .build()) + .setDataFormat(DataFormat.ARROW) + .build(); + + ReadRowsRequest expectedRequest = + ReadRowsRequest.newBuilder().setReadStream("readStream").build(); + + List names = Arrays.asList("A", "B", "C", "D", "E", "F", "G"); + List values = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L); + List responses = + Lists.newArrayList( + createResponseArrow(ARROW_SCHEMA, names.subList(0, 2), values.subList(0, 2), 0.0, 0.25), + createResponseArrow( + ARROW_SCHEMA, Lists.newArrayList(), Lists.newArrayList(), 0.25, 0.25), + createResponseArrow(ARROW_SCHEMA, names.subList(2, 4), values.subList(2, 4), 0.3, 0.5), + createResponseArrow(ARROW_SCHEMA, names.subList(4, 7), values.subList(4, 7), 0.7, 1.0)); + + StorageClient fakeStorageClient = mock(StorageClient.class); + when(fakeStorageClient.readRows(expectedRequest, "")) + .thenReturn(new FakeBigQueryServerStream<>(responses)); + + List streamBundle = + Lists.newArrayList(ReadStream.newBuilder().setName("readStream").build()); + BigQueryStorageStreamBundleSource streamSource = + BigQueryStorageStreamBundleSource.create( + readSession, + streamBundle, + TABLE_SCHEMA, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withStorageClient(fakeStorageClient), + 1L); + + BoundedReader reader = streamSource.createReader(options); + + // Before call to BoundedReader#start, fraction consumed must be zero. + assertEquals(0.0, reader.getFractionConsumed(), DELTA); + + assertTrue(reader.start()); // Reads A. + assertEquals(0.125, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads B. + assertEquals(0.25, reader.getFractionConsumed(), DELTA); + + assertTrue(reader.advance()); // Reads C. + assertEquals(0.4, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads D. + assertEquals(0.5, reader.getFractionConsumed(), DELTA); + + assertTrue(reader.advance()); // Reads E. + assertEquals(0.8, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads F. + assertEquals(0.9, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads G. + assertEquals(1.0, reader.getFractionConsumed(), DELTA); + + assertFalse(reader.advance()); // Reaches the end. + + // We are done with the stream, so we should report 100% consumption. 
+ assertEquals(Double.valueOf(1.0), reader.getFractionConsumed()); + } + + @Test + public void testFractionConsumedWithArrowAndMultipleStreamsInBundle() throws Exception { + ReadSession readSession = + ReadSession.newBuilder() + .setName("readSession") + .setArrowSchema( + ArrowSchema.newBuilder() + .setSerializedSchema(serializeArrowSchema(ARROW_SCHEMA)) + .build()) + .setDataFormat(DataFormat.ARROW) + .build(); + + ReadRowsRequest expectedRequestOne = + ReadRowsRequest.newBuilder().setReadStream("readStream1").build(); + ReadRowsRequest expectedRequestTwo = + ReadRowsRequest.newBuilder().setReadStream("readStream2").build(); + + List names = Arrays.asList("A", "B", "C", "D", "E", "F", "G"); + List values = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L); + List responsesOne = + Lists.newArrayList( + createResponseArrow(ARROW_SCHEMA, names.subList(0, 2), values.subList(0, 2), 0.0, 0.5), + createResponseArrow(ARROW_SCHEMA, Lists.newArrayList(), Lists.newArrayList(), 0.5, 0.5), + createResponseArrow(ARROW_SCHEMA, names.subList(2, 4), values.subList(2, 4), 0.5, 1.0)); + + List responsesTwo = + Lists.newArrayList( + createResponseArrow(ARROW_SCHEMA, names.subList(4, 7), values.subList(4, 7), 0.0, 1.0)); + + StorageClient fakeStorageClient = mock(StorageClient.class); + when(fakeStorageClient.readRows(expectedRequestOne, "")) + .thenReturn(new FakeBigQueryServerStream<>(responsesOne)); + when(fakeStorageClient.readRows(expectedRequestTwo, "")) + .thenReturn(new FakeBigQueryServerStream<>(responsesTwo)); + + List streamBundle = + Lists.newArrayList( + ReadStream.newBuilder().setName("readStream1").build(), + ReadStream.newBuilder().setName("readStream2").build()); + + BigQueryStorageStreamBundleSource streamSource = + BigQueryStorageStreamBundleSource.create( + readSession, + streamBundle, + TABLE_SCHEMA, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withStorageClient(fakeStorageClient), + 1L); + + BoundedReader reader = streamSource.createReader(options); + + // Before call to BoundedReader#start, fraction consumed must be zero. + assertEquals(0.0, reader.getFractionConsumed(), DELTA); + + assertTrue(reader.start()); // Reads A. + assertEquals(0.125, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads B. + assertEquals(0.25, reader.getFractionConsumed(), DELTA); + + assertTrue(reader.advance()); // Reads C. + assertEquals(0.375, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads D. + assertEquals(0.5, reader.getFractionConsumed(), DELTA); + + assertTrue(reader.advance()); // Reads E. + assertEquals(0.6666666666666666, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads F. + assertEquals(0.8333333333333333, reader.getFractionConsumed(), DELTA); + assertTrue(reader.advance()); // Reads G. + assertEquals(1.0, reader.getFractionConsumed(), DELTA); + + assertFalse(reader.advance()); // Reaches the end. + + // We are done with the streams, so we should report 100% consumption. 
+ assertEquals(Double.valueOf(1.0), reader.getFractionConsumed()); + } + + @Test + public void testStreamSourceSplitAtFractionWithArrowAndMultipleStreamsInBundle() + throws Exception { + List names = Arrays.asList("A", "B", "C", "D", "E", "F", "G"); + List values = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L); + List responses = + Lists.newArrayList( + createResponseArrow(ARROW_SCHEMA, names.subList(0, 2), values.subList(0, 2), 0.0, 0.6), + createResponseArrow(ARROW_SCHEMA, names.subList(2, 3), values.subList(2, 3), 0.6, 1.0), + createResponseArrow(ARROW_SCHEMA, names.subList(3, 6), values.subList(3, 6), 0.0, 1.0), + createResponseArrow(ARROW_SCHEMA, names.subList(6, 7), values.subList(6, 7), 0.0, 1.0)); + + StorageClient fakeStorageClient = mock(StorageClient.class); + when(fakeStorageClient.readRows( + ReadRowsRequest.newBuilder().setReadStream("readStream1").build(), "")) + .thenReturn(new FakeBigQueryServerStream<>(responses.subList(0, 2))); + when(fakeStorageClient.readRows( + ReadRowsRequest.newBuilder().setReadStream("readStream2").build(), "")) + .thenReturn(new FakeBigQueryServerStream<>(responses.subList(2, 3))); + when(fakeStorageClient.readRows( + ReadRowsRequest.newBuilder().setReadStream("readStream3").build(), "")) + .thenReturn(new FakeBigQueryServerStream<>(responses.subList(3, 4))); + + List primaryStreamBundle = + Lists.newArrayList( + ReadStream.newBuilder().setName("readStream1").build(), + ReadStream.newBuilder().setName("readStream2").build(), + ReadStream.newBuilder().setName("readStream3").build()); + + BigQueryStorageStreamBundleSource primarySource = + BigQueryStorageStreamBundleSource.create( + ReadSession.newBuilder() + .setName("readSession") + .setArrowSchema( + ArrowSchema.newBuilder() + .setSerializedSchema(serializeArrowSchema(ARROW_SCHEMA)) + .build()) + .setDataFormat(DataFormat.ARROW) + .build(), + primaryStreamBundle, + TABLE_SCHEMA, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withStorageClient(fakeStorageClient), + 1L); + + // Read a few records from the primary bundle and ensure that records are returned in the + // prescribed order. + BoundedReader primary = primarySource.createReader(options); + assertTrue(primary.start()); + assertEquals("A", primary.getCurrent().get("name")); + assertTrue(primary.advance()); + assertEquals("B", primary.getCurrent().get("name")); + assertTrue(primary.advance()); + + // Now split the StreamBundle, and ensure that the returned source points to a non-null + // secondary StreamBundle. 
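+ // With 3 Streams in the bundle, a fraction of 0.35 hands Streams 2 and 3 off to the
+ // secondary source while the primary finishes Stream 1.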
+ BoundedSource secondarySource = primary.splitAtFraction(0.35); + assertNotNull(secondarySource); + BoundedReader secondary = secondarySource.createReader(options); + + assertEquals("C", primary.getCurrent().get("name")); + assertFalse(primary.advance()); + + assertTrue(secondary.start()); + assertEquals("D", secondary.getCurrent().get("name")); + assertTrue(secondary.advance()); + assertEquals("E", secondary.getCurrent().get("name")); + assertTrue(secondary.advance()); + assertEquals("F", secondary.getCurrent().get("name")); + assertTrue((secondary.advance())); + assertEquals("G", secondary.getCurrent().get("name")); + assertFalse((secondary.advance())); + } + + @Test + public void testStreamSourceSplitAtFractionRepeatedWithArrowAndMultipleStreamsInBundle() + throws Exception { + List names = Arrays.asList("A", "B", "C", "D", "E", "F", "G"); + List values = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L, 7L); + List responses = + Lists.newArrayList( + createResponseArrow(ARROW_SCHEMA, names.subList(0, 2), values.subList(0, 2), 0.0, 0.6), + createResponseArrow(ARROW_SCHEMA, names.subList(2, 3), values.subList(2, 3), 0.6, 1.0), + createResponseArrow(ARROW_SCHEMA, names.subList(3, 6), values.subList(3, 6), 0.0, 1.0), + createResponseArrow(ARROW_SCHEMA, names.subList(6, 7), values.subList(6, 7), 0.0, 1.0)); + + StorageClient fakeStorageClient = mock(StorageClient.class); + when(fakeStorageClient.readRows( + ReadRowsRequest.newBuilder().setReadStream("readStream1").build(), "")) + .thenReturn(new FakeBigQueryServerStream<>(responses.subList(0, 2))); + when(fakeStorageClient.readRows( + ReadRowsRequest.newBuilder().setReadStream("readStream2").build(), "")) + .thenReturn(new FakeBigQueryServerStream<>(responses.subList(2, 3))); + when(fakeStorageClient.readRows( + ReadRowsRequest.newBuilder().setReadStream("readStream3").build(), "")) + .thenReturn(new FakeBigQueryServerStream<>(responses.subList(3, 4))); + + List primaryStreamBundle = + Lists.newArrayList( + ReadStream.newBuilder().setName("readStream1").build(), + ReadStream.newBuilder().setName("readStream2").build(), + ReadStream.newBuilder().setName("readStream3").build()); + + BigQueryStorageStreamBundleSource primarySource = + BigQueryStorageStreamBundleSource.create( + ReadSession.newBuilder() + .setName("readSession") + .setArrowSchema( + ArrowSchema.newBuilder() + .setSerializedSchema(serializeArrowSchema(ARROW_SCHEMA)) + .build()) + .setDataFormat(DataFormat.ARROW) + .build(), + primaryStreamBundle, + TABLE_SCHEMA, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withStorageClient(fakeStorageClient), + 1L); + + // Read a few records from the primary bundle and ensure that records are returned in the + // prescribed order. + BoundedReader primary = primarySource.createReader(options); + assertTrue(primary.start()); + assertEquals("A", primary.getCurrent().get("name")); + assertTrue(primary.advance()); + assertEquals("B", primary.getCurrent().get("name")); + assertTrue(primary.advance()); + + // Now split the StreamBundle, and ensure that the returned source points to a non-null + // secondary StreamBundle. Since there are 3 streams in this Bundle, splitting will only + // occur when fraction >= 0.33. 
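+ // The secondary source receives Streams 2 and 3 and is itself split again further below.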
+ BoundedSource secondarySource = primary.splitAtFraction(0.35); + assertNotNull(secondarySource); + BoundedReader secondary = secondarySource.createReader(options); + + assertEquals("C", primary.getCurrent().get("name")); + assertFalse(primary.advance()); + + assertTrue(secondary.start()); + assertEquals("D", secondary.getCurrent().get("name")); + assertTrue(secondary.advance()); + assertEquals("E", secondary.getCurrent().get("name")); + assertTrue(secondary.advance()); + + // Now split the StreamBundle again, and ensure that the returned source points to a non-null + // tertiary StreamBundle. Since there are 2 streams in this Bundle, splitting will only + // occur when fraction >= 0.5. + BoundedSource tertiarySource = secondary.splitAtFraction(0.5); + assertNotNull(tertiarySource); + BoundedReader tertiary = tertiarySource.createReader(options); + + assertEquals("F", secondary.getCurrent().get("name")); + assertFalse((secondary.advance())); + + assertTrue(tertiary.start()); + assertEquals("G", tertiary.getCurrent().get("name")); + assertFalse((tertiary.advance())); + } + + @Test + public void testStreamSourceSplitAtFractionFailsWhenParentIsPastSplitPointArrow() + throws Exception { + List names = Arrays.asList("A", "B", "C", "D", "E"); + List values = Arrays.asList(1L, 2L, 3L, 4L, 5L); + List responses = + Lists.newArrayList( + createResponseArrow(ARROW_SCHEMA, names.subList(0, 2), values.subList(0, 2), 0.0, 0.66), + createResponseArrow(ARROW_SCHEMA, names.subList(2, 3), values.subList(2, 3), 0.66, 1.0), + createResponseArrow(ARROW_SCHEMA, names.subList(3, 5), values.subList(3, 5), 0.0, 1.0)); + + StorageClient fakeStorageClient = mock(StorageClient.class); + when(fakeStorageClient.readRows( + ReadRowsRequest.newBuilder().setReadStream("readStream1").build(), "")) + .thenReturn(new FakeBigQueryServerStream<>(responses.subList(0, 2))); + when(fakeStorageClient.readRows( + ReadRowsRequest.newBuilder().setReadStream("readStream2").build(), "")) + .thenReturn(new FakeBigQueryServerStream<>(responses.subList(2, 3))); + + List parentStreamBundle = + Lists.newArrayList( + ReadStream.newBuilder().setName("readStream1").build(), + ReadStream.newBuilder().setName("readStream2").build()); + + BigQueryStorageStreamBundleSource streamBundleSource = + BigQueryStorageStreamBundleSource.create( + ReadSession.newBuilder() + .setName("readSession") + .setArrowSchema( + ArrowSchema.newBuilder() + .setSerializedSchema(serializeArrowSchema(ARROW_SCHEMA)) + .build()) + .setDataFormat(DataFormat.ARROW) + .build(), + parentStreamBundle, + TABLE_SCHEMA, + new TableRowParser(), + TableRowJsonCoder.of(), + new FakeBigQueryServices().withStorageClient(fakeStorageClient), + 1L); + + // Read a few records from the parent bundle and ensure the records are returned in + // the prescribed order. + BoundedReader primary = streamBundleSource.createReader(options); + assertTrue(primary.start()); + assertEquals("A", primary.getCurrent().get("name")); + assertTrue(primary.advance()); + assertEquals("B", primary.getCurrent().get("name")); + assertTrue(primary.advance()); + assertEquals("C", primary.getCurrent().get("name")); + assertTrue(primary.advance()); + assertEquals("D", primary.getCurrent().get("name")); + + // We attempt to split the StreamBundle after starting to read the contents of the second + // stream. 
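+ // As in the Avro variant above, the reader has already moved past the proposed split
+ // point, so null is returned.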
+ BoundedSource secondarySource = primary.splitAtFraction(0.5); + assertNull(secondarySource); + + assertTrue(primary.advance()); + assertEquals("E", primary.getCurrent().get("name")); + assertFalse(primary.advance()); + } + + @Test + public void testActuateProjectionPushdown() { + org.apache.beam.sdk.schemas.Schema schema = + org.apache.beam.sdk.schemas.Schema.builder() + .addStringField("foo") + .addStringField("bar") + .build(); + TypedRead read = + BigQueryIO.read( + record -> + BigQueryUtils.toBeamRow( + record.getRecord(), schema, ConversionOptions.builder().build())) + .withMethod(Method.DIRECT_READ) + .withCoder(SchemaCoder.of(schema)); + + assertTrue(read.supportsProjectionPushdown()); + PTransform> pushdownT = + read.actuateProjectionPushdown( + ImmutableMap.of(new TupleTag<>("output"), FieldAccessDescriptor.withFieldNames("foo"))); + + TypedRead pushdownRead = (TypedRead) pushdownT; + assertEquals(Method.DIRECT_READ, pushdownRead.getMethod()); + assertThat(pushdownRead.getSelectedFields().get(), Matchers.containsInAnyOrder("foo")); + assertTrue(pushdownRead.getProjectionPushdownApplied()); + } + + @Test + public void testReadFromQueryDoesNotSupportProjectionPushdown() { + org.apache.beam.sdk.schemas.Schema schema = + org.apache.beam.sdk.schemas.Schema.builder() + .addStringField("foo") + .addStringField("bar") + .build(); + TypedRead read = + BigQueryIO.read( + record -> + BigQueryUtils.toBeamRow( + record.getRecord(), schema, ConversionOptions.builder().build())) + .fromQuery("SELECT bar FROM `dataset.table`") + .withMethod(Method.DIRECT_READ) + .withCoder(SchemaCoder.of(schema)); + + assertFalse(read.supportsProjectionPushdown()); + assertThrows( + IllegalArgumentException.class, + () -> + read.actuateProjectionPushdown( + ImmutableMap.of( + new TupleTag<>("output"), FieldAccessDescriptor.withFieldNames("foo")))); + } + + private static org.apache.arrow.vector.types.pojo.Field field( + String name, + boolean nullable, + ArrowType type, + org.apache.arrow.vector.types.pojo.Field... children) { + return new org.apache.arrow.vector.types.pojo.Field( + name, + new org.apache.arrow.vector.types.pojo.FieldType(nullable, type, null, null), + asList(children)); + } + + static org.apache.arrow.vector.types.pojo.Field field( + String name, ArrowType type, org.apache.arrow.vector.types.pojo.Field... children) { + return field(name, false, type, children); + } +}
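For reference, the fraction-consumed values asserted in the tests above are consistent with a simple bundle-level progress model: average the per-Stream progress across the Streams in the bundle, counting every fully read Stream as 1.0. The sketch below is illustrative only; it does not use the reader's actual internals, and the class and method names are hypothetical.

// A minimal model of bundle-level progress, assuming the reader averages per-Stream
// progress over the number of Streams in the StreamBundle.
final class BundleProgressModel {

  // streamsCompleted: Streams in the bundle that have been read to completion.
  // currentStreamFraction: interpolated progress within the Stream currently being read.
  // totalStreams: total number of Streams in the bundle.
  static double fractionConsumed(
      int streamsCompleted, double currentStreamFraction, int totalStreams) {
    return (streamsCompleted + currentStreamFraction) / totalStreams;
  }

  public static void main(String[] args) {
    // Mirrors testFractionConsumedWithMultipleStreamsInBundle after reading "E": one Stream
    // fully consumed, one third of the second Stream read, two Streams in the bundle.
    System.out.println(fractionConsumed(1, 1.0 / 3.0, 2)); // 0.6666666666666666
  }
}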