diff --git a/.github/trigger_files/beam_PostCommit_Java.json b/.github/trigger_files/beam_PostCommit_Java.json new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/.github/trigger_files/beam_PostCommit_Java_DataflowV1.json b/.github/trigger_files/beam_PostCommit_Java_DataflowV1.json new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2.json new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_V2_Streaming.json new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink_Java11.json new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_SparkStructuredStreaming.json new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Twister2.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Twister2.json new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_ULR.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_ULR.json new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/.github/trigger_files/beam_PostCommit_XVR_Flink.json b/.github/trigger_files/beam_PostCommit_XVR_Flink.json new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/.github/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml b/.github/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml index 58ee88e20ec3..e55c9bb55436 100644 --- a/.github/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml +++ b/.github/workflows/beam_PostCommit_Java_PVR_Spark3_Streaming.yml @@ -54,7 +54,7 @@ jobs: beam_PostCommit_Java_PVR_Spark3_Streaming: name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) runs-on: [self-hosted, ubuntu-20.04, main] - timeout-minutes: 120 + timeout-minutes: 180 strategy: matrix: job_name: [beam_PostCommit_Java_PVR_Spark3_Streaming] diff --git a/.gitignore b/.gitignore index 193d650c32be..9d8d301032b6 100644 --- a/.gitignore +++ b/.gitignore @@ -54,6 +54,7 @@ sdks/python/NOTICE sdks/python/README.md sdks/python/apache_beam/transforms/xlang/* sdks/python/apache_beam/portability/api/* +sdks/python/apache_beam/yaml/docs/* sdks/python/nosetests*.xml sdks/python/pytest*.xml sdks/python/postcommit_requirements.txt diff --git a/.test-infra/tools/stale_dataflow_prebuilt_image_cleaner.sh b/.test-infra/tools/stale_dataflow_prebuilt_image_cleaner.sh index 126249324fed..e34f637dfbe2 100755 --- a/.test-infra/tools/stale_dataflow_prebuilt_image_cleaner.sh +++ 
b/.test-infra/tools/stale_dataflow_prebuilt_image_cleaner.sh @@ -56,6 +56,7 @@ done HAS_STALE_IMAGES="" FAILED_IMAGES="" +FAILED_COUNT=0 for image_name in ${IMAGE_NAMES[@]}; do echo IMAGES FOR image ${image_name} @@ -99,6 +100,7 @@ for image_name in ${IMAGE_NAMES[@]}; do if [ -z "$MANIFEST" ]; then # Sometimes "no such manifest" seen. Skip current if command hit error FAILED_IMAGES+=" $current" + FAILED_COUNT=$(($FAILED_COUNT + 1)) continue fi SHOULD_DELETE=0 @@ -129,7 +131,7 @@ for image_name in ${IMAGE_NAMES[@]}; do echo "Failed to delete the following images: ${FAILED_TO_DELETE}. Retrying each of them." for current in $RETRY_DELETE; do echo "Trying again to delete image ${image_name}@"${current}". Command: gcloud container images delete ${image_name}@"${current}" --force-delete-tags -q" - gcloud container images delete ${image_name}@"${current}" --force-delete-tags -q || FAILED_IMAGES+=" ${image_name}@${current}" + gcloud container images delete ${image_name}@"${current}" --force-delete-tags -q || (FAILED_IMAGES+=" ${image_name}@${current}" && FAILED_COUNT=$(($FAILED_COUNT + 1))) done fi done @@ -142,5 +144,10 @@ fi if [ -n "$FAILED_IMAGES" ]; then echo "Failed delete images $FAILED_IMAGES" - exit 1 -fi \ No newline at end of file + # Sometimes images may not be deleted on the first pass if they have dependencies on previous images. Only fail if we have a persistent leak + FAILED_THRESHOLD=10 + if [ $FAILED_COUNT -gt $FAILED_THRESHOLD ]; then + echo "Failed delete at least $FAILED_THRESHOLD images, failing job." + exit 1 + fi +fi diff --git a/CHANGES.md b/CHANGES.md index 3b9e0b60140e..2e23d72e664a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -62,12 +62,16 @@ ## I/Os +* Added support for handling bad records to BigQueryIO ([#30081](https://github.com/apache/beam/pull/30081)). + * Full Support for Storage Read and Write APIs + * Partial Support for File Loads (Failures writing to files supported, failures loading files to BQ unsupported) + * No Support for Extract or Streaming Inserts * Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). ## New Features / Improvements -* [Enrichment Transform](https://s.apache.org/enrichment-transform) along with GCP BigTable handler added to Python SDK ([#30001](https://github.com/apache/beam/pull/30001)). * Allow writing clustered and not time partitioned BigQuery tables (Java) ([#30094](https://github.com/apache/beam/pull/30094)). +* Redis cache support added to RequestResponseIO and Enrichment transform (Python) ([#30307](https://github.com/apache/beam/pull/30307)) * Merged sdks/java/fn-execution and runners/core-construction-java into the main SDK. These artifacts were never meant for users, but noting that they no longer exist. These are steps to bring portability into the core SDK alongside all other core functionality. @@ -78,6 +82,7 @@ * Go SDK users who build custom worker containers may run into issues with the move to distroless containers as a base (see Security Fixes). * The issue stems from distroless containers lacking additional tools, which current custom container processes may rely on. * See https://beam.apache.org/documentation/runtime/environments/#from-scratch-go for instructions on building and using a custom container. +* Python SDK has changed the default value for the `--max_cache_memory_usage_mb` pipeline option from 100 to 0. This option was first introduced in 2.52.0 SDK. 
This change restores the behavior of 2.51.0 SDK, which does not use the state cache. If your pipeline uses iterable side inputs views, consider increasing the cache size by setting the option manually. ([#30360](https://github.com/apache/beam/issues/30360)). ## Deprecations @@ -132,7 +137,7 @@ ## Known Issues -* N/A +* Some Python pipelines that run with 2.52.0-2.54.0 SDKs and use large materialized side inputs might be affected by a performance regression. To restore the prior behavior on these SDK versions, supply the `--max_cache_memory_usage_mb=0` pipeline option. ([#30360](https://github.com/apache/beam/issues/30360)). # [2.53.0] - 2024-01-04 @@ -173,7 +178,8 @@ ## Known Issues -* ([#29987](https://github.com/apache/beam/issues/29987)). +* Potential race condition causing NPE in DataflowExecutionStateSampler in Dataflow Java Streaming pipelines ([#29987](https://github.com/apache/beam/issues/29987)). +* Some Python pipelines that run with 2.52.0-2.54.0 SDKs and use large materialized side inputs might be affected by a performance regression. To restore the prior behavior on these SDK versions, supply the `--max_cache_memory_usage_mb=0` pipeline option. ([#30360](https://github.com/apache/beam/issues/30360)). # [2.52.0] - 2023-11-17 @@ -192,7 +198,7 @@ should handle this. ([#25252](https://github.com/apache/beam/issues/25252)). jobs using the DataStream API. By default the option is set to false, so the batch jobs are still executed using the DataSet API. * `upload_graph` as one of the Experiments options for DataflowRunner is no longer required when the graph is larger than 10MB for Java SDK ([PR#28621](https://github.com/apache/beam/pull/28621)). -* state amd side input cache has been enabled to a default of 100 MB. Use `--max_cache_memory_usage_mb=X` to provide cache size for the user state API and side inputs. (Python) ([#28770](https://github.com/apache/beam/issues/28770)). +* Introduced a pipeline option `--max_cache_memory_usage_mb` to configure state and side input cache size. The cache has been enabled to a default of 100 MB. Use `--max_cache_memory_usage_mb=X` to provide cache size for the user state API and side inputs. ([#28770](https://github.com/apache/beam/issues/28770)). * Beam YAML stable release. Beam pipelines can now be written using YAML and leverage the Beam YAML framework which includes a preliminary set of IO's and turnkey transforms. More information can be found in the YAML root folder and in the [README](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/yaml/README.md). @@ -218,6 +224,7 @@ as a workaround, a copy of "old" `CountingSource` class should be placed into a ## Known issues * MLTransform drops the identical elements in the output PCollection. For any duplicate elements, a single element will be emitted downstream. ([#29600](https://github.com/apache/beam/issues/29600)). +* Some Python pipelines that run with 2.52.0-2.54.0 SDKs and use large materialized side inputs might be affected by a performance regression. To restore the prior behavior on these SDK versions, supply the `--max_cache_memory_usage_mb=0` pipeline option. (Python) ([#30360](https://github.com/apache/beam/issues/30360)). 
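The I/Os entry above notes that BigQueryIO can now route bad records to an error handler (full support for the Storage Read and Write APIs). As a rough illustration, here is a minimal sketch that assumes the `registerBadRecordErrorHandler` / `withErrorHandler` hooks referenced by [#30081]; the table spec and the counting dead-letter sink are placeholders, not the shape of any real pipeline.

```java
import com.google.api.services.bigquery.model.TableRow;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO;
import org.apache.beam.sdk.io.gcp.bigquery.TableRowJsonCoder;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler;
import org.apache.beam.sdk.values.PCollection;

public class BadRecordsToBigQuerySketch {

  /** Placeholder dead-letter sink: just counts bad records instead of persisting them. */
  static class CountBadRecords extends PTransform<PCollection<BadRecord>, PCollection<Long>> {
    @Override
    public PCollection<Long> expand(PCollection<BadRecord> input) {
      return input.apply("CountBadRecords", Count.globally());
    }
  }

  public static void main(String[] args) throws Exception {
    Pipeline pipeline = Pipeline.create();
    PCollection<TableRow> rows =
        pipeline.apply(
            "SomeRows",
            Create.of(new TableRow().set("name", "beam")).withCoder(TableRowJsonCoder.of()));

    // Register a handler for rows that fail conversion or writing, then hand it to the sink.
    // Assumes the error-handling hooks described in the CHANGES entry above.
    try (BadRecordErrorHandler<PCollection<Long>> badRecords =
        pipeline.registerBadRecordErrorHandler(new CountBadRecords())) {
      rows.apply(
          "WriteToBQ",
          BigQueryIO.writeTableRows()
              .to("my-project:my_dataset.my_table") // hypothetical table spec
              .withMethod(BigQueryIO.Write.Method.STORAGE_WRITE_API)
              .withErrorHandler(badRecords));
    }
    pipeline.run().waitUntilFinish();
  }
}
```

In this sketch the handler is closed (via try-with-resources) before `pipeline.run()`, which is what finalizes the dead-letter branch.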
# [2.51.0] - 2023-10-03 diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 19a1830b9b27..6434746fd3ab 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -2646,6 +2646,7 @@ class BeamModulePlugin implements Plugin { project.evaluationDependsOn(":sdks:python") project.evaluationDependsOn(":sdks:java:testing:expansion-service") + project.evaluationDependsOn(":sdks:java:core") project.evaluationDependsOn(":sdks:java:extensions:python") project.evaluationDependsOn(":sdks:go:test") @@ -2710,9 +2711,11 @@ class BeamModulePlugin implements Plugin { systemProperty "expansionPort", port systemProperty "semiPersistDir", config.semiPersistDir classpath = config.classpath + project.files( + project.project(":sdks:java:core").sourceSets.test.runtimeClasspath, project.project(":sdks:java:extensions:python").sourceSets.test.runtimeClasspath ) testClassesDirs = project.files( + project.project(":sdks:java:core").sourceSets.test.output.classesDirs, project.project(":sdks:java:extensions:python").sourceSets.test.output.classesDirs ) maxParallelForks config.numParallelTests diff --git a/it/google-cloud-platform/build.gradle b/it/google-cloud-platform/build.gradle index 4c5327b44c9a..1c63ade152d0 100644 --- a/it/google-cloud-platform/build.gradle +++ b/it/google-cloud-platform/build.gradle @@ -81,4 +81,7 @@ dependencies { tasks.register("GCSPerformanceTest", IoPerformanceTestUtilities.IoPerformanceTest, project, 'google-cloud-platform', 'FileBasedIOLT', ['configuration':'large','project':'apache-beam-testing', 'artifactBucket':'io-performance-temp']) tasks.register("BigTablePerformanceTest", IoPerformanceTestUtilities.IoPerformanceTest, project, 'google-cloud-platform', 'BigTableIOLT', ['configuration':'large','project':'apache-beam-testing', 'artifactBucket':'io-performance-temp']) +tasks.register("BigQueryPerformanceTest", IoPerformanceTestUtilities.IoPerformanceTest, project, 'google-cloud-platform', 'BigQueryIOLT', ['configuration':'medium','project':'apache-beam-testing', 'artifactBucket':'io-performance-temp']) +tasks.register("BigQueryStressTest", IoPerformanceTestUtilities.IoPerformanceTest, project, 'google-cloud-platform', 'BigQueryIOST', ['configuration':'medium','project':'apache-beam-testing', 'artifactBucket':'io-performance-temp']) tasks.register("BigQueryStorageApiStreamingPerformanceTest", IoPerformanceTestUtilities.IoPerformanceTest, project, 'google-cloud-platform', 'BigQueryStreamingLT', ['configuration':'large', 'project':'apache-beam-testing', 'artifactBucket':'io-performance-temp']) +tasks.register("WordCountIntegrationTest", IoPerformanceTestUtilities.IoPerformanceTest, project, 'google-cloud-platform', 'WordCountIT', ['project':'apache-beam-testing', 'artifactBucket':'io-performance-temp']) diff --git a/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOST.java b/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOST.java new file mode 100644 index 000000000000..6ffe1014c8ad --- /dev/null +++ b/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOST.java @@ -0,0 +1,604 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.it.gcp.bigquery; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.Timestamp; +import java.io.IOException; +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.text.ParseException; +import java.time.Duration; +import java.time.ZoneId; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecord; +import org.apache.beam.it.common.PipelineLauncher; +import org.apache.beam.it.common.PipelineOperator; +import org.apache.beam.it.common.TestProperties; +import org.apache.beam.it.common.utils.ResourceManagerUtils; +import org.apache.beam.it.gcp.IOLoadTestBase; +import org.apache.beam.runners.dataflow.options.DataflowPipelineWorkerPoolOptions; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.io.gcp.bigquery.AvroWriteRequest; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; +import org.apache.beam.sdk.io.synthetic.SyntheticSourceOptions; +import org.apache.beam.sdk.options.StreamingOptions; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.testutils.NamedTestResult; +import org.apache.beam.sdk.testutils.metrics.IOITMetrics; +import org.apache.beam.sdk.testutils.publishing.InfluxDBSettings; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.PeriodicImpulse; +import org.apache.beam.sdk.transforms.Reshuffle; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Longs; +import org.joda.time.Instant; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; + +/** + * BigQueryIO stress tests. 
The test is designed to assess the performance of BigQueryIO under + * various conditions. + * + * <p>Usage: <br> + * - To run medium-scale stress tests: {@code gradle + * :it:google-cloud-platform:BigQueryStressTestMedium} <br>
+ * - To run large-scale stress tests: {@code gradle + * :it:google-cloud-platform:BigQueryStressTestLarge} + */ +public final class BigQueryIOST extends IOLoadTestBase { + + private static final String READ_ELEMENT_METRIC_NAME = "read_count"; + private static final String TEST_ID = UUID.randomUUID().toString(); + private static final String TEST_TIMESTAMP = Timestamp.now().toString(); + private static final int DEFAULT_ROWS_PER_SECOND = 1000; + + /** + * The load will initiate at 1x, progressively increase to 2x and 4x, then decrease to 2x and + * eventually return to 1x. + */ + private static final int[] DEFAULT_LOAD_INCREASE_ARRAY = {1, 2, 2, 4, 2, 1}; + + private static BigQueryResourceManager resourceManager; + private static String tableQualifier; + private static InfluxDBSettings influxDBSettings; + + private Configuration configuration; + private String tempLocation; + private String testConfigName; + private TableSchema schema; + + @Rule public TestPipeline writePipeline = TestPipeline.create(); + + @BeforeClass + public static void beforeClass() { + resourceManager = + BigQueryResourceManager.builder("io-bigquery-lt", project, CREDENTIALS).build(); + resourceManager.createDataset(region); + } + + @Before + public void setup() { + // generate a random table names + String sourceTableName = + "io-bq-source-table-" + + DateTimeFormatter.ofPattern("MMddHHmmssSSS") + .withZone(ZoneId.of("UTC")) + .format(java.time.Instant.now()) + + UUID.randomUUID().toString().substring(0, 10); + tableQualifier = + String.format("%s:%s.%s", project, resourceManager.getDatasetId(), sourceTableName); + + // parse configuration + testConfigName = + TestProperties.getProperty("configuration", "medium", TestProperties.Type.PROPERTY); + configuration = TEST_CONFIGS_PRESET.get(testConfigName); + if (configuration == null) { + try { + configuration = Configuration.fromJsonString(testConfigName, Configuration.class); + } catch (IOException e) { + throw new IllegalArgumentException( + String.format( + "Unknown test configuration: [%s]. 
Pass to a valid configuration json, or use" + + " config presets: %s", + testConfigName, TEST_CONFIGS_PRESET.keySet())); + } + } + + // prepare schema + List fields = new ArrayList<>(configuration.numColumns); + for (int idx = 0; idx < configuration.numColumns; ++idx) { + fields.add(new TableFieldSchema().setName("data_" + idx).setType("BYTES")); + } + schema = new TableSchema().setFields(fields); + + // tempLocation needs to be set for bigquery IO writes + if (!Strings.isNullOrEmpty(tempBucketName)) { + tempLocation = String.format("gs://%s/temp/", tempBucketName); + writePipeline.getOptions().as(TestPipelineOptions.class).setTempRoot(tempLocation); + writePipeline.getOptions().setTempLocation(tempLocation); + } + } + + @AfterClass + public static void tearDownClass() { + ResourceManagerUtils.cleanResources(resourceManager); + } + + private static final Map TEST_CONFIGS_PRESET; + + static { + try { + TEST_CONFIGS_PRESET = + ImmutableMap.of( + "medium", + Configuration.fromJsonString( + "{\"numColumns\":10,\"rowsPerSecond\":25000,\"minutes\":30,\"numRecords\":90000000,\"valueSizeBytes\":1000,\"pipelineTimeout\":60,\"runner\":\"DataflowRunner\"}", + Configuration.class), + "large", + Configuration.fromJsonString( + "{\"numColumns\":20,\"rowsPerSecond\":25000,\"minutes\":240,\"numRecords\":720000000,\"valueSizeBytes\":10000,\"pipelineTimeout\":300,\"runner\":\"DataflowRunner\"}", + Configuration.class)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Test + public void testJsonStreamingWriteThenRead() throws IOException { + configuration.writeFormat = "JSON"; + configuration.writeMethod = "STREAMING_INSERTS"; + runTest(); + } + + @Test + public void testAvroStorageAPIWrite() throws IOException { + configuration.writeFormat = "AVRO"; + configuration.writeMethod = "STORAGE_WRITE_API"; + runTest(); + } + + @Test + public void testJsonStorageAPIWrite() throws IOException { + configuration.writeFormat = "JSON"; + configuration.writeMethod = "STORAGE_WRITE_API"; + runTest(); + } + + @Test + public void testAvroStorageAPIWriteAtLeastOnce() throws IOException { + configuration.writeFormat = "AVRO"; + configuration.writeMethod = "STORAGE_API_AT_LEAST_ONCE"; + runTest(); + } + + @Test + public void testJsonStorageAPIWriteAtLeastOnce() throws IOException { + configuration.writeFormat = "JSON"; + configuration.writeMethod = "STORAGE_API_AT_LEAST_ONCE"; + runTest(); + } + + /** + * Runs a stress test for BigQueryIO based on the specified configuration parameters. The method + * initializes the stress test by determining the WriteFormat, configuring the BigQueryIO. Write + * instance accordingly, and then executing data generation and read/write operations on BigQuery. 
+ */ + private void runTest() throws IOException { + if (configuration.exportMetricsToInfluxDB) { + influxDBSettings = + InfluxDBSettings.builder() + .withHost(configuration.influxHost) + .withDatabase(configuration.influxDatabase) + .withMeasurement( + configuration.influxMeasurement + + "_" + + testConfigName + + "_" + + configuration.writeFormat + + "_" + + configuration.writeMethod) + .get(); + } + WriteFormat writeFormat = WriteFormat.valueOf(configuration.writeFormat); + BigQueryIO.Write writeIO = null; + switch (writeFormat) { + case AVRO: + writeIO = + BigQueryIO.write() + .withTriggeringFrequency(org.joda.time.Duration.standardSeconds(30)) + .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_APPEND) + .withAvroFormatFunction( + new AvroFormatFn( + configuration.numColumns, + !("STORAGE_WRITE_API".equalsIgnoreCase(configuration.writeMethod)))); + break; + case JSON: + writeIO = + BigQueryIO.write() + .withTriggeringFrequency(org.joda.time.Duration.standardSeconds(30)) + .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_APPEND) + .withSuccessfulInsertsPropagation(false) + .withFormatFunction(new JsonFormatFn(configuration.numColumns)); + break; + } + generateDataAndWrite(writeIO); + } + + /** + * The method creates a pipeline to simulate data generation and write operations to BigQuery + * table, based on the specified configuration parameters. The stress test involves varying the + * load dynamically over time, with options to use configurable parameters. + */ + private void generateDataAndWrite(BigQueryIO.Write writeIO) throws IOException { + BigQueryIO.Write.Method method = BigQueryIO.Write.Method.valueOf(configuration.writeMethod); + writePipeline.getOptions().as(StreamingOptions.class).setStreaming(true); + + // The PeriodicImpulse source will generate an element every this many millis: + int fireInterval = 1; + // Each element from PeriodicImpulse will fan out to this many elements: + int startMultiplier = + Math.max(configuration.rowsPerSecond, DEFAULT_ROWS_PER_SECOND) / DEFAULT_ROWS_PER_SECOND; + long stopAfterMillis = + org.joda.time.Duration.standardMinutes(configuration.minutes).getMillis(); + long totalRows = startMultiplier * stopAfterMillis / fireInterval; + List loadPeriods = + getLoadPeriods(configuration.minutes, DEFAULT_LOAD_INCREASE_ARRAY); + + PCollection source = + writePipeline + .apply( + PeriodicImpulse.create() + .stopAfter(org.joda.time.Duration.millis(stopAfterMillis - 1)) + .withInterval(org.joda.time.Duration.millis(fireInterval))) + .apply( + "Extract row IDs", + MapElements.into(TypeDescriptor.of(byte[].class)) + .via(instant -> Longs.toByteArray(instant.getMillis() % totalRows))); + if (startMultiplier > 1) { + source = + source + .apply( + "One input to multiple outputs", + ParDo.of(new MultiplierDoFn(startMultiplier, loadPeriods))) + .apply("Reshuffle fanout", Reshuffle.viaRandomKey()) + .apply("Counting element", ParDo.of(new CountingFn<>(READ_ELEMENT_METRIC_NAME))); + } + source.apply( + "Write to BQ", + writeIO + .to(tableQualifier) + .withMethod(method) + .withSchema(schema) + .withCustomGcsTempLocation(ValueProvider.StaticValueProvider.of(tempLocation))); + + PipelineLauncher.LaunchConfig options = + PipelineLauncher.LaunchConfig.builder("write-bigquery") + .setSdk(PipelineLauncher.Sdk.JAVA) + .setPipeline(writePipeline) + .addParameter("runner", configuration.runner) + .addParameter( + "autoscalingAlgorithm", + DataflowPipelineWorkerPoolOptions.AutoscalingAlgorithmType.THROUGHPUT_BASED + .toString()) + 
.addParameter("numWorkers", String.valueOf(configuration.numWorkers)) + .addParameter("maxNumWorkers", String.valueOf(configuration.maxNumWorkers)) + .addParameter("experiments", GcpOptions.STREAMING_ENGINE_EXPERIMENT) + .build(); + + PipelineLauncher.LaunchInfo launchInfo = pipelineLauncher.launch(project, region, options); + PipelineOperator.Result result = + pipelineOperator.waitUntilDone( + createConfig(launchInfo, Duration.ofMinutes(configuration.pipelineTimeout))); + + // Fail the test if pipeline failed. + assertNotEquals(PipelineOperator.Result.LAUNCH_FAILED, result); + + // check metrics + double numRecords = + pipelineLauncher.getMetric( + project, + region, + launchInfo.jobId(), + getBeamMetricsName(PipelineMetricsType.COUNTER, READ_ELEMENT_METRIC_NAME)); + Long rowCount = resourceManager.getRowCount(tableQualifier); + assertEquals(rowCount, numRecords, 0.5); + + // export metrics + MetricsConfiguration metricsConfig = + MetricsConfiguration.builder() + .setInputPCollection("Reshuffle fanout/Values/Values/Map.out0") + .setInputPCollectionV2("Reshuffle fanout/Values/Values/Map/ParMultiDo(Anonymous).out0") + .setOutputPCollection("Counting element.out0") + .setOutputPCollectionV2("Counting element/ParMultiDo(Counting).out0") + .build(); + try { + Map metrics = getMetrics(launchInfo, metricsConfig); + if (configuration.exportMetricsToInfluxDB) { + Collection namedTestResults = new ArrayList<>(); + for (Map.Entry entry : metrics.entrySet()) { + NamedTestResult metricResult = + NamedTestResult.create(TEST_ID, TEST_TIMESTAMP, entry.getKey(), entry.getValue()); + namedTestResults.add(metricResult); + } + IOITMetrics.publishToInflux(TEST_ID, TEST_TIMESTAMP, namedTestResults, influxDBSettings); + } else { + exportMetricsToBigQuery(launchInfo, metrics); + } + } catch (ParseException | InterruptedException e) { + throw new RuntimeException(e); + } + } + + /** + * Custom Apache Beam DoFn designed for use in stress testing scenarios. It introduces a dynamic + * load increase over time, multiplying the input elements based on the elapsed time since the + * start of processing. This class aims to simulate various load levels during stress testing. + */ + private static class MultiplierDoFn extends DoFn { + private final int startMultiplier; + private final long startTimesMillis; + private final List loadPeriods; + + MultiplierDoFn(int startMultiplier, List loadPeriods) { + this.startMultiplier = startMultiplier; + this.startTimesMillis = Instant.now().getMillis(); + this.loadPeriods = loadPeriods; + } + + @ProcessElement + public void processElement( + @Element byte[] element, + OutputReceiver outputReceiver, + @DoFn.Timestamp Instant timestamp) { + + int multiplier = this.startMultiplier; + long elapsedTimeMillis = timestamp.getMillis() - startTimesMillis; + + for (LoadPeriod loadPeriod : loadPeriods) { + if (elapsedTimeMillis >= loadPeriod.getPeriodStartMillis() + && elapsedTimeMillis < loadPeriod.getPeriodEndMillis()) { + multiplier *= loadPeriod.getLoadIncreaseMultiplier(); + break; + } + } + for (int i = 0; i < multiplier; i++) { + outputReceiver.output(element); + } + } + } + + abstract static class FormatFn implements SerializableFunction { + protected final int numColumns; + + public FormatFn(int numColumns) { + this.numColumns = numColumns; + } + } + + /** Avro format function that transforms AvroWriteRequest into a GenericRecord. 
*/ + private static class AvroFormatFn extends FormatFn, GenericRecord> { + + protected final boolean isWrapBytes; + + public AvroFormatFn(int numColumns, boolean isWrapBytes) { + super(numColumns); + this.isWrapBytes = isWrapBytes; + } + + // TODO(https://github.com/apache/beam/issues/26408) eliminate this method once the Beam issue + // resolved + private Object maybeWrapBytes(byte[] input) { + if (isWrapBytes) { + return ByteBuffer.wrap(input); + } else { + return input; + } + } + + @Override + public GenericRecord apply(AvroWriteRequest writeRequest) { + byte[] data = Objects.requireNonNull(writeRequest.getElement()); + GenericRecord record = new GenericData.Record(writeRequest.getSchema()); + if (numColumns == 1) { + // only one column, just wrap incoming bytes + record.put("data_0", maybeWrapBytes(data)); + } else { + // otherwise, distribute bytes + int bytePerCol = data.length / numColumns; + int curIdx = 0; + for (int idx = 0; idx < numColumns - 1; ++idx) { + record.put( + "data_" + idx, maybeWrapBytes(Arrays.copyOfRange(data, curIdx, curIdx + bytePerCol))); + curIdx += bytePerCol; + } + record.put( + "data_" + (numColumns - 1), + maybeWrapBytes(Arrays.copyOfRange(data, curIdx, data.length))); + } + return record; + } + } + + /** JSON format function that transforms byte[] into a TableRow. */ + private static class JsonFormatFn extends FormatFn { + + public JsonFormatFn(int numColumns) { + super(numColumns); + } + + @Override + public TableRow apply(byte[] input) { + TableRow tableRow = new TableRow(); + Base64.Encoder encoder = Base64.getEncoder(); + if (numColumns == 1) { + // only one column, just wrap incoming bytes + tableRow.set("data_0", encoder.encodeToString(input)); + } else { + // otherwise, distribute bytes + int bytePerCol = input.length / numColumns; + int curIdx = 0; + for (int idx = 0; idx < numColumns - 1; ++idx) { + tableRow.set( + "data_" + idx, + encoder.encodeToString(Arrays.copyOfRange(input, curIdx, curIdx + bytePerCol))); + curIdx += bytePerCol; + } + tableRow.set( + "data_" + (numColumns - 1), + encoder.encodeToString(Arrays.copyOfRange(input, curIdx, input.length))); + } + return tableRow; + } + } + + /** + * Generates and returns a list of LoadPeriod instances representing periods of load increase + * based on the specified load increase array and total duration in minutes. + * + * @param minutesTotal The total duration in minutes for which the load periods are generated. + * @return A list of LoadPeriod instances defining periods of load increase. + */ + private List getLoadPeriods(int minutesTotal, int[] loadIncreaseArray) { + + List loadPeriods = new ArrayList<>(); + long periodDurationMillis = + Duration.ofMinutes(minutesTotal / loadIncreaseArray.length).toMillis(); + long startTimeMillis = 0; + + for (int loadIncreaseMultiplier : loadIncreaseArray) { + long endTimeMillis = startTimeMillis + periodDurationMillis; + loadPeriods.add(new LoadPeriod(loadIncreaseMultiplier, startTimeMillis, endTimeMillis)); + + startTimeMillis = endTimeMillis; + } + return loadPeriods; + } + + private enum WriteFormat { + AVRO, + JSON + } + + /** Options for Bigquery IO stress test. */ + static class Configuration extends SyntheticSourceOptions { + + /** + * Number of columns of each record. The column size is equally distributed as + * valueSizeBytes/numColumns. + */ + @JsonProperty public int numColumns = 1; + + /** Pipeline timeout in minutes. Must be a positive value. 
*/ + @JsonProperty public int pipelineTimeout = 20; + + /** Runner specified to run the pipeline. */ + @JsonProperty public String runner = "DirectRunner"; + + /** Number of workers for the pipeline. */ + @JsonProperty public int numWorkers = 20; + + /** Maximum number of workers for the pipeline. */ + @JsonProperty public int maxNumWorkers = 100; + + /** BigQuery write method: DEFAULT/FILE_LOADS/STREAMING_INSERTS/STORAGE_WRITE_API. */ + @JsonProperty public String writeMethod = "DEFAULT"; + + /** BigQuery write format: AVRO/JSON. */ + @JsonProperty public String writeFormat = "AVRO"; + + /** + * Rate of generated elements sent to the source table. Will run with a minimum of 1k rows per + * second. + */ + @JsonProperty public int rowsPerSecond = DEFAULT_ROWS_PER_SECOND; + + /** Rows will be generated for this many minutes. */ + @JsonProperty public int minutes = 15; + + /** + * Determines the destination for exporting metrics. If set to true, metrics will be exported to + * InfluxDB and displayed using Grafana. If set to false, metrics will be exported to BigQuery + * and displayed with Looker Studio. + */ + @JsonProperty public boolean exportMetricsToInfluxDB = false; + + /** InfluxDB measurement to publish results to. * */ + @JsonProperty public String influxMeasurement = BigQueryIOST.class.getName(); + + /** InfluxDB host to publish metrics. * */ + @JsonProperty public String influxHost; + + /** InfluxDB database to publish metrics. * */ + @JsonProperty public String influxDatabase; + } + + /** + * Represents a period of time with associated load increase properties for stress testing + * scenarios. + */ + private static class LoadPeriod implements Serializable { + private final int loadIncreaseMultiplier; + private final long periodStartMillis; + private final long periodEndMillis; + + public LoadPeriod(int loadIncreaseMultiplier, long periodStartMillis, long periodEndMin) { + this.loadIncreaseMultiplier = loadIncreaseMultiplier; + this.periodStartMillis = periodStartMillis; + this.periodEndMillis = periodEndMin; + } + + public int getLoadIncreaseMultiplier() { + return loadIncreaseMultiplier; + } + + public long getPeriodStartMillis() { + return periodStartMillis; + } + + public long getPeriodEndMillis() { + return periodEndMillis; + } + } +} diff --git a/runners/direct-java/build.gradle b/runners/direct-java/build.gradle index b7f74dc3e538..c357b8a04328 100644 --- a/runners/direct-java/build.gradle +++ b/runners/direct-java/build.gradle @@ -142,8 +142,7 @@ task needsRunnerTests(type: Test) { excludeCategories "org.apache.beam.sdk.testing.LargeKeys\$Above100MB" // MetricsPusher isn't implemented in direct runner excludeCategories "org.apache.beam.sdk.testing.UsesMetricsPusher" - excludeCategories "org.apache.beam.sdk.testing.UsesJavaExpansionService" - excludeCategories "org.apache.beam.sdk.testing.UsesPythonExpansionService" + excludeCategories "org.apache.beam.sdk.testing.UsesExternalService" excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' } testLogging { @@ -173,8 +172,7 @@ task validatesRunner(type: Test) { excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment' excludeCategories "org.apache.beam.sdk.testing.LargeKeys\$Above100MB" excludeCategories 'org.apache.beam.sdk.testing.UsesMetricsPusher' - excludeCategories "org.apache.beam.sdk.testing.UsesJavaExpansionService" - excludeCategories "org.apache.beam.sdk.testing.UsesPythonExpansionService" + excludeCategories "org.apache.beam.sdk.testing.UsesExternalService" // 
https://github.com/apache/beam/issues/18499 excludeCategories 'org.apache.beam.sdk.testing.UsesLoopingTimer' } diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle index d8bfcb3f533a..1d91284a3d16 100644 --- a/runners/flink/flink_runner.gradle +++ b/runners/flink/flink_runner.gradle @@ -279,6 +279,7 @@ def createValidatesRunnerTask(Map m) { excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream' } else { includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner' + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' // Should be run only in a properly configured SDK harness environment excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment' excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' diff --git a/runners/flink/job-server/flink_job_server.gradle b/runners/flink/job-server/flink_job_server.gradle index ab3f726e6ad9..e88b3c7d9e81 100644 --- a/runners/flink/job-server/flink_job_server.gradle +++ b/runners/flink/job-server/flink_job_server.gradle @@ -167,8 +167,7 @@ def portableValidatesRunnerTask(String name, boolean streaming, boolean checkpoi // Larger keys are possible, but they require more memory. excludeCategories 'org.apache.beam.sdk.testing.LargeKeys$Above10MB' excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics' - excludeCategories 'org.apache.beam.sdk.testing.UsesJavaExpansionService' - excludeCategories 'org.apache.beam.sdk.testing.UsesPythonExpansionService' + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' excludeCategories 'org.apache.beam.sdk.testing.UsesCustomWindowMerging' excludeCategories 'org.apache.beam.sdk.testing.UsesFailureMessage' excludeCategories 'org.apache.beam.sdk.testing.UsesGaugeMetrics' diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle index 27b1ca83a9b9..4c4b58d90323 100644 --- a/runners/google-cloud-dataflow-java/build.gradle +++ b/runners/google-cloud-dataflow-java/build.gradle @@ -180,8 +180,7 @@ def commonLegacyExcludeCategories = [ 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment', 'org.apache.beam.sdk.testing.LargeKeys$Above10MB', 'org.apache.beam.sdk.testing.UsesAttemptedMetrics', - 'org.apache.beam.sdk.testing.UsesJavaExpansionService', - 'org.apache.beam.sdk.testing.UsesPythonExpansionService', + 'org.apache.beam.sdk.testing.UsesExternalService', 'org.apache.beam.sdk.testing.UsesDistributionMetrics', 'org.apache.beam.sdk.testing.UsesGaugeMetrics', 'org.apache.beam.sdk.testing.UsesMultimapState', @@ -192,6 +191,7 @@ def commonLegacyExcludeCategories = [ ] def commonRunnerV2ExcludeCategories = [ + 'org.apache.beam.sdk.testing.UsesExternalService', 'org.apache.beam.sdk.testing.UsesGaugeMetrics', 'org.apache.beam.sdk.testing.UsesSetState', 'org.apache.beam.sdk.testing.UsesMapState', diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index a454945c3d01..138aa22ff473 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -1687,7 +1687,11 @@ public CompositeBehavior enterCompositeTransform(Node node) { String rootBigQueryTransform = ""; if (transform.getClass().equals(StorageApiLoads.class)) { StorageApiLoads 
storageLoads = (StorageApiLoads) transform; - failedTag = storageLoads.getFailedRowsTag(); + // If the storage load is directing exceptions to an error handler, we don't need to + // warn for unconsumed rows + if (!storageLoads.usesErrorHandler()) { + failedTag = storageLoads.getFailedRowsTag(); + } // For storage API the transform that outputs failed rows is nested one layer below // BigQueryIO. rootBigQueryTransform = node.getEnclosingNode().getFullName(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java index 80c957996ab7..0c690cf97757 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java @@ -178,9 +178,7 @@ public static StreamingEngineClient create( getWorkBudgetDistributor, dispatcherClient, new Random().nextLong()); - streamingEngineClient.startGetWorkerMetadataStream(); - streamingEngineClient.startWorkerMetadataConsumer(); - streamingEngineClient.getWorkBudgetRefresher.start(); + streamingEngineClient.start(); return streamingEngineClient; } @@ -206,12 +204,16 @@ static StreamingEngineClient forTesting( getWorkBudgetDistributor, dispatcherClient, clientId); - streamingEngineClient.startGetWorkerMetadataStream(); - streamingEngineClient.startWorkerMetadataConsumer(); - streamingEngineClient.getWorkBudgetRefresher.start(); + streamingEngineClient.start(); return streamingEngineClient; } + private void start() { + startGetWorkerMetadataStream(); + startWorkerMetadataConsumer(); + getWorkBudgetRefresher.start(); + } + @SuppressWarnings("FutureReturnValueIgnored") private void startWorkerMetadataConsumer() { newWorkerMetadataConsumer.submit( @@ -223,11 +225,6 @@ private void startWorkerMetadataConsumer() { }); } - @VisibleForTesting - boolean isWorkerMetadataReady() { - return !connections.get().equals(StreamingEngineConnectionState.EMPTY); - } - @VisibleForTesting void finish() { if (!started.compareAndSet(true, false)) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java index f755f0333387..f9011a90c065 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClientTest.java @@ -29,11 +29,14 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Comparator; import java.util.HashSet; +import java.util.List; import java.util.Optional; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -50,7 +53,6 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemProcessor; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor; -import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; @@ -62,10 +64,10 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.junit.After; import org.junit.Before; -import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; @@ -92,7 +94,10 @@ public class StreamingEngineClientTest { .setProjectId(PROJECT_ID) .setWorkerId(WORKER_ID) .build(); + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private final Set channels = new HashSet<>(); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); private final GrpcWindmillStreamFactory streamFactory = @@ -109,11 +114,8 @@ public class StreamingEngineClientTest { private final GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.forTesting( stubFactory, new ArrayList<>(), new ArrayList<>(), new HashSet<>()); - private final GetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor()); private final AtomicReference connections = new AtomicReference<>(StreamingEngineConnectionState.EMPTY); - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private Server fakeStreamingEngineServer; private CountDownLatch getWorkerMetadataReady; private GetWorkerMetadataTestStub fakeGetWorkerMetadataStub; @@ -167,14 +169,16 @@ public void setUp() throws IOException { @After public void cleanUp() { + Preconditions.checkNotNull(streamingEngineClient).finish(); fakeGetWorkerMetadataStub.close(); fakeStreamingEngineServer.shutdownNow(); channels.forEach(ManagedChannel::shutdownNow); - Preconditions.checkNotNull(streamingEngineClient).finish(); } private StreamingEngineClient newStreamingEngineClient( - GetWorkBudget getWorkBudget, WorkItemProcessor workItemProcessor) { + GetWorkBudget getWorkBudget, + GetWorkBudgetDistributor getWorkBudgetDistributor, + WorkItemProcessor workItemProcessor) { return StreamingEngineClient.forTesting( JOB_HEADER, getWorkBudget, @@ -191,10 +195,15 @@ private StreamingEngineClient newStreamingEngineClient( public void testStreamsStartCorrectly() throws InterruptedException { long items = 10L; long bytes = 10L; + int numBudgetDistributionsExpected = 1; + + TestGetWorkBudgetDistributor getWorkBudgetDistributor = + spy(new TestGetWorkBudgetDistributor(numBudgetDistributionsExpected)); streamingEngineClient = newStreamingEngineClient( GetWorkBudget.builder().setItems(items).setBytes(bytes).build(), + getWorkBudgetDistributor, noOpProcessWorkItemFn()); String workerToken = "workerToken1"; @@ -210,12 +219,14 @@ public void 
testStreamsStartCorrectly() throws InterruptedException { getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); - StreamingEngineConnectionState currentConnections = waitForWorkerMetadataToBeConsumed(1); + waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); + + StreamingEngineConnectionState currentConnections = connections.get(); assertEquals(2, currentConnections.windmillConnections().size()); assertEquals(2, currentConnections.windmillStreams().size()); Set workerTokens = - connections.get().windmillConnections().values().stream() + currentConnections.windmillConnections().values().stream() .map(WindmillConnection::backendWorkerToken) .filter(Optional::isPresent) .map(Optional::get) @@ -238,9 +249,13 @@ public void testStreamsStartCorrectly() throws InterruptedException { @Test public void testScheduledBudgetRefresh() throws InterruptedException { + TestGetWorkBudgetDistributor getWorkBudgetDistributor = + spy(new TestGetWorkBudgetDistributor(2)); streamingEngineClient = newStreamingEngineClient( - GetWorkBudget.builder().setItems(1L).setBytes(1L).build(), noOpProcessWorkItemFn()); + GetWorkBudget.builder().setItems(1L).setBytes(1L).build(), + getWorkBudgetDistributor, + noOpProcessWorkItemFn()); getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata( @@ -249,18 +264,21 @@ public void testScheduledBudgetRefresh() throws InterruptedException { .addWorkEndpoints(metadataResponseEndpoint("workerToken")) .putAllGlobalDataEndpoints(DEFAULT) .build()); - waitForWorkerMetadataToBeConsumed(1); - Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS); + waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); verify(getWorkBudgetDistributor, atLeast(2)).distributeBudget(any(), any()); } @Test - @Ignore("https://github.com/apache/beam/issues/28957") // stuck test public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() throws InterruptedException { + int metadataCount = 2; + TestGetWorkBudgetDistributor getWorkBudgetDistributor = + spy(new TestGetWorkBudgetDistributor(metadataCount)); streamingEngineClient = newStreamingEngineClient( - GetWorkBudget.builder().setItems(1).setBytes(1).build(), noOpProcessWorkItemFn()); + GetWorkBudget.builder().setItems(1).setBytes(1).build(), + getWorkBudgetDistributor, + noOpProcessWorkItemFn()); String workerToken = "workerToken1"; String workerToken2 = "workerToken2"; @@ -292,9 +310,8 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); - - StreamingEngineConnectionState currentConnections = waitForWorkerMetadataToBeConsumed(2); - + waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); + StreamingEngineConnectionState currentConnections = connections.get(); assertEquals(1, currentConnections.windmillConnections().size()); assertEquals(1, currentConnections.windmillStreams().size()); Set workerTokens = @@ -310,10 +327,6 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() @Test public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedException { - streamingEngineClient = - newStreamingEngineClient( - GetWorkBudget.builder().setItems(1).setBytes(1).build(), noOpProcessWorkItemFn()); - String workerToken = "workerToken1"; String workerToken2 = "workerToken2"; String workerToken3 = "workerToken3"; @@ 
-346,39 +359,39 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce .putAllGlobalDataEndpoints(DEFAULT) .build(); + List workerMetadataResponses = + Lists.newArrayList(firstWorkerMetadata, secondWorkerMetadata, thirdWorkerMetadata); + + TestGetWorkBudgetDistributor getWorkBudgetDistributor = + spy(new TestGetWorkBudgetDistributor(workerMetadataResponses.size())); + streamingEngineClient = + newStreamingEngineClient( + GetWorkBudget.builder().setItems(1).setBytes(1).build(), + getWorkBudgetDistributor, + noOpProcessWorkItemFn()); + getWorkerMetadataReady.await(); - fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); - Thread.sleep(50); - fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); - Thread.sleep(50); - fakeGetWorkerMetadataStub.injectWorkerMetadata(thirdWorkerMetadata); - Thread.sleep(50); - verify(getWorkBudgetDistributor, atLeast(3)).distributeBudget(any(), any()); + + // Make sure we are injecting the metadata from smallest to largest. + workerMetadataResponses.stream() + .sorted(Comparator.comparingLong(WorkerMetadataResponse::getMetadataVersion)) + .forEach(fakeGetWorkerMetadataStub::injectWorkerMetadata); + + waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); + verify(getWorkBudgetDistributor, atLeast(workerMetadataResponses.size())) + .distributeBudget(any(), any()); } - private StreamingEngineConnectionState waitForWorkerMetadataToBeConsumed( - int expectedMetadataConsumed) throws InterruptedException { - int currentMetadataConsumed = 0; - StreamingEngineConnectionState currentConsumedMetadata = StreamingEngineConnectionState.EMPTY; - while (true) { - if (!connections.get().equals(currentConsumedMetadata)) { - ++currentMetadataConsumed; - if (currentMetadataConsumed == expectedMetadataConsumed) { - break; - } - currentConsumedMetadata = connections.get(); - } - } - // Wait for metadata to be consumed and budgets to be redistributed. 
- Thread.sleep(GetWorkBudgetRefresher.SCHEDULED_BUDGET_REFRESH_MILLIS); - return connections.get(); + private void waitForWorkerMetadataToBeConsumed( + TestGetWorkBudgetDistributor getWorkBudgetDistributor) throws InterruptedException { + getWorkBudgetDistributor.waitForBudgetDistribution(); } private static class GetWorkerMetadataTestStub extends CloudWindmillMetadataServiceV1Alpha1Grpc .CloudWindmillMetadataServiceV1Alpha1ImplBase { private static final WorkerMetadataResponse CLOSE_ALL_STREAMS = - WorkerMetadataResponse.newBuilder().setMetadataVersion(100).build(); + WorkerMetadataResponse.newBuilder().setMetadataVersion(Long.MAX_VALUE).build(); private final CountDownLatch ready; private @Nullable StreamObserver responseObserver; @@ -426,10 +439,22 @@ private void close() { } private static class TestGetWorkBudgetDistributor implements GetWorkBudgetDistributor { + private final CountDownLatch getWorkBudgetDistributorTriggered; + + private TestGetWorkBudgetDistributor(int numBudgetDistributionsExpected) { + this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected); + } + + @SuppressWarnings("ReturnValueIgnored") + private void waitForBudgetDistribution() throws InterruptedException { + getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS); + } + @Override public void distributeBudget( ImmutableCollection streams, GetWorkBudget getWorkBudget) { streams.forEach(stream -> stream.adjustBudget(getWorkBudget.items(), getWorkBudget.bytes())); + getWorkBudgetDistributorTriggered.countDown(); } } } diff --git a/runners/jet/build.gradle b/runners/jet/build.gradle index 2b6bf2bcdff4..56a001a2bceb 100644 --- a/runners/jet/build.gradle +++ b/runners/jet/build.gradle @@ -71,6 +71,7 @@ task validatesRunnerBatch(type: Test) { useJUnit { includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner' // Should be run only in a properly configured SDK harness environment + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment' excludeCategories "org.apache.beam.sdk.testing.LargeKeys\$Above100MB" excludeCategories 'org.apache.beam.sdk.testing.UsesTimerMap' diff --git a/runners/portability/java/build.gradle b/runners/portability/java/build.gradle index 45709a0dbd1e..9830b48c83ad 100644 --- a/runners/portability/java/build.gradle +++ b/runners/portability/java/build.gradle @@ -149,6 +149,7 @@ def createUlrValidatesRunnerTask = { name, environmentType, dockerImageTask = "" useJUnit { includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner' // Should be run only in a properly configured SDK harness environment + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment' excludeCategories 'org.apache.beam.sdk.testing.UsesGaugeMetrics' excludeCategories 'org.apache.beam.sdk.testing.UsesOnWindowExpiration' diff --git a/runners/samza/build.gradle b/runners/samza/build.gradle index 8e2cc2a5eb7d..a50e0d62e59a 100644 --- a/runners/samza/build.gradle +++ b/runners/samza/build.gradle @@ -124,6 +124,7 @@ tasks.register("validatesRunner", Test) { useJUnit { includeCategories 'org.apache.beam.sdk.testing.NeedsRunner' includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner' + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' // Should be run only in a properly configured SDK harness environment excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment' excludeCategories 
'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo' diff --git a/runners/samza/job-server/build.gradle b/runners/samza/job-server/build.gradle index c9401a8aff17..7bb1b84dbaae 100644 --- a/runners/samza/job-server/build.gradle +++ b/runners/samza/job-server/build.gradle @@ -86,8 +86,7 @@ def portableValidatesRunnerTask(String name, boolean docker) { // Larger keys are possible, but they require more memory. excludeCategories 'org.apache.beam.sdk.testing.LargeKeys$Above10MB' excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics' - excludeCategories 'org.apache.beam.sdk.testing.UsesJavaExpansionService' - excludeCategories 'org.apache.beam.sdk.testing.UsesPythonExpansionService' + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' excludeCategories 'org.apache.beam.sdk.testing.UsesCustomWindowMerging' excludeCategories 'org.apache.beam.sdk.testing.UsesFailureMessage' excludeCategories 'org.apache.beam.sdk.testing.UsesGaugeMetrics' diff --git a/runners/spark/job-server/spark_job_server.gradle b/runners/spark/job-server/spark_job_server.gradle index 8945ad3498c7..6c884e8e2233 100644 --- a/runners/spark/job-server/spark_job_server.gradle +++ b/runners/spark/job-server/spark_job_server.gradle @@ -108,6 +108,7 @@ def portableValidatesRunnerTask(String name, boolean streaming, boolean docker, testCategories = { includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner' + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' // Should be run only in a properly configured SDK harness environment excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment' excludeCategories 'org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders' @@ -167,6 +168,7 @@ def portableValidatesRunnerTask(String name, boolean streaming, boolean docker, // Batch testCategories = { includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner' + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' // Should be run only in a properly configured SDK harness environment excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment' excludeCategories 'org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders' diff --git a/runners/spark/spark_runner.gradle b/runners/spark/spark_runner.gradle index d775cfd4d6e9..5d0a7f02d17f 100644 --- a/runners/spark/spark_runner.gradle +++ b/runners/spark/spark_runner.gradle @@ -274,8 +274,7 @@ def applyBatchValidatesRunnerSetup = { Test it -> // SDF excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo' // Portability - excludeCategories 'org.apache.beam.sdk.testing.UsesJavaExpansionService' - excludeCategories 'org.apache.beam.sdk.testing.UsesPythonExpansionService' + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' // Ordering excludeCategories 'org.apache.beam.sdk.testing.UsesPerKeyOrderedDelivery' @@ -349,8 +348,7 @@ def validatesRunnerStreaming = tasks.register("validatesRunnerStreaming", Test) excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo' excludeCategories 'org.apache.beam.sdk.testing.UsesBoundedSplittableParDo' // Portability - excludeCategories 'org.apache.beam.sdk.testing.UsesJavaExpansionService' - excludeCategories 'org.apache.beam.sdk.testing.UsesPythonExpansionService' + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' // Ordering excludeCategories 
'org.apache.beam.sdk.testing.UsesPerKeyOrderedDelivery' @@ -405,8 +403,7 @@ tasks.register("validatesStructuredStreamingRunnerBatch", Test) { // SDF excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo' // Portability - excludeCategories 'org.apache.beam.sdk.testing.UsesJavaExpansionService' - excludeCategories 'org.apache.beam.sdk.testing.UsesPythonExpansionService' + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer' excludeCategories 'org.apache.beam.sdk.testing.UsesTriggeredSideInputs' } diff --git a/runners/twister2/build.gradle b/runners/twister2/build.gradle index 36a044737d82..744c4e171fd1 100644 --- a/runners/twister2/build.gradle +++ b/runners/twister2/build.gradle @@ -80,6 +80,7 @@ def validatesRunnerBatch = tasks.register("validatesRunnerBatch", Test) { forkEvery 1 useJUnit { includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner' + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' // Should be run only in a properly configured SDK harness environment excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment' excludeCategories 'org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders' diff --git a/sdks/go.mod b/sdks/go.mod index 7eb08fd17050..df02e9a90614 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -78,6 +78,7 @@ require ( github.com/Microsoft/hcsshim v0.11.4 // indirect github.com/containerd/log v0.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/distribution/reference v0.5.0 // indirect github.com/frankban/quicktest v1.14.0 // indirect github.com/go-logr/logr v1.3.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -85,6 +86,7 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/minio/highwayhash v1.0.2 // indirect + github.com/moby/sys/user v0.1.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/nats-io/jwt/v2 v2.5.3 // indirect @@ -141,8 +143,7 @@ require ( github.com/cncf/xds/go v0.0.0-20231109132714-523115ebc101 // indirect github.com/containerd/containerd v1.7.11 // indirect github.com/cpuguy83/dockercfg v0.3.1 // indirect - github.com/docker/distribution v2.8.2+incompatible // indirect - github.com/docker/docker v24.0.7+incompatible // but required to resolve issue docker has with go1.20 + github.com/docker/docker v25.0.3+incompatible // but required to resolve issue docker has with go1.20 github.com/docker/go-units v0.5.0 // indirect github.com/envoyproxy/go-control-plane v0.11.1 // indirect github.com/envoyproxy/protoc-gen-validate v1.0.2 // indirect @@ -174,7 +175,6 @@ require ( github.com/morikuni/aec v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0-rc5 // indirect - github.com/opencontainers/runc v1.1.12 // indirect github.com/pierrec/lz4/v4 v4.1.15 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pkg/xattr v0.4.9 // indirect diff --git a/sdks/go.sum b/sdks/go.sum index 275bec50e406..50b6d2039fc4 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -161,10 +161,10 @@ 
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8= -github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= -github.com/docker/docker v24.0.7+incompatible h1:Wo6l37AuwP3JaMnZa226lzVXGA3F9Ig1seQen0cKYlM= -github.com/docker/docker v24.0.7+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/distribution/reference v0.5.0 h1:/FUIFXtfc/x2gpa5/VGfiGLuOIdYa1t65IKK2OFGvA0= +github.com/distribution/reference v0.5.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/docker/docker v25.0.3+incompatible h1:D5fy/lYmY7bvZa0XTZ5/UJPljor41F+vdyJG5luQLfQ= +github.com/docker/docker v25.0.3+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= @@ -288,6 +288,7 @@ github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyE github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms= github.com/hashicorp/go-uuid v0.0.0-20180228145832-27454136f036/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= @@ -353,6 +354,8 @@ github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkV github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc= github.com/moby/sys/sequential v0.5.0 h1:OPvI35Lzn9K04PBbCLW0g4LcFAJgHsvXsRyewg5lXtc= github.com/moby/sys/sequential v0.5.0/go.mod h1:tH2cOOs5V9MlPiXcQzRC+eEyab644PWKGRYaaV5ZZlo= +github.com/moby/sys/user v0.1.0 h1:WmZ93f5Ux6het5iituh9x2zAG7NFY9Aqi49jjE1PaQg= +github.com/moby/sys/user v0.1.0/go.mod h1:fKJhFOnsCN6xZ5gSfbM6zaHGgDJMrqt9/reuj4T7MmU= github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -379,8 +382,6 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8 github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.0-rc5 h1:Ygwkfw9bpDvs+c9E34SdgGOj41dX/cbdlwvlWt0pnFI= 
github.com/opencontainers/image-spec v1.1.0-rc5/go.mod h1:X4pATf0uXsnn3g5aiGIsVnJBR4mxhKzfwmvK/B2NTm8= -github.com/opencontainers/runc v1.1.12 h1:BOIssBaW1La0/qbNZHXOOa71dZfZEQOzW7dqQf3phss= -github.com/opencontainers/runc v1.1.12/go.mod h1:S+lQwSfncpBha7XTy/5lBwWgm5+y5Ma/O44Ekby9FK8= github.com/pborman/getopt v0.0.0-20180729010549-6fdd0a2c7117/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pierrec/lz4/v4 v4.1.15 h1:MO0/ucJhngq7299dKLwIMtgTfbkoSPF6AoMYDd8Q4q0= @@ -483,12 +484,15 @@ go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 h1:aFJWCqJ go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1/go.mod h1:sEGXWArGqc3tVa+ekntsN65DmVbVeW+7lTKTjZF3/Fo= go.opentelemetry.io/otel v1.21.0 h1:hzLeKBZEL7Okw2mGzZ0cc4k/A7Fta0uoPgaJCr8fsFc= go.opentelemetry.io/otel v1.21.0/go.mod h1:QZzNPQPm1zLX4gZK4cMi+71eaorMSGT3A4znnUvNNEo= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg= go.opentelemetry.io/otel/metric v1.21.0 h1:tlYWfeo+Bocx5kLEloTjbcDwBuELRrIFxwdQ36PlJu4= go.opentelemetry.io/otel/metric v1.21.0/go.mod h1:o1p3CA8nNHW8j5yuQLdc1eeqEaPfzug24uvsyIEJRWM= go.opentelemetry.io/otel/sdk v1.21.0 h1:FTt8qirL1EysG6sTQRZ5TokkU8d0ugCj8htOgThZXQ8= go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E= go.opentelemetry.io/otel/trace v1.21.0 h1:WD9i5gzvoUPuXIXH24ZNBudiarZDKuekPqi/E8fpfLc= go.opentelemetry.io/otel/trace v1.21.0/go.mod h1:LGbsEB0f9LGjN+OZaQQ26sohbOmiMR+BaslueVtS/qQ= +go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java index 43ab473dcc49..d68338eceaf4 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileBasedSource.java @@ -33,6 +33,7 @@ import org.apache.beam.sdk.io.fs.EmptyMatchTreatment; import org.apache.beam.sdk.io.fs.MatchResult; import org.apache.beam.sdk.io.fs.MatchResult.Metadata; +import org.apache.beam.sdk.io.fs.ResourceId; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; @@ -472,23 +473,30 @@ public synchronized FileBasedSource getCurrentSource() { @Override protected final boolean startImpl() throws IOException { FileBasedSource source = getCurrentSource(); - this.channel = FileSystems.open(source.getSingleFileMetadata().resourceId()); - if (channel instanceof SeekableByteChannel) { - SeekableByteChannel seekChannel = (SeekableByteChannel) channel; - seekChannel.position(source.getStartOffset()); - } else { - // Channel is not seekable. Must not be a subrange. 
- checkArgument( - source.mode != Mode.SINGLE_FILE_OR_SUBRANGE, - "Subrange-based sources must only be defined for file types that support seekable " - + " read channels"); - checkArgument( - source.getStartOffset() == 0, - "Start offset %s is not zero but channel for reading the file is not seekable.", - source.getStartOffset()); - } + ResourceId resourceId = source.getSingleFileMetadata().resourceId(); + try { + this.channel = FileSystems.open(resourceId); + if (channel instanceof SeekableByteChannel) { + SeekableByteChannel seekChannel = (SeekableByteChannel) channel; + seekChannel.position(source.getStartOffset()); + } else { + // Channel is not seekable. Must not be a subrange. + checkArgument( + source.mode != Mode.SINGLE_FILE_OR_SUBRANGE, + "Subrange-based sources must only be defined for file types that support seekable " + + " read channels"); + checkArgument( + source.getStartOffset() == 0, + "Start offset %s is not zero but channel for reading the file is not seekable.", + source.getStartOffset()); + } - startReading(channel); + startReading(channel); + } catch (IOException e) { + LOG.error( + "Failed to process {}, which could be corrupted or have a wrong format.", resourceId); + throw new IOException(e); + } // Advance once to load the first record. return advanceImpl(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesExternalService.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesExternalService.java new file mode 100644 index 000000000000..a9e0b9d2236c --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesExternalService.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.testing; + +import org.apache.beam.sdk.annotations.Internal; + +/** + * Category tag for tests which rely on a pre-defined port, such as an expansion service or transform + * service. Tests tagged with {@link UsesExternalService} should initialize such a port before + * test execution. + */ +@Internal +public interface UsesExternalService {} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesJavaExpansionService.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesJavaExpansionService.java index ad919ae4b88f..766854c8cafd 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesJavaExpansionService.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesJavaExpansionService.java @@ -24,4 +24,4 @@ * UsesJavaExpansionService} should be run for runners which support cross-language transforms.
*/ @Internal -public interface UsesJavaExpansionService {} +public interface UsesJavaExpansionService extends UsesExternalService {} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesPythonExpansionService.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesPythonExpansionService.java index b92742e5db8b..0fbab223934d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesPythonExpansionService.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/UsesPythonExpansionService.java @@ -24,4 +24,4 @@ * UsesPythonExpansionService} should be run for runners which support cross-language transforms. */ @Internal -public interface UsesPythonExpansionService {} +public interface UsesPythonExpansionService extends UsesExternalService {} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandlingTestUtils.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandlingTestUtils.java index 41367765b920..b4db4867cfc7 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandlingTestUtils.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/errorhandling/ErrorHandlingTestUtils.java @@ -45,4 +45,13 @@ public static class ErrorSinkTransform } } } + + public static class EchoErrorTransform + extends PTransform, PCollection> { + + @Override + public PCollection expand(PCollection input) { + return input; + } + } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AvroRowWriter.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AvroRowWriter.java index 0b64a65c4503..1f45371b19ff 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AvroRowWriter.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AvroRowWriter.java @@ -47,9 +47,15 @@ class AvroRowWriter extends BigQueryRowWriter { } @Override - public void write(T element) throws IOException { + public void write(T element) throws IOException, BigQueryRowSerializationException { AvroWriteRequest writeRequest = new AvroWriteRequest<>(element, schema); - writer.append(toAvroRecord.apply(writeRequest)); + AvroT serializedRequest; + try { + serializedRequest = toAvroRecord.apply(writeRequest); + } catch (Exception e) { + throw new BigQueryRowSerializationException(e); + } + writer.append(serializedRequest); } public Schema getSchema() { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java index 23acd8e01f7f..56bd14318be4 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.gcp.bigquery; import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.resolveTempLocation; +import static org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.BAD_RECORD_TAG; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; @@ -57,6 +58,9 @@ import 
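Because both expansion-service categories now extend the new UsesExternalService marker, a single excludeCategories entry covers every test that needs a pre-arranged port. A minimal sketch of how a test opts in via the standard JUnit 4 @Category mechanism; the test class itself is hypothetical:

```java
import org.apache.beam.sdk.testing.UsesJavaExpansionService;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.junit.Test;
import org.junit.experimental.categories.Category;

public class HypotheticalCrossLanguageTest {

  @Test
  @Category({ValidatesRunner.class, UsesJavaExpansionService.class})
  public void expandsAgainstJavaExpansionService() {
    // Because UsesJavaExpansionService extends UsesExternalService, a runner build that
    // declares excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService'
    // skips this test without listing each expansion-service category separately.
  }
}
```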
org.apache.beam.sdk.transforms.Values; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.WithKeys; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; import org.apache.beam.sdk.transforms.windowing.AfterFirst; import org.apache.beam.sdk.transforms.windowing.AfterPane; import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime; @@ -77,6 +81,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; @@ -161,6 +166,8 @@ class BatchLoads private final RowWriterFactory rowWriterFactory; private final @Nullable String kmsKey; private final String tempDataset; + private final BadRecordRouter badRecordRouter; + private final ErrorHandler badRecordErrorHandler; private Coder tableDestinationCoder; // The maximum number of times to retry failed load or copy jobs. @@ -180,7 +187,9 @@ class BatchLoads @Nullable String kmsKey, boolean clusteringEnabled, boolean useAvroLogicalTypes, - String tempDataset) { + String tempDataset, + BadRecordRouter badRecordRouter, + ErrorHandler badRecordErrorHandler) { bigQueryServices = new BigQueryServicesImpl(); this.writeDisposition = writeDisposition; this.createDisposition = createDisposition; @@ -207,6 +216,8 @@ class BatchLoads this.tempDataset = tempDataset; this.tableDestinationCoder = clusteringEnabled ? TableDestinationCoderV3.of() : TableDestinationCoderV2.of(); + this.badRecordRouter = badRecordRouter; + this.badRecordErrorHandler = badRecordErrorHandler; } void setSchemaUpdateOptions(Set schemaUpdateOptions) { @@ -601,9 +612,13 @@ PCollection> writeDynamicallyShardedFil unwrittedRecordsTag, maxNumWritersPerBundle, maxFileSize, - rowWriterFactory)) + rowWriterFactory, + input.getCoder(), + badRecordRouter)) .withSideInputs(tempFilePrefix) - .withOutputTags(writtenFilesTag, TupleTagList.of(unwrittedRecordsTag))); + .withOutputTags( + writtenFilesTag, + TupleTagList.of(ImmutableList.of(unwrittedRecordsTag, BAD_RECORD_TAG)))); PCollection> writtenFiles = writeBundlesTuple .get(writtenFilesTag) @@ -612,6 +627,8 @@ PCollection> writeDynamicallyShardedFil writeBundlesTuple .get(unwrittedRecordsTag) .setCoder(KvCoder.of(ShardedKeyCoder.of(destinationCoder), elementCoder)); + badRecordErrorHandler.addErrorCollection( + writeBundlesTuple.get(BAD_RECORD_TAG).setCoder(BadRecord.getCoder(input.getPipeline()))); // If the bundles contain too many output tables to be written inline to files (due to memory // limits), any unwritten records will be spilled to the unwrittenRecordsTag PCollection. @@ -680,62 +697,92 @@ PCollection> writeDynamicallyShardedFil // parallelize properly. We also ensure that the files are written if a threshold number of // records are ready. Dynamic sharding is achieved via the withShardedKey() option provided by // GroupIntoBatches. 
- return input - .apply( - GroupIntoBatches.ofSize(FILE_TRIGGERING_RECORD_COUNT) - .withByteSize(byteSize) - .withMaxBufferingDuration(maxBufferingDuration) - .withShardedKey()) - .setCoder( - KvCoder.of( - org.apache.beam.sdk.util.ShardedKey.Coder.of(destinationCoder), - IterableCoder.of(elementCoder))) - .apply( - "StripShardId", - MapElements.via( - new SimpleFunction< - KV, Iterable>, - KV>>() { - @Override - public KV> apply( - KV, Iterable> - input) { - return KV.of(input.getKey().getKey(), input.getValue()); - } - })) - .setCoder(KvCoder.of(destinationCoder, IterableCoder.of(elementCoder))) - .apply( - "WriteGroupedRecords", - ParDo.of( - new WriteGroupedRecordsToFiles( - tempFilePrefix, maxFileSize, rowWriterFactory)) - .withSideInputs(tempFilePrefix)) + TupleTag> successfulResultsTag = new TupleTag<>(); + PCollectionTuple writeResults = + input + .apply( + GroupIntoBatches.ofSize(FILE_TRIGGERING_RECORD_COUNT) + .withByteSize(byteSize) + .withMaxBufferingDuration(maxBufferingDuration) + .withShardedKey()) + .setCoder( + KvCoder.of( + org.apache.beam.sdk.util.ShardedKey.Coder.of(destinationCoder), + IterableCoder.of(elementCoder))) + .apply( + "StripShardId", + MapElements.via( + new SimpleFunction< + KV, Iterable>, + KV>>() { + @Override + public KV> apply( + KV, Iterable> + input) { + return KV.of(input.getKey().getKey(), input.getValue()); + } + })) + .setCoder(KvCoder.of(destinationCoder, IterableCoder.of(elementCoder))) + .apply( + "WriteGroupedRecords", + ParDo.of( + new WriteGroupedRecordsToFiles( + tempFilePrefix, + maxFileSize, + rowWriterFactory, + badRecordRouter, + successfulResultsTag, + elementCoder)) + .withSideInputs(tempFilePrefix) + .withOutputTags(successfulResultsTag, TupleTagList.of(BAD_RECORD_TAG))); + badRecordErrorHandler.addErrorCollection( + writeResults.get(BAD_RECORD_TAG).setCoder(BadRecord.getCoder(input.getPipeline()))); + + return writeResults + .get(successfulResultsTag) .setCoder(WriteBundlesToFiles.ResultCoder.of(destinationCoder)); } private PCollection> writeShardedRecords( PCollection, ElementT>> shardedRecords, PCollectionView tempFilePrefix) { - return shardedRecords - .apply("GroupByDestination", GroupByKey.create()) - .apply( - "StripShardId", - MapElements.via( - new SimpleFunction< - KV, Iterable>, - KV>>() { - @Override - public KV> apply( - KV, Iterable> input) { - return KV.of(input.getKey().getKey(), input.getValue()); - } - })) - .setCoder(KvCoder.of(destinationCoder, IterableCoder.of(elementCoder))) - .apply( - "WriteGroupedRecords", - ParDo.of( - new WriteGroupedRecordsToFiles<>(tempFilePrefix, maxFileSize, rowWriterFactory)) - .withSideInputs(tempFilePrefix)) + TupleTag> successfulResultsTag = new TupleTag<>(); + PCollectionTuple writeResults = + shardedRecords + .apply("GroupByDestination", GroupByKey.create()) + .apply( + "StripShardId", + MapElements.via( + new SimpleFunction< + KV, Iterable>, + KV>>() { + @Override + public KV> apply( + KV, Iterable> input) { + return KV.of(input.getKey().getKey(), input.getValue()); + } + })) + .setCoder(KvCoder.of(destinationCoder, IterableCoder.of(elementCoder))) + .apply( + "WriteGroupedRecords", + ParDo.of( + new WriteGroupedRecordsToFiles<>( + tempFilePrefix, + maxFileSize, + rowWriterFactory, + badRecordRouter, + successfulResultsTag, + elementCoder)) + .withSideInputs(tempFilePrefix) + .withOutputTags(successfulResultsTag, TupleTagList.of(BAD_RECORD_TAG))); + + badRecordErrorHandler.addErrorCollection( + writeResults + .get(BAD_RECORD_TAG) + 
.setCoder(BadRecord.getCoder(shardedRecords.getPipeline()))); + + return writeResults + .get(successfulResultsTag) .setCoder(WriteBundlesToFiles.ResultCoder.of(destinationCoder)); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java index cd62c5810d81..43c5af163190 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java @@ -19,6 +19,8 @@ import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.resolveTempLocation; import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryResourceNaming.createTempTableReference; +import static org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.BAD_RECORD_TAG; +import static org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.RECORDING_ROUTER; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; @@ -74,13 +76,13 @@ import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; import org.apache.beam.sdk.extensions.avro.io.AvroSource; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.extensions.gcp.util.Transport; import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath; import org.apache.beam.sdk.extensions.protobuf.ProtoCoder; import org.apache.beam.sdk.io.BoundedSource; -import org.apache.beam.sdk.io.BoundedSource.BoundedReader; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.fs.MoveOptions; import org.apache.beam.sdk.io.fs.ResolveOptions; @@ -111,6 +113,9 @@ import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.Element; +import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; +import org.apache.beam.sdk.transforms.DoFn.ProcessElement; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -120,6 +125,11 @@ import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.ThrowingBadRecordRouter; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.DefaultErrorHandler; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; @@ -742,6 +752,8 @@ public static TypedRead read(SerializableFunction par .setUseAvroLogicalTypes(false) .setFormat(DataFormat.AVRO) .setProjectionPushdownApplied(false) + .setBadRecordErrorHandler(new DefaultErrorHandler<>()) + .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER) .build(); } @@ -770,6 +782,8 @@ public static TypedRead readWithDatumReader( 
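BatchLoads wires each file-writing ParDo the same way: failures are routed to BadRecordRouter.BAD_RECORD_TAG as an additional output, and that collection is handed to the registered error handler. A distilled sketch of the pattern with illustrative element types and names (not code from the patch):

```java
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter;
import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;

public class BadRecordWiringSketch {

  /** Parses elements, routing anything that throws to the bad-record output. */
  static class ParseFn extends DoFn<String, String> {
    private final TupleTag<String> okTag;
    private final BadRecordRouter badRecordRouter;
    private final Coder<String> inputCoder = StringUtf8Coder.of();

    ParseFn(TupleTag<String> okTag, BadRecordRouter badRecordRouter) {
      this.okTag = okTag;
      this.badRecordRouter = badRecordRouter;
    }

    @ProcessElement
    public void processElement(@Element String element, MultiOutputReceiver receiver)
        throws Exception {
      try {
        if (element.isEmpty()) {
          throw new IllegalArgumentException("empty record");
        }
        receiver.get(okTag).output(element.toUpperCase());
      } catch (Exception e) {
        // RECORDING_ROUTER emits a BadRecord to BAD_RECORD_TAG; THROWING_ROUTER would rethrow.
        badRecordRouter.route(receiver, element, inputCoder, e, "Unable to parse record");
      }
    }
  }

  static PCollection<String> applyWithDeadLetters(
      PCollection<String> input, ErrorHandler<BadRecord, ?> errorHandler) {
    TupleTag<String> okTag = new TupleTag<String>() {};
    PCollectionTuple result =
        input.apply(
            ParDo.of(new ParseFn(okTag, BadRecordRouter.RECORDING_ROUTER))
                .withOutputTags(okTag, TupleTagList.of(BadRecordRouter.BAD_RECORD_TAG)));
    errorHandler.addErrorCollection(
        result
            .get(BadRecordRouter.BAD_RECORD_TAG)
            .setCoder(BadRecord.getCoder(input.getPipeline())));
    return result.get(okTag);
  }
}
```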
.setUseAvroLogicalTypes(false) .setFormat(DataFormat.AVRO) .setProjectionPushdownApplied(false) + .setBadRecordErrorHandler(new DefaultErrorHandler<>()) + .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER) .build(); } @@ -985,6 +999,11 @@ abstract Builder setDatumReaderFactory( abstract Builder setUseAvroLogicalTypes(Boolean useAvroLogicalTypes); + abstract Builder setBadRecordErrorHandler( + ErrorHandler badRecordErrorHandler); + + abstract Builder setBadRecordRouter(BadRecordRouter badRecordRouter); + abstract Builder setProjectionPushdownApplied(boolean projectionPushdownApplied); } @@ -1033,6 +1052,10 @@ abstract Builder setDatumReaderFactory( abstract Boolean getUseAvroLogicalTypes(); + abstract ErrorHandler getBadRecordErrorHandler(); + + abstract BadRecordRouter getBadRecordRouter(); + abstract boolean getProjectionPushdownApplied(); /** @@ -1138,6 +1161,9 @@ public void validate(PipelineOptions options) { e); } } + checkArgument( + getBadRecordRouter().equals(BadRecordRouter.THROWING_ROUTER), + "BigQueryIO Read with Error Handling is only available when DIRECT_READ is used"); } ValueProvider table = getTableProvider(); @@ -1429,27 +1455,75 @@ private PCollection expandForDirectRead( ValueProvider tableProvider = getTableProvider(); Pipeline p = input.getPipeline(); if (tableProvider != null) { - // No job ID is required. Read directly from BigQuery storage. - PCollection rows = - p.apply( - org.apache.beam.sdk.io.Read.from( - BigQueryStorageTableSource.create( - tableProvider, - getFormat(), - getSelectedFields(), - getRowRestriction(), - getParseFn(), - outputCoder, - getBigQueryServices(), - getProjectionPushdownApplied()))); - if (beamSchema != null) { - rows.setSchema( - beamSchema, - getTypeDescriptor(), - getToBeamRowFn().apply(beamSchema), - getFromBeamRowFn().apply(beamSchema)); + // ThrowingBadRecordRouter is the default value, and is what is used if the user hasn't + // specified any particular error handling. + if (getBadRecordRouter() instanceof ThrowingBadRecordRouter) { + // No job ID is required. Read directly from BigQuery storage. + PCollection rows = + p.apply( + org.apache.beam.sdk.io.Read.from( + BigQueryStorageTableSource.create( + tableProvider, + getFormat(), + getSelectedFields(), + getRowRestriction(), + getParseFn(), + outputCoder, + getBigQueryServices(), + getProjectionPushdownApplied()))); + if (beamSchema != null) { + rows.setSchema( + beamSchema, + getTypeDescriptor(), + getToBeamRowFn().apply(beamSchema), + getFromBeamRowFn().apply(beamSchema)); + } + return rows; + } else { + // We need to manually execute the table source, so as to be able to capture exceptions + // to pipe to error handling + BigQueryStorageTableSource source = + BigQueryStorageTableSource.create( + tableProvider, + getFormat(), + getSelectedFields(), + getRowRestriction(), + getParseFn(), + outputCoder, + getBigQueryServices(), + getProjectionPushdownApplied()); + List> sources; + try { + // This splitting logic taken from the SDF implementation of Read + long estimatedSize = source.getEstimatedSizeBytes(bqOptions); + // Split into pieces as close to the default desired bundle size but if that would cause + // too few splits then prefer to split up to the default desired number of splits. 
+ long desiredChunkSize; + if (estimatedSize <= 0) { + desiredChunkSize = 64 << 20; // 64mb + } else { + // 1mb --> 1 shard; 1gb --> 32 shards; 1tb --> 1000 shards, 1pb --> 32k shards + desiredChunkSize = Math.max(1 << 20, (long) (1000 * Math.sqrt(estimatedSize))); + } + sources = source.split(desiredChunkSize, bqOptions); + } catch (Exception e) { + throw new RuntimeException("Unable to split TableSource", e); + } + TupleTag rowTag = new TupleTag<>(); + PCollectionTuple resultTuple = + p.apply(Create.of(sources)) + .apply( + "Read Storage Table Source", + ParDo.of(new ReadTableSource(rowTag, getParseFn(), getBadRecordRouter())) + .withOutputTags(rowTag, TupleTagList.of(BAD_RECORD_TAG))); + getBadRecordErrorHandler() + .addErrorCollection( + resultTuple + .get(BAD_RECORD_TAG) + .setCoder(BadRecord.getCoder(input.getPipeline()))); + + return resultTuple.get(rowTag).setCoder(outputCoder); } - return rows; } checkArgument( @@ -1475,7 +1549,8 @@ private PCollection expandForDirectRead( PCollectionTuple tuple; PCollection rows; - if (!getWithTemplateCompatibility()) { + if (!getWithTemplateCompatibility() + && getBadRecordRouter() instanceof ThrowingBadRecordRouter) { // Create a singleton job ID token at pipeline construction time. String staticJobUuid = BigQueryHelpers.randomUUIDString(); jobIdTokenView = @@ -1588,6 +1663,52 @@ void cleanup(ContextContainer c) throws Exception { return rows.apply(new PassThroughThenCleanup<>(cleanupOperation, jobIdTokenView)); } + private static class ReadTableSource extends DoFn, T> { + + private final TupleTag rowTag; + + private final SerializableFunction parseFn; + + private final BadRecordRouter badRecordRouter; + + public ReadTableSource( + TupleTag rowTag, + SerializableFunction parseFn, + BadRecordRouter badRecordRouter) { + this.rowTag = rowTag; + this.parseFn = parseFn; + this.badRecordRouter = badRecordRouter; + } + + @ProcessElement + public void processElement( + @Element BoundedSource boundedSource, + MultiOutputReceiver outputReceiver, + PipelineOptions options) + throws Exception { + ErrorHandlingParseFn errorHandlingParseFn = new ErrorHandlingParseFn(parseFn); + BoundedSource sourceWithErrorHandlingParseFn; + if (boundedSource instanceof BigQueryStorageStreamSource) { + sourceWithErrorHandlingParseFn = + ((BigQueryStorageStreamSource) boundedSource).fromExisting(errorHandlingParseFn); + } else if (boundedSource instanceof BigQueryStorageStreamBundleSource) { + sourceWithErrorHandlingParseFn = + ((BigQueryStorageStreamBundleSource) boundedSource) + .fromExisting(errorHandlingParseFn); + } else { + throw new RuntimeException( + "Bounded Source is not BigQueryStorageStreamSource or BigQueryStorageStreamBundleSource, unable to read"); + } + readSource( + options, + rowTag, + outputReceiver, + sourceWithErrorHandlingParseFn, + errorHandlingParseFn, + badRecordRouter); + } + } + private PCollectionTuple createTupleForDirectRead( PCollection jobIdTokenCollection, Coder outputCoder, @@ -1724,13 +1845,45 @@ public void processElement(ProcessContext c) throws Exception { return tuple; } + private static class ErrorHandlingParseFn + implements SerializableFunction { + private final SerializableFunction parseFn; + + private transient SchemaAndRecord schemaAndRecord = null; + + private ErrorHandlingParseFn(SerializableFunction parseFn) { + this.parseFn = parseFn; + } + + @Override + public T apply(SchemaAndRecord input) { + schemaAndRecord = input; + try { + return parseFn.apply(input); + } catch (Exception e) { + throw new ParseException(e); + } + } + + 
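The split heuristic above targets roughly 1000 * sqrt(estimatedSize) bytes per split, clamped to at least 1 MB, with a 64 MB fallback when the size is unknown. A small worked example confirming the shard counts quoted in the comment (illustrative code, not part of the patch):

```java
public class DesiredChunkSizeExample {

  /** Mirrors the heuristic in the snippet above. */
  static long desiredChunkSize(long estimatedSize) {
    if (estimatedSize <= 0) {
      return 64L << 20; // 64 MB fallback when the size is unknown
    }
    return Math.max(1L << 20, (long) (1000 * Math.sqrt(estimatedSize)));
  }

  public static void main(String[] args) {
    long oneMb = 1L << 20, oneGb = 1L << 30, oneTb = 1L << 40;
    System.out.println(oneMb / desiredChunkSize(oneMb)); // 1    -> "1mb --> 1 shard"
    System.out.println(oneGb / desiredChunkSize(oneGb)); // 32   -> "1gb --> 32 shards"
    System.out.println(oneTb / desiredChunkSize(oneTb)); // 1048 -> "1tb --> ~1000 shards"
  }
}
```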
public SchemaAndRecord getSchemaAndRecord() { + return schemaAndRecord; + } + } + + private static class ParseException extends RuntimeException { + public ParseException(Exception e) { + super(e); + } + } + private PCollection createPCollectionForDirectRead( PCollectionTuple tuple, Coder outputCoder, TupleTag readStreamsTag, PCollectionView readSessionView, PCollectionView tableSchemaView) { - PCollection rows = + TupleTag rowTag = new TupleTag<>(); + PCollectionTuple resultTuple = tuple .get(readStreamsTag) .apply(Reshuffle.viaRandomKey()) @@ -1738,36 +1891,44 @@ private PCollection createPCollectionForDirectRead( ParDo.of( new DoFn() { @ProcessElement - public void processElement(ProcessContext c) throws Exception { + public void processElement( + ProcessContext c, MultiOutputReceiver outputReceiver) + throws Exception { ReadSession readSession = c.sideInput(readSessionView); TableSchema tableSchema = BigQueryHelpers.fromJsonString( c.sideInput(tableSchemaView), TableSchema.class); ReadStream readStream = c.element(); + ErrorHandlingParseFn errorHandlingParseFn = + new ErrorHandlingParseFn(getParseFn()); + BigQueryStorageStreamSource streamSource = BigQueryStorageStreamSource.create( readSession, readStream, tableSchema, - getParseFn(), + errorHandlingParseFn, outputCoder, getBigQueryServices()); - // Read all of the data from the stream. In the event that this work - // item fails and is rescheduled, the same rows will be returned in - // the same order. - BoundedSource.BoundedReader reader = - streamSource.createReader(c.getPipelineOptions()); - for (boolean more = reader.start(); more; more = reader.advance()) { - c.output(reader.getCurrent()); - } + readSource( + c.getPipelineOptions(), + rowTag, + outputReceiver, + streamSource, + errorHandlingParseFn, + getBadRecordRouter()); } }) - .withSideInputs(readSessionView, tableSchemaView)) - .setCoder(outputCoder); + .withSideInputs(readSessionView, tableSchemaView) + .withOutputTags(rowTag, TupleTagList.of(BAD_RECORD_TAG))); - return rows; + getBadRecordErrorHandler() + .addErrorCollection( + resultTuple.get(BAD_RECORD_TAG).setCoder(BadRecord.getCoder(tuple.getPipeline()))); + + return resultTuple.get(rowTag).setCoder(outputCoder); } private PCollection createPCollectionForDirectReadWithStreamBundle( @@ -1776,7 +1937,8 @@ private PCollection createPCollectionForDirectReadWithStreamBundle( TupleTag> listReadStreamsTag, PCollectionView readSessionView, PCollectionView tableSchemaView) { - PCollection rows = + TupleTag rowTag = new TupleTag<>(); + PCollectionTuple resultTuple = tuple .get(listReadStreamsTag) .apply(Reshuffle.viaRandomKey()) @@ -1784,37 +1946,93 @@ private PCollection createPCollectionForDirectReadWithStreamBundle( ParDo.of( new DoFn, T>() { @ProcessElement - public void processElement(ProcessContext c) throws Exception { + public void processElement( + ProcessContext c, MultiOutputReceiver outputReceiver) + throws Exception { ReadSession readSession = c.sideInput(readSessionView); TableSchema tableSchema = BigQueryHelpers.fromJsonString( c.sideInput(tableSchemaView), TableSchema.class); List streamBundle = c.element(); + ErrorHandlingParseFn errorHandlingParseFn = + new ErrorHandlingParseFn(getParseFn()); + BigQueryStorageStreamBundleSource streamSource = BigQueryStorageStreamBundleSource.create( readSession, streamBundle, tableSchema, - getParseFn(), + errorHandlingParseFn, outputCoder, getBigQueryServices(), 1L); - // Read all of the data from the stream. 
In the event that this work - // item fails and is rescheduled, the same rows will be returned in - // the same order. - BoundedReader reader = - streamSource.createReader(c.getPipelineOptions()); - for (boolean more = reader.start(); more; more = reader.advance()) { - c.output(reader.getCurrent()); - } + readSource( + c.getPipelineOptions(), + rowTag, + outputReceiver, + streamSource, + errorHandlingParseFn, + getBadRecordRouter()); } }) - .withSideInputs(readSessionView, tableSchemaView)) - .setCoder(outputCoder); + .withSideInputs(readSessionView, tableSchemaView) + .withOutputTags(rowTag, TupleTagList.of(BAD_RECORD_TAG))); - return rows; + getBadRecordErrorHandler() + .addErrorCollection( + resultTuple.get(BAD_RECORD_TAG).setCoder(BadRecord.getCoder(tuple.getPipeline()))); + + return resultTuple.get(rowTag).setCoder(outputCoder); + } + + public static void readSource( + PipelineOptions options, + TupleTag rowTag, + MultiOutputReceiver outputReceiver, + BoundedSource streamSource, + ErrorHandlingParseFn errorHandlingParseFn, + BadRecordRouter badRecordRouter) + throws Exception { + // Read all the data from the stream. In the event that this work + // item fails and is rescheduled, the same rows will be returned in + // the same order. + BoundedSource.BoundedReader reader = streamSource.createReader(options); + + try { + if (reader.start()) { + outputReceiver.get(rowTag).output(reader.getCurrent()); + } else { + return; + } + } catch (ParseException e) { + GenericRecord record = errorHandlingParseFn.getSchemaAndRecord().getRecord(); + badRecordRouter.route( + outputReceiver, + record, + AvroCoder.of(record.getSchema()), + (Exception) e.getCause(), + "Unable to parse record reading from BigQuery"); + } + + while (true) { + try { + if (reader.advance()) { + outputReceiver.get(rowTag).output(reader.getCurrent()); + } else { + return; + } + } catch (ParseException e) { + GenericRecord record = errorHandlingParseFn.getSchemaAndRecord().getRecord(); + badRecordRouter.route( + outputReceiver, + record, + AvroCoder.of(record.getSchema()), + (Exception) e.getCause(), + "Unable to parse record reading from BigQuery"); + } + } } @Override @@ -2014,6 +2232,13 @@ public TypedRead withRowRestriction(ValueProvider rowRestriction) { return toBuilder().setRowRestriction(rowRestriction).build(); } + public TypedRead withErrorHandler(ErrorHandler badRecordErrorHandler) { + return toBuilder() + .setBadRecordErrorHandler(badRecordErrorHandler) + .setBadRecordRouter(BadRecordRouter.RECORDING_ROUTER) + .build(); + } + public TypedRead withTemplateCompatibility() { return toBuilder().setWithTemplateCompatibility(true).build(); } @@ -2151,6 +2376,8 @@ public static Write write() { .setDirectWriteProtos(true) .setDefaultMissingValueInterpretation( AppendRowsRequest.MissingValueInterpretation.DEFAULT_VALUE) + .setBadRecordErrorHandler(new DefaultErrorHandler<>()) + .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER) .build(); } @@ -2357,6 +2584,10 @@ public enum Method { abstract @Nullable SerializableFunction getRowMutationInformationFn(); + abstract ErrorHandler getBadRecordErrorHandler(); + + abstract BadRecordRouter getBadRecordRouter(); + abstract Builder toBuilder(); @AutoValue.Builder @@ -2465,6 +2696,11 @@ abstract Builder setDeterministicRecordIdFn( abstract Builder setRowMutationInformationFn( SerializableFunction rowMutationFn); + abstract Builder setBadRecordErrorHandler( + ErrorHandler badRecordErrorHandler); + + abstract Builder setBadRecordRouter(BadRecordRouter badRecordRouter); + abstract Write 
build(); } @@ -3131,6 +3367,13 @@ public Write withWriteTempDataset(String writeTempDataset) { return toBuilder().setWriteTempDataset(writeTempDataset).build(); } + public Write withErrorHandler(ErrorHandler errorHandler) { + return toBuilder() + .setBadRecordErrorHandler(errorHandler) + .setBadRecordRouter(RECORDING_ROUTER) + .build(); + } + @Override public void validate(PipelineOptions pipelineOptions) { BigQueryOptions options = pipelineOptions.as(BigQueryOptions.class); @@ -3538,6 +3781,9 @@ private WriteResult continueExpandTyped( checkArgument( !getPropagateSuccessfulStorageApiWrites(), "withPropagateSuccessfulStorageApiWrites only supported when using storage api writes."); + checkArgument( + getBadRecordRouter() instanceof ThrowingBadRecordRouter, + "Error Handling is not supported with STREAMING_INSERTS"); RowWriterFactory.TableRowWriterFactory tableRowWriterFactory = (RowWriterFactory.TableRowWriterFactory) rowWriterFactory; @@ -3572,6 +3818,10 @@ private WriteResult continueExpandTyped( checkArgument( !getPropagateSuccessfulStorageApiWrites(), "withPropagateSuccessfulStorageApiWrites only supported when using storage api writes."); + if (!(getBadRecordRouter() instanceof ThrowingBadRecordRouter)) { + LOG.warn( + "Error Handling is partially supported when using FILE_LOADS. Consider using STORAGE_WRITE_API or STORAGE_API_AT_LEAST_ONCE"); + } // Batch load handles wrapped json string value differently than the other methods. Raise a // warning when applies. @@ -3610,7 +3860,9 @@ private WriteResult continueExpandTyped( getKmsKey(), getClustering() != null, getUseAvroLogicalTypes(), - getWriteTempDataset()); + getWriteTempDataset(), + getBadRecordRouter(), + getBadRecordErrorHandler()); batchLoads.setTestServices(getBigQueryServices()); if (getSchemaUpdateOptions() != null) { batchLoads.setSchemaUpdateOptions(getSchemaUpdateOptions()); @@ -3730,7 +3982,9 @@ private WriteResult continueExpandTyped( getIgnoreUnknownValues(), getPropagateSuccessfulStorageApiWrites(), getRowMutationInformationFn() != null, - getDefaultMissingValueInterpretation()); + getDefaultMissingValueInterpretation(), + getBadRecordRouter(), + getBadRecordErrorHandler()); return input.apply("StorageApiLoads", storageApiLoads); } else { throw new RuntimeException("Unexpected write method " + method); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslation.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslation.java index f659cc066829..2fa5bdf25a10 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslation.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslation.java @@ -56,6 +56,9 @@ import org.apache.beam.sdk.schemas.logicaltypes.NanosDuration; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; import org.apache.beam.sdk.util.construction.PTransformTranslation.TransformPayloadTranslator; import org.apache.beam.sdk.util.construction.SdkComponents; import org.apache.beam.sdk.util.construction.TransformPayloadTranslatorRegistrar; @@ -101,6 +104,8 @@ static class BigQueryIOReadTranslator implements 
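With the new withErrorHandler hooks on both TypedRead and Write, callers register a dead-letter sink on the pipeline and pass the resulting handler into BigQueryIO; per the checks above, reads require DIRECT_READ, STREAMING_INSERTS is rejected, and FILE_LOADS is only partially covered. A minimal end-to-end sketch, assuming the Pipeline.registerBadRecordErrorHandler entry point from the Beam error-handling framework; the project/dataset/table names and the DeadLetterSink pass-through are placeholders:

```java
import com.google.api.services.bigquery.model.TableFieldSchema;
import com.google.api.services.bigquery.model.TableRow;
import com.google.api.services.bigquery.model.TableSchema;
import java.util.Collections;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO;
import org.apache.beam.sdk.io.gcp.bigquery.TableRowJsonCoder;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler.BadRecordErrorHandler;
import org.apache.beam.sdk.values.PCollection;

public class BigQueryBadRecordExample {

  /** Placeholder dead-letter sink; any PTransform over PCollection<BadRecord> works here. */
  static class DeadLetterSink
      extends PTransform<PCollection<BadRecord>, PCollection<BadRecord>> {
    @Override
    public PCollection<BadRecord> expand(PCollection<BadRecord> input) {
      return input; // e.g. replace with a write to GCS or a logging transform
    }
  }

  public static void main(String[] args) throws Exception {
    Pipeline pipeline = Pipeline.create();

    BadRecordErrorHandler<PCollection<BadRecord>> badRecords =
        pipeline.registerBadRecordErrorHandler(new DeadLetterSink());

    PCollection<TableRow> rows =
        pipeline.apply(
            Create.of(new TableRow().set("name", "example")).withCoder(TableRowJsonCoder.of()));

    rows.apply(
        "WriteToBigQuery",
        BigQueryIO.writeTableRows()
            .to("my-project:my_dataset.my_table") // placeholder destination
            .withSchema(
                new TableSchema()
                    .setFields(
                        Collections.singletonList(
                            new TableFieldSchema().setName("name").setType("STRING"))))
            .withMethod(BigQueryIO.Write.Method.STORAGE_WRITE_API)
            .withErrorHandler(badRecords)); // failed conversions become BadRecords, not bundle failures

    badRecords.close(); // handlers must be closed before the pipeline runs
    pipeline.run().waitUntilFinish();
  }
}
```

On the read side, the same handler can be passed to BigQueryIO.read(...).withMethod(TypedRead.Method.DIRECT_READ).withErrorHandler(badRecords), which is the only read path the new validation permits.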
TransformPayloadTranslator transform) { fieldValues.put("use_avro_logical_types", transform.getUseAvroLogicalTypes()); } fieldValues.put("projection_pushdown_applied", transform.getProjectionPushdownApplied()); + fieldValues.put("bad_record_router", toByteArray(transform.getBadRecordRouter())); + fieldValues.put( + "bad_record_error_handler", toByteArray(transform.getBadRecordErrorHandler())); return Row.withSchema(schema).withFieldValues(fieldValues).build(); } @@ -304,6 +312,11 @@ public TypedRead fromConfigRow(Row configRow, PipelineOptions options) { if (projectionPushdownApplied != null) { builder = builder.setProjectionPushdownApplied(projectionPushdownApplied); } + byte[] badRecordRouter = configRow.getBytes("bad_record_router"); + builder.setBadRecordRouter((BadRecordRouter) fromByteArray(badRecordRouter)); + byte[] badRecordErrorHandler = configRow.getBytes("bad_record_error_handler"); + builder.setBadRecordErrorHandler( + (ErrorHandler) fromByteArray(badRecordErrorHandler)); return builder.build(); } catch (InvalidClassException e) { @@ -378,6 +391,8 @@ static class BigQueryIOWriteTranslator implements TransformPayloadTranslator transform) { fieldValues.put( "row_mutation_information_fn", toByteArray(transform.getRowMutationInformationFn())); } + fieldValues.put("bad_record_router", toByteArray(transform.getBadRecordRouter())); + fieldValues.put( + "bad_record_error_handler", toByteArray(transform.getBadRecordErrorHandler())); return Row.withSchema(schema).withFieldValues(fieldValues).build(); } @@ -822,6 +840,11 @@ public Write fromConfigRow(Row configRow, PipelineOptions options) { builder.setRowMutationInformationFn( (SerializableFunction) fromByteArray(rowMutationInformationFnBytes)); } + byte[] badRecordRouter = configRow.getBytes("bad_record_router"); + builder.setBadRecordRouter((BadRecordRouter) fromByteArray(badRecordRouter)); + byte[] badRecordErrorHandler = configRow.getBytes("bad_record_error_handler"); + builder.setBadRecordErrorHandler( + (ErrorHandler) fromByteArray(badRecordErrorHandler)); return builder.build(); } catch (InvalidClassException e) { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryRowWriter.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryRowWriter.java index a442144e1610..b846a06af580 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryRowWriter.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryRowWriter.java @@ -63,7 +63,7 @@ protected OutputStream getOutputStream() { return out; } - abstract void write(T value) throws Exception; + abstract void write(T value) throws IOException, BigQueryRowSerializationException; long getByteSize() { return out.getCount(); @@ -80,4 +80,11 @@ Result getResult() { checkState(isClosed, "Not yet closed"); return new Result(resourceId, out.getCount()); } + + public static class BigQueryRowSerializationException extends Exception { + + public BigQueryRowSerializationException(Exception e) { + super(e); + } + } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamBundleSource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamBundleSource.java index a2df86af1ee6..eeb747f21ea5 100644 --- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamBundleSource.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamBundleSource.java @@ -110,6 +110,18 @@ public BigQueryStorageStreamBundleSource fromExisting(List newStr getMinBundleSize()); } + public BigQueryStorageStreamBundleSource fromExisting( + SerializableFunction parseFn) { + return new BigQueryStorageStreamBundleSource<>( + readSession, + streamBundle, + jsonTableSchema, + parseFn, + outputCoder, + bqServices, + getMinBundleSize()); + } + private final ReadSession readSession; private final List streamBundle; private final String jsonTableSchema; @@ -334,10 +346,6 @@ private boolean readNextRecord() throws IOException { reader.processReadRowsResponse(response); } - SchemaAndRecord schemaAndRecord = new SchemaAndRecord(reader.readSingleRecord(), tableSchema); - - current = parseFn.apply(schemaAndRecord); - // Calculates the fraction of the current stream that has been consumed. This value is // calculated by interpolating between the fraction consumed value from the previous server // response (or zero if we're consuming the first response) and the fractional value in the @@ -355,6 +363,11 @@ private boolean readNextRecord() throws IOException { // progress made in the current Stream gives us the overall StreamBundle progress. fractionOfStreamBundleConsumed = (currentStreamBundleIndex + fractionOfCurrentStreamConsumed) / source.streamBundle.size(); + + SchemaAndRecord schemaAndRecord = new SchemaAndRecord(reader.readSingleRecord(), tableSchema); + + current = parseFn.apply(schemaAndRecord); + return true; } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamSource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamSource.java index a4336cd48f94..8f7f50febaf4 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamSource.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamSource.java @@ -82,6 +82,12 @@ public BigQueryStorageStreamSource fromExisting(ReadStream newReadStream) { readSession, newReadStream, jsonTableSchema, parseFn, outputCoder, bqServices); } + public BigQueryStorageStreamSource fromExisting( + SerializableFunction parseFn) { + return new BigQueryStorageStreamSource<>( + readSession, readStream, jsonTableSchema, parseFn, outputCoder, bqServices); + } + private final ReadSession readSession; private final ReadStream readStream; private final String jsonTableSchema; @@ -274,10 +280,6 @@ private synchronized boolean readNextRecord() throws IOException { reader.processReadRowsResponse(response); } - SchemaAndRecord schemaAndRecord = new SchemaAndRecord(reader.readSingleRecord(), tableSchema); - - current = parseFn.apply(schemaAndRecord); - // Updates the fraction consumed value. 
This value is calculated by interpolating between // the fraction consumed value from the previous server response (or zero if we're consuming // the first response) and the fractional value in the current response based on how many of @@ -291,6 +293,10 @@ private synchronized boolean readNextRecord() throws IOException { * 1.0 / totalRowsInCurrentResponse; + SchemaAndRecord schemaAndRecord = new SchemaAndRecord(reader.readSingleRecord(), tableSchema); + + current = parseFn.apply(schemaAndRecord); + return true; } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages.java index 23fe0250b7d9..aefdb79c535c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; +import static org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.BAD_RECORD_TAG; + import com.google.api.services.bigquery.model.TableRow; import java.io.IOException; import org.apache.beam.sdk.coders.Coder; @@ -27,12 +29,15 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Instant; @@ -50,6 +55,7 @@ public class StorageApiConvertMessages private final Coder> successCoder; private final @Nullable SerializableFunction rowMutationFn; + private final BadRecordRouter badRecordRouter; public StorageApiConvertMessages( StorageApiDynamicDestinations dynamicDestinations, @@ -58,7 +64,8 @@ public StorageApiConvertMessages( TupleTag> successfulWritesTag, Coder errorCoder, Coder> successCoder, - @Nullable SerializableFunction rowMutationFn) { + @Nullable SerializableFunction rowMutationFn, + BadRecordRouter badRecordRouter) { this.dynamicDestinations = dynamicDestinations; this.bqServices = bqServices; this.failedWritesTag = failedWritesTag; @@ -66,6 +73,7 @@ public StorageApiConvertMessages( this.errorCoder = errorCoder; this.successCoder = successCoder; this.rowMutationFn = rowMutationFn; + this.badRecordRouter = badRecordRouter; } @Override @@ -82,11 +90,16 @@ public PCollectionTuple expand(PCollection> input) { operationName, failedWritesTag, successfulWritesTag, - rowMutationFn)) - .withOutputTags(successfulWritesTag, TupleTagList.of(failedWritesTag)) + rowMutationFn, + badRecordRouter, + input.getCoder())) + .withOutputTags( + successfulWritesTag, + TupleTagList.of(ImmutableList.of(failedWritesTag, BAD_RECORD_TAG))) .withSideInputs(dynamicDestinations.getSideInputs())); result.get(successfulWritesTag).setCoder(successCoder); 
result.get(failedWritesTag).setCoder(errorCoder); + result.get(BAD_RECORD_TAG).setCoder(BadRecord.getCoder(input.getPipeline())); return result; } @@ -98,6 +111,8 @@ public static class ConvertMessagesDoFn failedWritesTag; private final TupleTag> successfulWritesTag; private final @Nullable SerializableFunction rowMutationFn; + private final BadRecordRouter badRecordRouter; + Coder> elementCoder; private transient @Nullable DatasetService datasetServiceInternal = null; ConvertMessagesDoFn( @@ -106,13 +121,17 @@ public static class ConvertMessagesDoFn failedWritesTag, TupleTag> successfulWritesTag, - @Nullable SerializableFunction rowMutationFn) { + @Nullable SerializableFunction rowMutationFn, + BadRecordRouter badRecordRouter, + Coder> elementCoder) { this.dynamicDestinations = dynamicDestinations; this.messageConverters = new TwoLevelMessageConverterCache<>(operationName); this.bqServices = bqServices; this.failedWritesTag = failedWritesTag; this.successfulWritesTag = successfulWritesTag; this.rowMutationFn = rowMutationFn; + this.badRecordRouter = badRecordRouter; + this.elementCoder = elementCoder; } private DatasetService getDatasetService(PipelineOptions pipelineOptions) throws IOException { @@ -159,9 +178,19 @@ public void processElement( .toMessage(element.getValue(), rowMutationInformation) .withTimestamp(timestamp); o.get(successfulWritesTag).output(KV.of(element.getKey(), payload)); - } catch (TableRowToStorageApiProto.SchemaConversionException e) { - TableRow tableRow = messageConverter.toTableRow(element.getValue()); - o.get(failedWritesTag).output(new BigQueryStorageApiInsertError(tableRow, e.toString())); + } catch (TableRowToStorageApiProto.SchemaConversionException conversionException) { + TableRow tableRow; + try { + tableRow = messageConverter.toTableRow(element.getValue()); + } catch (Exception e) { + badRecordRouter.route(o, element, elementCoder, e, "Unable to convert value to TableRow"); + return; + } + o.get(failedWritesTag) + .output(new BigQueryStorageApiInsertError(tableRow, conversionException.toString())); + } catch (Exception e) { + badRecordRouter.route( + o, element, elementCoder, e, "Unable to convert value to StorageWriteApiPayload"); } } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java index 0227b8020129..62174b5c917a 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java @@ -17,8 +17,11 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; +import static org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.BAD_RECORD_TAG; + import com.google.api.services.bigquery.model.TableRow; import com.google.cloud.bigquery.storage.v1.AppendRowsRequest; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.concurrent.ThreadLocalRandom; import javax.annotation.Nullable; @@ -32,6 +35,10 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.ThrowingBadRecordRouter; +import 
org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.util.ShardedKey; @@ -68,6 +75,10 @@ public class StorageApiLoads private final AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation; + private final BadRecordRouter badRecordRouter; + + private final ErrorHandler badRecordErrorHandler; + public StorageApiLoads( Coder destinationCoder, StorageApiDynamicDestinations dynamicDestinations, @@ -83,7 +94,9 @@ public StorageApiLoads( boolean ignoreUnknownValues, boolean propagateSuccessfulStorageApiWrites, boolean usesCdc, - AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation) { + AppendRowsRequest.MissingValueInterpretation defaultMissingValueInterpretation, + BadRecordRouter badRecordRouter, + ErrorHandler badRecordErrorHandler) { this.destinationCoder = destinationCoder; this.dynamicDestinations = dynamicDestinations; this.rowUpdateFn = rowUpdateFn; @@ -101,12 +114,18 @@ public StorageApiLoads( } this.usesCdc = usesCdc; this.defaultMissingValueInterpretation = defaultMissingValueInterpretation; + this.badRecordRouter = badRecordRouter; + this.badRecordErrorHandler = badRecordErrorHandler; } public TupleTag getFailedRowsTag() { return failedRowsTag; } + public boolean usesErrorHandler() { + return !(badRecordRouter instanceof ThrowingBadRecordRouter); + } + @Override public WriteResult expand(PCollection> input) { Coder payloadCoder; @@ -143,7 +162,8 @@ public WriteResult expandInconsistent( successfulConvertedRowsTag, BigQueryStorageApiInsertErrorCoder.of(), successCoder, - rowUpdateFn)); + rowUpdateFn, + badRecordRouter)); PCollectionTuple writeRecordsResult = convertMessagesResult .get(successfulConvertedRowsTag) @@ -171,6 +191,9 @@ public WriteResult expandInconsistent( if (successfulWrittenRowsTag != null) { successfulWrittenRows = writeRecordsResult.get(successfulWrittenRowsTag); } + + addErrorCollections(convertMessagesResult, writeRecordsResult); + return WriteResult.in( input.getPipeline(), null, @@ -201,7 +224,8 @@ public WriteResult expandTriggered( successfulConvertedRowsTag, BigQueryStorageApiInsertErrorCoder.of(), successCoder, - rowUpdateFn)); + rowUpdateFn, + badRecordRouter)); PCollection, Iterable>> groupedRecords; @@ -261,6 +285,8 @@ public WriteResult expandTriggered( successfulWrittenRows = writeRecordsResult.get(successfulWrittenRowsTag); } + addErrorCollections(convertMessagesResult, writeRecordsResult); + return WriteResult.in( input.getPipeline(), null, @@ -319,7 +345,8 @@ public WriteResult expandUntriggered( successfulConvertedRowsTag, BigQueryStorageApiInsertErrorCoder.of(), successCoder, - rowUpdateFn)); + rowUpdateFn, + badRecordRouter)); PCollectionTuple writeRecordsResult = convertMessagesResult @@ -350,6 +377,8 @@ public WriteResult expandUntriggered( successfulWrittenRows = writeRecordsResult.get(successfulWrittenRowsTag); } + addErrorCollections(convertMessagesResult, writeRecordsResult); + return WriteResult.in( input.getPipeline(), null, @@ -362,4 +391,53 @@ public WriteResult expandUntriggered( successfulWrittenRowsTag, successfulWrittenRows); } + + private void addErrorCollections( + PCollectionTuple convertMessagesResult, PCollectionTuple writeRecordsResult) { + if (usesErrorHandler()) { + PCollection badRecords = + PCollectionList.of( + convertMessagesResult + .get(failedRowsTag) + .apply( + "ConvertMessageFailuresToBadRecord", + ParDo.of( + new 
ConvertInsertErrorToBadRecord( + "Failed to Convert to Storage API Message")))) + .and(convertMessagesResult.get(BAD_RECORD_TAG)) + .and( + writeRecordsResult + .get(failedRowsTag) + .apply( + "WriteRecordFailuresToBadRecord", + ParDo.of( + new ConvertInsertErrorToBadRecord( + "Failed to Write Message to Storage API")))) + .apply("flattenBadRecords", Flatten.pCollections()); + badRecordErrorHandler.addErrorCollection(badRecords); + } + } + + private static class ConvertInsertErrorToBadRecord + extends DoFn { + + private final String errorMessage; + + public ConvertInsertErrorToBadRecord(String errorMessage) { + this.errorMessage = errorMessage; + } + + @ProcessElement + public void processElement( + @Element BigQueryStorageApiInsertError bigQueryStorageApiInsertError, + OutputReceiver outputReceiver) + throws IOException { + outputReceiver.output( + BadRecord.fromExceptionInformation( + bigQueryStorageApiInsertError, + BigQueryStorageApiInsertErrorCoder.of(), + null, + errorMessage)); + } + } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowWriter.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowWriter.java index 6cbeb61f624f..4d5fb1b3d746 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowWriter.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowWriter.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.gcp.bigquery; import com.google.api.services.bigquery.model.TableRow; +import java.io.IOException; import java.nio.charset.StandardCharsets; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.Coder.Context; @@ -37,8 +38,13 @@ class TableRowWriter extends BigQueryRowWriter { } @Override - void write(T value) throws Exception { - TableRow tableRow = toRow.apply(value); + void write(T value) throws IOException, BigQueryRowSerializationException { + TableRow tableRow; + try { + tableRow = toRow.apply(value); + } catch (Exception e) { + throw new BigQueryRowSerializationException(e); + } CODER.encode(tableRow, getOutputStream(), Context.OUTER); getOutputStream().write(NEWLINE); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java index 894983ab664f..9d84abbbbf1a 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java @@ -32,8 +32,10 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.StructuredCoder; import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryRowWriter.BigQueryRowSerializationException; import org.apache.beam.sdk.io.gcp.bigquery.WriteBundlesToFiles.Result; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.KV; @@ -69,6 +71,8 @@ class WriteBundlesToFiles private final int maxNumWritersPerBundle; private final long maxFileSize; private final RowWriterFactory rowWriterFactory; + private final Coder> 
coder; + private final BadRecordRouter badRecordRouter; private int spilledShardNumber; /** @@ -165,12 +169,16 @@ public void verifyDeterministic() {} TupleTag, ElementT>> unwrittenRecordsTag, int maxNumWritersPerBundle, long maxFileSize, - RowWriterFactory rowWriterFactory) { + RowWriterFactory rowWriterFactory, + Coder> coder, + BadRecordRouter badRecordRouter) { this.tempFilePrefixView = tempFilePrefixView; this.unwrittenRecordsTag = unwrittenRecordsTag; this.maxNumWritersPerBundle = maxNumWritersPerBundle; this.maxFileSize = maxFileSize; this.rowWriterFactory = rowWriterFactory; + this.coder = coder; + this.badRecordRouter = badRecordRouter; } @StartBundle @@ -197,7 +205,10 @@ BigQueryRowWriter createAndInsertWriter( @ProcessElement public void processElement( - ProcessContext c, @Element KV element, BoundedWindow window) + ProcessContext c, + @Element KV element, + BoundedWindow window, + MultiOutputReceiver outputReceiver) throws Exception { Map> writers = Preconditions.checkStateNotNull(this.writers); @@ -234,17 +245,32 @@ public void processElement( try { writer.write(element.getValue()); - } catch (Exception e) { - // Discard write result and close the write. + } catch (BigQueryRowSerializationException e) { try { - writer.close(); - // The writer does not need to be reset, as this DoFn cannot be reused. - } catch (Exception closeException) { - // Do not mask the exception that caused the write to fail. - e.addSuppressed(closeException); + badRecordRouter.route( + outputReceiver, + element, + coder, + e, + "Unable to Write BQ Record to File because serialization to TableRow failed"); + } catch (Exception e2) { + cleanupWriter(writer, e2); } - throw e; + } catch (Exception e) { + cleanupWriter(writer, e); + } + } + + private void cleanupWriter(BigQueryRowWriter writer, Exception e) throws Exception { + // Discard write result and close the write. + try { + writer.close(); + // The writer does not need to be reset, as this DoFn cannot be reused. + } catch (Exception closeException) { + // Do not mask the exception that caused the write to fail. + e.addSuppressed(closeException); } + throw e; } @FinishBundle diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteGroupedRecordsToFiles.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteGroupedRecordsToFiles.java index 236b07d74756..3a4f377ce2b8 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteGroupedRecordsToFiles.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteGroupedRecordsToFiles.java @@ -17,9 +17,14 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryRowWriter.BigQueryRowSerializationException; +import org.apache.beam.sdk.io.gcp.bigquery.WriteBundlesToFiles.Result; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; /** * Receives elements grouped by their destination, and writes them out to a file. 
Since all the @@ -31,21 +36,30 @@ class WriteGroupedRecordsToFiles private final PCollectionView tempFilePrefix; private final long maxFileSize; private final RowWriterFactory rowWriterFactory; + private final BadRecordRouter badRecordRouter; + private final TupleTag> successfulResultsTag; + private final Coder elementCoder; WriteGroupedRecordsToFiles( PCollectionView tempFilePrefix, long maxFileSize, - RowWriterFactory rowWriterFactory) { + RowWriterFactory rowWriterFactory, + BadRecordRouter badRecordRouter, + TupleTag> successfulResultsTag, + Coder elementCoder) { this.tempFilePrefix = tempFilePrefix; this.maxFileSize = maxFileSize; this.rowWriterFactory = rowWriterFactory; + this.badRecordRouter = badRecordRouter; + this.successfulResultsTag = successfulResultsTag; + this.elementCoder = elementCoder; } @ProcessElement public void processElement( ProcessContext c, @Element KV> element, - OutputReceiver> o) + MultiOutputReceiver outputReceiver) throws Exception { String tempFilePrefix = c.sideInput(this.tempFilePrefix); @@ -58,20 +72,29 @@ public void processElement( if (writer.getByteSize() > maxFileSize) { writer.close(); BigQueryRowWriter.Result result = writer.getResult(); - o.output( - new WriteBundlesToFiles.Result<>( - result.resourceId.toString(), result.byteSize, c.element().getKey())); + outputReceiver + .get(successfulResultsTag) + .output( + new WriteBundlesToFiles.Result<>( + result.resourceId.toString(), result.byteSize, c.element().getKey())); writer = rowWriterFactory.createRowWriter(tempFilePrefix, element.getKey()); } - writer.write(tableRow); + try { + writer.write(tableRow); + } catch (BigQueryRowSerializationException e) { + badRecordRouter.route( + outputReceiver, tableRow, elementCoder, e, "Unable to Write BQ Record to File"); + } } } finally { writer.close(); } BigQueryRowWriter.Result result = writer.getResult(); - o.output( - new WriteBundlesToFiles.Result<>( - result.resourceId.toString(), result.byteSize, c.element().getKey())); + outputReceiver + .get(successfulResultsTag) + .output( + new WriteBundlesToFiles.Result<>( + result.resourceId.toString(), result.byteSize, c.element().getKey())); } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java index d355d6bb9336..2b1c111269df 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java @@ -19,6 +19,7 @@ import static org.apache.beam.sdk.io.gcp.bigquery.TestBigQueryOptions.BIGQUERY_EARLY_ROLLOUT_REGION; +import com.google.api.services.bigquery.model.TableRow; import java.util.Map; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; @@ -33,6 +34,9 @@ import org.apache.beam.sdk.testing.TestPipelineOptions; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandlingTestUtils.ErrorSinkTransform; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.junit.Test; @@ 
-107,4 +111,46 @@ public void testBigQueryStorageQuery1G() throws Exception { setUpTestEnvironment("1G"); runBigQueryIOStorageQueryPipeline(); } + + static class FailingTableRowParser implements SerializableFunction { + + public static final BigQueryIOStorageReadIT.FailingTableRowParser INSTANCE = + new BigQueryIOStorageReadIT.FailingTableRowParser(); + + private int parseCount = 0; + + @Override + public TableRow apply(SchemaAndRecord schemaAndRecord) { + parseCount++; + if (parseCount % 50 == 0) { + throw new RuntimeException("ExpectedException"); + } + return TableRowParser.INSTANCE.apply(schemaAndRecord); + } + } + + @Test + public void testBigQueryStorageQueryWithErrorHandling1M() throws Exception { + setUpTestEnvironment("1M"); + Pipeline p = Pipeline.create(options); + ErrorHandler> errorHandler = + p.registerBadRecordErrorHandler(new ErrorSinkTransform()); + PCollection count = + p.apply( + "Read", + BigQueryIO.read(FailingTableRowParser.INSTANCE) + .fromQuery("SELECT * FROM `" + options.getInputTable() + "`") + .usingStandardSql() + .withMethod(Method.DIRECT_READ) + .withErrorHandler(errorHandler)) + .apply("Count", Count.globally()); + + errorHandler.close(); + + // When 1/50 elements fail sequentially, this is the expected success count + PAssert.thatSingleton(count).isEqualTo(10381L); + // this is the total elements, less the successful elements + PAssert.thatSingleton(errorHandler.getOutput()).isEqualTo(10592L - 10381L); + p.run().waitUntilFinish(); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryTest.java index af6dd505b916..0c5325286dd7 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryTest.java @@ -78,6 +78,9 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandlingTestUtils.ErrorSinkTransform; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; @@ -769,18 +772,8 @@ public void testQuerySourceCreateReader() throws Exception { querySource.createReader(options); } - @Test - public void testReadFromBigQueryIO() throws Exception { - doReadFromBigQueryIO(false); - } - - @Test - public void testReadFromBigQueryIOWithTemplateCompatibility() throws Exception { - doReadFromBigQueryIO(true); - } - - private void doReadFromBigQueryIO(boolean templateCompatibility) throws Exception { - + public TypedRead> configureTypedRead( + SerializableFunction> parseFn) throws Exception { TableReference sourceTableRef = BigQueryHelpers.parseTableSpec("project:dataset.table"); fakeDatasetService.createDataset( @@ -840,15 +833,29 @@ private void doReadFromBigQueryIO(boolean templateCompatibility) throws Exceptio when(fakeStorageClient.readRows(expectedReadRowsRequest, "")) .thenReturn(new FakeBigQueryServerStream<>(readRowsResponses)); - BigQueryIO.TypedRead> typedRead = - BigQueryIO.read(new 
ParseKeyValue()) - .fromQuery(encodedQuery) - .withMethod(Method.DIRECT_READ) - .withTestServices( - new FakeBigQueryServices() - .withDatasetService(fakeDatasetService) - .withJobService(fakeJobService) - .withStorageClient(fakeStorageClient)); + return BigQueryIO.read(parseFn) + .fromQuery(encodedQuery) + .withMethod(Method.DIRECT_READ) + .withTestServices( + new FakeBigQueryServices() + .withDatasetService(fakeDatasetService) + .withJobService(fakeJobService) + .withStorageClient(fakeStorageClient)); + } + + @Test + public void testReadFromBigQueryIO() throws Exception { + doReadFromBigQueryIO(false); + } + + @Test + public void testReadFromBigQueryIOWithTemplateCompatibility() throws Exception { + doReadFromBigQueryIO(true); + } + + private void doReadFromBigQueryIO(boolean templateCompatibility) throws Exception { + + BigQueryIO.TypedRead> typedRead = configureTypedRead(new ParseKeyValue()); if (templateCompatibility) { typedRead = typedRead.withTemplateCompatibility(); @@ -862,4 +869,35 @@ private void doReadFromBigQueryIO(boolean templateCompatibility) throws Exceptio p.run(); } + + private static final class FailingParseKeyValue + implements SerializableFunction> { + @Override + public KV apply(SchemaAndRecord input) { + if (input.getRecord().get("name").toString().equals("B")) { + throw new RuntimeException("ExpectedException"); + } + return KV.of( + input.getRecord().get("name").toString(), (Long) input.getRecord().get("number")); + } + } + + @Test + public void testReadFromBigQueryWithExceptionHandling() throws Exception { + + TypedRead> typedRead = configureTypedRead(new FailingParseKeyValue()); + + ErrorHandler> errorHandler = + p.registerBadRecordErrorHandler(new ErrorSinkTransform()); + typedRead = typedRead.withErrorHandler(errorHandler); + PCollection> output = p.apply(typedRead); + errorHandler.close(); + + PAssert.that(output) + .containsInAnyOrder(ImmutableList.of(KV.of("A", 1L), KV.of("C", 3L), KV.of("D", 4L))); + + PAssert.thatSingleton(errorHandler.getOutput()).isEqualTo(1L); + + p.run(); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java index e95ad4678ea8..4e20d3634800 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java @@ -20,6 +20,7 @@ import static org.apache.beam.sdk.io.gcp.bigquery.TestBigQueryOptions.BIGQUERY_EARLY_ROLLOUT_REGION; import static org.junit.Assert.assertEquals; +import com.google.api.services.bigquery.model.TableRow; import com.google.cloud.bigquery.storage.v1.DataFormat; import java.util.Map; import org.apache.beam.sdk.Pipeline; @@ -43,6 +44,9 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandlingTestUtils.ErrorSinkTransform; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; @@ -121,6 +125,45 @@ private void runBigQueryIOStorageReadPipeline() { 
p.run().waitUntilFinish(); } + static class FailingTableRowParser implements SerializableFunction { + + public static final FailingTableRowParser INSTANCE = new FailingTableRowParser(); + + private int parseCount = 0; + + @Override + public TableRow apply(SchemaAndRecord schemaAndRecord) { + parseCount++; + if (parseCount % 50 == 0) { + throw new RuntimeException("ExpectedException"); + } + return TableRowParser.INSTANCE.apply(schemaAndRecord); + } + } + + private void runBigQueryIOStorageReadPipelineErrorHandling() throws Exception { + Pipeline p = Pipeline.create(options); + ErrorHandler> errorHandler = + p.registerBadRecordErrorHandler(new ErrorSinkTransform()); + PCollection count = + p.apply( + "Read", + BigQueryIO.read(FailingTableRowParser.INSTANCE) + .from(options.getInputTable()) + .withMethod(Method.DIRECT_READ) + .withFormat(options.getDataFormat()) + .withErrorHandler(errorHandler)) + .apply("Count", Count.globally()); + + errorHandler.close(); + + // When 1/50 elements fail sequentially, this is the expected success count + PAssert.thatSingleton(count).isEqualTo(10381L); + // this is the total elements, less the successful elements + PAssert.thatSingleton(errorHandler.getOutput()).isEqualTo(10592L - 10381L); + p.run().waitUntilFinish(); + } + @Test public void testBigQueryStorageRead1GAvro() throws Exception { setUpTestEnvironment("1G", DataFormat.AVRO); @@ -133,6 +176,18 @@ public void testBigQueryStorageRead1GArrow() throws Exception { runBigQueryIOStorageReadPipeline(); } + @Test + public void testBigQueryStorageRead1MErrorHandlingAvro() throws Exception { + setUpTestEnvironment("1M", DataFormat.AVRO); + runBigQueryIOStorageReadPipelineErrorHandling(); + } + + @Test + public void testBigQueryStorageRead1MErrorHandlingArrow() throws Exception { + setUpTestEnvironment("1M", DataFormat.ARROW); + runBigQueryIOStorageReadPipelineErrorHandling(); + } + @Test public void testBigQueryStorageReadWithAvro() throws Exception { storageReadWithSchema(DataFormat.AVRO); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslationTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslationTest.java index 668f4eef4d83..7f2ff8945482 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslationTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslationTest.java @@ -71,6 +71,8 @@ public class BigQueryIOTranslationTest { READ_TRANSFORM_SCHEMA_MAPPING.put("getUseAvroLogicalTypes", "use_avro_logical_types"); READ_TRANSFORM_SCHEMA_MAPPING.put( "getProjectionPushdownApplied", "projection_pushdown_applied"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getBadRecordRouter", "bad_record_router"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getBadRecordErrorHandler", "bad_record_error_handler"); } static final Map WRITE_TRANSFORM_SCHEMA_MAPPING = new HashMap<>(); @@ -128,6 +130,8 @@ public class BigQueryIOTranslationTest { WRITE_TRANSFORM_SCHEMA_MAPPING.put("getWriteTempDataset", "write_temp_dataset"); WRITE_TRANSFORM_SCHEMA_MAPPING.put( "getRowMutationInformationFn", "row_mutation_information_fn"); + WRITE_TRANSFORM_SCHEMA_MAPPING.put("getBadRecordRouter", "bad_record_router"); + WRITE_TRANSFORM_SCHEMA_MAPPING.put("getBadRecordErrorHandler", "bad_record_error_handler"); } @Test diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java index 21d3e53a0701..f42734af7671 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java @@ -71,6 +71,7 @@ import java.util.Collections; import java.util.EnumSet; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; @@ -83,6 +84,7 @@ import java.util.stream.Collectors; import java.util.stream.LongStream; import java.util.stream.StreamSupport; +import org.apache.avro.Schema.Field; import org.apache.avro.SchemaBuilder; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericDatumWriter; @@ -92,6 +94,7 @@ import org.apache.avro.io.Encoder; import org.apache.beam.runners.direct.DirectOptions; import org.apache.beam.sdk.coders.AtomicCoder; +import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; @@ -142,6 +145,10 @@ import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandlingTestUtils.EchoErrorTransform; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandlingTestUtils.ErrorSinkTransform; import org.apache.beam.sdk.transforms.windowing.AfterWatermark; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; @@ -175,6 +182,7 @@ import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -904,6 +912,124 @@ public void testBatchFileLoadsWithTempTablesCreateNever() throws Exception { containsInAnyOrder(Iterables.toArray(elements, TableRow.class))); } + private static final SerializableFunction failingIntegerToTableRow = + new SerializableFunction() { + @Override + public TableRow apply(Integer input) { + if (input == 15) { + throw new RuntimeException("Expected Exception"); + } + return new TableRow().set("number", input); + } + }; + + @Test + public void testBatchLoadsWithTableRowErrorHandling() throws Exception { + assumeTrue(!useStreaming); + assumeTrue(!useStorageApi); + List elements = Lists.newArrayList(); + for (int i = 0; i < 30; ++i) { + elements.add(i); + } + + ErrorHandler> errorHandler = + p.registerBadRecordErrorHandler(new ErrorSinkTransform()); + + WriteResult result = + p.apply(Create.of(elements).withCoder(BigEndianIntegerCoder.of())) + .apply( + BigQueryIO.write() + .to("dataset-id.table-id") + .withFormatFunction(failingIntegerToTableRow) + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) + .withSchema( + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("name").setType("STRING"), + new TableFieldSchema().setName("number").setType("INTEGER")))) + .withTestServices(fakeBqServices) + .withErrorHandler(errorHandler) + 
.withoutValidation()); + + errorHandler.close(); + + PAssert.that(result.getSuccessfulTableLoads()) + .containsInAnyOrder(new TableDestination("project-id:dataset-id.table-id", null)); + PAssert.thatSingleton(errorHandler.getOutput()).isEqualTo(1L); + p.run(); + + elements.remove(15); + assertThat( + fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id").stream() + .map(tr -> ((Integer) tr.get("number"))) + .collect(Collectors.toList()), + containsInAnyOrder(Iterables.toArray(elements, Integer.class))); + } + + private static final org.apache.avro.Schema avroSchema = + org.apache.avro.Schema.createRecord( + ImmutableList.of( + new Field( + "number", + org.apache.avro.Schema.create(org.apache.avro.Schema.Type.LONG), + "nodoc", + 0))); + private static final SerializableFunction, GenericRecord> + failingLongToAvro = + new SerializableFunction, GenericRecord>() { + @Override + public GenericRecord apply(AvroWriteRequest input) { + if (input.getElement() == 15) { + throw new RuntimeException("Expected Exception"); + } + return new GenericRecordBuilder(avroSchema).set("number", input.getElement()).build(); + } + }; + + @Test + public void testBatchLoadsWithAvroErrorHandling() throws Exception { + assumeTrue(!useStreaming); + assumeTrue(!useStorageApi); + List elements = Lists.newArrayList(); + for (long i = 0L; i < 30L; ++i) { + elements.add(i); + } + + ErrorHandler> errorHandler = + p.registerBadRecordErrorHandler(new ErrorSinkTransform()); + + WriteResult result = + p.apply(Create.of(elements).withCoder(VarLongCoder.of())) + .apply( + BigQueryIO.write() + .to("dataset-id.table-id") + .withAvroFormatFunction(failingLongToAvro) + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) + .withSchema( + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("number").setType("INTEGER")))) + .withTestServices(fakeBqServices) + .withErrorHandler(errorHandler) + .withoutValidation()); + + errorHandler.close(); + + PAssert.that(result.getSuccessfulTableLoads()) + .containsInAnyOrder(new TableDestination("project-id:dataset-id.table-id", null)); + PAssert.thatSingleton(errorHandler.getOutput()).isEqualTo(1L); + p.run(); + + elements.remove(15); + assertThat( + fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id").stream() + .map(tr -> Long.valueOf((String) tr.get("number"))) + .collect(Collectors.toList()), + containsInAnyOrder(Iterables.toArray(elements, Long.class))); + } + @Test public void testStreamingInsertsFailuresNoRetryPolicy() throws Exception { assumeTrue(!useStorageApi); @@ -1337,6 +1463,120 @@ public void testStreamingStorageApiWriteWithAutoSharding() throws Exception { storageWrite(true); } + // There are two failure scenarios in storage write. 
+ // first is in conversion, which is triggered by using a bad format function + // second is in actually sending to BQ, which is triggered by telling the dataset service + // to fail a row + private void storageWriteWithErrorHandling(boolean autoSharding) throws Exception { + assumeTrue(useStorageApi); + if (autoSharding) { + assumeTrue(!useStorageApiApproximate); + assumeTrue(useStreaming); + } + List elements = Lists.newArrayList(); + for (int i = 0; i < 30; ++i) { + elements.add(i); + } + + Function shouldFailRow = + (Function & Serializable) + tr -> + tr.containsKey("number") + && (tr.get("number").equals("27") || tr.get("number").equals("3")); + fakeDatasetService.setShouldFailRow(shouldFailRow); + + TestStream testStream = + TestStream.create(BigEndianIntegerCoder.of()) + .addElements(elements.get(0), Iterables.toArray(elements.subList(1, 10), Integer.class)) + .advanceProcessingTime(Duration.standardMinutes(1)) + .addElements( + elements.get(10), Iterables.toArray(elements.subList(11, 20), Integer.class)) + .advanceProcessingTime(Duration.standardMinutes(1)) + .addElements( + elements.get(20), Iterables.toArray(elements.subList(21, 30), Integer.class)) + .advanceWatermarkToInfinity(); + + ErrorHandler> errorHandler = + p.registerBadRecordErrorHandler(new EchoErrorTransform()); + + BigQueryIO.Write write = + BigQueryIO.write() + .to("project-id:dataset-id.table-id") + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) + .withFormatFunction(failingIntegerToTableRow) + .withSchema( + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("number").setType("INTEGER")))) + .withTestServices(fakeBqServices) + .withErrorHandler(errorHandler) + .withoutValidation(); + + if (useStreaming) { + if (!useStorageApiApproximate) { + write = + write + .withTriggeringFrequency(Duration.standardSeconds(30)) + .withNumStorageWriteApiStreams(2); + } + if (autoSharding) { + write = write.withAutoSharding(); + } + } + + PTransform> source = + useStreaming ? 
testStream : Create.of(elements).withCoder(BigEndianIntegerCoder.of()); + + p.apply(source).apply("WriteToBQ", write); + + errorHandler.close(); + + PAssert.that(errorHandler.getOutput()) + .satisfies( + badRecords -> { + int count = 0; + Iterator iterator = badRecords.iterator(); + while (iterator.hasNext()) { + count++; + iterator.next(); + } + Assert.assertEquals("Wrong number of bad records", 3, count); + return null; + }); + + p.run().waitUntilFinish(); + + // remove the "bad" elements from the expected elements written + elements.remove(27); + elements.remove(15); + elements.remove(3); + assertThat( + fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id").stream() + .map(tr -> Integer.valueOf((String) tr.get("number"))) + .collect(Collectors.toList()), + containsInAnyOrder(Iterables.toArray(elements, Integer.class))); + } + + @Test + public void testBatchStorageApiWriteWithErrorHandling() throws Exception { + assumeTrue(!useStreaming); + storageWriteWithErrorHandling(false); + } + + @Test + public void testStreamingStorageApiWriteWithErrorHandling() throws Exception { + assumeTrue(useStreaming); + storageWriteWithErrorHandling(false); + } + + @Test + public void testStreamingStorageApiWriteWithAutoShardingWithErrorHandling() throws Exception { + assumeTrue(useStreaming); + assumeTrue(!useStorageApiApproximate); + storageWriteWithErrorHandling(true); + } + @DefaultSchema(JavaFieldSchema.class) static class SchemaPojo { final String name; diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py index 7648ab4064da..ad9e31d52852 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery.py +++ b/sdks/python/apache_beam/io/gcp/bigquery.py @@ -2178,13 +2178,11 @@ def expand(self, pcoll): def find_in_nested_dict(schema): for field in schema['fields']: if field['type'] == 'JSON': - raise ValueError( - 'Found JSON type in table schema. JSON data ' - 'insertion is currently not supported with ' - 'FILE_LOADS write method. This is supported with ' - 'STREAMING_INSERTS. For more information: ' - 'https://cloud.google.com/bigquery/docs/reference/' - 'standard-sql/json-data#ingest_json_data') + logging.warning( + 'Found JSON type in TableSchema for "FILE_LOADS" write ' + 'method. Make sure the TableSchema field is parsed ' + 'JSON to ensure it is read as a JSON type. 
Otherwise it ' + 'will read as a raw (escaped) string.') elif field['type'] == 'STRUCT': find_in_nested_dict(field) diff --git a/sdks/python/apache_beam/io/requestresponse.py b/sdks/python/apache_beam/io/requestresponse.py index 63ec7061d3e5..706bce95f5ee 100644 --- a/sdks/python/apache_beam/io/requestresponse.py +++ b/sdks/python/apache_beam/io/requestresponse.py @@ -19,16 +19,26 @@ import abc import concurrent.futures import contextlib +import enum +import json import logging import sys import time +from datetime import timedelta +from typing import Any +from typing import Dict from typing import Generic from typing import Optional +from typing import Tuple from typing import TypeVar +from typing import Union from google.api_core.exceptions import TooManyRequests import apache_beam as beam +import redis +from apache_beam import pvalue +from apache_beam.coders import coders from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler from apache_beam.metrics import Metrics from apache_beam.ml.inference.vertex_ai_inference import MSEC_TO_SEC @@ -37,10 +47,24 @@ RequestT = TypeVar('RequestT') ResponseT = TypeVar('ResponseT') -DEFAULT_TIMEOUT_SECS = 30 # seconds +# DEFAULT_TIMEOUT_SECS represents the time interval for completing the request +# with external source. +DEFAULT_TIMEOUT_SECS = 30 + +# DEFAULT_CACHE_ENTRY_TTL_SEC represents the total time-to-live +# for cache record. +DEFAULT_CACHE_ENTRY_TTL_SEC = 24 * 60 * 60 _LOGGER = logging.getLogger(__name__) +__all__ = [ + 'RequestResponseIO', + 'ExponentialBackOffRepeater', + 'DefaultThrottler', + 'NoOpsRepeater', + 'RedisCache', +] + class UserCodeExecutionException(Exception): """Base class for errors related to calling Web APIs.""" @@ -90,6 +114,7 @@ class Caller(contextlib.AbstractContextManager, abc.ABC, Generic[RequestT, ResponseT]): """Interface for user custom code intended for API calls. + For setup and teardown of clients when applicable, implement the ``__enter__`` and ``__exit__`` methods respectively.""" @abc.abstractmethod @@ -107,16 +132,27 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): return None + def get_cache_key(self, request: RequestT) -> str: + """Returns the request to be cached. + + This is how the response will be looked up in the cache as well. + By default, entire request is cached as the key for the cache. + Implement this method to override the key for the cache. + For example, in `BigTableEnrichmentHandler`, the row key for the element + is returned here. + """ + return "" + class ShouldBackOff(abc.ABC): """ - ShouldBackOff provides mechanism to apply adaptive throttling. + Provides mechanism to apply adaptive throttling. """ pass class Repeater(abc.ABC): - """Repeater provides mechanism to repeat requests for a + """Provides mechanism to repeat requests for a configurable condition.""" @abc.abstractmethod def repeat( @@ -125,17 +161,17 @@ def repeat( request: RequestT, timeout: float, metrics_collector: Optional[_MetricsCollector]) -> ResponseT: - """repeat method is called from the RequestResponseIO when - a repeater is enabled. + """Implements a repeater strategy for RequestResponseIO when a repeater + is enabled. Args: - caller: :class:`apache_beam.io.requestresponse.Caller` object that calls - the API. + caller: a `~apache_beam.io.requestresponse.Caller` object that + calls the API. request: input request to repeat. timeout: time to wait for the request to complete. 
metrics_collector: (Optional) a - ``:class:`apache_beam.io.requestresponse._MetricsCollector``` object to - collect the metrics for RequestResponseIO. + `~apache_beam.io.requestresponse._MetricsCollector` object + to collect the metrics for RequestResponseIO. """ pass @@ -167,9 +203,10 @@ def _execute_request( class ExponentialBackOffRepeater(Repeater): - """Exponential BackOff Repeater uses exponential backoff retry strategy for - exceptions due to the remote service such as TooManyRequests (HTTP 429), - UserCodeTimeoutException, UserCodeQuotaException. + """Configure exponential backoff retry strategy. + + It retries for exceptions due to the remote service such as + TooManyRequests (HTTP 429), UserCodeTimeoutException, UserCodeQuotaException. It utilizes the decorator :func:`apache_beam.utils.retry.with_exponential_backoff`. @@ -189,20 +226,19 @@ def repeat( a repeater is enabled. Args: - caller: :class:`apache_beam.io.requestresponse.Caller` object that + caller: a `~apache_beam.io.requestresponse.Caller` object that calls the API. request: input request to repeat. timeout: time to wait for the request to complete. metrics_collector: (Optional) a - ``:class:`apache_beam.io.requestresponse._MetricsCollector``` object to + `~apache_beam.io.requestresponse._MetricsCollector` object to collect the metrics for RequestResponseIO. """ return _execute_request(caller, request, timeout, metrics_collector) class NoOpsRepeater(Repeater): - """ - NoOpsRepeater executes a request just once irrespective of any exception. + """Executes a request just once irrespective of any exception. """ def repeat( self, @@ -213,18 +249,8 @@ def repeat( return _execute_request(caller, request, timeout, metrics_collector) -class CacheReader(abc.ABC): - """CacheReader provides mechanism to read from the cache.""" - pass - - -class CacheWriter(abc.ABC): - """CacheWriter provides mechanism to write to the cache.""" - pass - - class PreCallThrottler(abc.ABC): - """PreCallThrottler provides a throttle mechanism before sending request.""" + """Provides a throttle mechanism before sending request.""" pass @@ -251,75 +277,16 @@ def __init__( self.delay_secs = delay_secs -class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT], - beam.PCollection[ResponseT]]): - """A :class:`RequestResponseIO` transform to read and write to APIs. - - Processes an input :class:`~apache_beam.pvalue.PCollection` of requests - by making a call to the API as defined in :class:`Caller`'s `__call__` - and returns a :class:`~apache_beam.pvalue.PCollection` of responses. - """ - def __init__( - self, - caller: Caller[RequestT, ResponseT], - timeout: Optional[float] = DEFAULT_TIMEOUT_SECS, - should_backoff: Optional[ShouldBackOff] = None, - repeater: Repeater = ExponentialBackOffRepeater(), - cache_reader: Optional[CacheReader] = None, - cache_writer: Optional[CacheWriter] = None, - throttler: PreCallThrottler = DefaultThrottler(), - ): - """ - Instantiates a RequestResponseIO transform. - - Args: - caller (~apache_beam.io.requestresponse.Caller): an implementation of - `Caller` object that makes call to the API. - timeout (float): timeout value in seconds to wait for response from API. - should_backoff (~apache_beam.io.requestresponse.ShouldBackOff): - (Optional) provides methods for backoff. - repeater (~apache_beam.io.requestresponse.Repeater): provides method to - repeat failed requests to API due to service errors. 
Defaults to - :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater` to - repeat requests with exponential backoff. - cache_reader (~apache_beam.io.requestresponse.CacheReader): (Optional) - provides methods to read external cache. - cache_writer (~apache_beam.io.requestresponse.CacheWriter): (Optional) - provides methods to write to external cache. - throttler (~apache_beam.io.requestresponse.PreCallThrottler): - provides methods to pre-throttle a request. Defaults to - :class:`apache_beam.io.requestresponse.DefaultThrottler` for - client-side adaptive throttling using - :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler` - """ - self._caller = caller - self._timeout = timeout - self._should_backoff = should_backoff - if repeater: - self._repeater = repeater - else: - self._repeater = NoOpsRepeater() - self._cache_reader = cache_reader - self._cache_writer = cache_writer - self._throttler = throttler +class _FilterCacheReadFn(beam.DoFn): + """A `DoFn` that partitions cache reads. - def expand( - self, - requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]: - # TODO(riteshghorse): handle Cache and Throttle PTransforms when available. - if isinstance(self._throttler, DefaultThrottler): - return requests | _Call( - caller=self._caller, - timeout=self._timeout, - should_backoff=self._should_backoff, - repeater=self._repeater, - throttler=self._throttler) + It emits to main output for successful cache read requests or + to the tagged output - `cache_misses` - otherwise.""" + def process(self, element: Tuple[RequestT, ResponseT], *args, **kwargs): + if not element[1]: + yield pvalue.TaggedOutput('cache_misses', element[0]) else: - return requests | _Call( - caller=self._caller, - timeout=self._timeout, - should_backoff=self._should_backoff, - repeater=self._repeater) + yield element class _Call(beam.PTransform[beam.PCollection[RequestT], @@ -333,15 +300,11 @@ class _Call(beam.PTransform[beam.PCollection[RequestT], regulate the duration of each call, defaults to 30 seconds. Args: - caller (:class:`apache_beam.io.requestresponse.Caller`): a callable - object that invokes API call. + caller: a `Caller` object that invokes API call. timeout (float): timeout value in seconds to wait for response from API. - should_backoff (~apache_beam.io.requestresponse.ShouldBackOff): - (Optional) provides methods for backoff. - repeater (~apache_beam.io.requestresponse.Repeater): (Optional) provides - methods to repeat requests to API. - throttler (~apache_beam.io.requestresponse.PreCallThrottler): - (Optional) provides methods to pre-throttle a request. + should_backoff: (Optional) provides methods for backoff. + repeater: (Optional) provides methods to repeat requests to API. + throttler: (Optional) provides methods to pre-throttle a request. """ def __init__( self, @@ -411,3 +374,431 @@ def process(self, request: RequestT, *args, **kwargs): def teardown(self): self._metrics_collector.teardown_counter.inc(1) self._caller.__exit__(*sys.exc_info()) + + +class Cache(abc.ABC): + """Base Cache class for + :class:`apache_beam.io.requestresponse.RequestResponseIO`. + + For adding cache support to RequestResponseIO, implement this class. 
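A minimal usage sketch (illustrative only: the host, port, `my_caller`, and the upstream `requests` PCollection are placeholders, not names defined in this module) with the concrete `RedisCache` implemented later in this file:

    from apache_beam.coders import coders
    from apache_beam.io.requestresponse import RedisCache, RequestResponseIO

    # Cache entries live in a Redis instance at the placeholder host/port and
    # expire after two minutes.
    cache = RedisCache(
        'localhost',
        6379,
        time_to_live=120,
        request_coder=coders.StrUtf8Coder(),
        response_coder=coders.StrUtf8Coder())

    # RequestResponseIO reads the cache first, invokes my_caller only for
    # cache misses, and writes the fresh responses back to the cache.
    responses = requests | RequestResponseIO(my_caller, cache=cache)

A `request_coder` is required whenever a cache is configured; `get_read` and `get_write` raise a ValueError otherwise.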
+ """ + @abc.abstractmethod + def get_read(self): + """returns a PTransform that reads from the cache.""" + pass + + @abc.abstractmethod + def get_write(self): + """returns a PTransform that writes to the cache.""" + pass + + @property + @abc.abstractmethod + def request_coder(self): + """request coder to use with Cache.""" + pass + + @request_coder.setter + @abc.abstractmethod + def request_coder(self, request_coder: coders.Coder): + """sets the request coder to use with Cache.""" + pass + + @property + @abc.abstractmethod + def source_caller(self): + """Actual caller that is using the cache.""" + pass + + @source_caller.setter + @abc.abstractmethod + def source_caller(self, caller: Caller): + """Sets the source caller for + :class:`apache_beam.io.requestresponse.RequestResponseIO` to pull + cache request key from respective callers.""" + pass + + +class _RedisMode(enum.Enum): + """ + Mode of operation for redis cache when using + `~apache_beam.io.requestresponse._RedisCaller`. + """ + READ = 0 + WRITE = 1 + + +class _RedisCaller(Caller): + """An implementation of + `~apache_beam.io.requestresponse.Caller` for Redis client. + + It provides the functionality for making requests to Redis server using + :class:`apache_beam.io.requestresponse.RequestResponseIO`. + """ + def __init__( + self, + host: str, + port: int, + time_to_live: Union[int, timedelta], + *, + request_coder: Optional[coders.Coder], + response_coder: Optional[coders.Coder], + kwargs: Optional[Dict[str, Any]] = None, + source_caller: Optional[Caller] = None, + mode: _RedisMode, + ): + """ + Args: + host (str): The hostname or IP address of the Redis server. + port (int): The port number of the Redis server. + time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for + records stored in Redis. Provide an integer (in seconds) or a + `datetime.timedelta` object. + request_coder: (Optional[`coders.Coder`]) coder for requests stored + in Redis. + response_coder: (Optional[`coders.Coder`]) coder for decoding responses + received from Redis. + kwargs: Optional(Dict[str, Any]) additional keyword arguments that + are required to connect to your redis server. Same as `redis.Redis()`. + source_caller: (Optional[`Caller`]): The source caller using this Redis + cache in case of fetching the cache request to store in Redis. + mode: `_RedisMode` An enum type specifying the operational mode of + the `_RedisCaller`. + """ + self.host, self.port = host, port + self.time_to_live = time_to_live + self.request_coder = request_coder + self.response_coder = response_coder + self.kwargs = kwargs + self.source_caller = source_caller + self.mode = mode + + def __enter__(self): + self.client = redis.Redis(self.host, self.port, **self.kwargs) + + def __call__(self, element, *args, **kwargs): + if self.mode == _RedisMode.READ: + cache_request = self.source_caller.get_cache_key(element) + # check if the caller is a enrichment handler. EnrichmentHandler + # provides the request format for cache. + if cache_request: + encoded_request = self.request_coder.encode(cache_request) + else: + encoded_request = self.request_coder.encode(element) + + encoded_response = self.client.get(encoded_request) + if not encoded_response: + # no cache entry present for this request. + return element, None + + if self.response_coder is None: + try: + response_dict = json.loads(encoded_response.decode('utf-8')) + response = beam.Row(**response_dict) + except Exception: + _LOGGER.warning( + 'cannot decode response from redis cache for %s.' 
% element) + return element, None + else: + response = self.response_coder.decode(encoded_response) + return element, response + else: + cache_request = self.source_caller.get_cache_key(element[0]) + if cache_request: + encoded_request = self.request_coder.encode(cache_request) + else: + encoded_request = self.request_coder.encode(element[0]) + if self.response_coder is None: + try: + encoded_response = json.dumps(element[1]._asdict()).encode('utf-8') + except Exception: + _LOGGER.warning( + 'cannot encode response %s for %s to store in ' + 'redis cache.' % (element[1], element[0])) + return element + else: + encoded_response = self.response_coder.encode(element[1]) + # Write to cache with TTL. Set nx to True to prevent overwriting for the + # same key. + self.client.set( + encoded_request, encoded_response, self.time_to_live, nx=True) + return element + + def __exit__(self, exc_type, exc_val, exc_tb): + self.client.close() + + +class _ReadFromRedis(beam.PTransform[beam.PCollection[RequestT], + beam.PCollection[ResponseT]]): + """A `PTransform` that performs Redis cache read.""" + def __init__( + self, + host: str, + port: int, + time_to_live: Union[int, timedelta], + *, + kwargs: Optional[Dict[str, Any]] = None, + request_coder: Optional[coders.Coder], + response_coder: Optional[coders.Coder], + source_caller: Optional[Caller[RequestT, ResponseT]] = None, + ): + """ + Args: + host (str): The hostname or IP address of the Redis server. + port (int): The port number of the Redis server. + time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for + records stored in Redis. Provide an integer (in seconds) or a + `datetime.timedelta` object. + kwargs: Optional(Dict[str, Any]) additional keyword arguments that + are required to connect to your redis server. Same as `redis.Redis()`. + request_coder: (Optional[`coders.Coder`]) coder for requests stored + in Redis. + response_coder: (Optional[`coders.Coder`]) coder for decoding responses + received from Redis. + source_caller: (Optional[`Caller`]): The source caller using this Redis + cache in case of fetching the cache request to store in Redis. + """ + self.request_coder = request_coder + self.response_coder = response_coder + self.redis_caller = _RedisCaller( + host, + port, + time_to_live, + request_coder=self.request_coder, + response_coder=self.response_coder, + kwargs=kwargs, + source_caller=source_caller, + mode=_RedisMode.READ) + + def expand( + self, + requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]: + return requests | RequestResponseIO(self.redis_caller) + + +class _WriteToRedis(beam.PTransform[beam.PCollection[Tuple[RequestT, + ResponseT]], + beam.PCollection[ResponseT]]): + """A `PTransfrom` that performs write to Redis cache.""" + def __init__( + self, + host: str, + port: int, + time_to_live: Union[int, timedelta], + *, + kwargs: Optional[Dict[str, Any]] = None, + request_coder: Optional[coders.Coder], + response_coder: Optional[coders.Coder], + source_caller: Optional[Caller[RequestT, ResponseT]] = None, + ): + """ + Args: + host (str): The hostname or IP address of the Redis server. + port (int): The port number of the Redis server. + time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for + records stored in Redis. Provide an integer (in seconds) or a + `datetime.timedelta` object. + kwargs: Optional(Dict[str, Any]) additional keyword arguments that + are required to connect to your redis server. Same as `redis.Redis()`. 
+ request_coder: (Optional[`coders.Coder`]) coder for requests stored + in Redis. + response_coder: (Optional[`coders.Coder`]) coder for decoding responses + received from Redis. + source_caller: (Optional[`Caller`]): The source caller using this Redis + cache in case of fetching the cache request to store in Redis. + """ + self.request_coder = request_coder + self.response_coder = response_coder + self.redis_caller = _RedisCaller( + host, + port, + time_to_live, + request_coder=self.request_coder, + response_coder=self.response_coder, + kwargs=kwargs, + source_caller=source_caller, + mode=_RedisMode.WRITE) + + def expand( + self, elements: beam.PCollection[Tuple[RequestT, ResponseT]] + ) -> beam.PCollection[ResponseT]: + return elements | RequestResponseIO(self.redis_caller) + + +def ensure_coders_exist(request_coder): + """checks if the coder exists to encode the request for caching.""" + if not request_coder: + raise ValueError( + 'need request coder to be able to use ' + 'Cache with RequestResponseIO.') + + +class RedisCache(Cache): + """Configure cache using Redis for + :class:`apache_beam.io.requestresponse.RequestResponseIO`.""" + def __init__( + self, + host: str, + port: int, + time_to_live: Union[int, timedelta] = DEFAULT_CACHE_ENTRY_TTL_SEC, + *, + request_coder: Optional[coders.Coder] = None, + response_coder: Optional[coders.Coder] = None, + **kwargs, + ): + """ + Args: + host (str): The hostname or IP address of the Redis server. + port (int): The port number of the Redis server. + time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for + records stored in Redis. Provide an integer (in seconds) or a + `datetime.timedelta` object. + request_coder: (Optional[`coders.Coder`]) coder for encoding requests. + response_coder: (Optional[`coders.Coder`]) coder for decoding responses + received from Redis. + kwargs: Optional additional keyword arguments that + are required to connect to your redis server. Same as `redis.Redis()`. + """ + self._host = host + self._port = port + self._time_to_live = time_to_live + self._request_coder = request_coder + self._response_coder = response_coder + self._kwargs = kwargs if kwargs else {} + self._source_caller = None + + def get_read(self): + """get_read returns a PTransform for reading from the cache.""" + ensure_coders_exist(self._request_coder) + return _ReadFromRedis( + self._host, + self._port, + time_to_live=self._time_to_live, + kwargs=self._kwargs, + request_coder=self._request_coder, + response_coder=self._response_coder, + source_caller=self._source_caller) + + def get_write(self): + """returns a PTransform for writing to the cache.""" + ensure_coders_exist(self._request_coder) + return _WriteToRedis( + self._host, + self._port, + time_to_live=self._time_to_live, + kwargs=self._kwargs, + request_coder=self._request_coder, + response_coder=self._response_coder, + source_caller=self._source_caller) + + @property + def source_caller(self): + return self._source_caller + + @source_caller.setter + def source_caller(self, source_caller: Caller): + self._source_caller = source_caller + + @property + def request_coder(self): + return self._request_coder + + @request_coder.setter + def request_coder(self, request_coder: coders.Coder): + self._request_coder = request_coder + + +class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT], + beam.PCollection[ResponseT]]): + """A :class:`RequestResponseIO` transform to read and write to APIs. 
+ + Processes an input :class:`~apache_beam.pvalue.PCollection` of requests + by making a call to the API as defined in `Caller`'s `__call__` method + and returns a :class:`~apache_beam.pvalue.PCollection` of responses. + """ + def __init__( + self, + caller: Caller[RequestT, ResponseT], + timeout: Optional[float] = DEFAULT_TIMEOUT_SECS, + should_backoff: Optional[ShouldBackOff] = None, + repeater: Repeater = ExponentialBackOffRepeater(), + cache: Optional[Cache] = None, + throttler: PreCallThrottler = DefaultThrottler(), + ): + """ + Instantiates a RequestResponseIO transform. + + Args: + caller: an implementation of + `Caller` object that makes call to the API. + timeout (float): timeout value in seconds to wait for response from API. + should_backoff: (Optional) provides methods for backoff. + repeater: provides method to repeat failed requests to API due to service + errors. Defaults to + :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater` to + repeat requests with exponential backoff. + cache: (Optional) a `~apache_beam.io.requestresponse.Cache` object + to use the appropriate cache. + throttler: provides methods to pre-throttle a request. Defaults to + :class:`apache_beam.io.requestresponse.DefaultThrottler` for + client-side adaptive throttling using + :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler` + """ + self._caller = caller + self._timeout = timeout + self._should_backoff = should_backoff + if repeater: + self._repeater = repeater + else: + self._repeater = NoOpsRepeater() + self._cache = cache + self._throttler = throttler + + def expand( + self, + requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]: + # TODO(riteshghorse): handle Throttle PTransforms when available. + + if self._cache: + self._cache.source_caller = self._caller + + inputs = requests + + if self._cache: + # read from cache. + outputs = inputs | self._cache.get_read() + # filter responses that are None and send them to the Call transform + # to fetch a value from external service. + cached_responses, inputs = (outputs + | beam.ParDo(_FilterCacheReadFn() + ).with_outputs( + 'cache_misses', main='cached_responses')) + + if isinstance(self._throttler, DefaultThrottler): + # DefaultThrottler applies throttling in the DoFn of + # Call PTransform. + responses = ( + inputs + | _Call( + caller=self._caller, + timeout=self._timeout, + should_backoff=self._should_backoff, + repeater=self._repeater, + throttler=self._throttler)) + else: + # No throttling mechanism. The requests are made to the external source + # as they come. + responses = ( + inputs + | _Call( + caller=self._caller, + timeout=self._timeout, + should_backoff=self._should_backoff, + repeater=self._repeater)) + + if self._cache: + # write to cache. + _ = responses | self._cache.get_write() + return (cached_responses, responses) | beam.Flatten() + + return responses diff --git a/sdks/python/apache_beam/io/requestresponse_it_test.py b/sdks/python/apache_beam/io/requestresponse_it_test.py index 396347c58d16..bd8c63dea587 100644 --- a/sdks/python/apache_beam/io/requestresponse_it_test.py +++ b/sdks/python/apache_beam/io/requestresponse_it_test.py @@ -15,6 +15,7 @@ # limitations under the License. 
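Before the Redis integration tests below, a short illustrative sketch (the class name, the `user_id` field, the host and port, and the input values are placeholders, not part of this change): a `Caller` can override `get_cache_key` so the Redis entry is keyed on a single field of the request rather than on the whole element, mirroring the BigTable enrichment handler mentioned in the `get_cache_key` docstring. When a cache is configured, the caller returns `(request, response)` pairs, the same pattern `FakeCallerForCache` uses in the tests that follow.

    import apache_beam as beam
    from apache_beam.coders import coders
    from apache_beam.io.requestresponse import Caller, RedisCache, RequestResponseIO

    class UserLookupCaller(Caller[beam.Row, str]):
      """Illustrative caller; a real implementation would call an external service."""
      def __call__(self, request, *args, **kwargs):
        # Stand-in for the external call; returns a (request, response) pair.
        return request, 'OK-%s' % request.user_id

      def get_cache_key(self, request) -> str:
        # Key the cache entry on user_id alone instead of the full request.
        return str(request.user_id)

    cache = RedisCache(
        'localhost',
        6379,
        request_coder=coders.StrUtf8Coder(),
        response_coder=coders.StrUtf8Coder())

    with beam.Pipeline() as p:
      responses = (
          p
          | beam.Create([beam.Row(user_id='u1'), beam.Row(user_id='u2')])
          | RequestResponseIO(UserLookupCaller(), cache=cache))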
# import base64 +import logging import sys import typing import unittest @@ -22,15 +23,19 @@ from typing import Tuple from typing import Union +import pytest import urllib3 import apache_beam as beam +from apache_beam.coders import coders from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.testing.test_pipeline import TestPipeline # pylint: disable=ungrouped-imports try: + from testcontainers.redis import RedisContainer from apache_beam.io.requestresponse import Caller + from apache_beam.io.requestresponse import RedisCache from apache_beam.io.requestresponse import RequestResponseIO from apache_beam.io.requestresponse import UserCodeExecutionException from apache_beam.io.requestresponse import UserCodeQuotaException @@ -41,6 +46,8 @@ _PAYLOAD = base64.b64encode(bytes('payload', 'utf-8')) _HTTP_ENDPOINT_ADDRESS_FLAG = '--httpEndpointAddress' +_LOGGER = logging.getLogger(__name__) + class EchoITOptions(PipelineOptions): """Shared options for running integration tests on a deployed @@ -52,6 +59,7 @@ class EchoITOptions(PipelineOptions): def _add_argparse_args(cls, parser) -> None: parser.add_argument( _HTTP_ENDPOINT_ADDRESS_FLAG, + default='http://10.138.0.32:8080', dest='http_endpoint_address', help='The HTTP address of the Echo API endpoint; must being with ' 'http(s)://') @@ -95,7 +103,8 @@ def __call__(self, request: Request, *args, **kwargs) -> EchoResponse: or a ``UserCodeQuotaException``. """ try: - resp = urllib3.request( + http = urllib3.PoolManager() + resp = http.request( "POST", self.url, json={ @@ -118,6 +127,18 @@ def __call__(self, request: Request, *args, **kwargs) -> EchoResponse: raise UserCodeExecutionException(e) +class ValidateResponse(beam.DoFn): + """Validates response received from Mock API server.""" + def process(self, element, *args, **kwargs): + if (element.id != 'echo-should-never-exceed-quota' or + element.payload != _PAYLOAD): + raise ValueError( + 'got EchoResponse(id: %s, payload: %s), want ' + 'EchoResponse(id: echo-should-never-exceed-quota, ' + 'payload: %s' % (element.id, element.payload, _PAYLOAD)) + + +@pytest.mark.uses_mock_api class EchoHTTPCallerTestIT(unittest.TestCase): options: Union[EchoITOptions, None] = None client: Union[EchoHTTPCaller, None] = None @@ -131,58 +152,157 @@ def setUpClass(cls) -> None: cls.client = EchoHTTPCaller(http_endpoint_address) - def setUp(self) -> None: - client, options = EchoHTTPCallerTestIT._get_client_and_options() - - req = Request(id=options.should_exceed_quota_id, payload=_PAYLOAD) - try: - # The following is needed to exceed the API - client(req) - client(req) - client(req) - except UserCodeExecutionException as e: - if not isinstance(e, UserCodeQuotaException): - raise e - @classmethod def _get_client_and_options(cls) -> Tuple[EchoHTTPCaller, EchoITOptions]: assert cls.options is not None assert cls.client is not None return cls.client, cls.options - def test_given_valid_request_receives_response(self): + def test_request_response_io(self): client, options = EchoHTTPCallerTestIT._get_client_and_options() - req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD) + with TestPipeline(is_integration_test=True) as test_pipeline: + output = ( + test_pipeline + | 'Create PCollection' >> beam.Create([req]) + | 'RRIO Transform' >> RequestResponseIO(client) + | 'Validate' >> beam.ParDo(ValidateResponse())) + self.assertIsNotNone(output) - response: EchoResponse = client(req) - self.assertEqual(req.id, response.id) - self.assertEqual(req.payload, response.payload) +class 
ValidateCacheResponses(beam.DoFn): + """Validates that the responses are fetched from the cache.""" + def process(self, element, *args, **kwargs): + if not element[1] or 'cached-' not in element[1]: + raise ValueError( + 'responses not fetched from cache even though cache ' + 'entries are present.') - def test_given_exceeded_quota_should_raise(self): - client, options = EchoHTTPCallerTestIT._get_client_and_options() - req = Request(id=options.should_exceed_quota_id, payload=_PAYLOAD) +class ValidateCallerResponses(beam.DoFn): + """Validates that the responses are fetched from the caller.""" + def process(self, element, *args, **kwargs): + if not element[1] or 'ACK-' not in element[1]: + raise ValueError('responses not fetched from caller when they should.') - self.assertRaises(UserCodeQuotaException, lambda: client(req)) - def test_not_found_should_raise(self): - client, _ = EchoHTTPCallerTestIT._get_client_and_options() +class FakeCallerForCache(Caller[str, str]): + def __init__(self, use_cache: bool = False): + self.use_cache = use_cache - req = Request(id='i-dont-exist-quota-id', payload=_PAYLOAD) - self.assertRaisesRegex( - UserCodeExecutionException, "Not Found", lambda: client(req)) + def __enter__(self): + pass - def test_request_response_io(self): - client, options = EchoHTTPCallerTestIT._get_client_and_options() - req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD) - with TestPipeline(is_integration_test=True) as test_pipeline: - output = ( + def __call__(self, element, *args, **kwargs): + if self.use_cache: + return None, None + + return element, 'ACK-{element}' + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +@pytest.mark.uses_redis +class TestRedisCache(unittest.TestCase): + def setUp(self) -> None: + self.retries = 3 + self._start_container() + + def test_rrio_cache_all_miss(self): + """Cache is empty so all responses are fetched from caller.""" + caller = FakeCallerForCache() + req = ['redis', 'cachetools', 'memcache'] + cache = RedisCache( + self.host, + self.port, + time_to_live=30, + request_coder=coders.StrUtf8Coder(), + response_coder=coders.StrUtf8Coder()) + with TestPipeline(is_integration_test=True) as p: + _ = ( + p + | beam.Create(req) + | RequestResponseIO(caller, cache=cache) + | beam.ParDo(ValidateCallerResponses())) + + def test_rrio_cache_all_hit(self): + """Validate that records are fetched from cache.""" + caller = FakeCallerForCache() + requests = ['foo', 'bar'] + responses = ['cached-foo', 'cached-bar'] + coder = coders.StrUtf8Coder() + for i in range(len(requests)): + enc_req = coder.encode(requests[i]) + enc_resp = coder.encode(responses[i]) + self.client.setex(enc_req, 120, enc_resp) + cache = RedisCache( + self.host, + self.port, + time_to_live=30, + request_coder=coders.StrUtf8Coder(), + response_coder=coders.StrUtf8Coder()) + with TestPipeline(is_integration_test=True) as p: + _ = ( + p + | beam.Create(requests) + | RequestResponseIO(caller, cache=cache) + | beam.ParDo(ValidateCacheResponses())) + + def test_rrio_cache_miss_and_hit(self): + """Run two back-to-back pipelines, one with pulling the data from caller + and other from the cache.""" + caller = FakeCallerForCache() + requests = ['beam', 'flink', 'spark'] + cache = RedisCache( + self.host, + self.port, + request_coder=coders.StrUtf8Coder(), + response_coder=coders.StrUtf8Coder()) + with TestPipeline(is_integration_test=True) as p: + _ = ( + p + | beam.Create(requests) + | RequestResponseIO(caller, cache=cache) + | beam.ParDo(ValidateCallerResponses())) + + caller = 
FakeCallerForCache(use_cache=True) + with TestPipeline(is_integration_test=True) as p: + _ = ( + p + | beam.Create(requests) + | RequestResponseIO(caller, cache=cache) + | beam.ParDo(ValidateCallerResponses())) + + def test_rrio_no_coder_exception(self): + caller = FakeCallerForCache() + requests = ['beam', 'flink', 'spark'] + cache = RedisCache(self.host, self.port) + with self.assertRaises(ValueError): + test_pipeline = beam.Pipeline() + _ = ( test_pipeline - | 'Create PCollection' >> beam.Create([req]) - | 'RRIO Transform' >> RequestResponseIO(client)) - self.assertIsNotNone(output) + | beam.Create(requests) + | RequestResponseIO(caller, cache=cache)) + res = test_pipeline.run() + res.wait_until_finish() + + def tearDown(self) -> None: + self.container.stop() + + def _start_container(self): + for i in range(self.retries): + try: + self.container = RedisContainer(image='redis:7.2.4') + self.container.start() + self.host = self.container.get_container_host_ip() + self.port = self.container.get_exposed_port(6379) + self.client = self.container.get_client() + break + except Exception as e: + if i == self.retries - 1: + _LOGGER.error('Unable to start redis container for RRIO tests.') + raise e if __name__ == '__main__': diff --git a/sdks/python/apache_beam/io/requestresponse_test.py b/sdks/python/apache_beam/io/requestresponse_test.py index 6d807c2a8eb8..cfc2fe5e668d 100644 --- a/sdks/python/apache_beam/io/requestresponse_test.py +++ b/sdks/python/apache_beam/io/requestresponse_test.py @@ -23,7 +23,8 @@ # pylint: disable=ungrouped-imports try: from google.api_core.exceptions import TooManyRequests - from apache_beam.io.requestresponse import Caller, DefaultThrottler + from apache_beam.io.requestresponse import Caller + from apache_beam.io.requestresponse import DefaultThrottler from apache_beam.io.requestresponse import RequestResponseIO from apache_beam.io.requestresponse import UserCodeExecutionException from apache_beam.io.requestresponse import UserCodeTimeoutException diff --git a/sdks/python/apache_beam/io/requestresponse_tests_requirements.txt b/sdks/python/apache_beam/io/requestresponse_tests_requirements.txt new file mode 100644 index 000000000000..1d8869705097 --- /dev/null +++ b/sdks/python/apache_beam/io/requestresponse_tests_requirements.txt @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +redis>=5.0.0 diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py new file mode 100644 index 000000000000..9e4480788257 --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py @@ -0,0 +1,134 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterable +from typing import List +from typing import Optional + +import apache_beam as beam +import tensorflow as tf +import tensorflow_hub as hub +import tensorflow_text as text # required to register TF ops. # pylint: disable=unused-import +from apache_beam.ml.inference import utils +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor +from apache_beam.ml.inference.tensorflow_inference import default_tensor_inference_fn +from apache_beam.ml.transforms.base import EmbeddingsManager +from apache_beam.ml.transforms.base import _TextEmbeddingHandler + +__all__ = ['TensorflowHubTextEmbeddings'] + + +# TODO: https://github.com/apache/beam/issues/30288 +# Replace with TFModelHandlerTensor when load_model() supports TFHUB models. +class _TensorflowHubModelHandler(TFModelHandlerTensor): + """ + Note: Intended for internal use only. No backwards compatibility guarantees. + """ + def __init__(self, preprocessing_url: Optional[str], *args, **kwargs): + self.preprocessing_url = preprocessing_url + super().__init__(*args, **kwargs) + + def load_model(self): + # unable to load the models with tf.keras.models.load_model so + # using hub.KerasLayer instead + model = hub.KerasLayer(self._model_uri, **self._load_model_args) + return model + + def _convert_prediction_result_to_list( + self, predictions: Iterable[PredictionResult]): + result = [] + for prediction in predictions: + inference = prediction.inference.numpy().tolist() + result.append(inference) + return result + + def run_inference(self, batch, model, inference_args, model_id=None): + if not inference_args: + inference_args = {} + if not self.preprocessing_url: + predictions = default_tensor_inference_fn( + model=model, + batch=batch, + inference_args=inference_args, + model_id=model_id) + return self._convert_prediction_result_to_list(predictions) + + vectorized_batch = tf.stack(batch, axis=0) + preprocessor_fn = hub.KerasLayer(self.preprocessing_url) + vectorized_batch = preprocessor_fn(vectorized_batch) + predictions = model(vectorized_batch) + # https://www.tensorflow.org/text/tutorials/classify_text_with_bert#using_the_bert_model # pylint: disable=line-too-long + # pooled_output -> represents the text as a whole. This is an embeddings + # of the whole text. The shape is [batch_size, embedding_dimension] + # sequence_output -> represents the text as a sequence of tokens. This is + # an embeddings of each token in the text. The shape is + # [batch_size, max_sequence_length, embedding_dimension] + # pooled output is the embeedings as per the documentation. so let's use + # that. 
+ embeddings = predictions['pooled_output'] + predictions = utils._convert_to_result(batch, embeddings, model_id) + return self._convert_prediction_result_to_list(predictions) + + +class TensorflowHubTextEmbeddings(EmbeddingsManager): + def __init__( + self, + columns: List[str], + hub_url: str, + preprocessing_url: Optional[str] = None, + **kwargs): + """ + Embedding config for tensorflow hub models. This config can be used with + MLTransform to embed text data. Models are loaded using the RunInference + PTransform with the help of a ModelHandler. + + Args: + columns: The columns containing the text to be embedded. + hub_url: The url of the tensorflow hub model. + preprocessing_url: The url of the preprocessing model. This is optional. + If provided, the preprocessing model will be used to preprocess the + text before feeding it to the main model. + min_batch_size: The minimum batch size to be used for inference. + max_batch_size: The maximum batch size to be used for inference. + large_model: Whether to share the model across processes. + """ + super().__init__(columns=columns, **kwargs) + self.model_uri = hub_url + self.preprocessing_url = preprocessing_url + + def get_model_handler(self) -> ModelHandler: + # override the default inference function + return _TensorflowHubModelHandler( + model_uri=self.model_uri, + preprocessing_url=self.preprocessing_url, + min_batch_size=self.min_batch_size, + max_batch_size=self.max_batch_size, + large_model=self.large_model, + ) + + def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: + """ + Returns a RunInference object that is used to run inference on the text + input using _TextEmbeddingHandler. + """ + return ( + RunInference( + model_handler=_TextEmbeddingHandler(self), + inference_args=self.inference_args, + )) diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py new file mode 100644 index 000000000000..b08ca8e2d8ea --- /dev/null +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub_test.py @@ -0,0 +1,176 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
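To see how the new config is meant to be driven, the unit tests below apply it through `MLTransform`. A minimal sketch under the same assumptions as those tests (the publicly hosted nnlm-en-dim128 TF Hub model, with `tensorflow`, `tensorflow_hub` and `tensorflow_text` installed); the artifact location here is just a local temporary directory.

```
# Minimal sketch of TensorflowHubTextEmbeddings driven through MLTransform.
# Assumes the TF Hub model can be downloaded from the worker environment.
import tempfile

import apache_beam as beam
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings

embedding_config = TensorflowHubTextEmbeddings(
    columns=['text'],
    hub_url='https://tfhub.dev/google/nnlm-en-dim128/2')

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{'text': 'This is a test query'}])
      | MLTransform(
          write_artifact_location=tempfile.mkdtemp()).with_transform(
              embedding_config)
      # Each output row now carries a 128-dimensional embedding in 'text'.
      | beam.Map(print))
```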
+ +import os +import shutil +import tempfile +import unittest +import uuid + +import apache_beam as beam +from apache_beam.ml.transforms.base import MLTransform + +hub_url = 'https://tfhub.dev/google/nnlm-en-dim128/2' +test_query_column = 'test_query' +test_query = 'This is a test query' + +# pylint: disable=ungrouped-imports +try: + from apache_beam.ml.transforms.embeddings.tensorflow_hub import TensorflowHubTextEmbeddings +except ImportError: + TensorflowHubTextEmbeddings = None # type: ignore + +# pylint: disable=ungrouped-imports +try: + import tensorflow_transform as tft + from apache_beam.ml.transforms.tft import ScaleTo01 +except ImportError: + tft = None + + +@unittest.skipIf( + TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.') +class TFHubEmbeddingsTest(unittest.TestCase): + def setUp(self) -> None: + self.artifact_location = tempfile.mkdtemp() + + def tearDown(self) -> None: + shutil.rmtree(self.artifact_location) + + def test_tfhub_text_embeddings(self): + embedding_config = TensorflowHubTextEmbeddings( + hub_url=hub_url, columns=[test_query_column]) + with beam.Pipeline() as pipeline: + transformed_pcoll = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + def assert_element(element): + assert len(element[test_query_column]) == 128 + + _ = (transformed_pcoll | beam.Map(assert_element)) + + @unittest.skipIf(tft is None, 'Tensorflow Transform is not installed.') + def test_embeddings_with_scale_to_0_1(self): + embedding_config = TensorflowHubTextEmbeddings( + hub_url=hub_url, + columns=[test_query_column], + ) + with beam.Pipeline() as pipeline: + transformed_pcoll = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config).with_transform( + ScaleTo01(columns=[test_query_column]))) + + def assert_element(element): + assert max(element[test_query_column]) == 1 + + _ = ( + transformed_pcoll | beam.Map(lambda x: x.as_dict()) + | beam.Map(assert_element)) + + def pipeline_with_configurable_artifact_location( + self, + pipeline, + embedding_config=None, + read_artifact_location=None, + write_artifact_location=None): + if write_artifact_location: + return ( + pipeline + | MLTransform(write_artifact_location=write_artifact_location). 
+ with_transform(embedding_config)) + elif read_artifact_location: + return ( + pipeline + | MLTransform(read_artifact_location=read_artifact_location)) + else: + raise NotImplementedError + + def test_embeddings_with_read_artifact_location(self): + with beam.Pipeline() as p: + embedding_config = TensorflowHubTextEmbeddings( + hub_url=hub_url, columns=[test_query_column]) + + with beam.Pipeline() as p: + data = ( + p + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }])) + _ = self.pipeline_with_configurable_artifact_location( + pipeline=data, + embedding_config=embedding_config, + write_artifact_location=self.artifact_location) + + with beam.Pipeline() as p: + data = ( + p + | "CreateData" >> beam.Create([{ + test_query_column: test_query + }, { + test_query_column: test_query + }])) + result_pcoll = self.pipeline_with_configurable_artifact_location( + pipeline=data, read_artifact_location=self.artifact_location) + + def assert_element(element): + # 0.29836970567703247 + assert round(element, 2) == 0.3 + + _ = ( + result_pcoll + | beam.Map(lambda x: max(x[test_query_column])) + | beam.Map(assert_element)) + + def test_with_int_data_types(self): + embedding_config = TensorflowHubTextEmbeddings( + hub_url=hub_url, columns=[test_query_column]) + with self.assertRaises(TypeError): + with beam.Pipeline() as pipeline: + _ = ( + pipeline + | "CreateData" >> beam.Create([{ + test_query_column: 1 + }]) + | "MLTransform" >> MLTransform( + write_artifact_location=self.artifact_location).with_transform( + embedding_config)) + + +@unittest.skipIf( + TensorflowHubTextEmbeddings is None, 'Tensorflow is not installed.') +class TFHubEmbeddingsGCSArtifactLocationTest(TFHubEmbeddingsTest): + def setUp(self): + self.artifact_location = os.path.join( + 'gs://temp-storage-for-perf-tests/tfhub', uuid.uuid4().hex) + + def tearDown(self): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 030a25aa85c2..65e7824f6891 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -1194,18 +1194,22 @@ def _add_argparse_args(cls, parser): '--max_cache_memory_usage_mb', dest='max_cache_memory_usage_mb', type=int, - default=100, + default=0, help=( 'Size of the SDK Harness cache to store user state and side ' - 'inputs in MB. Default is 100MB. If the cache is full, least ' + 'inputs in MB. The cache is disabled by default. Increasing ' + 'cache size might improve performance of some pipelines, such as ' + 'pipelines that use iterable side input views, but can ' + 'lead to an increase in memory consumption and OOM errors if ' + 'workers are not appropriately provisioned. ' + 'Using the cache might decrease performance pipelines using ' + 'materialized side inputs. ' + 'If the cache is full, least ' 'recently used elements will be evicted. This cache is per ' 'each SDK Harness instance. SDK Harness is a component ' 'responsible for executing the user code and communicating with ' 'the runner. Depending on the runner, there may be more than one ' - 'SDK Harness process running on the same worker node. 
Increasing ' - 'cache size might improve performance of some pipelines, but can ' - 'lead to an increase in memory consumption and OOM errors if ' - 'workers are not appropriately provisioned.')) + 'SDK Harness process running on the same worker node.')) def validate(self, validator): errors = [] diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index b783d61f95c9..0858d628a55c 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -61,6 +61,7 @@ 'AsIter', 'AsList', 'AsDict', + 'AsMultiMap', 'EmptySideInput', 'Row', ] diff --git a/sdks/python/apache_beam/transforms/enrichment.py b/sdks/python/apache_beam/transforms/enrichment.py index a2f961be6437..93344835e930 100644 --- a/sdks/python/apache_beam/transforms/enrichment.py +++ b/sdks/python/apache_beam/transforms/enrichment.py @@ -15,18 +15,23 @@ # limitations under the License. # import logging +from datetime import timedelta from typing import Any from typing import Callable from typing import Dict from typing import Optional from typing import TypeVar +from typing import Union import apache_beam as beam +from apache_beam.coders import coders +from apache_beam.io.requestresponse import DEFAULT_CACHE_ENTRY_TTL_SEC from apache_beam.io.requestresponse import DEFAULT_TIMEOUT_SECS from apache_beam.io.requestresponse import Caller from apache_beam.io.requestresponse import DefaultThrottler from apache_beam.io.requestresponse import ExponentialBackOffRepeater from apache_beam.io.requestresponse import PreCallThrottler +from apache_beam.io.requestresponse import RedisCache from apache_beam.io.requestresponse import Repeater from apache_beam.io.requestresponse import RequestResponseIO @@ -44,8 +49,15 @@ _LOGGER = logging.getLogger(__name__) +def has_valid_redis_address(host: str, port: int) -> bool: + """returns `True` if both host and port are not `None`.""" + if host and port: + return True + return False + + def cross_join(left: Dict[str, Any], right: Dict[str, Any]) -> beam.Row: - """cross_join performs a cross join on two `dict` objects. + """performs a cross join on two `dict` objects. Joins the columns of the right row onto the left row. @@ -71,20 +83,29 @@ def cross_join(left: Dict[str, Any], right: Dict[str, Any]) -> beam.Row: class EnrichmentSourceHandler(Caller[InputT, OutputT]): - """Wrapper class for :class:`apache_beam.io.requestresponse.Caller`. + """Wrapper class for `apache_beam.io.requestresponse.Caller`. Ensure that the implementation of ``__call__`` method returns a tuple of `beam.Row` objects. """ - pass + def get_cache_key(self, request: InputT) -> str: + """Returns the request to be cached. This is how the response will be + looked up in the cache as well. + + Implement this method to provide the key for the cache. + By default, the entire request is stored as the cache key. + + For example, in `BigTableEnrichmentHandler`, the row key for the element + is returned here. + """ + return "request: %s" % request class Enrichment(beam.PTransform[beam.PCollection[InputT], beam.PCollection[OutputT]]): """A :class:`apache_beam.transforms.enrichment.Enrichment` transform to enrich elements in a PCollection. - **NOTE:** This transform and its implementation are under development and - do not provide backward compatibility guarantees. + Uses the :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler` to enrich elements by joining the metadata from external source. 
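The `get_cache_key` hook is what connects a handler to the cache: `Enrichment.expand` encodes the returned string with a `StrUtf8Coder` and uses it as the Redis key. A rough sketch of a custom handler overriding it, with a hypothetical in-memory dictionary standing in for a real external service (the context-manager methods mirror the handlers in this change):

```
# Sketch of a custom EnrichmentSourceHandler with an explicit cache key.
# _PRODUCTS is a hypothetical stand-in for a real lookup service.
import apache_beam as beam
from apache_beam.transforms.enrichment import EnrichmentSourceHandler

_PRODUCTS = {1: 'pixel', 2: 'nest'}


class ProductLookupHandler(EnrichmentSourceHandler[beam.Row, beam.Row]):
  def __enter__(self):
    pass

  def __call__(self, request: beam.Row, *args, **kwargs):
    # Must return a (request, response) tuple of beam.Row objects.
    product_id = request._asdict()['product_id']
    name = _PRODUCTS.get(product_id, 'unknown')
    return request, beam.Row(product_name=name)

  def __exit__(self, exc_type, exc_val, exc_tb):
    pass

  def get_cache_key(self, request: beam.Row) -> str:
    # Only product_id identifies the request, so key the cache on it rather
    # than the whole row (same idea as BigTableEnrichmentHandler below).
    return 'product_id: %s' % request._asdict()['product_id']
```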
@@ -100,12 +121,11 @@ class Enrichment(beam.PTransform[beam.PCollection[InputT], join_fn: A lambda function to join original element with lookup metadata. Defaults to `CROSS_JOIN`. timeout: (Optional) timeout for source requests. Defaults to 30 seconds. - repeater (~apache_beam.io.requestresponse.Repeater): provides method to - repeat failed requests to API due to service errors. Defaults to + repeater: provides method to repeat failed requests to API due to service + errors. Defaults to :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater` to repeat requests with exponential backoff. - throttler (~apache_beam.io.requestresponse.PreCallThrottler): - provides methods to pre-throttle a request. Defaults to + throttler: provides methods to pre-throttle a request. Defaults to :class:`apache_beam.io.requestresponse.DefaultThrottler` for client-side adaptive throttling using :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler`. @@ -116,8 +136,8 @@ def __init__( join_fn: JoinFn = cross_join, timeout: Optional[float] = DEFAULT_TIMEOUT_SECS, repeater: Repeater = ExponentialBackOffRepeater(), - throttler: PreCallThrottler = DefaultThrottler(), - ): + throttler: PreCallThrottler = DefaultThrottler()): + self._cache = None self._source_handler = source_handler self._join_fn = join_fn self._timeout = timeout @@ -126,12 +146,55 @@ def __init__( def expand(self, input_row: beam.PCollection[InputT]) -> beam.PCollection[OutputT]: + # For caching with enrichment transform, enrichment handlers provide a + # get_cache_key() method that returns a unique string formatted + # request for that row. + request_coder = coders.StrUtf8Coder() + if self._cache: + self._cache.request_coder = request_coder + fetched_data = input_row | RequestResponseIO( caller=self._source_handler, timeout=self._timeout, repeater=self._repeater, + cache=self._cache, throttler=self._throttler) # EnrichmentSourceHandler returns a tuple of (request,response). return fetched_data | beam.Map( lambda x: self._join_fn(x[0]._asdict(), x[1]._asdict())) + + def with_redis_cache( + self, + host: str, + port: int, + time_to_live: Union[int, timedelta] = DEFAULT_CACHE_ENTRY_TTL_SEC, + *, + request_coder: Optional[coders.Coder] = None, + response_coder: Optional[coders.Coder] = None, + **kwargs, + ): + """Configure the Redis cache to use with enrichment transform. + + Args: + host (str): The hostname or IP address of the Redis server. + port (int): The port number of the Redis server. + time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for + records stored in Redis. Provide an integer (in seconds) or a + `datetime.timedelta` object. + request_coder: (Optional[`coders.Coder`]) coder for requests stored + in Redis. + response_coder: (Optional[`coders.Coder`]) coder for decoding responses + received from Redis. + kwargs: Optional additional keyword arguments that + are required to connect to your redis server. Same as `redis.Redis()`. 
+ """ + if has_valid_redis_address(host, port): + self._cache = RedisCache( # type: ignore[assignment] + host=host, + port=port, + time_to_live=time_to_live, + request_coder=request_coder, + response_coder=response_coder, + **kwargs) + return self diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py index 873dd156cb87..943000a9f6bb 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py @@ -53,9 +53,8 @@ class ExceptionLevel(Enum): class BigTableEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]): - """BigTableEnrichmentHandler is a handler for - :class:`apache_beam.transforms.enrichment.Enrichment` transform to interact - with GCP BigTable. + """A handler for :class:`apache_beam.transforms.enrichment.Enrichment` + transform to interact with GCP BigTable. Args: project_id (str): GCP project-id of the BigTable cluster. @@ -161,3 +160,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.client = None self.instance = None self._table = None + + def get_cache_key(self, request: beam.Row) -> str: + """Returns a string formatted with row key since it is unique to + a request made to `Bigtable`.""" + return "%s: %s" % (self._row_key, request._asdict()[self._row_key]) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py index 86fc438960d3..b792bc8ba946 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py @@ -16,15 +16,18 @@ # import datetime +import logging import unittest from typing import Dict from typing import List from typing import NamedTuple from typing import Tuple +from unittest.mock import MagicMock import pytest import apache_beam as beam +from apache_beam.coders import coders from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import BeamAssertException @@ -33,11 +36,14 @@ from google.api_core.exceptions import NotFound from google.cloud.bigtable import Client from google.cloud.bigtable.row_filters import ColumnRangeFilter + from testcontainers.redis import RedisContainer from apache_beam.transforms.enrichment import Enrichment from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler from apache_beam.transforms.enrichment_handlers.bigtable import ExceptionLevel except ImportError: - raise unittest.SkipTest('GCP BigTable dependencies are not installed.') + raise unittest.SkipTest('Bigtable test dependencies are not installed.') + +_LOGGER = logging.getLogger(__name__) class ValidateResponse(beam.DoFn): @@ -142,7 +148,7 @@ def create_rows(table): row.commit() -@pytest.mark.it_postcommit +@pytest.mark.uses_redis class TestBigTableEnrichment(unittest.TestCase): def setUp(self): self.project_id = 'apache-beam-testing' @@ -160,8 +166,25 @@ def setUp(self): instance = client.instance(self.instance_id) self.table = instance.table(self.table_id) create_rows(self.table) + self.retries = 3 + self._start_container() + + def _start_container(self): + for i in range(self.retries): + try: + self.container = RedisContainer(image='redis:7.2.4') + self.container.start() + self.host = self.container.get_container_host_ip() + self.port = self.container.get_exposed_port(6379) + self.client = 
self.container.get_client() + break + except Exception as e: + if i == self.retries - 1: + _LOGGER.error('Unable to start redis container for RRIO tests.') + raise e def tearDown(self) -> None: + self.container.stop() self.table = None def test_enrichment_with_bigtable(self): @@ -336,6 +359,73 @@ def test_enrichment_with_bigtable_with_timestamp(self): expected_enriched_fields, include_timestamp=True))) + def test_bigtable_enrichment_with_redis(self): + """ + In this test, we run two pipelines back to back. + + In the first pipeline, we run a simple bigtable enrichment pipeline with + zero cache records. Therefore, it makes call to the Bigtable source and + ultimately writes to the cache with a TTL of 300 seconds. + + For the second pipeline, we mock the `BigTableEnrichmentHandler`'s + `__call__` method to always return a `None` response. However, this change + won't impact the second pipeline because the Enrichment transform first + checks the cache to fulfill requests. Since all requests are cached, it + will return from there without making calls to the Bigtable source. + """ + expected_fields = [ + 'sale_id', 'customer_id', 'product_id', 'quantity', 'product' + ] + expected_enriched_fields = { + 'product': ['product_name', 'product_stock'], + } + start_column = 'product_name'.encode() + column_filter = ColumnRangeFilter(self.column_family_id, start_column) + bigtable = BigTableEnrichmentHandler( + project_id=self.project_id, + instance_id=self.instance_id, + table_id=self.table_id, + row_key=self.row_key, + row_filter=column_filter) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | "Create1" >> beam.Create(self.req) + | "Enrich W/ BigTable1" >> Enrichment(bigtable).with_redis_cache( + self.host, self.port, 300) + | "Validate Response" >> beam.ParDo( + ValidateResponse( + len(expected_fields), + expected_fields, + expected_enriched_fields))) + + # manually check cache entry + c = coders.StrUtf8Coder() + for req in self.req: + key = bigtable.get_cache_key(req) + response = self.client.get(c.encode(key)) + if not response: + raise ValueError("No cache entry found for %s" % key) + + actual = BigTableEnrichmentHandler.__call__ + BigTableEnrichmentHandler.__call__ = MagicMock( + return_value=( + beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1), + beam.Row())) + + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | "Create2" >> beam.Create(self.req) + | "Enrich W/ BigTable2" >> Enrichment(bigtable).with_redis_cache( + self.host, self.port) + | "Validate Response" >> beam.ParDo( + ValidateResponse( + len(expected_fields), + expected_fields, + expected_enriched_fields))) + BigTableEnrichmentHandler.__call__ = actual + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment_it_test.py b/sdks/python/apache_beam/transforms/enrichment_it_test.py index 89842cb18be0..4a45fae2e869 100644 --- a/sdks/python/apache_beam/transforms/enrichment_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_it_test.py @@ -109,7 +109,7 @@ def process(self, element: beam.Row, *args, **kwargs): raise BeamAssertException(f"Expected a not None field: {field}") -@pytest.mark.it_postcommit +@pytest.mark.uses_mock_api class TestEnrichment(unittest.TestCase): options: Union[EchoITOptions, None] = None client: Union[SampleHTTPEnrichment, None] = None diff --git a/sdks/python/apache_beam/yaml/README.md b/sdks/python/apache_beam/yaml/README.md index c278c7755ba1..b90809eaf03e 100644 
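Putting the pieces together, the cached enrichment path exercised above looks roughly like the following. This is a sketch only: the project, instance and table IDs are placeholders, the Bigtable rows are assumed to be keyed by `product_id`, and a Redis server is assumed at localhost:6379.

```
# Sketch of Bigtable enrichment fronted by the new Redis cache.
# All GCP identifiers and the Redis address below are placeholders.
import apache_beam as beam
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler

bigtable_handler = BigTableEnrichmentHandler(
    project_id='my-project',
    instance_id='my-instance',
    table_id='product-catalog',
    row_key='product_id')

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(
          [beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1)])
      # Responses are cached in Redis for 300 seconds, keyed by
      # BigTableEnrichmentHandler.get_cache_key, so repeated product_ids
      # skip the Bigtable lookup.
      | Enrichment(bigtable_handler).with_redis_cache('localhost', 6379, 300)
      | beam.Map(print))
```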
--- a/sdks/python/apache_beam/yaml/README.md +++ b/sdks/python/apache_beam/yaml/README.md @@ -19,15 +19,7 @@ # Beam YAML API -While Beam provides powerful APIs for authoring sophisticated data -processing pipelines, it often still has too high a barrier for -getting started and authoring simple pipelines. Even setting up the -environment, installing the dependencies, and setting up the project -can be an overwhelming amount of boilerplate for some (though -https://beam.apache.org/blog/beam-starter-projects/ has gone a long -way in making this easier). - -Here we provide a simple declarative syntax for describing pipelines +The Beam YAML API provides a simple declarative syntax for describing pipelines that does not require coding experience or learning how to use an SDK—any text editor will do. Some installation may be required to actually *execute* a pipeline, but @@ -44,544 +36,10 @@ or consumption (e.g. a lineage analysis tool) and expect it to be more easily manipulated and semantically meaningful than the Beam protos themselves (which concern themselves more with execution). -It should be noted that everything here is still under development, but any -features already included are considered stable. Feedback is welcome at -dev@apache.beam.org. - -## Running pipelines - -The Beam yaml parser is currently included as part of the Apache Beam Python SDK. -This can be installed (e.g. within a virtual environment) as - -``` -pip install apache_beam[yaml,gcp] -``` - -In addition, several of the provided transforms (such as SQL) are implemented -in Java and their expansion will require a working Java interpeter. (The -requisite artifacts will be automatically downloaded from the apache maven -repositories, so no further installs will be required.) -Docker is also currently required for local execution of these -cross-language-requiring transforms, but not for submission to a non-local -runner such as Flink or Dataflow. - -Once the prerequisites are installed, you can execute a pipeline defined -in a yaml file as - -``` -python -m apache_beam.yaml.main --yaml_pipeline_file=/path/to/pipeline.yaml [other pipeline options such as the runner] -``` - -You can do a dry-run of your pipeline using the render runner to see what the -execution graph is, e.g. - -``` -python -m apache_beam.yaml.main --yaml_pipeline_file=/path/to/pipeline.yaml --runner=apache_beam.runners.render.RenderRunner --render_output=out.png [--render_port=0] -``` - -(This requires [Graphviz](https://graphviz.org/download/) to be installed to render the pipeline.) - -We intend to support running a pipeline on Dataflow by directly passing the -yaml specification to a template, no local installation of the Beam SDKs required. - -## Example pipelines - -Here is a simple pipeline that reads some data from csv files and -writes it out in json format. - -``` -pipeline: - transforms: - - type: ReadFromCsv - config: - path: /path/to/input*.csv - - type: WriteToJson - config: - path: /path/to/output.json - input: ReadFromCsv -``` - -We can also add a transformation - -``` -pipeline: - transforms: - - type: ReadFromCsv - config: - path: /path/to/input*.csv - - type: Filter - config: - language: python - keep: "col3 > 100" - input: ReadFromCsv - - type: WriteToJson - config: - path: /path/to/output.json - input: Filter -``` - -or two. 
- -``` -pipeline: - transforms: - - type: ReadFromCsv - config: - path: /path/to/input*.csv - - type: Filter - config: - language: python - keep: "col3 > 100" - input: ReadFromCsv - - type: Sql - config: - query: "select col1, count(*) as cnt from PCOLLECTION group by col1" - input: Filter - - type: WriteToJson - config: - path: /path/to/output.json - input: Sql -``` - -Transforms can be named to help with monitoring and debugging. - -``` -pipeline: - transforms: - - type: ReadFromCsv - name: ReadMyData - config: - path: /path/to/input*.csv - - type: Filter - name: KeepBigRecords - config: - language: python - keep: "col3 > 100" - input: ReadMyData - - type: Sql - name: MySqlTransform - config: - query: "select col1, count(*) as cnt from PCOLLECTION group by col1" - input: KeepBigRecords - - type: WriteToJson - name: WriteTheOutput - config: - path: /path/to/output.json - input: MySqlTransform -``` - -(This is also needed to disambiguate if more than one transform of the same -type is used.) - -If the pipeline is linear, we can let the inputs be implicit by designating -the pipeline as a `chain` type. - -``` -pipeline: - type: chain - - transforms: - - type: ReadFromCsv - config: - path: /path/to/input*.csv - - type: Filter - config: - language: python - keep: "col3 > 100" - - type: Sql - name: MySqlTransform - config: - query: "select col1, count(*) as cnt from PCOLLECTION group by col1" - - type: WriteToJson - config: - path: /path/to/output.json -``` - -As syntactic sugar, we can name the first and last transforms in our pipeline -as `source` and `sink`. - -``` -pipeline: - type: chain - - source: - type: ReadFromCsv - config: - path: /path/to/input*.csv - - transforms: - - type: Filter - config: - language: python - keep: "col3 > 100" - - - type: Sql - name: MySqlTransform - config: - query: "select col1, count(*) as cnt from PCOLLECTION group by col1" - - sink: - type: WriteToJson - config: - path: /path/to/output.json -``` - -Arbitrary non-linear pipelines are supported as well, though in this case -inputs must be explicitly named. -Here we read two sources, join them, and write two outputs. - -``` -pipeline: - transforms: - - type: ReadFromCsv - name: ReadLeft - config: - path: /path/to/left*.csv - - - type: ReadFromCsv - name: ReadRight - config: - path: /path/to/right*.csv - - - type: Sql - config: - query: select A.col1, B.col2 from A join B using (col3) - input: - A: ReadLeft - B: ReadRight - - - type: WriteToJson - name: WriteAll - input: Sql - config: - path: /path/to/all.json - - - type: Filter - name: FilterToBig - input: Sql - config: - language: python - keep: "col2 > 100" - - - type: WriteToCsv - name: WriteBig - input: FilterToBig - config: - path: /path/to/big.csv -``` - -One can, however, nest `chains` within a non-linear pipeline. -For example, here `ExtraProcessingForBigRows` is itself a "chain" transform -that has a single input and contains its own sink. 
- -``` -pipeline: - transforms: - - type: ReadFromCsv - name: ReadLeft - config: - path: /path/to/left*.csv - - - type: ReadFromCsv - name: ReadRight - config: - path: /path/to/right*.csv - - - type: Sql - config: - query: select A.col1, B.col2 from A join B using (col3) - input: - A: ReadLeft - B: ReadRight - - - type: WriteToJson - name: WriteAll - input: Sql - config: - path: /path/to/all.json - - - type: chain - name: ExtraProcessingForBigRows - input: Sql - transforms: - - type: Filter - config: - language: python - keep: "col2 > 100" - - type: Filter - config: - language: python - keep: "len(col1) > 10" - - type: Filter - config: - language: python - keep: "col1 > 'z'" - sink: - type: WriteToCsv - config: - path: /path/to/big.csv -``` - -## Windowing - -This API can be used to define both streaming and batch pipelines. -In order to meaningfully aggregate elements in a streaming pipeline, -some kind of windowing is typically required. Beam's -[windowing](https://beam.apache.org/documentation/programming-guide/#windowing) -and [triggering](https://beam.apache.org/documentation/programming-guide/#triggers) -can be declared using the same WindowInto transform available in all other -SDKs. - -``` -pipeline: - type: chain - transforms: - - type: ReadFromPubSub - config: - topic: myPubSubTopic - format: json - schema: - type: object - properties: - col1: {type: string} - col2: {type: integer} - col3: {type: number} - - type: WindowInto - windowing: - type: fixed - size: 60s - - type: SomeGroupingTransform - config: - arg: ... - - type: WriteToPubSub - config: - topic: anotherPubSubTopic - format: json -``` - -Rather than using an explicit `WindowInto` operation, one may instead tag a -transform itself with a specified windowing which will cause its inputs -(and hence the transform itself) to be applied with that windowing. - -``` -pipeline: - type: chain - transforms: - - type: ReadFromPubSub - config: - topic: myPubSubTopic - format: ... - schema: ... - - type: SomeGroupingTransform - config: - arg: ... - windowing: - type: sliding - size: 60s - period: 10s - - type: WriteToPubSub - config: - topic: anotherPubSubTopic - format: json -``` - -Note that the `Sql` operation itself is often a from of aggregation, and -applying a windowing (or consuming an already windowed input) will cause all -grouping to be done per window. - -``` -pipeline: - type: chain - transforms: - - type: ReadFromPubSub - config: - topic: myPubSubTopic - format: ... - schema: ... - - type: Sql - config: - query: "select col1, count(*) as c from PCOLLECTION" - windowing: - type: sessions - gap: 60s - - type: WriteToPubSub - config: - topic: anotherPubSubTopic - format: json -``` - -The specified windowing is applied to all inputs, in this case resulting in -a join per window. - -``` -pipeline: - transforms: - - type: ReadFromPubSub - name: ReadLeft - config: - topic: leftTopic - format: ... - schema: ... - - - type: ReadFromPubSub - name: ReadRight - config: - topic: rightTopic - format: ... - schema: ... - - - type: Sql - config: - query: select A.col1, B.col2 from A join B using (col3) - input: - A: ReadLeft - B: ReadRight - windowing: - type: fixed - size: 60s -``` - -For a transform with no inputs, the specified windowing is instead applied to -its output(s). As per the Beam model, the windowing is then inherited by all -consuming operations. This is especially useful for root operations like Read. - -``` -pipeline: - type: chain - transforms: - - type: ReadFromPubSub - config: - topic: myPubSubTopic - format: ... 
- schema: ... - windowing: - type: fixed - size: 60s - - type: Sql - config: - query: "select col1, count(*) as c from PCOLLECTION" - - type: WriteToPubSub - config: - topic: anotherPubSubTopic - format: json -``` - -One can also specify windowing at the top level of a pipeline (or composite), -which is a shorthand to simply applying this same windowing to all root -operations (that don't otherwise specify their own windowing), -and can be an effective way to apply it everywhere. - -``` -pipeline: - type: chain - transforms: - - type: ReadFromPubSub - config: - topic: myPubSubTopic - format: ... - schema: ... - - type: Sql - config: - query: "select col1, count(*) as c from PCOLLECTION" - - type: WriteToPubSub - config: - topic: anotherPubSubTopic - format: json - windowing: - type: fixed - size: 60 -``` - -Note that all these windowing specifications are compatible with the `source` -and `sink` syntax as well - -``` -pipeline: - type: chain - - source: - type: ReadFromPubSub - config: - topic: myPubSubTopic - format: ... - schema: ... - windowing: - type: fixed - size: 10s - - transforms: - - type: Sql - config: - query: "select col1, count(*) as c from PCOLLECTION" - - sink: - type: WriteToCsv - config: - path: /path/to/output.json - windowing: - type: fixed - size: 5m -``` - - -## Providers - -Though we aim to offer a large suite of built-in transforms, it is inevitable -that people will want to be able to author their own. This is made possible -through the notion of Providers which leverage expansion services and -schema transforms. - -For example, one could build a jar that vends a -[cross language transform](https://beam.apache.org/documentation/sdks/python-multi-language-pipelines/) -or [schema transform](https://beam.apache.org/releases/javadoc/current/org/apache/beam/sdk/schemas/transforms/SchemaTransformProvider.html) -and then use it in a transform as follows - -``` -pipeline: - type: chain - source: - type: ReadFromCsv - config: - path: /path/to/input*.csv - - transforms: - - type: MyCustomTransform - config: - arg: whatever - - sink: - type: WriteToJson - config: - path: /path/to/output.json - -providers: - - type: javaJar - jar: /path/or/url/to/myExpansionService.jar - transforms: - MyCustomTransform: "urn:registered:in:expansion:service" -``` - -Arbitrary Python transforms can be provided as well, using the syntax - -``` -providers: - - type: pythonPackage - packages: - - my_pypi_package>=version - - /path/to/local/package.zip - transforms: - MyCustomTransform: "pkg.subpkg.PTransformClassOrCallable" -``` - -## Other Resources - -* [Example pipelines](https://gist.github.com/robertwb/2cb26973f1b1203e8f5f8f88c5764da0) -* [More examples](https://github.com/Polber/beam/tree/jkinard/bug-bash/sdks/python/apache_beam/yaml/examples) -* [Transform glossary](https://gist.github.com/robertwb/64e2f51ff88320eeb6ffd96634202df7) +## More details -Additional documentation in this directory +User-facing documentation for Beam YAML has moved to the main Beam site at +https://beam.apache.org/documentation/sdks/yaml/ -* [Mapping](yaml_mapping.md) -* [Aggregation](yaml_combine.md) -* [Error handling](yaml_errors.md) -* [Inlining Python](inline_python.md) +For information about contributing to Beam YAML see +https://docs.google.com/document/d/19zswPXxxBxlAUmswYPUtSc-IVAu1qWvpjo1ZSDMRbu0 diff --git a/sdks/python/apache_beam/yaml/inline_python.md b/sdks/python/apache_beam/yaml/inline_python.md index 72b8b76c58a2..1d8ded5fe557 100644 --- a/sdks/python/apache_beam/yaml/inline_python.md +++ 
b/sdks/python/apache_beam/yaml/inline_python.md @@ -17,186 +17,7 @@ under the License. --> -# Using PyTransform form YAML +# Using PyTransform from YAML -Beam YAML provides the ability to easily invoke Python transforms via the -`PyTransform` type, simply referencing them by fully qualified name. -For example, - -``` -- type: PyTransform - config: - constructor: apache_beam.pkg.module.SomeTransform - args: [1, 'foo'] - kwargs: - baz: 3 -``` - -will invoke the transform `apache_beam.pkg.mod.SomeTransform(1, 'foo', baz=3)`. -This fully qualified name can be any PTransform class or other callable that -returns a PTransform. Note, however, that PTransforms that do not accept or -return schema'd data may not be as useable to use from YAML. -Restoring the schema-ness after a non-schema returning transform can be done -by using the `callable` option on `MapToFields` which takes the entire element -as an input, e.g. - -``` -- type: PyTransform - config: - constructor: apache_beam.pkg.module.SomeTransform - args: [1, 'foo'] - kwargs: - baz: 3 -- type: MapToFields - config: - language: python - fields: - col1: - callable: 'lambda element: element.col1' - output_type: string - col2: - callable: 'lambda element: element.col2' - output_type: integer -``` - -This can be used to call arbitrary transforms in the Beam SDK, e.g. - -``` -pipeline: - transforms: - - type: PyTransform - name: ReadFromTsv - input: {} - config: - constructor: apache_beam.io.ReadFromCsv - kwargs: - path: '/path/to/*.tsv' - sep: '\t' - skip_blank_lines: True - true_values: ['yes'] - false_values: ['no'] - comment: '#' - on_bad_lines: 'skip' - binary: False - splittable: False -``` - - -## Defining a transform inline using `__constructor__` - -If the desired transform does not exist, one can define it inline as well. -This is done with the special `__constructor__` keywords, -similar to how cross-language transforms are done. - -With the `__constuctor__` keyword, one defines a Python callable that, on -invocation, *returns* the desired transform. The first argument (or `source` -keyword argument, if there are no positional arguments) -is interpreted as the Python code. For example - -``` -- type: PyTransform - config: - constructor: __constructor__ - kwargs: - source: | - import apache_beam as beam - - def create_my_transform(inc): - return beam.Map(lambda x: beam.Row(a=x.col2 + inc)) - - inc: 10 -``` - -will apply `beam.Map(lambda x: beam.Row(a=x.col2 + 10))` to the incoming -PCollection. - -As a class object can be invoked as its own constructor, this allows one to -define a `beam.PTransform` inline, e.g. - -``` -- type: PyTransform - config: - constructor: __constructor__ - kwargs: - source: | - class MyPTransform(beam.PTransform): - def __init__(self, inc): - self._inc = inc - def expand(self, pcoll): - return pcoll | beam.Map(lambda x: beam.Row(a=x.col2 + self._inc)) - - inc: 10 -``` - -which works exactly as one would expect. - - -## Defining a transform inline using `__callable__` - -The `__callable__` keyword works similarly, but instead of defining a -callable that returns an applicable `PTransform` one simply defines the -expansion to be performed as a callable. This is analogous to BeamPython's -`ptransform.ptransform_fn` decorator. 
- -In this case one can simply write - -``` -- type: PyTransform - config: - constructor: __callable__ - kwargs: - source: | - def my_ptransform(pcoll, inc): - return pcoll | beam.Map(lambda x: beam.Row(a=x.col2 + inc)) - - inc: 10 -``` - - -# External transforms - -One can also invoke PTransforms define elsewhere via a `python` provider, -for example - -``` -pipeline: - transforms: - - ... - - type: MyTransform - config: - kwarg: whatever - -providers: - - ... - - type: python - input: ... - config: - packages: - - 'some_pypi_package>=version' - transforms: - MyTransform: 'pkg.module.MyTransform' -``` - -These can be defined inline as well, with or without dependencies, e.g. - -``` -pipeline: - transforms: - - ... - - type: ToCase - input: ... - config: - upper: True - -providers: - - type: python - config: {} - transforms: - 'ToCase': | - @beam.ptransform_fn - def ToCase(pcoll, upper): - if upper: - return pcoll | beam.Map(lambda x: str(x).upper()) - else: - return pcoll | beam.Map(lambda x: str(x).lower()) -``` +The contents of this file have been moved to the main Apache Beam site at +https://beam.apache.org/documentation/sdks/yaml-inline-python/ diff --git a/sdks/python/apache_beam/yaml/readme_test.py b/sdks/python/apache_beam/yaml/readme_test.py index dca1dbbc7365..85ce47d0a3d3 100644 --- a/sdks/python/apache_beam/yaml/readme_test.py +++ b/sdks/python/apache_beam/yaml/readme_test.py @@ -287,19 +287,24 @@ def expand(self, pcoll): return pcoll +# These are copied from $ROOT/website/www/site/content/en/documentation/sdks +# at build time. +YAML_DOCS_DIR = os.path.join(os.path.join(os.path.dirname(__file__), 'docs')) + ReadMeTest = createTestSuite( - 'ReadMeTest', os.path.join(os.path.dirname(__file__), 'README.md')) + 'ReadMeTest', os.path.join(YAML_DOCS_DIR, 'yaml.md')) ErrorHandlingTest = createTestSuite( - 'ErrorHandlingTest', - os.path.join(os.path.dirname(__file__), 'yaml_errors.md')) + 'ErrorHandlingTest', os.path.join(YAML_DOCS_DIR, 'yaml-errors.md')) + +MappingTest = createTestSuite( + 'MappingTest', os.path.join(YAML_DOCS_DIR, 'yaml-udf.md')) CombineTest = createTestSuite( - 'CombineTest', os.path.join(os.path.dirname(__file__), 'yaml_combine.md')) + 'CombineTest', os.path.join(YAML_DOCS_DIR, 'yaml-combine.md')) InlinePythonTest = createTestSuite( - 'InlinePythonTest', - os.path.join(os.path.dirname(__file__), 'inline_python.md')) + 'InlinePythonTest', os.path.join(YAML_DOCS_DIR, 'yaml-inline-python.md')) if __name__ == '__main__': parser = argparse.ArgumentParser() diff --git a/sdks/python/apache_beam/yaml/yaml_combine.md b/sdks/python/apache_beam/yaml/yaml_combine.md index e2fef304fb0a..fd3b4fe5f829 100644 --- a/sdks/python/apache_beam/yaml/yaml_combine.md +++ b/sdks/python/apache_beam/yaml/yaml_combine.md @@ -17,150 +17,5 @@ under the License. --> -# Beam YAML Aggregations - -Beam YAML has EXPERIMENTAL ability to do aggregations to group and combine -values across records. The is accomplished via the `Combine` transform type. -Currently `Combine` needs to be in the `yaml_experimental_features` -option to use this transform. 
- -For example, one can write - -``` -- type: Combine - config: - group_by: col1 - combine: - total: - value: col2 - fn: - type: sum -``` - -If the function has no configuration requirements, it can be provided directly -as a string - -``` -- type: Combine - config: - group_by: col1 - combine: - total: - value: col2 - fn: sum -``` - -This can be simplified further if the output field name is the same as the input -field name - -``` -- type: Combine - config: - group_by: col1 - combine: - col2: sum -``` - -One can aggregate over may fields at once - -``` -- type: Combine - config: - group_by: col1 - combine: - col2: sum - col3: max -``` - -and/or group by more than one field - -``` -- type: Combine - config: - group_by: [col1, col2] - combine: - col3: sum -``` - -or none at all (which will result in a global combine with a single output) - -``` -- type: Combine - config: - group_by: [] - combine: - col2: sum - col3: max -``` - -## Windowed aggregation - -As with all transforms, `Combine` can take a windowing parameter - -``` -- type: Combine - windowing: - type: fixed - size: 60 - config: - group_by: col1 - combine: - col2: sum - col3: max -``` - -If no windowing specification is provided, it inherits the windowing -parameters from upstream, e.g. - -``` -- type: WindowInto - windowing: - type: fixed - size: 60 -- type: Combine - config: - group_by: col1 - combine: - col2: sum - col3: max -``` - -is equivalent to the previous example. - - -## Custom aggregation functions - -One can use aggregation functions defined in Python by setting the language -parameter. - -``` -- type: Combine - config: - language: python - group_by: col1 - combine: - biggest: - value: "col2 + col2" - fn: - type: 'apache_beam.transforms.combiners.TopCombineFn' - config: - n: 10 -``` - -## SQL-style aggregations - -By setting the language to SQL, one can provide full SQL snippets as the -combine fn. - -``` -- type: Combine - config: - language: sql - group_by: col1 - combine: - num_values: "count(*)" - total: "sum(col2)" -``` - -One can of course also use the `Sql` transform type and provide a query -directly. +The contents of this file have been moved to the main Beam site at +https://beam.apache.org/documentation/sdks/yaml-combine/ diff --git a/sdks/python/apache_beam/yaml/yaml_errors.md b/sdks/python/apache_beam/yaml/yaml_errors.md index aec602393674..369dbeb7a274 100644 --- a/sdks/python/apache_beam/yaml/yaml_errors.md +++ b/sdks/python/apache_beam/yaml/yaml_errors.md @@ -19,182 +19,5 @@ # Beam YAML Error Handling -The larger one's pipeline gets, the more common it is to encounter "exceptional" -data that is malformatted, doesn't handle the proper preconditions, or otherwise -breaks during processing. Generally any such record will cause the pipeline to -permanently fail, but often it is desirable to allow the pipeline to continue, -re-directing bad records to another path for special handling or simply -recording them for later off-line analysis. This is often called the -"dead letter queue" pattern. - -Beam YAML has special support for this pattern if the transform supports a -`error_handling` config parameter with an `output` field. For example, -the following code will write all "good" processed records to one file and -any "bad" records to a separate file. - -``` -pipeline: - transforms: - - type: ReadFromCsv - config: - path: /path/to/input*.csv - - - type: MapToFields - input: ReadFromCsv - config: - language: python - fields: - col1: col1 - # This could raise a divide-by-zero error. 
- ratio: col2 / col3 - error_handling: - output: my_error_output - - - type: WriteToJson - input: MapToFields - config: - path: /path/to/output.json - - - type: WriteToJson - name: WriteErrorsToJson - input: MapToFields.my_error_output - config: - path: /path/to/errors.json -``` - -Note that with `error_handling` declared, `MapToFields.my_error_output` -**must** be consumed; to ignore it will be an error. Any use is fine, e.g. -logging the bad records to stdout would be sufficient (though not recommended -for a robust pipeline). - -Note also that the exact format of the error outputs is still being finalized. -They can be safely printed and written to outputs, but their precise schema -may change in a future version of Beam and should not yet be depended on. - -Some transforms allow for extra arguments in their error_handling config, e.g. -for Python functions one can give a `threshold` which limits the relative number -of records that can be bad before considering the entire pipeline a failure - -``` -pipeline: - transforms: - - type: ReadFromCsv - config: - path: /path/to/input*.csv - - - type: MapToFields - input: ReadFromCsv - config: - language: python - fields: - col1: col1 - # This could raise a divide-by-zero error. - ratio: col2 / col3 - error_handling: - output: my_error_output - # If more than 10% of records throw an error, stop the pipeline. - threshold: 0.1 - - - type: WriteToJson - input: MapToFields - config: - path: /path/to/output.json - - - type: WriteToJson - name: WriteErrorsToJson - input: MapToFields.my_error_output - config: - path: /path/to/errors.json -``` - -One can do arbitrary further processing on these failed records if desired, -e.g. - -``` -pipeline: - transforms: - - type: ReadFromCsv - config: - path: /path/to/input*.csv - - - type: MapToFields - name: ComputeRatio - input: ReadFromCsv - config: - language: python - fields: - col1: col1 - # This could raise a divide-by-zero error. - ratio: col2 / col3 - error_handling: - output: my_error_output - - - type: MapToFields - name: ComputeRatioForBadRecords - input: ComputeRatio.my_error_output - config: - language: python - fields: - col1: col1 - ratio: col2 / (col3 + 1) - error_handling: - output: still_bad - - - type: WriteToJson - # Takes as input everything from the "success" path of both transforms. - input: [ComputeRatio, ComputeRatioForBadRecords] - config: - path: /path/to/output.json - - - type: WriteToJson - name: WriteErrorsToJson - # These failed the first and the second transform. - input: ComputeRatioForBadRecords.still_bad - config: - path: /path/to/errors.json -``` - -When using the `chain` syntax, the required error consumption can happen -in an `extra_transforms` block. - -``` -pipeline: - type: chain - transforms: - - type: ReadFromCsv - config: - path: /path/to/input*.csv - - - type: MapToFields - name: SomeStep - config: - language: python - fields: - col1: col1 - # This could raise a divide-by-zero error. - ratio: col2 / col3 - error_handling: - output: errors - - - type: MapToFields - name: AnotherStep - config: - language: python - fields: - col1: col1 - # This could raise a divide-by-zero error. 
- inverse_ratio: 1 / ratio - error_handling: - output: errors - - - type: WriteToJson - config: - path: /path/to/output.json - - extra_transforms: - - type: WriteToJson - name: WriteErrors - input: [SomeStep.errors, AnotherStep.errors] - config: - path: /path/to/errors.json -``` +The contents of this file have been moved to the main Beam site at +https://beam.apache.org/documentation/sdks/yaml-errors/ diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.md b/sdks/python/apache_beam/yaml/yaml_mapping.md index 74b95d3cab2a..d0400c209917 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.md +++ b/sdks/python/apache_beam/yaml/yaml_mapping.md @@ -19,230 +19,5 @@ # Beam YAML mappings -Beam YAML has the ability to do simple transformations which can be used to -get data into the correct shape. The simplest of these is `MaptoFields` -which creates records with new fields defined in terms of the input fields. - -## Field renames - -To rename fields one can write - -``` -- type: MapToFields - config: - fields: - new_col1: col1 - new_col2: col2 -``` - -will result in an output where each record has two fields, -`new_col1` and `new_col2`, whose values are those of `col1` and `col2` -respectively (which are the names of two fields from the input schema). - -One can specify the append parameter which indicates the original fields should -be retained similar to the use of `*` in an SQL select statement. For example - -``` -- type: MapToFields - config: - append: true - fields: - new_col1: col1 - new_col2: col2 -``` - -will output records that have `new_col1` and `new_col2` as *additional* -fields. When the append field is specified, one can drop fields as well, e.g. - -``` -- type: MapToFields - config: - append: true - drop: - - col3 - fields: - new_col1: col1 - new_col2: col2 -``` - -which includes all original fiels *except* col3 in addition to outputting the -two new ones. - - -## Mapping functions - -Of course one may want to do transformations beyond just dropping and renaming -fields. Beam YAML has the ability to inline simple UDFs. -This requires a language specification. For example, we can provide a -Python expression referencing the input fields - -``` -- type: MapToFields - config: - language: python - fields: - new_col: "col1.upper()" - another_col: "col2 + col3" -``` - -In addition, one can provide a full Python callable that takes the row as an -argument to do more complex mappings -(see [PythonCallableSource](https://beam.apache.org/releases/pydoc/current/apache_beam.utils.python_callable.html#apache_beam.utils.python_callable.PythonCallableWithSource) -for acceptable formats). Thus one can write - -``` -- type: MapToFields - config: - language: python - fields: - new_col: - callable: | - import re - def my_mapping(row): - if re.match("[0-9]+", row.col1) and row.col2 > 0: - return "good" - else: - return "bad" -``` - -Once one reaches a certain level of complexity, it may be preferable to package -this up as a dependency and simply refer to it by fully qualified name, e.g. - -``` -- type: MapToFields - config: - language: python - fields: - new_col: - callable: pkg.module.fn -``` - -Currently, in addition to Python, Java, SQL, and JavaScript (experimental) -expressions are supported as well - -``` -- type: MapToFields - config: - language: sql - fields: - new_col: "UPPER(col1)" - another_col: "col2 + col3" -``` - -## FlatMap - -Sometimes it may be desirable to emit more (or less) than one record for each -input record. 
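In the Beam Python SDK this is the role of `FlatMap`; a minimal sketch, assuming plain string elements:

```
import apache_beam as beam

with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create(['to be or not to be'])
        # FlatMap may emit zero, one, or many outputs per input element.
        | beam.FlatMap(lambda line: line.split())
        | beam.Map(print))
```

The YAML equivalent is built from two existing transforms.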
This can be accomplished by mapping to an iterable type and -following the mapping with an Explode operation, e.g. - -``` -- type: MapToFields - config: - language: python - fields: - new_col: "[col1.upper(), col1.lower(), col1.title()]" - another_col: "col2 + col3" -- type: Explode - config: - fields: new_col -``` - -will result in three output records for every input record. - -If more than one record is to be exploded, one must specify whether the cross -product over all fields should be taken. For example - -``` -- type: MapToFields - config: - language: python - fields: - new_col: "[col1.upper(), col1.lower(), col1.title()]" - another_col: "[col2 - 1, col2, col2 + 1]" -- type: Explode - config: - fields: [new_col, another_col] - cross_product: true -``` - -will emit nine records whereas - -``` -- type: MapToFields - config: - language: python - fields: - new_col: "[col1.upper(), col1.lower(), col1.title()]" - another_col: "[col2 - 1, col2, col2 + 1]" -- type: Explode - config: - fields: [new_col, another_col] - cross_product: false -``` - -will only emit three. - -The `Explode` operation can be used on its own if the field in question is -already an iterable type. - -``` -- type: Explode - config: - fields: [col1] -``` - -## Filtering - -Sometimes it can be desirable to only keep records that satisfy a certain -criteria. This can be accomplished with a `Filter` transform, e.g. - -``` -- type: Filter - config: - language: sql - keep: "col2 > 0" -``` - -## Types - -Beam will try to infer the types involved in the mappings, but sometimes this -is not possible. In these cases one can explicitly denote the expected output -type, e.g. - -``` -- type: MapToFields - config: - language: python - fields: - new_col: - expression: "col1.upper()" - output_type: string -``` - -The expected type is given in json schema notation, with the addition that -a top-level basic types may be given as a literal string rather than requiring -a `{type: 'basic_type_name'}` nesting. - -``` -- type: MapToFields - config: - language: python - fields: - new_col: - expression: "col1.upper()" - output_type: string - another_col: - expression: "beam.Row(a=col1, b=[col2])" - output_type: - type: 'object' - properties: - a: - type: 'string' - b: - type: 'array' - items: - type: 'number' -``` - -This can be especially useful to resolve errors involving the inability to -handle the `beam:logical:pythonsdk_any:v1` type. 
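When the same logic is written directly in the Python SDK, the analogous remedy is to declare a concrete element type so that schema inference does not fall back to `Any`. A minimal sketch, in which the `Output` type and its fields are purely illustrative:

```
import typing

import apache_beam as beam


class Output(typing.NamedTuple):
    a: str
    b: typing.List[float]


with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create([beam.Row(col1='x', col2=1.5)])
        # Declaring the output type lets Beam infer a proper row schema.
        | beam.Map(
            lambda row: Output(a=row.col1, b=[row.col2])
          ).with_output_types(Output)
        | beam.Map(print))
```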
+The contents of this file have been moved to the main Apache Beam site at +https://beam.apache.org/documentation/sdks/yaml-udf/ diff --git a/sdks/python/apache_beam/yaml/yaml_mapping_test.py b/sdks/python/apache_beam/yaml/yaml_mapping_test.py index 0de2f7022550..9dca107dca51 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping_test.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping_test.py @@ -16,13 +16,11 @@ # import logging -import os import unittest import apache_beam as beam from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to -from apache_beam.yaml.readme_test import createTestSuite from apache_beam.yaml.yaml_transform import YamlTransform DATA = [ @@ -154,10 +152,6 @@ def test_validate_explicit_types(self): self.assertEqual(result.element_type._fields[0][1], str) -YamlMappingDocTest = createTestSuite( - 'YamlMappingDocTest', - os.path.join(os.path.dirname(__file__), 'yaml_mapping.md')) - if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/container/py310/base_image_requirements.txt b/sdks/python/container/py310/base_image_requirements.txt index 3a43dc7ffe7b..94b68ae836ab 100644 --- a/sdks/python/container/py310/base_image_requirements.txt +++ b/sdks/python/container/py310/base_image_requirements.txt @@ -32,27 +32,27 @@ charset-normalizer==3.3.2 click==8.1.7 cloudpickle==2.2.1 crcmod==1.7 -cryptography==42.0.2 +cryptography==42.0.3 Cython==0.29.37 Deprecated==1.2.14 deprecation==2.1.0 dill==0.3.1.1 -dnspython==2.5.0 +dnspython==2.6.0 docker==7.0.0 docopt==0.6.2 docstring-parser==0.15 exceptiongroup==1.2.0 execnet==2.0.2 -fastavro==1.9.3 +fastavro==1.9.4 fasteners==0.19 freezegun==1.4.0 future==0.18.3 -google-api-core==2.17.0 -google-api-python-client==2.117.0 +google-api-core==2.17.1 +google-api-python-client==2.118.0 google-apitools==0.5.31 -google-auth==2.27.0 +google-auth==2.28.0 google-auth-httplib2==0.1.1 -google-cloud-aiplatform==1.41.0 +google-cloud-aiplatform==1.42.1 google-cloud-bigquery==3.17.2 google-cloud-bigquery-storage==2.24.0 google-cloud-bigtable==2.23.0 @@ -80,7 +80,7 @@ grpcio-status==1.60.1 guppy3==3.1.4.post1 hdfs==2.7.3 httplib2==0.22.0 -hypothesis==6.98.3 +hypothesis==6.98.6 idna==3.6 iniconfig==2.0.0 joblib==1.3.2 @@ -95,14 +95,14 @@ nose==1.3.7 numpy==1.26.4 oauth2client==4.1.3 objsize==0.7.0 -orjson==3.9.13 +orjson==3.9.14 overrides==7.7.0 packaging==23.2 pandas==2.0.3 parameterized==0.9.0 pluggy==1.4.0 proto-plus==1.23.0 -protobuf==4.25.2 +protobuf==4.25.3 psycopg2-binary==2.9.9 pyarrow==14.0.2 pyarrow-hotfix==0.6 @@ -127,11 +127,11 @@ referencing==0.33.0 regex==2023.12.25 requests==2.31.0 requests-mock==1.11.0 -rpds-py==0.17.1 +rpds-py==0.18.0 rsa==4.9 -scikit-learn==1.4.0 +scikit-learn==1.4.1.post1 scipy==1.12.0 -shapely==2.0.2 +shapely==2.0.3 six==1.16.0 sortedcontainers==2.4.0 soupsieve==2.5 @@ -139,11 +139,11 @@ SQLAlchemy==1.4.51 sqlparse==0.4.4 tenacity==8.2.3 testcontainers==3.7.1 -threadpoolctl==3.2.0 +threadpoolctl==3.3.0 tomli==2.0.1 tqdm==4.66.2 typing_extensions==4.9.0 -tzdata==2023.4 +tzdata==2024.1 tzlocal==5.2 uritemplate==4.1.1 urllib3==2.2.0 diff --git a/sdks/python/container/py311/base_image_requirements.txt b/sdks/python/container/py311/base_image_requirements.txt index 067f053be7f7..13708c7a82e2 100644 --- a/sdks/python/container/py311/base_image_requirements.txt +++ b/sdks/python/container/py311/base_image_requirements.txt @@ -32,26 +32,26 @@ charset-normalizer==3.3.2 click==8.1.7 cloudpickle==2.2.1 crcmod==1.7 -cryptography==42.0.2 
+cryptography==42.0.3 Cython==0.29.37 Deprecated==1.2.14 deprecation==2.1.0 dill==0.3.1.1 -dnspython==2.5.0 +dnspython==2.6.0 docker==7.0.0 docopt==0.6.2 docstring-parser==0.15 execnet==2.0.2 -fastavro==1.9.3 +fastavro==1.9.4 fasteners==0.19 freezegun==1.4.0 future==0.18.3 -google-api-core==2.17.0 -google-api-python-client==2.117.0 +google-api-core==2.17.1 +google-api-python-client==2.118.0 google-apitools==0.5.31 -google-auth==2.27.0 +google-auth==2.28.0 google-auth-httplib2==0.1.1 -google-cloud-aiplatform==1.41.0 +google-cloud-aiplatform==1.42.1 google-cloud-bigquery==3.17.2 google-cloud-bigquery-storage==2.24.0 google-cloud-bigtable==2.23.0 @@ -79,7 +79,7 @@ grpcio-status==1.60.1 guppy3==3.1.4.post1 hdfs==2.7.3 httplib2==0.22.0 -hypothesis==6.98.3 +hypothesis==6.98.6 idna==3.6 iniconfig==2.0.0 joblib==1.3.2 @@ -94,14 +94,14 @@ nose==1.3.7 numpy==1.26.4 oauth2client==4.1.3 objsize==0.7.0 -orjson==3.9.13 +orjson==3.9.14 overrides==7.7.0 packaging==23.2 pandas==2.0.3 parameterized==0.9.0 pluggy==1.4.0 proto-plus==1.23.0 -protobuf==4.25.2 +protobuf==4.25.3 psycopg2-binary==2.9.9 pyarrow==14.0.2 pyarrow-hotfix==0.6 @@ -125,11 +125,11 @@ referencing==0.33.0 regex==2023.12.25 requests==2.31.0 requests-mock==1.11.0 -rpds-py==0.17.1 +rpds-py==0.18.0 rsa==4.9 -scikit-learn==1.4.0 +scikit-learn==1.4.1.post1 scipy==1.12.0 -shapely==2.0.2 +shapely==2.0.3 six==1.16.0 sortedcontainers==2.4.0 soupsieve==2.5 @@ -137,10 +137,10 @@ SQLAlchemy==1.4.51 sqlparse==0.4.4 tenacity==8.2.3 testcontainers==3.7.1 -threadpoolctl==3.2.0 +threadpoolctl==3.3.0 tqdm==4.66.2 typing_extensions==4.9.0 -tzdata==2023.4 +tzdata==2024.1 tzlocal==5.2 uritemplate==4.1.1 urllib3==2.2.0 diff --git a/sdks/python/container/py38/base_image_requirements.txt b/sdks/python/container/py38/base_image_requirements.txt index 71013cbde747..687b61481ec8 100644 --- a/sdks/python/container/py38/base_image_requirements.txt +++ b/sdks/python/container/py38/base_image_requirements.txt @@ -33,27 +33,27 @@ charset-normalizer==3.3.2 click==8.1.7 cloudpickle==2.2.1 crcmod==1.7 -cryptography==42.0.2 +cryptography==42.0.3 Cython==0.29.37 Deprecated==1.2.14 deprecation==2.1.0 dill==0.3.1.1 -dnspython==2.5.0 +dnspython==2.6.0 docker==7.0.0 docopt==0.6.2 docstring-parser==0.15 exceptiongroup==1.2.0 execnet==2.0.2 -fastavro==1.9.3 +fastavro==1.9.4 fasteners==0.19 freezegun==1.4.0 future==0.18.3 -google-api-core==2.17.0 -google-api-python-client==2.117.0 +google-api-core==2.17.1 +google-api-python-client==2.118.0 google-apitools==0.5.31 -google-auth==2.27.0 +google-auth==2.28.0 google-auth-httplib2==0.1.1 -google-cloud-aiplatform==1.41.0 +google-cloud-aiplatform==1.42.1 google-cloud-bigquery==3.17.2 google-cloud-bigquery-storage==2.24.0 google-cloud-bigtable==2.23.0 @@ -81,7 +81,7 @@ grpcio-status==1.60.1 guppy3==3.1.4.post1 hdfs==2.7.3 httplib2==0.22.0 -hypothesis==6.98.3 +hypothesis==6.98.6 idna==3.6 importlib-metadata==7.0.1 importlib-resources==6.1.1 @@ -98,7 +98,7 @@ nose==1.3.7 numpy==1.24.4 oauth2client==4.1.3 objsize==0.7.0 -orjson==3.9.13 +orjson==3.9.14 overrides==7.7.0 packaging==23.2 pandas==2.0.3 @@ -106,7 +106,7 @@ parameterized==0.9.0 pkgutil_resolve_name==1.3.10 pluggy==1.4.0 proto-plus==1.23.0 -protobuf==4.25.2 +protobuf==4.25.3 psycopg2-binary==2.9.9 pyarrow==14.0.2 pyarrow-hotfix==0.6 @@ -131,11 +131,11 @@ referencing==0.33.0 regex==2023.12.25 requests==2.31.0 requests-mock==1.11.0 -rpds-py==0.17.1 +rpds-py==0.18.0 rsa==4.9 scikit-learn==1.3.2 scipy==1.10.1 -shapely==2.0.2 +shapely==2.0.3 six==1.16.0 sortedcontainers==2.4.0 soupsieve==2.5 
@@ -143,11 +143,11 @@ SQLAlchemy==1.4.51 sqlparse==0.4.4 tenacity==8.2.3 testcontainers==3.7.1 -threadpoolctl==3.2.0 +threadpoolctl==3.3.0 tomli==2.0.1 tqdm==4.66.2 typing_extensions==4.9.0 -tzdata==2023.4 +tzdata==2024.1 tzlocal==5.2 uritemplate==4.1.1 urllib3==2.2.0 diff --git a/sdks/python/container/py39/base_image_requirements.txt b/sdks/python/container/py39/base_image_requirements.txt index 6440c1322c59..bfe9a74e1453 100644 --- a/sdks/python/container/py39/base_image_requirements.txt +++ b/sdks/python/container/py39/base_image_requirements.txt @@ -32,27 +32,27 @@ charset-normalizer==3.3.2 click==8.1.7 cloudpickle==2.2.1 crcmod==1.7 -cryptography==42.0.2 +cryptography==42.0.3 Cython==0.29.37 Deprecated==1.2.14 deprecation==2.1.0 dill==0.3.1.1 -dnspython==2.5.0 +dnspython==2.6.0 docker==7.0.0 docopt==0.6.2 docstring-parser==0.15 exceptiongroup==1.2.0 execnet==2.0.2 -fastavro==1.9.3 +fastavro==1.9.4 fasteners==0.19 freezegun==1.4.0 future==0.18.3 -google-api-core==2.17.0 -google-api-python-client==2.117.0 +google-api-core==2.17.1 +google-api-python-client==2.118.0 google-apitools==0.5.31 -google-auth==2.27.0 +google-auth==2.28.0 google-auth-httplib2==0.1.1 -google-cloud-aiplatform==1.41.0 +google-cloud-aiplatform==1.42.1 google-cloud-bigquery==3.17.2 google-cloud-bigquery-storage==2.24.0 google-cloud-bigtable==2.23.0 @@ -80,7 +80,7 @@ grpcio-status==1.60.1 guppy3==3.1.4.post1 hdfs==2.7.3 httplib2==0.22.0 -hypothesis==6.98.3 +hypothesis==6.98.6 idna==3.6 importlib-metadata==7.0.1 iniconfig==2.0.0 @@ -96,14 +96,14 @@ nose==1.3.7 numpy==1.26.4 oauth2client==4.1.3 objsize==0.7.0 -orjson==3.9.13 +orjson==3.9.14 overrides==7.7.0 packaging==23.2 pandas==2.0.3 parameterized==0.9.0 pluggy==1.4.0 proto-plus==1.23.0 -protobuf==4.25.2 +protobuf==4.25.3 psycopg2-binary==2.9.9 pyarrow==14.0.2 pyarrow-hotfix==0.6 @@ -128,11 +128,11 @@ referencing==0.33.0 regex==2023.12.25 requests==2.31.0 requests-mock==1.11.0 -rpds-py==0.17.1 +rpds-py==0.18.0 rsa==4.9 -scikit-learn==1.4.0 +scikit-learn==1.4.1.post1 scipy==1.12.0 -shapely==2.0.2 +shapely==2.0.3 six==1.16.0 sortedcontainers==2.4.0 soupsieve==2.5 @@ -140,11 +140,11 @@ SQLAlchemy==1.4.51 sqlparse==0.4.4 tenacity==8.2.3 testcontainers==3.7.1 -threadpoolctl==3.2.0 +threadpoolctl==3.3.0 tomli==2.0.1 tqdm==4.66.2 typing_extensions==4.9.0 -tzdata==2023.4 +tzdata==2024.1 tzlocal==5.2 uritemplate==4.1.1 urllib3==2.2.0 diff --git a/sdks/python/pytest.ini b/sdks/python/pytest.ini index bf4c1aa909a8..e78697169bb0 100644 --- a/sdks/python/pytest.ini +++ b/sdks/python/pytest.ini @@ -65,6 +65,8 @@ markers = uses_tf: tests that utilize tensorflow. uses_transformers: tests that utilize transformers in some way. vertex_ai_postcommit: vertex ai postcommits that need additional deps. + uses_redis: enrichment transform tests that need redis. + uses_mock_api: tests that uses the mock API cluster. # Default timeout intended for unit tests. 
# If certain tests need a different value, please see the docs on how to diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh index 82740ae67c9f..1e25f54dc462 100755 --- a/sdks/python/scripts/generate_pydoc.sh +++ b/sdks/python/scripts/generate_pydoc.sh @@ -134,7 +134,7 @@ autodoc_member_order = 'bysource' autodoc_mock_imports = ["tensorrt", "cuda", "torch", "onnxruntime", "onnx", "tensorflow", "tensorflow_hub", "tensorflow_transform", "tensorflow_metadata", "transformers", "xgboost", "datatable", "transformers", - "sentence_transformers", + "sentence_transformers", "redis", "tensorflow_text", ] # Allow a special section for documenting DataFrame API @@ -204,6 +204,10 @@ ignore_identifiers = [ 'apache_beam.typehints.typehints.validate_composite_type_param()', 'apache_beam.utils.windowed_value._IntervalWindowBase', 'apache_beam.coders.coder_impl.StreamCoderImpl', + 'apache_beam.io.requestresponse.Caller', + 'apache_beam.io.requestresponse.Repeater', + 'apache_beam.io.requestresponse.PreCallThrottler', + 'apache_beam.io.requestresponse.Cache', # Private classes which are used within the same module 'apache_beam.transforms.external_test.PayloadBase', diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 7b635365cc64..409951cbc415 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -16,8 +16,10 @@ # """Apache Beam SDK for Python setup file.""" +import glob import logging import os +import shutil import subprocess import sys import warnings @@ -148,11 +150,11 @@ def cythonize(*args, **kwargs): # not called even though S3 was initialized. This could lead to a # segmentation fault at exit. Keep pyarrow<13 until this is resolved. pyarrow_dependency = [ - 'pyarrow>=3.0.0,<12.0.0', - # NOTE: We can remove this once Beam increases the pyarrow lower bound - # to a version that fixes CVE. - 'pyarrow-hotfix<1' - ] + 'pyarrow>=3.0.0,<12.0.0', + # NOTE: We can remove this once Beam increases the pyarrow lower bound + # to a version that fixes CVE. + 'pyarrow-hotfix<1' + ] else: pyarrow_dependency = [ 'pyarrow>=3.0.0,<15.0.0', @@ -161,7 +163,6 @@ def cythonize(*args, **kwargs): 'pyarrow-hotfix<1' ] - # Exclude pandas<=1.4.2 since it doesn't work with numpy 1.24.x. # Exclude 1.5.0 and 1.5.1 because of # https://github.com/pandas-dev/pandas/issues/45725 @@ -169,12 +170,14 @@ def cythonize(*args, **kwargs): 'pandas>=1.4.3,!=1.5.0,!=1.5.1,<2.1;python_version>="3.8"', ] + def find_by_ext(root_dir, ext): for root, _, files in os.walk(root_dir): for file in files: if file.endswith(ext): yield os.path.realpath(os.path.join(root, file)) + # We must generate protos after setup_requires are installed. def generate_protos_first(): try: @@ -186,23 +189,42 @@ def generate_protos_first(): # skip proto generation in that case. if not os.path.exists(os.path.join(cwd, 'gen_protos.py')): # make sure we already generated protos - pb2_files = list(find_by_ext(os.path.join( - cwd, 'apache_beam', 'portability', 'api'), '_pb2.py')) + pb2_files = list( + find_by_ext( + os.path.join(cwd, 'apache_beam', 'portability', 'api'), + '_pb2.py')) if not pb2_files: - raise RuntimeError('protobuf files are not generated. ' - 'Please generate pb2 files') + raise RuntimeError( + 'protobuf files are not generated. 
' + 'Please generate pb2 files') warnings.warn('Skipping proto generation as they are already generated.') return - out = subprocess.run([ - sys.executable, - os.path.join(cwd, 'gen_protos.py'), - '--no-force' - ], capture_output=True, check=True) + out = subprocess.run( + [sys.executable, os.path.join(cwd, 'gen_protos.py'), '--no-force'], + capture_output=True, + check=True) print(out.stdout) except subprocess.CalledProcessError as err: - raise RuntimeError('Could not generate protos due to error: %s', - err.stderr) + raise RuntimeError('Could not generate protos due to error: %s', err.stderr) + + +def copy_tests_from_docs(): + python_root = os.path.abspath(os.path.dirname(__file__)) + docs_src = os.path.normpath( + os.path.join( + python_root, '../../website/www/site/content/en/documentation/sdks')) + docs_dest = os.path.normpath( + os.path.join(python_root, 'apache_beam/yaml/docs')) + if os.path.exists(docs_src): + shutil.rmtree(docs_dest, ignore_errors=True) + os.mkdir(docs_dest) + for path in glob.glob(os.path.join(docs_src, 'yaml*.md')): + shutil.copy(path, docs_dest) + else: + if not os.path.exists(docs_dest): + raise RuntimeError( + f'Could not locate yaml docs in {docs_src} or {docs_dest}.') def generate_external_transform_wrappers(): @@ -277,24 +299,27 @@ def get_portability_package_data(): generate_external_transform_wrappers() + # These data files live elsewhere in the full Beam repository. + copy_tests_from_docs() + # generate cythonize extensions only if we are building a wheel or # building an extension or running in editable mode. cythonize_cmds = ('bdist_wheel', 'build_ext', 'editable_wheel') if any(cmd in sys.argv for cmd in cythonize_cmds): extensions = cythonize([ - 'apache_beam/**/*.pyx', - 'apache_beam/coders/coder_impl.py', - 'apache_beam/metrics/cells.py', - 'apache_beam/metrics/execution.py', - 'apache_beam/runners/common.py', - 'apache_beam/runners/worker/logger.py', - 'apache_beam/runners/worker/opcounters.py', - 'apache_beam/runners/worker/operations.py', - 'apache_beam/transforms/cy_combiners.py', - 'apache_beam/transforms/stats.py', - 'apache_beam/utils/counters.py', - 'apache_beam/utils/windowed_value.py', - ]) + 'apache_beam/**/*.pyx', + 'apache_beam/coders/coder_impl.py', + 'apache_beam/metrics/cells.py', + 'apache_beam/metrics/execution.py', + 'apache_beam/runners/common.py', + 'apache_beam/runners/worker/logger.py', + 'apache_beam/runners/worker/opcounters.py', + 'apache_beam/runners/worker/operations.py', + 'apache_beam/transforms/cy_combiners.py', + 'apache_beam/transforms/stats.py', + 'apache_beam/utils/counters.py', + 'apache_beam/utils/windowed_value.py', + ]) else: extensions = [] # Keep all dependencies inlined in the setup call, otherwise Dependabot won't @@ -319,6 +344,7 @@ def get_portability_package_data(): '*/*/*.h', 'testing/data/*.yaml', 'yaml/*.yaml', + 'yaml/docs/*.md', *get_portability_package_data() ] }, diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index 9ab34d904cbc..cadf3a6ae2c6 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -487,6 +487,35 @@ task tftTests { } } +// Tests that depend on Mock API:https://github.com/apache/beam/tree/master/.test-infra/mock-apis. . 
+task mockAPITests { + dependsOn 'initializeForDataflowJob' + dependsOn ':sdks:python:sdist' + def requirementsFile = "${rootDir}/sdks/python/apache_beam/io/requestresponse_tests_requirements.txt" + doFirst { + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && pip install -r $requirementsFile" + } + } + doLast { + def testOpts = basicTestOpts + def argMap = [ + "test_opts": testOpts, + "collect": "uses_mock_api", + "runner": "TestDataflowRunner", + "project": "apache-beam-testing", + "region": "us-west1", + "requirements_file": "$requirementsFile" + ] + def cmdArgs = mapToArgString(argMap) + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && ${runScriptsDir}/run_integration_test.sh $cmdArgs" + } + } +} + // add all RunInference E2E tests that run on DataflowRunner // As of now, this test suite is enable in py38 suite as the base NVIDIA image used for Tensor RT // contains Python 3.8. @@ -495,6 +524,7 @@ project.tasks.register("inferencePostCommitIT") { dependsOn = [ 'tensorRTtests', 'vertexAIInferenceTest', + 'mockAPITests', ] } diff --git a/sdks/python/test-suites/direct/common.gradle b/sdks/python/test-suites/direct/common.gradle index ea1eaf726b52..b5680c2e1e9a 100644 --- a/sdks/python/test-suites/direct/common.gradle +++ b/sdks/python/test-suites/direct/common.gradle @@ -364,6 +364,33 @@ task transformersInferenceTest { } } +// Enrichment transform tests that uses Redis +task enrichmentRedisTest { + dependsOn 'installGcpTest' + dependsOn ':sdks:python:sdist' + def requirementsFile = "${rootDir}/sdks/python/apache_beam/io/requestresponse_tests_requirements.txt" + doFirst { + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && pip install -r $requirementsFile" + } + } + doLast { + def testOpts = basicTestOpts + def argMap = [ + "test_opts": testOpts, + "suite": "postCommitIT-direct-py${pythonVersionSuffix}", + "collect": "uses_redis", + "runner": "TestDirectRunner" + ] + def cmdArgs = mapToArgString(argMap) + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && ${runScriptsDir}/run_integration_test.sh $cmdArgs" + } + } +} + // Add all the RunInference framework IT tests to this gradle task that runs on Direct Runner Post commit suite. project.tasks.register("inferencePostCommitIT") { dependsOn = [ @@ -372,6 +399,7 @@ project.tasks.register("inferencePostCommitIT") { 'tensorflowInferenceTest', 'xgboostInferenceTest', 'transformersInferenceTest', + 'enrichmentRedisTest', // (TODO) https://github.com/apache/beam/issues/25799 // uncomment tfx bsl tests once tfx supports protobuf 4.x // 'tfxInferenceTest', diff --git a/sdks/python/test-suites/tox/py38/build.gradle b/sdks/python/test-suites/tox/py38/build.gradle index ea34eec1fae4..60245e84c73f 100644 --- a/sdks/python/test-suites/tox/py38/build.gradle +++ b/sdks/python/test-suites/tox/py38/build.gradle @@ -149,6 +149,16 @@ toxTask "testPy38embeddingsMLTransform", "py38-embeddings", "${posargs}" test.dependsOn "testPy38embeddingsMLTransform" preCommitPyCoverage.dependsOn "testPy38embeddingsMLTransform" +// Part of MLTransform embeddings test suite but requires tensorflow hub, which we need to test on +// mutliple versions so keeping this suite separate. 
+toxTask "testPy38TensorflowHubEmbeddings-014", "py38-TFHubEmbeddings-014", "${posargs}" +test.dependsOn "testPy38TensorflowHubEmbeddings-014" +preCommitPyCoverage.dependsOn "testPy38TensorflowHubEmbeddings-014" + +toxTask "testPy38TensorflowHubEmbeddings-015", "py38-TFHubEmbeddings-015", "${posargs}" +test.dependsOn "testPy38TensorflowHubEmbeddings-015" +preCommitPyCoverage.dependsOn "testPy38TensorflowHubEmbeddings-015" + toxTask "whitespacelint", "whitespacelint", "${posargs}" task archiveFilesToLint(type: Zip) { diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini index 5fadd05391ce..ca35c383eea1 100644 --- a/sdks/python/tox.ini +++ b/sdks/python/tox.ini @@ -440,3 +440,17 @@ commands = /bin/sh -c "pip freeze | grep -E google-cloud-aiplatform" # Allow exit code 5 (no tests run) so that we can run this command safely on arbitrary subdirectories. /bin/sh -c 'pytest apache_beam/ml/transforms/embeddings -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml -n 6 {posargs}; ret=$?; [ $ret = 5 ] && exit 0 || exit $ret' + + +[testenv:py{38,39,310,311}-TFHubEmbeddings-{014,015}] +deps = + 014: tensorflow-hub>=0.14.0,<0.15.0 + 015: tensorflow-hub>=0.15.0,<0.16.0 + tensorflow-text # required to register ops for text embedding models. + +extras = test,gcp +commands = + # Log aiplatform and its dependencies version for debugging + /bin/sh -c "pip freeze | grep -E tensorflow" + # Allow exit code 5 (no tests run) so that we can run this command safely on arbitrary subdirectories. + bash {toxinidir}/scripts/run_pytest.sh {envname} 'apache_beam/ml/transforms/embeddings' \ No newline at end of file diff --git a/website/www/site/content/en/documentation/sdks/yaml-inline-python.md b/website/www/site/content/en/documentation/sdks/yaml-inline-python.md index a3df36022943..2afdadb9360b 100644 --- a/website/www/site/content/en/documentation/sdks/yaml-inline-python.md +++ b/website/www/site/content/en/documentation/sdks/yaml-inline-python.md @@ -21,186 +21,7 @@ title: "Apache Beam YAML Inline Python" under the License. --> -# Using PyTransform form YAML +# Using PyTransform from YAML -Beam YAML provides the ability to easily invoke Python transforms via the -`PyTransform` type, simply referencing them by fully qualified name. -For example, - -``` -- type: PyTransform - config: - constructor: apache_beam.pkg.module.SomeTransform - args: [1, 'foo'] - kwargs: - baz: 3 -``` - -will invoke the transform `apache_beam.pkg.mod.SomeTransform(1, 'foo', baz=3)`. -This fully qualified name can be any PTransform class or other callable that -returns a PTransform. Note, however, that PTransforms that do not accept or -return schema'd data may not be as useable to use from YAML. -Restoring the schema-ness after a non-schema returning transform can be done -by using the `callable` option on `MapToFields` which takes the entire element -as an input, e.g. - -``` -- type: PyTransform - config: - constructor: apache_beam.pkg.module.SomeTransform - args: [1, 'foo'] - kwargs: - baz: 3 -- type: MapToFields - config: - language: python - fields: - col1: - callable: 'lambda element: element.col1' - output_type: string - col2: - callable: 'lambda element: element.col2' - output_type: integer -``` - -This can be used to call arbitrary transforms in the Beam SDK, e.g. 
- -``` -pipeline: - transforms: - - type: PyTransform - name: ReadFromTsv - input: {} - config: - constructor: apache_beam.io.ReadFromCsv - kwargs: - path: '/path/to/*.tsv' - sep: '\t' - skip_blank_lines: True - true_values: ['yes'] - false_values: ['no'] - comment: '#' - on_bad_lines: 'skip' - binary: False - splittable: False -``` - - -## Defining a transform inline using `__constructor__` - -If the desired transform does not exist, one can define it inline as well. -This is done with the special `__constructor__` keywords, -similar to how cross-language transforms are done. - -With the `__constuctor__` keyword, one defines a Python callable that, on -invocation, *returns* the desired transform. The first argument (or `source` -keyword argument, if there are no positional arguments) -is interpreted as the Python code. For example - -``` -- type: PyTransform - config: - constructor: __constructor__ - kwargs: - source: | - import apache_beam as beam - - def create_my_transform(inc): - return beam.Map(lambda x: beam.Row(a=x.col2 + inc)) - - inc: 10 -``` - -will apply `beam.Map(lambda x: beam.Row(a=x.col2 + 10))` to the incoming -PCollection. - -As a class object can be invoked as its own constructor, this allows one to -define a `beam.PTransform` inline, e.g. - -``` -- type: PyTransform - config: - constructor: __constructor__ - kwargs: - source: | - class MyPTransform(beam.PTransform): - def __init__(self, inc): - self._inc = inc - def expand(self, pcoll): - return pcoll | beam.Map(lambda x: beam.Row(a=x.col2 + self._inc)) - - inc: 10 -``` - -which works exactly as one would expect. - - -## Defining a transform inline using `__callable__` - -The `__callable__` keyword works similarly, but instead of defining a -callable that returns an applicable `PTransform` one simply defines the -expansion to be performed as a callable. This is analogous to BeamPython's -`ptransform.ptransform_fn` decorator. - -In this case one can simply write - -``` -- type: PyTransform - config: - constructor: __callable__ - kwargs: - source: | - def my_ptransform(pcoll, inc): - return pcoll | beam.Map(lambda x: beam.Row(a=x.col2 + inc)) - - inc: 10 -``` - - -# External transforms - -One can also invoke PTransforms define elsewhere via a `python` provider, -for example - -``` -pipeline: - transforms: - - ... - - type: MyTransform - config: - kwarg: whatever - -providers: - - ... - - type: python - input: ... - config: - packages: - - 'some_pypi_package>=version' - transforms: - MyTransform: 'pkg.module.MyTransform' -``` - -These can be defined inline as well, with or without dependencies, e.g. - -``` -pipeline: - transforms: - - ... - - type: ToCase - input: ... - config: - upper: True - -providers: - - type: python - config: {} - transforms: - 'ToCase': | - @beam.ptransform_fn - def ToCase(pcoll, upper): - if upper: - return pcoll | beam.Map(lambda x: str(x).upper()) - else: - return pcoll | beam.Map(lambda x: str(x).lower()) -``` +The contents of this file have been moved to the main Apache Beam site at +https://beam.apache.org/documentation/sdks/yaml-inline-python/
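For reference, the inline `ToCase` provider shown above corresponds to the following plain Python SDK code; this is a sketch meant only to illustrate what such an inline provider wraps, and the sample input is made up.

```
import apache_beam as beam


@beam.ptransform_fn
def ToCase(pcoll, upper):
    if upper:
        return pcoll | beam.Map(lambda x: str(x).upper())
    else:
        return pcoll | beam.Map(lambda x: str(x).lower())


with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create(['Hello', 'World'])
        | ToCase(upper=True)
        | beam.Map(print))
```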