From 6746fbec7f7900237e3e3c07c9bbf308e5239a24 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Fri, 10 Mar 2023 14:46:34 -0800 Subject: [PATCH 01/13] Use the cord ctype for proto generation, where available. This is natural as this field is often large and composed of the concatenation of many small elements, which can lead to significant performance improvements. --- .../org/apache/beam/model/fn_execution/v1/beam_fn_api.proto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto index 975d18cbdfb7..2e728fd657f7 100644 --- a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto +++ b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto @@ -663,7 +663,7 @@ message Elements { // (Optional) Represents a part of a logical byte stream. Elements within // the logical byte stream are encoded in the nested context and // concatenated together. - bytes data = 3; + bytes data = 3 [ctype = CORD]; // (Optional) Set this bit to indicate the this is the last data block // for the given instruction and transform, ending the stream. From 1a4b2cd1a6ef6a5a80b5f3e2d7dc41c2f7d832e8 Mon Sep 17 00:00:00 2001 From: Bjorn Pedersen Date: Mon, 20 Mar 2023 09:59:15 -0400 Subject: [PATCH 02/13] added test for KMS key --- sdks/python/apache_beam/io/gcp/gcsio_test.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/io/gcp/gcsio_test.py b/sdks/python/apache_beam/io/gcp/gcsio_test.py index 260090461c8c..db6a992632b5 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsio_test.py @@ -62,12 +62,13 @@ def __init__(self): class FakeFile(object): def __init__( - self, bucket, obj, contents, generation, crc32c=None, last_updated=None): + self, bucket, obj, contents, generation, crc32c=None, kms_key=None, last_updated=None): self.bucket = bucket self.object = obj self.contents = contents self.generation = generation self.crc32c = crc32c + self.kms_key = kms_key self.last_updated = last_updated def get_metadata(self): @@ -82,6 +83,7 @@ def get_metadata(self): generation=self.generation, size=len(self.contents), crc32c=self.crc32c, + kmsKeyName=self.kms_key, updated=last_updated_datetime) @@ -320,6 +322,7 @@ def _insert_random_file( size, generation=1, crc32c=None, + kms_key=None, last_updated=None, fail_when_getting_metadata=False, fail_when_reading=False): @@ -330,6 +333,7 @@ def _insert_random_file( os.urandom(size), generation, crc32c=crc32c, + kms_key=None, last_updated=last_updated) client.objects.add_file(f, fail_when_getting_metadata, fail_when_reading) return f @@ -395,6 +399,16 @@ def test_size(self): self.assertTrue(self.gcs.exists(file_name)) self.assertEqual(1234, self.gcs.size(file_name)) + def test_kms_key(self): + file_name = 'gs://gcsio-test/dummy_file' + file_size = 1234 + kms_key = "dummy" + + self._insert_random_file( + self.client, file_name, file_size, kms_key=kms_key) + self.assertTrue(self.gcs.exists(file_name)) + self.assertEqual(kms_key, self.gcs.kms_key(file_name)) + def test_last_updated(self): file_name = 'gs://gcsio-test/dummy_file' file_size = 1234 From b5ce1106a24912d650e89e140fd98cc00c609b50 Mon Sep 17 00:00:00 2001 From: liferoad Date: Mon, 20 Mar 2023 11:27:35 -0400 Subject: [PATCH 03/13] Add one example to learn beam by doing 
(#25719) * Add one example to learn beam by doing * add license * clear the output * Polished the notebook based on the comments * cleared out the output * Changed the answer views * Update some cells * Polished the words * Update examples/notebooks/get-started/learn_beam_basics_by_doing.ipynb Co-authored-by: tvalentyn --------- Co-authored-by: xqhu Co-authored-by: tvalentyn --- .../learn_beam_basics_by_doing.ipynb | 1095 +++++++++++++++++ 1 file changed, 1095 insertions(+) create mode 100644 examples/notebooks/get-started/learn_beam_basics_by_doing.ipynb diff --git a/examples/notebooks/get-started/learn_beam_basics_by_doing.ipynb b/examples/notebooks/get-started/learn_beam_basics_by_doing.ipynb new file mode 100644 index 000000000000..e44b2d8d2164 --- /dev/null +++ b/examples/notebooks/get-started/learn_beam_basics_by_doing.ipynb @@ -0,0 +1,1095 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "V913rQcmLS72" + }, + "source": [ + "# Welcome to Apache Beam!\n", + "\n", + "This notebook will be your introductory guide to Beam's main concepts and its uses. This tutorial **does not** assume any prior Apache Beam knowledge.\n", + "\n", + "We'll cover what Beam is, what it does, and a few basic transforms!\n", + "\n", + "We aim to give you familiarity with:\n", + "- Creating a `Pipeline`\n", + "- Creating a `PCollection`\n", + "- Performing basic `PTransforms`\n", + " - Map\n", + " - Filter\n", + " - FlatMap\n", + " - Combine\n", + "- Applications\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Lduh-9oXt3P_" + }, + "source": [ + "## How To Approach This Tutorial\n", + "\n", + "This tutorial was designed for someone who likes to **learn by doing**. As such, there will be opportunities for you to practice writing your own code in these notebooks with the answer hidden in a cell below.\n", + "\n", + "Codes that require editing will be with an `...` and each cell title will say `Edit This Code`. However, you are free to play around with the other cells if you would like to add something beyond our tutorial.\n", + "\n", + "It may be tempting to just copy and paste solutions, but even if you do look at the Answer cells, try typing out the solutions manually. The muscle memory will be very helpful.\n", + "\n", + "> Tip: For those who would like to learn concepts more from the ground up, check out these [notebooks](https://beam.apache.org/get-started/tour-of-beam/)!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "42B-64Lvef3K" + }, + "source": [ + "## Prerequisites\n", + "\n", + "We'll assume you have familiarity with Python or Pandas, but you should be able to follow along even if you’re coming from a different programming language. We'll also assume you understand programming concepts like functions, objects, arrays, and dictionaries.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SeDD0nardXyL" + }, + "source": [ + "## Running CoLab\n", + "\n", + "To navigate through different sections, use the table of contents. From View drop-down list, select Table of contents.\n", + "\n", + "To run a code cell, you can click the Run cell button at the top left of the cell, or by select it and press `Shift+Enter`. Try modifying a code cell and re-running it to see what happens." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UdMPnMDDkGc8" + }, + "source": [ + "To begin, we have to set up our environment. Let's install and import Apache Beam by running the cell below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5cPHukaOgDDM" + }, + "outputs": [], + "source": [ + "# Remember: You can press shift+enter to run this cell\n", + "!pip install --quiet apache-beam\n", + "import apache_beam as beam" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "30l8_MD-undP" + }, + "outputs": [], + "source": [ + "# Set the logging level to reduce verbose information\n", + "import logging\n", + "\n", + "logging.root.setLevel(logging.ERROR)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gyoo1gLKtZmU" + }, + "source": [ + "\n", + "\n", + "---\n", + "\n", + "\n", + "\n", + "---\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N29KJTfdMCtB" + }, + "source": [ + "# What is Apache Beam?" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "B2evhzEuMu8a" + }, + "source": [ + "Apache Beam is a library for data processing. It is often used for [Extract-Transform-Load (ETL)](https://en.wikipedia.org/wiki/Extract,_transform,_load) jobs, where we:\n", + "1. *Extract* from a data source\n", + "2. *Transform* that data\n", + "3. *Load* that data into a data sink (like a database)\n", + "\n", + "Apache Beam makes these jobs easy with the ability to process everything at the same time and its unified model and open-source SDKs. There are many more parts of Beam, but throughout these tutorials, we will break down each part to show you how they will all fit together.\n", + "\n", + "For this tutorial, you will use these Beam SDKs to build your own `Pipeline` to process your data.\n", + "\n", + "Below, we will run through creating the heart of the `Pipeline`. There are three main abstractions in Beam:\n", + "1. `Pipeline`\n", + "2. `PCollection`\n", + "3. `PTransform`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WecDJbfqpWb1" + }, + "source": [ + "\n", + "\n", + "---\n", + "\n", + "\n", + "---\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FditMJAIAp9Q" + }, + "source": [ + "# 1. Design Your Pipeline" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_xYj6e3Hvt1W" + }, + "source": [ + "## 1.1 What is a Pipeline" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "bVJCh9FLBtU9" + }, + "source": [ + "A `Pipeline` describes the whole cycle of your data processing task, starting from the data sources to the processing transforms you will apply to them until your desired output.\n", + "\n", + "`Pipeline` is responsible for reading, processing, and saving the data. \n", + "Each `PTransform` is done on or outputs a `PCollection`, and this process is done in your `Pipeline`.\n", + "More glossary details can be found at [here](https://beam.apache.org/documentation/glossary/).\n", + "\n", + "A diagram of this process is shown below:\n", + "\n", + "" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "cyn8_mB6zN4m" + }, + "source": [ + "In code, this process will look like this:\n", + "\n", + "\n", + "```\n", + "# Each `step` represents a specific transform. 
After `step3`, it will save the data reference to `outputs`.\n", + "outputs = pipeline | step1 | step2 | step3\n", + "```\n", + "\n", + ">The pipe operator `|` applies the `PTransform` on the right side of the pipe to the input `PCollection`.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gUyk2UWypI7g" + }, + "source": [ + "Pipelines can quickly grow long, so it's sometimes easier to read if we surround them with parentheses and break them into multiple lines.\n", + "\n", + "```\n", + "# This is equivalent to the example above.\n", + "outputs = (\n", + " pipeline\n", + " | step1\n", + " | step2\n", + " | step3\n", + ")\n", + "```\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7hmlbjPlvZVY" + }, + "source": [ + "Sometimes, the transform names aren't very descriptive. Beam allows each transform, or step, to have a unique label, or description. This makes it a lot easier to debug, and it's in general a good practice to start.\n", + "\n", + "> You can use the right shift operator `>>` to add a label to your transforms, like `'My description' >> MyTransform`.\n", + "\n", + "```\n", + "# Try to give short but descriptive labels.\n", + "# These serve both as comments and help debug later on.\n", + "outputs = (\n", + " pipeline\n", + " | 'First step' >> step1\n", + " | 'Second step' >> step2\n", + " | 'Third step' >> step3\n", + ")\n", + "```" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "GXQ__Kwxvr3J" + }, + "source": [ + "## 1.2 Loading Our Data\n", + "\n", + "Now, you can try to write your own pipeline!\n", + "\n", + "First, let's load the example data we will be using throughout this tutorial into our file directory. This [dataset](https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection) consists of a **collection of SMS messages in English tagged as either \"spam\" or \"ham\" (a legitimate SMS\n", + ")**.\n", + "\n", + "For this tutorial, we will create a pipeline to **explore the dataset using Beam to count words in SMS messages that contain spam or ham**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BkZ0wxsBPuvO" + }, + "outputs": [], + "source": [ + "# Creates a data directory with our dataset SMSSpamCollection\n", + "!mkdir -p data\n", + "!gsutil cp gs://apachebeamdt/SMSSpamCollection data/" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TiHFDIMnLTbm" + }, + "source": [ + "**What does the data look like?**\n", + "\n", + "This dataset is a `txt` file with 5,574 rows and 4 columns recording the following attributes:\n", + "1. `Column 1`: The label (either `ham` or `spam`)\n", + "2. `Column 2`: The SMS as raw text (type `string`)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "! head data/SMSSpamCollection" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "r-0WM38KNaTI" + }, + "source": [ + "## 1.3 Writing Our Own Pipeline\n", + "\n", + "Now that we understand our dataset, let's go into creating our pipeline.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1Ajz7-WQCVfj" + }, + "source": [ + "To initialize a `Pipeline`, you first assign your pipeline `beam.Pipeline()` to a name. Assign your pipeline to the name, `pipeline`, in the code cell below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cXm0n80E9ZwT" + }, + "outputs": [], + "source": [ + "#@title Edit This Code Cell\n", + "..." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "q0gzEcos9hjp" + }, + "outputs": [], + "source": [ + "#@title Answer\n", + "pipeline = beam.Pipeline()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "-K6tpdVK9KzC" + }, + "source": [ + "This pipeline will be where we create our transformed `PCollection`. In Beam, your data lives in a `PCollection`, which stands for `Parallel Collection`.\n", + "\n", + "A **PCollection** is like a list of elements, but without any order guarantees. This allows Beam to easily parallelize and distribute the `PCollection`'s elements.\n", + "\n", + "Now, let's use one of Beam's `Read` transforms to turn our text file (our dataset) into a `PCollection`. " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MS0XWqTiAIlE" + }, + "source": [ + "## 1.4 Reading from Text File\n", + "\n", + "We can use the\n", + "[`ReadFromText`](https://beam.apache.org/releases/pydoc/current/apache_beam.io.textio.html#apache_beam.io.textio.ReadFromText)\n", + "transform to read text files into `str` elements.\n", + "\n", + "It takes a\n", + "[_glob pattern_](https://en.wikipedia.org/wiki/Glob_%28programming%29)\n", + "as an input, and reads all the files that match that pattern. For example, in the pattern `data/*.txt`, the `*` is a wildcard that matches anything. This pattern matches all the files in the `data/` directory with a `.txt` extension. It then **returns one element for each line** in the file.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PXUSzLhhBwdr" + }, + "source": [ + "Because we only want this pipeline to read the `SMSSpamCollection` file that's in the `data/` directory, we will specify the input pattern to be `'data/SMSSpamCollection'`.\n", + "\n", + "We will then use that input pattern with our transform `beam.io.ReadFromText()` and apply it onto our pipeline. The `beam.io.ReadFromText()` transform can take in an input pattern as an input." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HYs5697QEArB" + }, + "outputs": [], + "source": [ + "#@title Hint\n", + "# If you get stuck on the syntax, use the Table of Contents to navigate to 1.1\n", + "# What is a Pipeline and reread that section." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "p5z6C65tEmtM" + }, + "outputs": [], + "source": [ + "#@title Edit This Code\n", + "\n", + "inputs_pattern = 'data/SMSSpamCollection'\n", + "\n", + "pipeline = beam.Pipeline()\n", + "\n", + "outputs = (\n", + " pipeline\n", + " | beam.io.ReadFromText(inputs_pattern)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "zC9blgIXBv5A" + }, + "outputs": [], + "source": [ + "#@title Answer\n", + "inputs_pattern = 'data/SMSSpamCollection'\n", + "\n", + "pipeline = beam.Pipeline()\n", + "\n", + "# Remember: | is the apply function in Beam in Python\n", + "outputs = (\n", + " pipeline\n", + " # Remember to add short descriptions to your transforms for good practice and easier understanding\n", + " | 'Take in Dataset' >> beam.io.ReadFromText(inputs_pattern)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ISm7PF9wE7Ts" + }, + "source": [ + "## 1.5 Writing to Text File\n", + "\n", + "Now, how do we know if we did it correctly? 
Let's take a look at the text file you just read.\n", + "\n", + "You may have noticed that we can't simply `print` the output `PCollection` to see the elements. In Beam, you can __NOT__ access the elements from a `PCollection` directly like a Python list.\n", + "\n", + "This is because, depending on the runner,\n", + "the `PCollection` elements might live in multiple worker machines.\n", + "\n", + "However, we can see our output `PCollection` by using a [`WriteToText`](https://beam.apache.org/releases/pydoc/2.27.0/apache_beam.io.textio.html#apache_beam.io.textio.WriteToText) transform to turn our `str` elements into a `txt` file (or another file type of your choosing) and then running a command to show the head of our output file." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iz73k9d1IUGw" + }, + "source": [ + "> `beam.io.WriteToText` takes a _file path prefix_ as an input, and it writes the all `str` elements into one or more files with filenames starting with that prefix.\n", + "> You can optionally pass a `file_name_suffix` as well, usually used for the file extension.\n", + "> Each element goes into its own line in the output files." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D80HrbeKFxCv" + }, + "source": [ + "Now, you can try it. Save the results to a file path prefix `'output'` and make the file_name_suffix `'.txt'`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "v6RvVVINKOt9" + }, + "outputs": [], + "source": [ + "#@title Edit This Code\n", + "inputs_pattern = 'data/SMSSpamCollection'\n", + "\n", + "pipeline = beam.Pipeline()\n", + "\n", + "# Remember: | is the apply function in Beam in Python\n", + "outputs = (\n", + " pipeline\n", + " | 'Take in Dataset' >> beam.io.ReadFromText(inputs_pattern)\n", + " | 'Write results' >> beam.io.WriteToText(..., file_name_suffix = ...)\n", + " # To see the results from the previous transform\n", + " | 'Print the text file name' >> beam.Map(print) # or beam.LogElements()\n", + ")\n", + "\n", + "# To run the pipeline\n", + "pipeline.run()\n", + "\n", + "# The command used to view your output txt file.\n", + "# If you choose to save the file path prefix to a different location or change the file type,\n", + "# you have to update this command as well.\n", + "! head output*.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "nRYxclPaJArp" + }, + "outputs": [], + "source": [ + "#@title Answer\n", + "inputs_pattern = 'data/SMSSpamCollection'\n", + "\n", + "pipeline = beam.Pipeline()\n", + "\n", + "outputs = (\n", + " pipeline\n", + " | 'Take in Dataset' >> beam.io.ReadFromText(inputs_pattern)\n", + " # ADDED\n", + " | 'Write results' >> beam.io.WriteToText(\"ansoutput1\", file_name_suffix = \".txt\")\n", + " | 'Print the text file name' >> beam.Map(print)\n", + ")\n", + "\n", + "pipeline.run()\n", + "\n", + "# The file this data is saved to is called \"ansoutput1\" as seen in the WriteToText transform.\n", + "# The command below and the transform input should match.\n", + "! head ansoutput1*.txt" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "iie23Id-O33B" + }, + "source": [ + "# 2. 
PTransforms\n", + "\n", + "Now that we have read in our code, we can now process our data to explore the text messages that could be classified as spam or non-spam (ham).\n", + "\n", + "In order to achieve this, we need to use `PTransforms`.\n", + "\n", + "A **`PTransform`** is any data processing operation that performs a processing function on one or more `PCollection`, outputting zero or more `PCollection`.\n", + "\n", + "Some PTransforms accept user-defined functions that apply custom logic, which you will learn in the *Advanced Transforms* notebook. The “P” stands for “parallel.”" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vPhZG2GgPCW7" + }, + "source": [ + "## 2.1 Map\n", + "\n", + "One feature to use for the classifier to distinguish spam SMS from ham SMS is to compare the distribution of common words between the two categories. To find the common words for the two categories, we want to perform a frequency count of each word in the data set." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cR1YWyvQj1tm" + }, + "source": [ + "First, because the data set is read line by line, let's clean up the `PCollection` so that the label and the SMS is separated.\n", + "\n", + "To do so, we will use the transform `Map`, which takes a **function** and **maps it** to **each element** of the collection and transforms a single input `a` to a single output `b`.\n", + "\n", + "In this case, we will use `beam.Map` which takes in a lambda function and uses regex to split the line into a two item list: [label, SMS]. The lambda function we will use is `lambda line: line.split(\"\\t\")`, which splits each element of the `PCollection` by tab and putting them into a list.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Xey-I5BFvCiH" + }, + "source": [ + "Add a line of code between the `ReadFromText` and `WriteToText` transform that applies a `beam.Map` transform that takes in the function described above. Remember to add a short description for your transform!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hdrOsitSuj7z" + }, + "outputs": [], + "source": [ + "#@title Edit This Code\n", + "import re\n", + "\n", + "inputs_pattern = 'data/SMSSpamCollection'\n", + "\n", + "pipeline = beam.Pipeline()\n", + "\n", + "outputs = (\n", + " pipeline\n", + " | 'Take in Dataset' >> beam.io.ReadFromText(inputs_pattern)\n", + " ...\n", + " | 'Write results' >> beam.io.WriteToText(\"output2\", file_name_suffix = \".txt\")\n", + " | 'Print the text file name' >> beam.Map(print)\n", + ")\n", + "\n", + "pipeline.run()\n", + "\n", + "! head output2*.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "2k4e3j24wTzM" + }, + "outputs": [], + "source": [ + "#@title Answer\n", + "import re\n", + "\n", + "inputs_pattern = 'data/SMSSpamCollection'\n", + "\n", + "pipeline = beam.Pipeline()\n", + "\n", + "outputs = (\n", + " pipeline\n", + " | 'Take in Dataset' >> beam.io.ReadFromText(inputs_pattern)\n", + " # ADDED\n", + " | 'Separate to list' >> beam.Map(lambda line: line.split(\"\\t\"))\n", + " | 'Write results' >> beam.io.WriteToText(\"ansoutput2\", file_name_suffix = \".txt\")\n", + " | 'Print the text file name' >> beam.Map(print)\n", + ")\n", + "\n", + "pipeline.run()\n", + "\n", + "! 
head ansoutput2*.txt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ju8TXxBYwvzX" + }, + "source": [ + "## 2.2 Filter\n", + "\n", + "Now that we have a list separating the label and the SMS, let's first focus on only counting words with the **spam** label. In order to process certain elements while igorning others, we want to filter out specific elements in a collection using the transform `Filter`.\n", + "\n", + "`beam.Filter` takes in a function that checks a single element a, and returns True to keep the element, or False to discard it.\n", + "\n", + "In this case, we want `Filter` to return true if the list contains the label **spam**.\n", + "\n", + "We will use a lambda function again for this example, but this time, you will write the lambda function yourself. Add a line of code after your `beam.Map` transform to only return a `PCollection` that only contains lists with the label **spam**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "a72v9SQ0zb5u" + }, + "outputs": [], + "source": [ + "#@title Edit This Code\n", + "import re\n", + "\n", + "inputs_pattern = 'data/SMSSpamCollection'\n", + "\n", + "pipeline = beam.Pipeline()\n", + "\n", + "outputs = (\n", + " pipeline\n", + " | 'Take in Dataset' >> beam.io.ReadFromText(inputs_pattern)\n", + " | 'Separate to list' >> beam.Map(lambda line: line.split(\"\\t\"))\n", + " ...\n", + " | 'Write results' >> beam.io.WriteToText(\"ansoutput3\", file_name_suffix = \".txt\")\n", + " | 'Print the text file name' >> beam.Map(print)\n", + ")\n", + "\n", + "pipeline.run()\n", + "\n", + "! head ansoutput3*.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "BmvUeOztzCv0" + }, + "outputs": [], + "source": [ + "#@title Answer\n", + "import re\n", + "\n", + "inputs_pattern = 'data/SMSSpamCollection'\n", + "\n", + "pipeline = beam.Pipeline()\n", + "\n", + "outputs = (\n", + " pipeline\n", + " | 'Take in Dataset' >> beam.io.ReadFromText(inputs_pattern)\n", + " | 'Separate to list' >> beam.Map(lambda line: line.split(\"\\t\"))\n", + " # ADDED\n", + " | 'Keep only spam' >> beam.Filter(lambda line: line[0] == \"spam\")\n", + " | 'Write results' >> beam.io.WriteToText(\"ansoutput3\", file_name_suffix = \".txt\")\n", + " | 'Print the text file name' >> beam.Map(print)\n", + ")\n", + "\n", + "pipeline.run()\n", + "\n", + "! head ansoutput3*.txt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "v4EcJw45zy6Y" + }, + "source": [ + "## 2.3 FlatMap\n", + "\n", + "Now, that we know we only have SMS labelled spam, we now need to change the element such that instead of each element being a list containing the label and the SMS, each element is a word in the SMS.\n", + "\n", + "We can't use `Map`, since `Map` allows us to transform each individual element, but we can't change the number of elements with it.\n", + "\n", + "Instead, we want to map a function to each element of a collection. That function returns a list of output elements, so we would get a list of lists of elements. Then we want to flatten the list of lists into a single list.\n", + "\n", + "To do this, we will use `FlatMap`, which takes a **function** that transforms a single input `a` into an **iterable of outputs** `b`. But we get a **single collection** containing the outputs of all the elements. In this case, all these elements will be the words found in the SMS." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3V3w2cz11S_8" + }, + "source": [ + "Add a `FlatMap` transform that takes in the function `lambda line: re.findall(r\"[a-zA-Z']+\", line[1])` to your code below. The lambda function finds words by finding all elements in the SMS that match the specifications of the regex." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "USHHIPO91xDn" + }, + "outputs": [], + "source": [ + "#@title Edit This Code\n", + "import re\n", + "\n", + "inputs_pattern = 'data/SMSSpamCollection'\n", + "\n", + "pipeline = beam.Pipeline()\n", + "\n", + "outputs = (\n", + " pipeline\n", + " | 'Take in Dataset' >> beam.io.ReadFromText(inputs_pattern)\n", + " | 'Separate to list' >> beam.Map(lambda line: line.split(\"\\t\"))\n", + " | 'Keep only spam' >> beam.Filter(lambda line: line[0] == \"spam\")\n", + " ...\n", + " | 'Write results' >> beam.io.WriteToText(\"ansoutput3\", file_name_suffix = \".txt\")\n", + " | 'Print the text file name' >> beam.Map(print)\n", + ")\n", + "\n", + "pipeline.run()\n", + "\n", + "! head ansoutput3*.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "VFuRMIsN0Edn" + }, + "outputs": [], + "source": [ + "#@title Answer\n", + "import re\n", + "\n", + "inputs_pattern = 'data/SMSSpamCollection'\n", + "\n", + "pipeline = beam.Pipeline()\n", + "\n", + "outputs = (\n", + " pipeline\n", + " | 'Take in Dataset' >> beam.io.ReadFromText(inputs_pattern)\n", + " | 'Separate to list' >> beam.Map(lambda line: line.split(\"\\t\"))\n", + " | 'Keep only spam' >> beam.Filter(lambda line: line[0] == \"spam\")\n", + " # ADDED\n", + " | 'Find words' >> beam.FlatMap(lambda line: re.findall(r\"[a-zA-Z']+\", line[1]))\n", + " | 'Write results' >> beam.io.WriteToText(\"ansoutput3\", file_name_suffix = \".txt\")\n", + " | 'Print the text file name' >> beam.Map(print)\n", + ")\n", + "\n", + "pipeline.run()\n", + "\n", + "! head ansoutput3*.txt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BFJMyIyJ17lE" + }, + "source": [ + "## 2.4 Combine\n", + "\n", + "Now that each word is one element, we have to count up the elements. To do that we can use [aggregation](https://beam.apache.org/documentation/transforms/python/overview/) transforms, specifically `CombinePerKey` in this instance which transforms an iterable of inputs a, and returns a single output a based on their key.\n", + "\n", + "Before using `CombinePerKey` however, we have to associate each word with a numerical value to then combine them.\n", + "\n", + "To do this, we add `| 'Pair words with 1' >> beam.Map(lambda word: (word, 1))` to the `Pipeline`, which associates each word with the numerical value 1.\n", + "\n", + "With each word assigned to a numerical value, we can now combine these numerical values to sum up all the counts of each word. Like the past transforms, `CombinePerKey` takes in a function and applies it to each element of the `PCollection`.\n", + "\n", + "However, instead of writing our own lambda function, we can use pass one of Beam's built-in function `sum` into `CombinePerKey`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Yw_3dfnzI2xA" + }, + "outputs": [], + "source": [ + "#@title Edit This Code\n", + "import re\n", + "\n", + "inputs_pattern = 'data/SMSSpamCollection'\n", + "\n", + "pipeline = beam.Pipeline()\n", + "\n", + "outputs = (\n", + " pipeline\n", + " | 'Take in Dataset' >> beam.io.ReadFromText(inputs_pattern)\n", + " | 'Separate to list' >> beam.Map(lambda line: line.split(\"\\t\"))\n", + " | 'Keep only spam' >> beam.Filter(lambda line: line[0] == \"spam\")\n", + " | 'Find words' >> beam.FlatMap(lambda line: re.findall(r\"[a-zA-Z']+\", line[1]))\n", + " | 'Pair words with 1' >> beam.Map(lambda word: (word, 1))\n", + " ...\n", + " | 'Write results' >> beam.io.WriteToText(\"ansoutput4\", file_name_suffix = \".txt\")\n", + " | 'Print the text file name' >> beam.Map(print)\n", + ")\n", + "\n", + "pipeline.run()\n", + "\n", + "! head ansoutput4*.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "qDyh9_faIkCo" + }, + "outputs": [], + "source": [ + "#@title Answer\n", + "import re\n", + "\n", + "inputs_pattern = 'data/SMSSpamCollection'\n", + "\n", + "pipeline = beam.Pipeline()\n", + "\n", + "outputs = (\n", + " pipeline\n", + " | 'Take in Dataset' >> beam.io.ReadFromText(inputs_pattern)\n", + " | 'Separate to list' >> beam.Map(lambda line: line.split(\"\\t\"))\n", + " | 'Keep only spam' >> beam.Filter(lambda line: line[0] == \"spam\")\n", + " | 'Find words' >> beam.FlatMap(lambda line: re.findall(r\"[a-zA-Z']+\", line[1]))\n", + " | 'Pair words with 1' >> beam.Map(lambda word: (word, 1))\n", + " # ADDED\n", + " | 'Group and sum' >> beam.CombinePerKey(sum)\n", + " | 'Write results' >> beam.io.WriteToText(\"ansoutput4\", file_name_suffix = \".txt\")\n", + " | 'Print the text file name' >> beam.Map(print)\n", + ")\n", + "\n", + "pipeline.run()\n", + "\n", + "! head ansoutput4*.txt" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "L9zNI4LMJHur" + }, + "source": [ + "And we finished! Now that we have a count of all the words to gain better understanding about our dataset." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BozCqbzUPItB" + }, + "source": [ + "\n", + "\n", + "---\n", + "\n", + "\n", + "\n", + "---\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Cv9508k4MeS4" + }, + "source": [ + "# Full Spam Ham Apache Beam Example\n", + "\n", + "Below is a summary of all the code we performed for your convenience. Note you do not need to explicitly call run with the with statement." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8iWWy3Y7Msxm" + }, + "outputs": [], + "source": [ + "!pip install --quiet apache-beam\n", + "!mkdir -p data\n", + "!gsutil cp gs://apachebeamdt/SMSSpamCollection data/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "43oijI4BCAW1" + }, + "outputs": [], + "source": [ + "import apache_beam as beam\n", + "import re\n", + "\n", + "inputs_pattern = 'data/SMSSpamCollection'\n", + "outputs_prefix_ham = 'outputs/fullcodeham'\n", + "outputs_prefix_spam = 'outputs/fullcodespam'\n", + "\n", + "# Ham Word Count\n", + "with beam.Pipeline() as pipeline:\n", + " ham = (\n", + " pipeline\n", + " | 'Take in Dataset' >> beam.io.ReadFromText(inputs_pattern)\n", + " | 'Separate to list' >> beam.Map(lambda line: line.split(\"\\t\"))\n", + " | 'Keep only ham' >> beam.Filter(lambda line: line[0] == \"ham\")\n", + " | 'Find words' >> beam.FlatMap(lambda line: re.findall(r\"[a-zA-Z']+\", line[1]))\n", + " | 'Pair words with 1' >> beam.Map(lambda word: (word, 1))\n", + " | 'Group and sum' >> beam.CombinePerKey(sum)\n", + " | 'Format results' >> beam.Map(lambda word_c: str(word_c))\n", + " | 'Write results' >> beam.io.WriteToText(outputs_prefix_ham, file_name_suffix = \".txt\")\n", + " )\n", + "\n", + "# Spam Word Count\n", + "with beam.Pipeline() as pipeline1:\n", + " spam = (\n", + " pipeline1\n", + " | 'Take in Dataset' >> beam.io.ReadFromText(inputs_pattern)\n", + " | 'Separate to list' >> beam.Map(lambda line: line.split(\"\\t\"))\n", + " | 'Filter out only spam' >> beam.Filter(lambda line: line[0] == \"spam\")\n", + " | 'Find words' >> beam.FlatMap(lambda line: re.findall(r\"[a-zA-Z']+\", line[1]))\n", + " | 'Pair words with 1' >> beam.Map(lambda word: (word, 1))\n", + " | 'Group and sum' >> beam.CombinePerKey(sum)\n", + " | 'Format results' >> beam.Map(lambda word_c: str(word_c))\n", + " | 'Write results' >> beam.io.WriteToText(outputs_prefix_spam, file_name_suffix = \".txt\")\n", + " )\n", + "\n", + "print('Ham Word Count Head')\n", + "! head outputs/fullcodeham*.txt\n", + "\n", + "print('Spam Word Count Head')\n", + "! head outputs/fullcodespam*.txt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FNg53x4fyXOO" + }, + "source": [ + "**One more thing: you can also visualize the Beam pipelines!**\n", + "\n", + "Check [this example](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/runners/interactive/examples/Interactive%20Beam%20Example.ipynb) to learn more about the interactive Beam." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "k5cYnOElwrm7" + }, + "outputs": [], + "source": [ + "import apache_beam.runners.interactive.interactive_beam as ib\n", + "ib.show_graph(pipeline)" + ] + } + ], + "license": [ + "Licensed to the Apache Software Foundation (ASF) under one", + "or more contributor license agreements. See the NOTICE file", + "distributed with this work for additional information", + "regarding copyright ownership. The ASF licenses this file", + "to you under the Apache License, Version 2.0 (the", + "\"License\"); you may not use this file except in compliance", + "with the License. 
You may obtain a copy of the License at", + "", + " http://www.apache.org/licenses/LICENSE-2.0", + "", + "Unless required by applicable law or agreed to in writing,", + "software distributed under the License is distributed on an", + "\"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY", + "KIND, either express or implied. See the License for the", + "specific language governing permissions and limitations", + "under the License." + ], + "metadata": { + "colab": { + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "vscode": { + "interpreter": { + "hash": "aab5fceeb08468f7e142944162550e82df74df803ff2eb1987d9526d4285522f" + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From 62b9b3814096fa796a291f596337c6d854a88442 Mon Sep 17 00:00:00 2001 From: Bjorn Pedersen Date: Mon, 20 Mar 2023 12:54:09 -0400 Subject: [PATCH 04/13] fix for test for KMS key --- sdks/python/apache_beam/io/gcp/gcsio_test.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/gcsio_test.py b/sdks/python/apache_beam/io/gcp/gcsio_test.py index db6a992632b5..7d504368ea0c 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsio_test.py @@ -62,7 +62,14 @@ def __init__(self): class FakeFile(object): def __init__( - self, bucket, obj, contents, generation, crc32c=None, kms_key=None, last_updated=None): + self, + bucket, + obj, + contents, + generation, + crc32c=None, + kms_key=None, + last_updated=None): self.bucket = bucket self.object = obj self.contents = contents @@ -333,7 +340,7 @@ def _insert_random_file( os.urandom(size), generation, crc32c=crc32c, - kms_key=None, + kms_key=kms_key, last_updated=last_updated) client.objects.add_file(f, fail_when_getting_metadata, fail_when_reading) return f From 6cb7b8e5f82b7022a75d5feb13bc851009a0cf19 Mon Sep 17 00:00:00 2001 From: Robert Burke Date: Mon, 20 Mar 2023 16:40:09 -0700 Subject: [PATCH 05/13] Ensure truncate element is wrapped in *FullValue (#25908) Co-authored-by: lostluck <13907733+lostluck@users.noreply.github.com> --- sdks/go/pkg/beam/core/runtime/exec/sdf.go | 27 ++++++++------------ sdks/go/test/integration/primitives/drain.go | 24 +++++++++-------- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf.go b/sdks/go/pkg/beam/core/runtime/exec/sdf.go index 1dd3e35dc4d3..6482bfb3a6ae 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/sdf.go +++ b/sdks/go/pkg/beam/core/runtime/exec/sdf.go @@ -297,10 +297,10 @@ func (n *TruncateSizedRestriction) StartBundle(ctx context.Context, id string, d // Input Diagram: // // *FullValue { -// Elm: *FullValue { -// Elm: *FullValue (original input) +// Elm: *FullValue { -- mainElm +// Elm: *FullValue (original input) -- inp // Elm2: *FullValue { -// Elm: Restriction +// Elm: Restriction -- rest // Elm2: Watermark estimator state // } // } @@ -325,24 +325,19 @@ func (n *TruncateSizedRestriction) StartBundle(ctx context.Context, id string, d // } func (n *TruncateSizedRestriction) ProcessElement(ctx context.Context, elm *FullValue, values ...ReStream) error { mainElm := elm.Elm.(*FullValue) - inp := mainElm.Elm - // For the main element, the way 
we fill it out depends on whether the input element - // is a KV or single-element. Single-elements might have been lifted out of - // their FullValue if they were decoded, so we need to have a case for that. - // TODO(https://github.com/apache/beam/issues/20196): Optimize this so it's decided in exec/translate.go - // instead of checking per-element. - if e, ok := mainElm.Elm.(*FullValue); ok { - mainElm = e - inp = e - } - rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm + + // If receiving directly from a datasource, + // the element may not be wrapped in a *FullValue + inp := convertIfNeeded(mainElm.Elm, &FullValue{}) + + rest := mainElm.Elm2.(*FullValue).Elm rt, err := n.ctInv.Invoke(ctx, rest) if err != nil { return err } - newRest, err := n.truncateInv.Invoke(ctx, rt, mainElm) + newRest, err := n.truncateInv.Invoke(ctx, rt, inp) if err != nil { return err } @@ -351,7 +346,7 @@ func (n *TruncateSizedRestriction) ProcessElement(ctx context.Context, elm *Full return nil } - size, err := n.sizeInv.Invoke(ctx, mainElm, newRest) + size, err := n.sizeInv.Invoke(ctx, inp, newRest) if err != nil { return err } diff --git a/sdks/go/test/integration/primitives/drain.go b/sdks/go/test/integration/primitives/drain.go index 2e861f54615c..d116dfa8bd3a 100644 --- a/sdks/go/test/integration/primitives/drain.go +++ b/sdks/go/test/integration/primitives/drain.go @@ -28,7 +28,7 @@ import ( ) func init() { - register.DoFn3x1[*sdf.LockRTracker, []byte, func(int64), sdf.ProcessContinuation](&TruncateFn{}) + register.DoFn4x1[context.Context, *sdf.LockRTracker, []byte, func(int64), sdf.ProcessContinuation](&TruncateFn{}) register.Emitter1[int64]() } @@ -83,9 +83,14 @@ func (fn *TruncateFn) SplitRestriction(_ []byte, rest offsetrange.Restriction) [ } // TruncateRestriction truncates the restriction during drain. -func (fn *TruncateFn) TruncateRestriction(rt *sdf.LockRTracker, _ []byte) offsetrange.Restriction { - start := rt.GetRestriction().(offsetrange.Restriction).Start +func (fn *TruncateFn) TruncateRestriction(ctx context.Context, rt *sdf.LockRTracker, _ []byte) offsetrange.Restriction { + rest := rt.GetRestriction().(offsetrange.Restriction) + start := rest.Start newEnd := start + 20 + + done, remaining := rt.GetProgress() + log.Infof(ctx, "Draining at: done %v, remaining %v, start %v, end %v, newEnd %v", done, remaining, start, rest.End, newEnd) + return offsetrange.Restriction{ Start: start, End: newEnd, @@ -93,29 +98,26 @@ func (fn *TruncateFn) TruncateRestriction(rt *sdf.LockRTracker, _ []byte) offset } // ProcessElement continually gets the start position of the restriction and emits the element as it is. -func (fn *TruncateFn) ProcessElement(rt *sdf.LockRTracker, _ []byte, emit func(int64)) sdf.ProcessContinuation { +func (fn *TruncateFn) ProcessElement(ctx context.Context, rt *sdf.LockRTracker, _ []byte, emit func(int64)) sdf.ProcessContinuation { position := rt.GetRestriction().(offsetrange.Restriction).Start - counter := 0 for { if rt.TryClaim(position) { + log.Infof(ctx, "Claimed position: %v", position) // Successful claim, emit the value and move on. 
emit(position) position++ - counter++ } else if rt.GetError() != nil || rt.IsDone() { // Stop processing on error or completion if err := rt.GetError(); err != nil { - log.Errorf(context.Background(), "error in restriction tracker, got %v", err) + log.Errorf(ctx, "error in restriction tracker, got %v", err) } + log.Infof(ctx, "Restriction done at position %v.", position) return sdf.StopProcessing() } else { + log.Infof(ctx, "Checkpointed at position %v, resuming later.", position) // Resume later. return sdf.ResumeProcessingIn(5 * time.Second) } - - if counter >= 10 { - return sdf.ResumeProcessingIn(1 * time.Second) - } time.Sleep(1 * time.Second) } } From 5f9bf8b74f3d1e19b326963258c48df9372ca2dc Mon Sep 17 00:00:00 2001 From: Nick Li <56149585+nickuncaged1201@users.noreply.github.com> Date: Mon, 20 Mar 2023 18:29:30 -0700 Subject: [PATCH 06/13] Pubsub test client fixup (#25907) * make incoming message public * add set clock method to test client * spotless --- .../org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java index 43dc244f5c25..575957c60728 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubTestClient.java @@ -100,7 +100,11 @@ private static class State { private static final State STATE = new State(); /** Closing the factory will validate all expected messages were processed. */ - public interface PubsubTestClientFactory extends PubsubClientFactory, Closeable, Serializable {} + public interface PubsubTestClientFactory extends PubsubClientFactory, Closeable, Serializable { + default PubsubIO.Read setClock(PubsubIO.Read readTransform, Clock clock) { + return readTransform.withClock(clock); + } + } /** * Return a factory for testing publishers. Only one factory may be in-flight at a time. 
The From 9a5e5b8ab76e8a9f08141e269f8dd9f0e6c35894 Mon Sep 17 00:00:00 2001 From: liferoad Date: Tue, 21 Mar 2023 11:11:52 -0400 Subject: [PATCH 07/13] Raise the Runtime error when DoFn.process uses both yield and return (#25743) Co-authored-by: xqhu --- sdks/python/apache_beam/transforms/core.py | 81 +++++++++++-- .../apache_beam/transforms/core_test.py | 113 ++++++++++++++++++ 2 files changed, 187 insertions(+), 7 deletions(-) create mode 100644 sdks/python/apache_beam/transforms/core_test.py diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 47aaeff43a6f..6260975b32c9 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -28,6 +28,7 @@ import traceback import types import typing +from itertools import dropwhile from apache_beam import coders from apache_beam import pvalue @@ -1387,6 +1388,59 @@ def partition_for(self, element, num_partitions, *args, **kwargs): return self._fn(element, num_partitions, *args, **kwargs) +def _get_function_body_without_inners(func): + source_lines = inspect.getsourcelines(func)[0] + source_lines = dropwhile(lambda x: x.startswith("@"), source_lines) + def_line = next(source_lines).strip() + if def_line.startswith("def ") and def_line.endswith(":"): + first_line = next(source_lines) + indentation = len(first_line) - len(first_line.lstrip()) + final_lines = [first_line[indentation:]] + + skip_inner_def = False + if first_line[indentation:].startswith("def "): + skip_inner_def = True + for line in source_lines: + line_indentation = len(line) - len(line.lstrip()) + + if line[indentation:].startswith("def "): + skip_inner_def = True + continue + + if skip_inner_def and line_indentation == indentation: + skip_inner_def = False + + if skip_inner_def and line_indentation > indentation: + continue + final_lines.append(line[indentation:]) + + return "".join(final_lines) + else: + return def_line.rsplit(":")[-1].strip() + + +def _check_fn_use_yield_and_return(fn): + if isinstance(fn, types.BuiltinFunctionType): + return False + try: + source_code = _get_function_body_without_inners(fn) + has_yield = False + has_return = False + for line in source_code.split("\n"): + if line.lstrip().startswith("yield ") or line.lstrip().startswith( + "yield("): + has_yield = True + if line.lstrip().startswith("return ") or line.lstrip().startswith( + "return("): + has_return = True + if has_yield and has_return: + return True + return False + except Exception as e: + _LOGGER.debug(str(e)) + return False + + class ParDo(PTransformWithSideInputs): """A :class:`ParDo` transform. 
@@ -1427,6 +1481,14 @@ def __init__(self, fn, *args, **kwargs): if not isinstance(self.fn, DoFn): raise TypeError('ParDo must be called with a DoFn instance.') + # DoFn.process cannot allow both return and yield + if _check_fn_use_yield_and_return(self.fn.process): + _LOGGER.warning( + 'Using yield and return in the process method ' + 'of %s can lead to unexpected behavior, see:' + 'https://github.com/apache/beam/issues/22969.', + self.fn.__class__) + # Validate the DoFn by creating a DoFnSignature from apache_beam.runners.common import DoFnSignature self._signature = DoFnSignature(self.fn) @@ -2663,6 +2725,7 @@ def from_runner_api_parameter(unused_ptransform, combine_payload, context): class CombineValuesDoFn(DoFn): """DoFn for performing per-key Combine transforms.""" + def __init__( self, input_pcoll_type, @@ -2725,6 +2788,7 @@ def default_type_hints(self): class _CombinePerKeyWithHotKeyFanout(PTransform): + def __init__( self, combine_fn, # type: CombineFn @@ -2939,11 +3003,12 @@ class GroupBy(PTransform): The GroupBy operation can be made into an aggregating operation by invoking its `aggregate_field` method. """ + def __init__( self, *fields, # type: typing.Union[str, typing.Callable] **kwargs # type: typing.Union[str, typing.Callable] - ): + ): if len(fields) == 1 and not kwargs: self._force_tuple_keys = False name = fields[0] if isinstance(fields[0], str) else 'key' @@ -2966,7 +3031,7 @@ def aggregate_field( field, # type: typing.Union[str, typing.Callable] combine_fn, # type: typing.Union[typing.Callable, CombineFn] dest, # type: str - ): + ): """Returns a grouping operation that also aggregates grouped values. Args: @@ -3054,7 +3119,7 @@ def aggregate_field( field, # type: typing.Union[str, typing.Callable] combine_fn, # type: typing.Union[typing.Callable, CombineFn] dest, # type: str - ): + ): field = _expr_to_callable(field, 0) return _GroupAndAggregate( self._grouping, list(self._aggregations) + [(field, combine_fn, dest)]) @@ -3096,10 +3161,12 @@ class Select(PTransform): pcoll | beam.Map(lambda x: beam.Row(a=x.a, b=foo(x))) """ - def __init__(self, - *args, # type: typing.Union[str, typing.Callable] - **kwargs # type: typing.Union[str, typing.Callable] - ): + + def __init__( + self, + *args, # type: typing.Union[str, typing.Callable] + **kwargs # type: typing.Union[str, typing.Callable] + ): self._fields = [( expr if isinstance(expr, str) else 'arg%02d' % ix, _expr_to_callable(expr, ix)) for (ix, expr) in enumerate(args) diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py new file mode 100644 index 000000000000..0fba28266138 --- /dev/null +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -0,0 +1,113 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""Unit tests for the core python file.""" +# pytype: skip-file + +import logging +import unittest + +import pytest + +import apache_beam as beam + + +class TestDoFn1(beam.DoFn): + def process(self, element): + yield element + + +class TestDoFn2(beam.DoFn): + def process(self, element): + def inner_func(x): + yield x + + return inner_func(element) + + +class TestDoFn3(beam.DoFn): + """mixing return and yield is not allowed""" + def process(self, element): + if not element: + return -1 + yield element + + +class TestDoFn4(beam.DoFn): + """test the variable name containing return""" + def process(self, element): + my_return = element + yield my_return + + +class TestDoFn5(beam.DoFn): + """test the variable name containing yield""" + def process(self, element): + my_yield = element + return my_yield + + +class TestDoFn6(beam.DoFn): + """test the variable name containing return""" + def process(self, element): + return_test = element + yield return_test + + +class TestDoFn7(beam.DoFn): + """test the variable name containing yield""" + def process(self, element): + yield_test = element + return yield_test + + +class TestDoFn8(beam.DoFn): + """test the code containing yield and yield from""" + def process(self, element): + if not element: + yield from [1, 2, 3] + else: + yield element + + +class CreateTest(unittest.TestCase): + @pytest.fixture(autouse=True) + def inject_fixtures(self, caplog): + self._caplog = caplog + + def test_dofn_with_yield_and_return(self): + warning_text = 'Using yield and return' + + with self._caplog.at_level(logging.WARNING): + assert beam.ParDo(sum) + assert beam.ParDo(TestDoFn1()) + assert beam.ParDo(TestDoFn2()) + assert beam.ParDo(TestDoFn4()) + assert beam.ParDo(TestDoFn5()) + assert beam.ParDo(TestDoFn6()) + assert beam.ParDo(TestDoFn7()) + assert beam.ParDo(TestDoFn8()) + assert warning_text not in self._caplog.text + + with self._caplog.at_level(logging.WARNING): + beam.ParDo(TestDoFn3()) + assert warning_text in self._caplog.text + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() From e08d93dd5dabc4f2b345eb334ac8323b13615b4c Mon Sep 17 00:00:00 2001 From: Bjorn Pedersen Date: Tue, 21 Mar 2023 16:56:20 -0400 Subject: [PATCH 08/13] replaced GcsIO non-batch methods with GCS equivalents --- sdks/python/apache_beam/io/gcp/gcsio.py | 153 +++++++++++------------- sdks/python/setup.py | 1 + 2 files changed, 71 insertions(+), 83 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py index 2010d2292593..6d5ad597af34 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio.py +++ b/sdks/python/apache_beam/io/gcp/gcsio.py @@ -40,6 +40,7 @@ from itertools import islice from typing import Optional from typing import Union +from google.cloud import storage import apache_beam from apache_beam.internal.http_client import get_new_http @@ -68,7 +69,6 @@ from apitools.base.py.exceptions import HttpError from apitools.base.py import transfer from apache_beam.internal.gcp import auth - from apache_beam.io.gcp.internal.clients import storage except ImportError: raise ImportError( 'Google Cloud Storage I/O not supported for this execution environment ' @@ -162,20 +162,13 @@ class GcsIOError(IOError, retry.PermanentException): class GcsIO(object): """Google Cloud Storage I/O client.""" def __init__(self, storage_client=None, pipeline_options=None): - # type: (Optional[storage.StorageV1], Optional[Union[dict, PipelineOptions]]) -> None if storage_client is None: if not pipeline_options: 
pipeline_options = PipelineOptions() elif isinstance(pipeline_options, dict): pipeline_options = PipelineOptions.from_dictionary(pipeline_options) - storage_client = storage.StorageV1( - credentials=auth.get_service_credentials(pipeline_options), - get_credentials=False, - http=get_new_http(), - response_encoding='utf8', - additional_http_headers={ - "User-Agent": "apache-beam-%s" % apache_beam.__version__ - }) + storage_client = storage.Client( + credentials=auth.get_service_credentials(pipeline_options)) self.client = storage_client self._rewrite_cb = None self.bucket_to_project_number = {} @@ -200,24 +193,24 @@ def _set_rewrite_response_callback(self, callback): def get_bucket(self, bucket_name): """Returns an object bucket from its name, or None if it does not exist.""" try: - request = storage.StorageBucketsGetRequest(bucket=bucket_name) - return self.client.buckets.Get(request) + return self.client.lookup_bucket(bucket_name) except HttpError: return None def create_bucket(self, bucket_name, project, kms_key=None, location=None): """Create and return a GCS bucket in a specific project.""" encryption = None - if kms_key: - encryption = storage.Bucket.EncryptionValue(kms_key) - - request = storage.StorageBucketsInsertRequest( - bucket=storage.Bucket( - name=bucket_name, location=location, encryption=encryption), - project=project, - ) + try: - return self.client.buckets.Insert(request) + bucket = self.client.create_bucket( + bucket_or_name=bucket_name, + project=project, + location=location, + ) + if kms_key: + bucket.default_kms_key_name(kms_key) + return self.get_bucket(bucket_name) + return bucket except HttpError: return None @@ -270,11 +263,10 @@ def delete(self, path): Args: path: GCS file path pattern in the form gs:///. """ - bucket, object_path = parse_gcs_path(path) - request = storage.StorageObjectsDeleteRequest( - bucket=bucket, object=object_path) + bucket_name, target_name = parse_gcs_path(path) try: - self.client.objects.Delete(request) + bucket = self.client.get_bucket(bucket_name) + bucket.delete_blob(target_name) except HttpError as http_error: if http_error.status_code == 404: # Return success when the file doesn't exist anymore for idempotency. @@ -329,47 +321,53 @@ def delete_batch(self, paths): def copy( self, src, - dest, - dest_kms_key_name=None, - max_bytes_rewritten_per_call=None): + dest): + # dest_kms_key_name=None, + # max_bytes_rewritten_per_call=None): """Copies the given GCS object from src to dest. Args: src: GCS file path pattern in the form gs:///. dest: GCS file path pattern in the form gs:///. + !!! dest_kms_key_name: Experimental. No backwards compatibility guarantees. Encrypt dest with this Cloud KMS key. If None, will use dest bucket encryption defaults. max_bytes_rewritten_per_call: Experimental. No backwards compatibility guarantees. Each rewrite API call will return after these many bytes. - Used for testing. + Used for testing. !!! Raises: TimeoutError: on timeout. 
""" - src_bucket, src_path = parse_gcs_path(src) - dest_bucket, dest_path = parse_gcs_path(dest) - request = storage.StorageObjectsRewriteRequest( - sourceBucket=src_bucket, - sourceObject=src_path, - destinationBucket=dest_bucket, - destinationObject=dest_path, - destinationKmsKeyName=dest_kms_key_name, - maxBytesRewrittenPerCall=max_bytes_rewritten_per_call) - response = self.client.objects.Rewrite(request) - while not response.done: - _LOGGER.debug( - 'Rewrite progress: %d of %d bytes, %s to %s', - response.totalBytesRewritten, - response.objectSize, - src, - dest) - request.rewriteToken = response.rewriteToken - response = self.client.objects.Rewrite(request) - if self._rewrite_cb is not None: - self._rewrite_cb(response) - - _LOGGER.debug('Rewrite done: %s to %s', src, dest) + src_bucket_name, src_path = parse_gcs_path(src) + dest_bucket_name, dest_path = parse_gcs_path(dest) + # request = storage.StorageObjectsRewriteRequest( + # sourceBucket=src_bucket, + # sourceObject=src_path, + # destinationBucket=dest_bucket, + # destinationObject=dest_path, + # destinationKmsKeyName=dest_kms_key_name, + # maxBytesRewrittenPerCall=max_bytes_rewritten_per_call) + src_bucket = self.get_bucket(src_bucket_name) + src_blob = src_bucket.get_blob(src_path) + dest_bucket = self.get_bucket(dest_bucket_name) + if not dest_path: + dest_path = None + response = src_bucket.copy_blob(src_blob, dest_bucket, new_name=dest_path) + # !!! while not response.done: + # _LOGGER.debug( + # 'Rewrite progress: %d of %d bytes, %s to %s', + # response.totalBytesRewritten, + # response.objectSize, + # src, + # dest) + # request.rewriteToken = response.rewriteToken + # response = self.client.objects.Rewrite(request) + # if self._rewrite_cb is not None: + # self._rewrite_cb(response) + + # _LOGGER.debug('Rewrite done: %s to %s', src, dest) !!! # We intentionally do not decorate this method with a retry, as retrying is # handled in BatchApiRequest.Execute(). @@ -565,10 +563,9 @@ def _gcs_object(self, path): Returns: GCS object. """ - bucket, object_path = parse_gcs_path(path) - request = storage.StorageObjectsGetRequest( - bucket=bucket, object=object_path) - return self.client.objects.Get(request) + bucket_name, object_path = parse_gcs_path(path) + bucket = self.client.get_bucket(bucket_name) + return bucket.get_blob(object_path) @deprecated(since='2.45.0', current='list_files') def list_prefix(self, path, with_metadata=False): @@ -604,7 +601,6 @@ def list_files(self, path, with_metadata=False): tuple(file name, tuple(size, timestamp)). 
""" bucket, prefix = parse_gcs_path(path, object_optional=True) - request = storage.StorageObjectsListRequest(bucket=bucket, prefix=prefix) file_info = set() counter = 0 start_time = time.time() @@ -612,35 +608,26 @@ def list_files(self, path, with_metadata=False): _LOGGER.debug("Starting the file information of the input") else: _LOGGER.debug("Starting the size estimation of the input") - while True: - response = retry.with_exponential_backoff( - retry_filter=retry.retry_on_server_errors_and_timeout_filter)( - self.client.objects.List)( - request) - - for item in response.items: - file_name = 'gs://%s/%s' % (item.bucket, item.name) - if file_name not in file_info: - file_info.add(file_name) - counter += 1 - if counter % 10000 == 0: - if with_metadata: - _LOGGER.info( - "Finished computing file information of: %s files", - len(file_info)) - else: - _LOGGER.info( - "Finished computing size of: %s files", len(file_info)) - + response = self.client.list_blobs(bucket, prefix=prefix) + for item in response: + file_name = 'gs://%s/%s' % (item.bucket(), item.name) + if file_name not in file_info: + file_info.add(file_name) + counter += 1 + if counter % 10000 == 0: if with_metadata: - yield file_name, (item.size, self._updated_to_seconds(item.updated)) + _LOGGER.info( + "Finished computing file information of: %s files", + len(file_info)) else: - yield file_name, item.size + _LOGGER.info( + "Finished computing size of: %s files", len(file_info)) + + if with_metadata: + yield file_name, (item.size(), self._updated_to_seconds(item.updated())) + else: + yield file_name, item.size() - if response.nextPageToken: - request.pageToken = response.nextPageToken - else: - break _LOGGER.log( # do not spam logs when list_prefix is likely used to check empty folder logging.INFO if counter > 0 else logging.DEBUG, diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 8b49937c97ee..dc4aae81ec98 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -300,6 +300,7 @@ def get_portability_package_data(): 'google-cloud-datastore>=2.0.0,<3', 'google-cloud-pubsub>=2.1.0,<3', 'google-cloud-pubsublite>=1.2.0,<2', + 'google-cloud-storage>=2.7.0,<3', # GCP packages required by tests 'google-cloud-bigquery>=2.0.0,<4', 'google-cloud-bigquery-storage>=2.6.3,<3', From 2c075b3603f009f36d809428c8cf7b97090a7138 Mon Sep 17 00:00:00 2001 From: Jack Dingilian Date: Tue, 21 Mar 2023 17:32:07 -0400 Subject: [PATCH 09/13] Advance DetectNewPartition's watermark by aggregating the (#25906) watermark of ReadChangeStreamPartitions --- .../changestreams/ByteStringRangeHelper.java | 158 +++++++++++------- .../action/DetectNewPartitionsAction.java | 92 ++++++++++ .../changestreams/dao/MetadataTableDao.java | 40 ++++- .../ByteStringRangeHelperTest.java | 105 ++++++++++++ .../action/DetectNewPartitionsActionTest.java | 61 +++++++ .../dao/MetadataTableDaoTest.java | 71 ++++++++ 6 files changed, 467 insertions(+), 60 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelper.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelper.java index 8d3b84f735de..c0448ecef13e 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelper.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelper.java @@ -20,6 +20,8 @@ import 
com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange; import com.google.protobuf.ByteString; import com.google.protobuf.TextFormat; +import java.util.ArrayList; +import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.stream.Collectors; @@ -29,31 +31,6 @@ /** Helper functions to evaluate the completeness of collection of ByteStringRanges. */ @Internal public class ByteStringRangeHelper { - /** - * Returns formatted string of a partition for debugging. - * - * @param partition partition to format. - * @return String representation of partition. - */ - public static String formatByteStringRange(ByteStringRange partition) { - return "['" - + TextFormat.escapeBytes(partition.getStart()) - + "','" - + TextFormat.escapeBytes(partition.getEnd()) - + "')"; - } - - /** - * Convert partitions to a string for debugging. - * - * @param partitions to print - * @return string representation of partitions - */ - public static String partitionsToString(List partitions) { - return partitions.stream() - .map(ByteStringRangeHelper::formatByteStringRange) - .collect(Collectors.joining(", ", "{", "}")); - } @VisibleForTesting static class PartitionComparator implements Comparator { @@ -83,6 +60,104 @@ public int compare(ByteStringRange first, ByteStringRange second) { } } + /** + * Returns true if parentPartitions is a superset of childPartition. + * + *
<p>
If ordered parentPartitions row ranges form a contiguous range, and start key is before or + * at childPartition's start key, and end key is at or after childPartition's end key, then + * parentPartitions is a superset of childPartition. + * + *
<p>
Overlaps from parents are valid because arbitrary partitions can merge and they may overlap. + * They will form a valid new partition. However, if there are any missing parent partitions, then + * merge cannot happen with missing row ranges. + * + * @param parentPartitions list of partitions to determine if it forms a large contiguous range + * @param childPartition the smaller partition + * @return true if parentPartitions is a superset of childPartition, otherwise false. + */ + public static boolean isSuperset( + List parentPartitions, ByteStringRange childPartition) { + // sort parentPartitions by starting key + // iterate through, check open end key and close start key of each iteration to ensure no gaps. + // first start key and last end key must be equal to or wider than child partition start and end + // key. + if (parentPartitions.isEmpty()) { + return false; + } + parentPartitions.sort(new PartitionComparator()); + ByteString parentStartKey = parentPartitions.get(0).getStart(); + ByteString parentEndKey = parentPartitions.get(parentPartitions.size() - 1).getEnd(); + + return !childStartsBeforeParent(parentStartKey, childPartition.getStart()) + && !childEndsAfterParent(parentEndKey, childPartition.getEnd()) + && !gapsInParentPartitions(parentPartitions); + } + + /** + * Convert partitions to a string for debugging. + * + * @param partitions to print + * @return string representation of partitions + */ + public static String partitionsToString(List partitions) { + return partitions.stream() + .map(ByteStringRangeHelper::formatByteStringRange) + .collect(Collectors.joining(", ", "{", "}")); + } + + /** + * Figure out if partitions cover the entire keyspace. If it doesn't, return a list of missing and + * overlapping partitions. + * + *
<p>
partitions covers the entire key space if, when ordered, the end key is the same as the + * start key of the next row range in the list, and the first start key is "" and the last end key + * is "". There should be no overlap. + * + * @param partitions to determine if they cover entire keyspace + * @return list of missing and overlapping partitions + */ + public static List getMissingAndOverlappingPartitionsFromKeySpace( + List partitions) { + if (partitions.isEmpty()) { + return Collections.singletonList(ByteStringRange.create("", "")); + } + + List missingPartitions = new ArrayList<>(); + + // sort partitions by start key + // iterate through ensuring end key is lexicographically after next start key. + partitions.sort(new PartitionComparator()); + + ByteString prevEnd = ByteString.EMPTY; + for (ByteStringRange partition : partitions) { + if (!partition.getStart().equals(prevEnd)) { + ByteStringRange missingPartition = ByteStringRange.create(prevEnd, partition.getStart()); + missingPartitions.add(missingPartition); + } + prevEnd = partition.getEnd(); + } + // Check that the last partition ends with "", otherwise it's missing. + if (!prevEnd.equals(ByteString.EMPTY)) { + ByteStringRange missingPartition = ByteStringRange.create(prevEnd, ByteString.EMPTY); + missingPartitions.add(missingPartition); + } + return missingPartitions; + } + + /** + * Returns formatted string of a partition for debugging. + * + * @param partition partition to format. + * @return String representation of partition. + */ + public static String formatByteStringRange(ByteStringRange partition) { + return "['" + + TextFormat.escapeBytes(partition.getStart()) + + "','" + + TextFormat.escapeBytes(partition.getEnd()) + + "')"; + } + private static boolean childStartsBeforeParent( ByteString parentStartKey, ByteString childStartKey) { // Check if the start key of the child partition comes before the start key of the entire @@ -121,37 +196,4 @@ private static boolean gapsInParentPartitions(List sortedParent } return false; } - - /** - * Returns true if parentPartitions is a superset of childPartition. - * - *
<p>
If ordered parentPartitions row ranges form a contiguous range, and start key is before or - * at childPartition's start key, and end key is at or after childPartition's end key, then - * parentPartitions is a superset of childPartition. - * - *
<p>
Overlaps from parents are valid because arbitrary partitions can merge and they may overlap. - * They will form a valid new partition. However, if there are any missing parent partitions, then - * merge cannot happen with missing row ranges. - * - * @param parentPartitions list of partitions to determine if it forms a large contiguous range - * @param childPartition the smaller partition - * @return true if parentPartitions is a superset of childPartition, otherwise false. - */ - public static boolean isSuperset( - List parentPartitions, ByteStringRange childPartition) { - // sort parentPartitions by starting key - // iterate through, check open end key and close start key of each iteration to ensure no gaps. - // first start key and last end key must be equal to or wider than child partition start and end - // key. - if (parentPartitions.isEmpty()) { - return false; - } - parentPartitions.sort(new PartitionComparator()); - ByteString parentStartKey = parentPartitions.get(0).getStart(); - ByteString parentEndKey = parentPartitions.get(parentPartitions.size() - 1).getEnd(); - - return !childStartsBeforeParent(parentStartKey, childPartition.getStart()) - && !childEndsAfterParent(parentEndKey, childPartition.getEnd()) - && !gapsInParentPartitions(parentPartitions); - } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/DetectNewPartitionsAction.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/DetectNewPartitionsAction.java index cc05efcb9832..08a0b8837e49 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/DetectNewPartitionsAction.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/DetectNewPartitionsAction.java @@ -17,10 +17,23 @@ */ package org.apache.beam.sdk.io.gcp.bigtable.changestreams.action; +import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.getMissingAndOverlappingPartitionsFromKeySpace; +import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.partitionsToString; + +import com.google.api.gax.rpc.ServerStream; +import com.google.cloud.bigtable.data.v2.models.Range; +import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange; +import com.google.cloud.bigtable.data.v2.models.Row; import com.google.protobuf.InvalidProtocolBufferException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.stream.Collectors; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper; import org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics; import org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableDao; +import org.apache.beam.sdk.io.gcp.bigtable.changestreams.encoder.MetadataTableEncoder; import org.apache.beam.sdk.io.gcp.bigtable.changestreams.model.PartitionRecord; import org.apache.beam.sdk.io.range.OffsetRange; import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; @@ -65,6 +78,83 @@ public DetectNewPartitionsAction( this.generateInitialPartitionsAction = generateInitialPartitionsAction; } + /** + * Periodically advances DetectNewPartition's (DNP) watermark based on the watermark of all the + * partitions recorded in the metadata table. 
We don't advance DNP's watermark on every run to + * "now" because of a possible inconsistent state. DNP's watermark is used to hold the entire + * pipeline's watermark back. + * + *
<p>
The low watermark is important to determine when to terminate the pipeline. During splits + * and merges, the watermark step may appear to be higher than it actually is. If a partition, at + * watermark 100, splits, it is considered completed. If all other partitions including DNP have + * watermark beyond 100, the low watermark of the pipeline is higher than 100. However, the split + * partitions will have a watermark of 100 because they will resume from where the parent + * partition has stopped. But this would mean the low watermark of the pipeline needs to move + * backwards in time, which is not possible. + * + *
<p>
We "fix" this by using DNP's watermark to hold the pipeline's watermark down. DNP will + * periodically scan the metadata table for all the partitions watermarks. It only advances its + * watermark forward to the low watermark of the partitions. So in the case of a partition + * split/merge the low watermark of the pipeline is held back by DNP. We guarantee correctness by + * ensuring all the partitions exists in the metadata table in order to calculate the low + * watermark. It is possible that some partitions might be missing in between split and merges. + * + * @param tracker restriction tracker to guide how frequently watermark should be advanced + * @param watermarkEstimator watermark estimator to advance the watermark + */ + private void advanceWatermark( + RestrictionTracker tracker, + ManualWatermarkEstimator watermarkEstimator) + throws InvalidProtocolBufferException { + // We currently choose to update the watermark every 10 runs. We want to choose a number that is + // frequent, so the watermark isn't lagged behind too far. Also not too frequent so we do not + // overload the table with full table scans. + if (tracker.currentRestriction().getFrom() % 10 == 0) { + // Get partitions with a watermark set but skip rows w a lock and no watermark yet + ServerStream rows = metadataTableDao.readFromMdTableStreamPartitionsWithWatermark(); + List partitions = new ArrayList<>(); + HashMap slowPartitions = new HashMap<>(); + Instant lowWatermark = Instant.ofEpochMilli(Long.MAX_VALUE); + for (Row row : rows) { + Instant watermark = MetadataTableEncoder.parseWatermarkFromRow(row); + if (watermark == null) { + continue; + } + // Update low watermark if watermark < low watermark. + if (watermark.compareTo(lowWatermark) < 0) { + lowWatermark = watermark; + } + Range.ByteStringRange partition = + metadataTableDao.convertStreamPartitionRowKeyToPartition(row.getKey()); + partitions.add(partition); + if (watermark.plus(DEBUG_WATERMARK_DELAY).isBeforeNow()) { + slowPartitions.put(partition, watermark); + } + } + List missingAndOverlappingPartitions = + getMissingAndOverlappingPartitionsFromKeySpace(partitions); + if (missingAndOverlappingPartitions.isEmpty()) { + watermarkEstimator.setWatermark(lowWatermark); + LOG.info("DNP: Updating watermark: " + watermarkEstimator.currentWatermark()); + } else { + LOG.warn( + "DNP: Could not update watermark because missing {}", + partitionsToString(missingAndOverlappingPartitions)); + } + if (!slowPartitions.isEmpty()) { + LOG.warn( + "DNP: Watermark is being held back by the following partitions: {}", + slowPartitions.entrySet().stream() + .map( + e -> + ByteStringRangeHelper.formatByteStringRange(e.getKey()) + + " => " + + e.getValue()) + .collect(Collectors.joining(", ", "{", "}"))); + } + } + } + /** * Perform the necessary steps to manage initial set of partitions and new partitions. Currently, * we set to process new partitions every second. 
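The hunk above adds advanceWatermark and the following hunk calls it from run(). Stripped of logging and DAO plumbing, the rule it applies is roughly the following (a sketch assuming the caller has already collected a partition-to-watermark map; the class and method names are illustrative, not part of the patch):

    import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.getMissingAndOverlappingPartitionsFromKeySpace;

    import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange;
    import java.util.ArrayList;
    import java.util.Comparator;
    import java.util.Map;
    import java.util.Optional;
    import org.joda.time.Instant;

    class LowWatermarkSketch {
      // DNP may only advance to the minimum watermark across partitions, and only when
      // the streamed partitions cover the whole keyspace; otherwise a split or merge is
      // in flight and the watermark hold must stay where it is.
      static Optional<Instant> lowWatermarkIfComplete(Map<ByteStringRange, Instant> watermarks) {
        if (!getMissingAndOverlappingPartitionsFromKeySpace(new ArrayList<>(watermarks.keySet()))
            .isEmpty()) {
          return Optional.empty();
        }
        return watermarks.values().stream().min(Comparator.naturalOrder());
      }
    }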
@@ -102,6 +192,8 @@ public ProcessContinuation run( return ProcessContinuation.stop(); } + advanceWatermark(tracker, watermarkEstimator); + return ProcessContinuation.resume().withResumeDelay(Duration.standardSeconds(1)); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java index 56bc7c3ab19d..eeeb3682ccf0 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDao.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao; +import static com.google.cloud.bigtable.data.v2.models.Filters.FILTERS; import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableAdminDao.DETECT_NEW_PARTITION_SUFFIX; import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableAdminDao.NEW_PARTITION_PREFIX; import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableAdminDao.STREAM_PARTITION_PREFIX; @@ -24,12 +25,13 @@ import com.google.api.gax.rpc.ServerStream; import com.google.cloud.bigtable.data.v2.BigtableDataClient; import com.google.cloud.bigtable.data.v2.models.ChangeStreamContinuationToken; -import com.google.cloud.bigtable.data.v2.models.Filters; import com.google.cloud.bigtable.data.v2.models.Query; import com.google.cloud.bigtable.data.v2.models.Range; +import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange; import com.google.cloud.bigtable.data.v2.models.Row; import com.google.cloud.bigtable.data.v2.models.RowMutation; import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Internal; import org.joda.time.Instant; @@ -90,6 +92,21 @@ private ByteString getFullDetectNewPartition() { return changeStreamNamePrefix.concat(DETECT_NEW_PARTITION_SUFFIX); } + /** + * Convert stream partition row key to partition to process metadata read from Bigtable. + * + *
<p>
RowKey should be directly from Cloud Bigtable and not altered in any way. + * + * @param rowKey row key from Cloud Bigtable + * @return partition extracted from rowKey + * @throws InvalidProtocolBufferException if conversion from rowKey to partition fails + */ + public ByteStringRange convertStreamPartitionRowKeyToPartition(ByteString rowKey) + throws InvalidProtocolBufferException { + int prefixLength = changeStreamNamePrefix.size() + STREAM_PARTITION_PREFIX.size(); + return ByteStringRange.toByteStringRange(rowKey.substring(prefixLength)); + } + /** * Convert partition to a Stream Partition row key to query for metadata of partitions that are * currently being streamed. @@ -125,7 +142,7 @@ public ServerStream readNewPartitions() { Query query = Query.create(tableId) .prefix(getFullNewPartitionPrefix()) - .filter(Filters.FILTERS.limit().cellsPerColumn(1)); + .filter(FILTERS.limit().cellsPerColumn(1)); return dataClient.readRows(query); } @@ -175,6 +192,25 @@ private void writeNewPartition( dataClient.mutateRow(rowMutation); } + /** + * @return stream of partitions currently being streamed by the beam job that have set a + * watermark. + */ + public ServerStream readFromMdTableStreamPartitionsWithWatermark() { + // We limit to the latest value per column. + Query query = + Query.create(tableId) + .prefix(getFullStreamPartitionPrefix()) + .filter( + FILTERS + .chain() + .filter(FILTERS.limit().cellsPerColumn(1)) + .filter(FILTERS.family().exactMatch(MetadataTableAdminDao.CF_WATERMARK)) + .filter( + FILTERS.qualifier().exactMatch(MetadataTableAdminDao.QUALIFIER_DEFAULT))); + return dataClient.readRows(query); + } + /** * Update the metadata for the rowKey. This helper adds necessary prefixes to the row key. * diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelperTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelperTest.java index 78e63a1a494c..f97442f00029 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelperTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/ByteStringRangeHelperTest.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.gcp.bigtable.changestreams; import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.formatByteStringRange; +import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.getMissingAndOverlappingPartitionsFromKeySpace; import static org.apache.beam.sdk.io.gcp.bigtable.changestreams.ByteStringRangeHelper.partitionsToString; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -172,6 +173,110 @@ public void testPartitionsToStringEmptyPartition() { assertEquals("{}", partitionsString); } + @Test + public void testGetMissingAndOverlappingPartitionFromKeySpaceEmptyPartition() { + List partitions = new ArrayList<>(); + List missingAndOverlappingPartitions = + getMissingAndOverlappingPartitionsFromKeySpace(partitions); + assertEquals( + Collections.singletonList(ByteStringRange.create("", "")), missingAndOverlappingPartitions); + } + + @Test + public void testGetMissingAndOverlappingPartitionFromKeySpaceSinglePartition() { + ByteStringRange partition1 = ByteStringRange.create("", ""); + List partitions = Collections.singletonList(partition1); + List 
missingAndOverlappingPartitions = + getMissingAndOverlappingPartitionsFromKeySpace(partitions); + assertEquals(Collections.emptyList(), missingAndOverlappingPartitions); + } + + @Test + public void testGetMissingAndOverlappingPartitionFromKeySpaceNoMissingPartition() { + ByteStringRange partition1 = ByteStringRange.create("", "A"); + ByteStringRange partition2 = ByteStringRange.create("A", "B"); + ByteStringRange partition3 = ByteStringRange.create("B", ""); + List partitions = Arrays.asList(partition1, partition2, partition3); + List missingAndOverlappingPartitions = + getMissingAndOverlappingPartitionsFromKeySpace(partitions); + assertEquals(Collections.emptyList(), missingAndOverlappingPartitions); + } + + @Test + public void testGetMissingAndOverlappingPartitionFromKeySpaceMissingStartPartition() { + ByteStringRange partition1 = ByteStringRange.create("", "A"); + ByteStringRange partition2 = ByteStringRange.create("A", "B"); + ByteStringRange partition3 = ByteStringRange.create("B", ""); + List partitions = Arrays.asList(partition2, partition3); + List missingAndOverlappingPartitions = + getMissingAndOverlappingPartitionsFromKeySpace(partitions); + assertEquals(Collections.singletonList(partition1), missingAndOverlappingPartitions); + } + + @Test + public void testGetMissingAndOverlappingPartitionFromKeySpaceMissingEndPartition() { + ByteStringRange partition1 = ByteStringRange.create("", "A"); + ByteStringRange partition2 = ByteStringRange.create("A", "B"); + ByteStringRange partition3 = ByteStringRange.create("B", ""); + List partitions = Arrays.asList(partition1, partition2); + List missingAndOverlappingPartitions = + getMissingAndOverlappingPartitionsFromKeySpace(partitions); + assertEquals(Collections.singletonList(partition3), missingAndOverlappingPartitions); + } + + @Test + public void testGetMissingAndOverlappingPartitionFromKeySpaceMissingMiddlePartition() { + ByteStringRange partition1 = ByteStringRange.create("", "A"); + ByteStringRange partition2 = ByteStringRange.create("A", "B"); + ByteStringRange partition3 = ByteStringRange.create("B", ""); + List partitions = Arrays.asList(partition1, partition3); + List missingAndOverlappingPartitions = + getMissingAndOverlappingPartitionsFromKeySpace(partitions); + assertEquals(Collections.singletonList(partition2), missingAndOverlappingPartitions); + } + + @Test + public void testGetMissingAndOverlappingPartitionFromKeySpaceOverlapPartition() { + ByteStringRange partition1 = ByteStringRange.create("", "B"); + ByteStringRange partition2 = ByteStringRange.create("A", ""); + List partitions = Arrays.asList(partition1, partition2); + List missingAndOverlappingPartitions = + getMissingAndOverlappingPartitionsFromKeySpace(partitions); + assertEquals( + Collections.singletonList(ByteStringRange.create("B", "A")), + missingAndOverlappingPartitions); + } + + @Test + public void testGetMissingAndOverlappingPartitionFromKeySpaceOverlapAndMissingPartition() { + ByteStringRange partition1 = ByteStringRange.create("", "B"); + ByteStringRange partition2 = ByteStringRange.create("C", "D"); + ByteStringRange partition3 = ByteStringRange.create("A", "C"); + ByteStringRange partition4 = ByteStringRange.create("E", ""); + ByteStringRange partition5 = ByteStringRange.create("C", "E"); + List partitions = + Arrays.asList(partition1, partition2, partition3, partition4, partition5); + List missingAndOverlappingPartitions = + getMissingAndOverlappingPartitionsFromKeySpace(partitions); + assertEquals( + Arrays.asList(ByteStringRange.create("B", "A"), 
ByteStringRange.create("D", "C")), + missingAndOverlappingPartitions); + } + + @Test + public void testGetMissingAndOverlappingPartitionsFromKeySpaceOverlapWithOpenEndKey() { + ByteStringRange fullKeySpace = ByteStringRange.create("", ""); + ByteStringRange partialKeySpace = ByteStringRange.create("n", ""); + List partitions = Arrays.asList(fullKeySpace, partialKeySpace); + // TODO come up with a better way to differentiate missing with start key "" and overlapping + // with end key "" + ByteStringRange overlappingPartition = ByteStringRange.create("", "n"); + List expectedOverlapping = Collections.singletonList(overlappingPartition); + List missingAndOverlappingPartitions = + getMissingAndOverlappingPartitionsFromKeySpace(partitions); + assertEquals(expectedOverlapping, missingAndOverlappingPartitions); + } + @Test public void testPartitionComparator() { ByteStringRange partition1 = ByteStringRange.create("", "a"); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/DetectNewPartitionsActionTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/DetectNewPartitionsActionTest.java index c2fbe8f8bd9c..449f188c2a95 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/DetectNewPartitionsActionTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/action/DetectNewPartitionsActionTest.java @@ -24,6 +24,7 @@ import com.google.cloud.bigtable.admin.v2.BigtableTableAdminSettings; import com.google.cloud.bigtable.data.v2.BigtableDataClient; import com.google.cloud.bigtable.data.v2.BigtableDataSettings; +import com.google.cloud.bigtable.data.v2.models.Range; import com.google.cloud.bigtable.emulator.v2.BigtableEmulatorRule; import java.io.IOException; import org.apache.beam.sdk.io.gcp.bigtable.changestreams.ChangeStreamMetrics; @@ -32,12 +33,14 @@ import org.apache.beam.sdk.io.gcp.bigtable.changestreams.dao.MetadataTableDao; import org.apache.beam.sdk.io.gcp.bigtable.changestreams.model.PartitionRecord; import org.apache.beam.sdk.io.range.OffsetRange; +import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators; +import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Before; import org.junit.BeforeClass; @@ -66,6 +69,7 @@ public class DetectNewPartitionsActionTest { private MetadataTableDao metadataTableDao; private ManualWatermarkEstimator watermarkEstimator; private Instant startTime; + private Instant partitionTime; private static BigtableDataClient dataClient; private static BigtableTableAdminClient adminClient; @@ -101,6 +105,7 @@ public void setUp() throws Exception { metadataTableAdminDao.getChangeStreamNamePrefix()); startTime = Instant.now(); + partitionTime = startTime.plus(Duration.standardSeconds(10)); action = new DetectNewPartitionsAction(metrics, metadataTableDao, generateInitialPartitionsAction); watermarkEstimator = new WatermarkEstimators.Manual(startTime); @@ -119,4 +124,60 @@ public void testInitialPartitions() throws 
Exception { ProcessContinuation.resume(), action.run(tracker, receiver, watermarkEstimator, bundleFinalizer, startTime)); } + + // Every 10 tryClaim, DNP updates the watermark based on the watermark of all the RCSP. + @Test + public void testAdvanceWatermarkWithAllPartitions() throws Exception { + // We advance watermark on every 10 restriction tracker advancement + OffsetRange offsetRange = new OffsetRange(10, Long.MAX_VALUE); + when(tracker.currentRestriction()).thenReturn(offsetRange); + when(tracker.tryClaim(offsetRange.getFrom())).thenReturn(true); + + assertEquals(startTime, watermarkEstimator.currentWatermark()); + + // Write 2 partitions to the table that covers entire keyspace. + Range.ByteStringRange partition1 = Range.ByteStringRange.create("", "b"); + Instant watermark1 = partitionTime.plus(Duration.millis(100)); + metadataTableDao.updateWatermark(partition1, watermark1, null); + Range.ByteStringRange partition2 = Range.ByteStringRange.create("b", ""); + Instant watermark2 = partitionTime.plus(Duration.millis(1)); + metadataTableDao.updateWatermark(partition2, watermark2, null); + + assertEquals( + DoFn.ProcessContinuation.resume().withResumeDelay(Duration.standardSeconds(1)), + action.run(tracker, receiver, watermarkEstimator, bundleFinalizer, startTime)); + + // Because the 2 partitions cover the entire keyspace, the watermark should have advanced. + // Also note the watermark is watermark2 which is the lowest of the 2 watermarks. + assertEquals(watermark2, watermarkEstimator.currentWatermark()); + } + + // Every 10 tryClaim, DNP only updates its watermark if all the RCSP currently streamed covers the + // entire key space. If there's any missing, they are in the process of split or merge. If the + // watermark is updated with missing partitions, the watermark might be further ahead than it + // actually is. + @Test + public void testAdvanceWatermarkWithMissingPartitions() throws Exception { + // We advance watermark on every 10 restriction tracker advancement + OffsetRange offsetRange = new OffsetRange(10, Long.MAX_VALUE); + when(tracker.currentRestriction()).thenReturn(offsetRange); + when(tracker.tryClaim(offsetRange.getFrom())).thenReturn(true); + + assertEquals(startTime, watermarkEstimator.currentWatermark()); + + // Write 2 partitions to the table that DO NOT cover the entire keyspace. + Range.ByteStringRange partition1 = Range.ByteStringRange.create("", "b"); + Instant watermark1 = partitionTime.plus(Duration.millis(100)); + metadataTableDao.updateWatermark(partition1, watermark1, null); + Range.ByteStringRange partition2 = Range.ByteStringRange.create("b", "c"); + Instant watermark2 = partitionTime.plus(Duration.millis(1)); + metadataTableDao.updateWatermark(partition2, watermark2, null); + + assertEquals( + DoFn.ProcessContinuation.resume().withResumeDelay(Duration.standardSeconds(1)), + action.run(tracker, receiver, watermarkEstimator, bundleFinalizer, startTime)); + + // Because the 2 partitions DO NOT cover the entire keyspace, watermark stays at startTime. 
+ assertEquals(startTime, watermarkEstimator.currentWatermark()); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java index 1a14a923a94e..0c0d2e671f37 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/changestreams/dao/MetadataTableDaoTest.java @@ -32,6 +32,7 @@ import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import java.io.IOException; +import java.util.ArrayList; import org.apache.beam.sdk.io.gcp.bigtable.changestreams.UniqueIdGenerator; import org.apache.beam.sdk.io.gcp.bigtable.changestreams.encoder.MetadataTableEncoder; import org.joda.time.Instant; @@ -84,6 +85,46 @@ public void before() { metadataTableAdminDao.getChangeStreamNamePrefix()); } + @Test + public void testStreamPartitionRowKeyConversion() throws InvalidProtocolBufferException { + ByteStringRange rowRange = ByteStringRange.create("a", "b"); + ByteString rowKey = metadataTableDao.convertPartitionToStreamPartitionRowKey(rowRange); + assertEquals(rowRange, metadataTableDao.convertStreamPartitionRowKeyToPartition(rowKey)); + } + + @Test + public void testStreamPartitionRowKeyConversionWithIllegalUtf8() + throws InvalidProtocolBufferException { + // Test that the conversion is able to handle non-utf8 values. + byte[] nonUtf8Bytes = {(byte) 0b10001100}; + ByteString nonUtf8RowKey = ByteString.copyFrom(nonUtf8Bytes); + ByteStringRange rowRange = ByteStringRange.create(nonUtf8RowKey, nonUtf8RowKey); + ByteString rowKey = metadataTableDao.convertPartitionToStreamPartitionRowKey(rowRange); + assertEquals(rowRange, metadataTableDao.convertStreamPartitionRowKeyToPartition(rowKey)); + } + + @Test + public void testReadStreamPartitionsWithWatermark() throws InvalidProtocolBufferException { + ByteStringRange partitionWithWatermark = ByteStringRange.create("a", ""); + Instant watermark = Instant.now(); + metadataTableDao.updateWatermark(partitionWithWatermark, watermark, null); + + // This should only return rows where the watermark has been set, not rows that have been locked + // but have not yet set the first watermark + ServerStream rowsWithWatermark = + metadataTableDao.readFromMdTableStreamPartitionsWithWatermark(); + ArrayList metadataRows = new ArrayList<>(); + for (Row row : rowsWithWatermark) { + metadataRows.add(row); + } + assertEquals(1, metadataRows.size()); + Instant metadataWatermark = MetadataTableEncoder.parseWatermarkFromRow(metadataRows.get(0)); + assertEquals(watermark, metadataWatermark); + ByteStringRange rowKeyResponse = + metadataTableDao.convertStreamPartitionRowKeyToPartition(metadataRows.get(0).getKey()); + assertEquals(partitionWithWatermark, rowKeyResponse); + } + @Test public void testNewPartitionsWriteRead() throws InvalidProtocolBufferException { // This test a split of ["", "") to ["", "a") and ["a", "") @@ -124,6 +165,36 @@ public void testNewPartitionsWriteRead() throws InvalidProtocolBufferException { assertEquals(2, rowsCount); } + @Test + public void testUpdateAndReadWatermark() throws InvalidProtocolBufferException { + ByteStringRange partition1 = ByteStringRange.create("a", "b"); + Instant watermark1 = Instant.now(); + 
metadataTableDao.updateWatermark(partition1, watermark1, null); + ByteStringRange partition2 = ByteStringRange.create("b", "c"); + Instant watermark2 = Instant.now(); + metadataTableDao.updateWatermark(partition2, watermark2, null); + + ServerStream rows = metadataTableDao.readFromMdTableStreamPartitionsWithWatermark(); + int rowsCount = 0; + boolean matchedPartition1 = false; + boolean matchedPartition2 = false; + for (Row row : rows) { + rowsCount++; + ByteStringRange partition = + metadataTableDao.convertStreamPartitionRowKeyToPartition(row.getKey()); + if (partition.equals(partition1)) { + assertEquals(watermark1, MetadataTableEncoder.parseWatermarkFromRow(row)); + matchedPartition1 = true; + } else if (partition.equals(partition2)) { + assertEquals(watermark2, MetadataTableEncoder.parseWatermarkFromRow(row)); + matchedPartition2 = true; + } + } + assertEquals(2, rowsCount); + assertTrue(matchedPartition1); + assertTrue(matchedPartition2); + } + @Test public void testUpdateWatermark() { ByteStringRange partition = ByteStringRange.create("a", "b"); From ca0787642a6b3804a742326147281c99ae8d08d2 Mon Sep 17 00:00:00 2001 From: reuvenlax Date: Tue, 21 Mar 2023 15:57:38 -0700 Subject: [PATCH 10/13] Merge pull request #25723: #25722 Add option to propagate successful storage-api writes --- .../beam/sdk/transforms/GroupIntoBatches.java | 64 +++++++-- .../beam/sdk/io/gcp/bigquery/BatchLoads.java | 2 + .../beam/sdk/io/gcp/bigquery/BigQueryIO.java | 26 +++- .../io/gcp/bigquery/SplittingIterable.java | 46 +++++-- .../bigquery/StorageApiConvertMessages.java | 5 +- .../StorageApiDynamicDestinationsBeamRow.java | 4 +- ...geApiDynamicDestinationsGenericRecord.java | 4 +- .../sdk/io/gcp/bigquery/StorageApiLoads.java | 78 +++++++++-- .../gcp/bigquery/StorageApiWritePayload.java | 26 +++- .../StorageApiWriteRecordsInconsistent.java | 18 ++- .../StorageApiWriteUnshardedRecords.java | 123 +++++++++++++++--- .../StorageApiWritesShardedRecords.java | 73 ++++++++--- .../io/gcp/bigquery/StreamingWriteTables.java | 2 + .../beam/sdk/io/gcp/bigquery/WriteResult.java | 40 +++++- .../io/gcp/bigquery/BigQueryIOWriteTest.java | 6 + 15 files changed, 434 insertions(+), 83 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java index 78ce549c3b08..311a3dac6ca1 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java @@ -329,7 +329,7 @@ public long getElementByteSize() { @Override public PCollection>> expand(PCollection> input) { - + Duration allowedLateness = input.getWindowingStrategy().getAllowedLateness(); checkArgument( input.getCoder() instanceof KvCoder, "coder specified in the input PCollection is not a KvCoder"); @@ -344,6 +344,7 @@ public PCollection>> expand(PCollection> in params.getBatchSizeBytes(), weigher, params.getMaxBufferingDuration(), + allowedLateness, valueCoder))); } @@ -357,12 +358,20 @@ private static class GroupIntoBatchesDoFn @Nullable private final SerializableFunction weigher; private final Duration maxBufferingDuration; + private final Duration allowedLateness; + // The following timer is no longer set. We maintain the spec for update compatibility. 
private static final String END_OF_WINDOW_ID = "endOFWindow"; @TimerId(END_OF_WINDOW_ID) private final TimerSpec windowTimer = TimerSpecs.timer(TimeDomain.EVENT_TIME); + // This timer manages the watermark hold if there is no buffering timer. + private static final String TIMER_HOLD_ID = "watermarkHold"; + + @TimerId(TIMER_HOLD_ID) + private final TimerSpec holdTimerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME); + // This timer expires when it's time to batch and output the buffered data. private static final String END_OF_BUFFERING_ID = "endOfBuffering"; @@ -410,11 +419,13 @@ private static class GroupIntoBatchesDoFn long batchSizeBytes, @Nullable SerializableFunction weigher, Duration maxBufferingDuration, + Duration allowedLateness, Coder inputValueCoder) { this.batchSize = batchSize; this.batchSizeBytes = batchSizeBytes; this.weigher = weigher; this.maxBufferingDuration = maxBufferingDuration; + this.allowedLateness = allowedLateness; this.batchSpec = StateSpecs.bag(inputValueCoder); Combine.BinaryCombineLongFn sumCombineFn = @@ -452,9 +463,18 @@ public long apply(long left, long right) { this.prefetchFrequency = ((batchSize / 5) <= 1) ? Long.MAX_VALUE : (batchSize / 5); } + @Override + public Duration getAllowedTimestampSkew() { + // This is required since flush is sometimes called from processElement. This is safe because + // a watermark hold + // will always be set using timer.withOutputTimestamp. + return Duration.millis(Long.MAX_VALUE); + } + @ProcessElement public void processElement( @TimerId(END_OF_BUFFERING_ID) Timer bufferingTimer, + @TimerId(TIMER_HOLD_ID) Timer holdTimer, @StateId(BATCH_ID) BagState batch, @StateId(NUM_ELEMENTS_IN_BATCH_ID) CombiningState storedBatchSize, @StateId(NUM_BYTES_IN_BATCH_ID) CombiningState storedBatchSizeBytes, @@ -473,9 +493,10 @@ public void processElement( storedBatchSizeBytes.readLater(); } storedBatchSize.readLater(); - if (shouldCareAboutMaxBufferingDuration) { - minBufferedTs.readLater(); - } + minBufferedTs.readLater(); + + // Make sure we always include the current timestamp in the minBufferedTs. + minBufferedTs.add(elementTs.getMillis()); LOG.debug("*** BATCH *** Add element for window {} ", window); if (shouldCareAboutWeight) { @@ -505,23 +526,26 @@ public void processElement( timerTs, minBufferedTs); bufferingTimer.clear(); + holdTimer.clear(); } storedBatchSizeBytes.add(elementWeight); } batch.add(element.getValue()); // Blind add is supported with combiningState storedBatchSize.add(1L); + // Add the timestamp back into minBufferedTs as it might be cleared by flushBatch above. + minBufferedTs.add(elementTs.getMillis()); final long num = storedBatchSize.read(); - if (shouldCareAboutMaxBufferingDuration) { - long oldOutputTs = - MoreObjects.firstNonNull( - minBufferedTs.read(), BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()); - minBufferedTs.add(elementTs.getMillis()); - // If this is the first element in the batch or if the timer's output timestamp needs - // modifying, then set a - // timer. - if (num == 1 || minBufferedTs.read() != oldOutputTs) { + + // If this is the first element in the batch or if the timer's output timestamp needs + // modifying, then set a timer. 
+ long oldOutputTs = + MoreObjects.firstNonNull( + minBufferedTs.read(), BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()); + boolean needsNewTimer = num == 1 || minBufferedTs.read() != oldOutputTs; + if (needsNewTimer) { + if (shouldCareAboutMaxBufferingDuration) { long targetTs = MoreObjects.firstNonNull( timerTs.read(), @@ -530,6 +554,12 @@ public void processElement( bufferingTimer .withOutputTimestamp(Instant.ofEpochMilli(minBufferedTs.read())) .set(Instant.ofEpochMilli(targetTs)); + } else { + // The only way to hold the watermark is to set a timer. Since there is no buffering + // timer, we set a dummy + // timer at the end of the window to manage the hold. + Instant windowEnd = window.maxTimestamp().plus(allowedLateness); + holdTimer.withOutputTimestamp(Instant.ofEpochMilli(minBufferedTs.read())).set(windowEnd); } } @@ -585,6 +615,11 @@ public void onWindowExpiration( receiver, key, batch, storedBatchSize, storedBatchSizeBytes, timerTs, minBufferedTs); } + @OnTimer(TIMER_HOLD_ID) + public void onHoldTimer() { + // Do nothing. The associated watermark hold will be automatically removed. + } + // We no longer set this timer, since OnWindowExpiration takes care of his. However we leave the // callback in place // for existing jobs that have already set these timers. @@ -618,7 +653,8 @@ private void flushBatch( Iterable values = batch.read(); // When the timer fires, batch state might be empty if (!Iterables.isEmpty(values)) { - receiver.output(KV.of(key, values)); + receiver.outputWithTimestamp( + KV.of(key, values), Instant.ofEpochMilli(minBufferedTs.read())); } clearState(batch, storedBatchSize, storedBatchSizeBytes, timerTs, minBufferedTs); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java index 9ba2a83d7b6d..bad44ee5b5c1 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java @@ -865,6 +865,8 @@ private WriteResult writeResult(Pipeline p, PCollection succes new TupleTag<>("successfulInserts"), successfulWrites, null, + null, + null, null); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java index 6745f7aceea0..1804af1ab987 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java @@ -2094,6 +2094,7 @@ public static Write write() { .setAutoSchemaUpdate(false) .setDeterministicRecordIdFn(null) .setMaxRetryJobs(1000) + .setPropagateSuccessfulStorageApiWrites(false) .build(); } @@ -2211,6 +2212,8 @@ public enum Method { abstract int getNumStorageWriteApiStreams(); + abstract boolean getPropagateSuccessfulStorageApiWrites(); + abstract int getMaxFilesPerPartition(); abstract long getMaxBytesPerPartition(); @@ -2306,6 +2309,9 @@ abstract Builder setAvroSchemaFactory( abstract Builder setNumStorageWriteApiStreams(int numStorageApiStreams); + abstract Builder setPropagateSuccessfulStorageApiWrites( + boolean propagateSuccessfulStorageApiWrites); + abstract Builder setMaxFilesPerPartition(int maxFilesPerPartition); abstract Builder 
setMaxBytesPerPartition(long maxBytesPerPartition); @@ -2763,6 +2769,17 @@ public Write withNumStorageWriteApiStreams(int numStorageWriteApiStreams) { return toBuilder().setNumStorageWriteApiStreams(numStorageWriteApiStreams).build(); } + /** + * If set to true, then all successful writes will be propagated to {@link WriteResult} and + * accessible via the {@link WriteResult#getSuccessfulStorageApiInserts} method. + */ + public Write withPropagateSuccessfulStorageApiWrites( + boolean propagateSuccessfulStorageApiWrites) { + return toBuilder() + .setPropagateSuccessfulStorageApiWrites(propagateSuccessfulStorageApiWrites) + .build(); + } + /** * Provides a custom location on GCS for storing temporary files to be loaded via BigQuery batch * load jobs. See "Usage with templates" in {@link BigQueryIO} documentation for discussion. @@ -3270,6 +3287,9 @@ private WriteResult continueExpandTyped( checkArgument( getSchemaUpdateOptions() == null || getSchemaUpdateOptions().isEmpty(), "SchemaUpdateOptions are not supported when method == STREAMING_INSERTS"); + checkArgument( + !getPropagateSuccessfulStorageApiWrites(), + "withPropagateSuccessfulStorageApiWrites only supported when using storage api writes."); RowWriterFactory.TableRowWriterFactory tableRowWriterFactory = (RowWriterFactory.TableRowWriterFactory) rowWriterFactory; @@ -3301,6 +3321,9 @@ private WriteResult continueExpandTyped( rowWriterFactory.getOutputType() == OutputType.AvroGenericRecord, "useAvroLogicalTypes can only be set with Avro output."); } + checkArgument( + !getPropagateSuccessfulStorageApiWrites(), + "withPropagateSuccessfulStorageApiWrites only supported when using storage api writes."); // Batch load jobs currently support JSON data insertion only with CSV files if (getJsonSchema() != null && getJsonSchema().isAccessible()) { @@ -3406,7 +3429,8 @@ private WriteResult continueExpandTyped( method == Method.STORAGE_API_AT_LEAST_ONCE, getAutoSharding(), getAutoSchemaUpdate(), - getIgnoreUnknownValues()); + getIgnoreUnknownValues(), + getPropagateSuccessfulStorageApiWrites()); return input.apply("StorageApiLoads", storageApiLoads); } else { throw new RuntimeException("Unexpected write method " + method); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SplittingIterable.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SplittingIterable.java index 4b4978bb30fe..a7de876b98fa 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SplittingIterable.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SplittingIterable.java @@ -18,20 +18,32 @@ package org.apache.beam.sdk.io.gcp.bigquery; import com.google.api.services.bigquery.model.TableRow; +import com.google.auto.value.AutoValue; import com.google.cloud.bigquery.storage.v1.ProtoRows; import com.google.protobuf.ByteString; import java.util.Iterator; +import java.util.List; import java.util.NoSuchElementException; import java.util.function.BiConsumer; import java.util.function.Function; -import javax.annotation.Nullable; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; /** * Takes in an iterable and batches the results into multiple ProtoRows objects. 
The splitSize * parameter controls how many rows are batched into a single ProtoRows object before we move on to * the next one. */ -class SplittingIterable implements Iterable { +class SplittingIterable implements Iterable { + @AutoValue + abstract static class Value { + abstract ProtoRows getProtoRows(); + + abstract List getTimestamps(); + } + interface ConvertUnknownFields { ByteString convert(TableRow tableRow, boolean ignoreUnknownValues) throws TableRowToStorageApiProto.SchemaConversionException; @@ -42,18 +54,21 @@ ByteString convert(TableRow tableRow, boolean ignoreUnknownValues) private final ConvertUnknownFields unknownFieldsToMessage; private final Function protoToTableRow; - private final BiConsumer failedRowsConsumer; + private final BiConsumer, String> failedRowsConsumer; private final boolean autoUpdateSchema; private final boolean ignoreUnknownValues; + private final Instant elementsTimestamp; + public SplittingIterable( Iterable underlying, long splitSize, ConvertUnknownFields unknownFieldsToMessage, Function protoToTableRow, - BiConsumer failedRowsConsumer, + BiConsumer, String> failedRowsConsumer, boolean autoUpdateSchema, - boolean ignoreUnknownValues) { + boolean ignoreUnknownValues, + Instant elementsTimestamp) { this.underlying = underlying; this.splitSize = splitSize; this.unknownFieldsToMessage = unknownFieldsToMessage; @@ -61,11 +76,12 @@ public SplittingIterable( this.failedRowsConsumer = failedRowsConsumer; this.autoUpdateSchema = autoUpdateSchema; this.ignoreUnknownValues = ignoreUnknownValues; + this.elementsTimestamp = elementsTimestamp; } @Override - public Iterator iterator() { - return new Iterator() { + public Iterator iterator() { + return new Iterator() { final Iterator underlyingIterator = underlying.iterator(); @Override @@ -74,11 +90,12 @@ public boolean hasNext() { } @Override - public ProtoRows next() { + public Value next() { if (!hasNext()) { throw new NoSuchElementException(); } + List timestamps = Lists.newArrayList(); ProtoRows.Builder inserts = ProtoRows.newBuilder(); long bytesSize = 0; while (underlyingIterator.hasNext()) { @@ -107,7 +124,11 @@ public ProtoRows next() { // 24926 is fixed, we need to merge the unknownFields back into the main row // before outputting to the // failed-rows consumer. 
- failedRowsConsumer.accept(tableRow, e.toString()); + Instant timestamp = payload.getTimestamp(); + if (timestamp == null) { + timestamp = elementsTimestamp; + } + failedRowsConsumer.accept(TimestampedValue.of(tableRow, timestamp), e.toString()); continue; } } @@ -116,12 +137,17 @@ public ProtoRows next() { } } inserts.addSerializedRows(byteString); + Instant timestamp = payload.getTimestamp(); + if (timestamp == null) { + timestamp = elementsTimestamp; + } + timestamps.add(timestamp); bytesSize += byteString.size(); if (bytesSize > splitSize) { break; } } - return inserts.build(); + return new AutoValue_SplittingIterable_Value(inserts.build(), timestamps); } }; } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages.java index a14409788b2d..fa16df4c4626 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiConvertMessages.java @@ -33,6 +33,7 @@ import org.apache.beam.sdk.values.TupleTagList; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; /** * A transform that converts messages to protocol buffers in preparation for writing to BigQuery. @@ -129,6 +130,7 @@ public void processElement( ProcessContext c, PipelineOptions pipelineOptions, @Element KV element, + @Timestamp Instant timestamp, MultiOutputReceiver o) throws Exception { dynamicDestinations.setSideInputAccessorFromProcessContext(c); @@ -136,7 +138,8 @@ public void processElement( messageConverters.get( element.getKey(), dynamicDestinations, getDatasetService(pipelineOptions)); try { - StorageApiWritePayload payload = messageConverter.toMessage(element.getValue()); + StorageApiWritePayload payload = + messageConverter.toMessage(element.getValue()).withTimestamp(timestamp); o.get(successfulWritesTag).output(KV.of(element.getKey(), payload)); } catch (TableRowToStorageApiProto.SchemaConversionException e) { TableRow tableRow = messageConverter.toTableRow(element.getValue()); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDynamicDestinationsBeamRow.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDynamicDestinationsBeamRow.java index 6e9f75d15dfe..e56e156c20f8 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDynamicDestinationsBeamRow.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDynamicDestinationsBeamRow.java @@ -62,9 +62,9 @@ public TableSchema getTableSchema() { @Override @SuppressWarnings("nullness") - public StorageApiWritePayload toMessage(T element) { + public StorageApiWritePayload toMessage(T element) throws Exception { Message msg = BeamRowToStorageApiProto.messageFromBeamRow(descriptor, toRow.apply(element)); - return new AutoValue_StorageApiWritePayload(msg.toByteArray(), null); + return StorageApiWritePayload.of(msg.toByteArray(), null); } @Override diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDynamicDestinationsGenericRecord.java 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDynamicDestinationsGenericRecord.java index 98684db558bb..bb0a1236e1f4 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDynamicDestinationsGenericRecord.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDynamicDestinationsGenericRecord.java @@ -68,11 +68,11 @@ class GenericRecordConverter implements MessageConverter { @Override @SuppressWarnings("nullness") - public StorageApiWritePayload toMessage(T element) { + public StorageApiWritePayload toMessage(T element) throws Exception { Message msg = AvroGenericRecordToStorageApiProto.messageFromGenericRecord( descriptor, toGenericRecord.apply(new AvroWriteRequest<>(element, avroSchema))); - return new AutoValue_StorageApiWritePayload(msg.toByteArray(), null); + return StorageApiWritePayload.of(msg.toByteArray(), null); } @Override diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java index a8d133c0c3a9..30f624618127 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java @@ -17,8 +17,10 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; +import com.google.api.services.bigquery.model.TableRow; import java.nio.ByteBuffer; import java.util.concurrent.ThreadLocalRandom; +import javax.annotation.Nullable; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; @@ -41,10 +43,11 @@ /** This {@link PTransform} manages loads into BigQuery using the Storage API. 
*/ public class StorageApiLoads extends PTransform>, WriteResult> { - final TupleTag> successfulRowsTag = + final TupleTag> successfulConvertedRowsTag = new TupleTag<>("successfulRows"); final TupleTag failedRowsTag = new TupleTag<>("failedRows"); + @Nullable TupleTag successfulWrittenRowsTag; private final Coder destinationCoder; private final StorageApiDynamicDestinations dynamicDestinations; private final CreateDisposition createDisposition; @@ -68,7 +71,8 @@ public StorageApiLoads( boolean allowInconsistentWrites, boolean allowAutosharding, boolean autoUpdateSchema, - boolean ignoreUnknownValues) { + boolean ignoreUnknownValues, + boolean propagateSuccessfulStorageApiWrites) { this.destinationCoder = destinationCoder; this.dynamicDestinations = dynamicDestinations; this.createDisposition = createDisposition; @@ -80,6 +84,9 @@ public StorageApiLoads( this.allowAutosharding = allowAutosharding; this.autoUpdateSchema = autoUpdateSchema; this.ignoreUnknownValues = ignoreUnknownValues; + if (propagateSuccessfulStorageApiWrites) { + this.successfulWrittenRowsTag = new TupleTag<>("successfulPublishedRowsTag"); + } } @Override @@ -120,19 +127,21 @@ public WriteResult expandInconsistent( dynamicDestinations, bqServices, failedRowsTag, - successfulRowsTag, + successfulConvertedRowsTag, BigQueryStorageApiInsertErrorCoder.of(), successCoder)); PCollectionTuple writeRecordsResult = convertMessagesResult - .get(successfulRowsTag) + .get(successfulConvertedRowsTag) .apply( "StorageApiWriteInconsistent", new StorageApiWriteRecordsInconsistent<>( dynamicDestinations, bqServices, failedRowsTag, + successfulWrittenRowsTag, BigQueryStorageApiInsertErrorCoder.of(), + TableRowJsonCoder.of(), autoUpdateSchema, ignoreUnknownValues)); @@ -140,8 +149,21 @@ public WriteResult expandInconsistent( PCollectionList.of(convertMessagesResult.get(failedRowsTag)) .and(writeRecordsResult.get(failedRowsTag)) .apply("flattenErrors", Flatten.pCollections()); + @Nullable PCollection successfulWrittenRows = null; + if (successfulWrittenRowsTag != null) { + successfulWrittenRows = writeRecordsResult.get(successfulWrittenRowsTag); + } return WriteResult.in( - input.getPipeline(), null, null, null, null, null, failedRowsTag, insertErrors); + input.getPipeline(), + null, + null, + null, + null, + null, + failedRowsTag, + insertErrors, + successfulWrittenRowsTag, + successfulWrittenRows); } public WriteResult expandTriggered( @@ -163,7 +185,7 @@ public WriteResult expandTriggered( dynamicDestinations, bqServices, failedRowsTag, - successfulRowsTag, + successfulConvertedRowsTag, BigQueryStorageApiInsertErrorCoder.of(), successCoder)); @@ -178,7 +200,7 @@ public WriteResult expandTriggered( if (this.allowAutosharding) { groupedRecords = convertMessagesResult - .get(successfulRowsTag) + .get(successfulConvertedRowsTag) .apply( "GroupIntoBatches", GroupIntoBatches.ofByteSize( @@ -208,7 +230,9 @@ public WriteResult expandTriggered( bqServices, destinationCoder, BigQueryStorageApiInsertErrorCoder.of(), + TableRowJsonCoder.of(), failedRowsTag, + successfulWrittenRowsTag, autoUpdateSchema, ignoreUnknownValues)); @@ -217,14 +241,28 @@ public WriteResult expandTriggered( .and(writeRecordsResult.get(failedRowsTag)) .apply("flattenErrors", Flatten.pCollections()); + @Nullable PCollection successfulWrittenRows = null; + if (successfulWrittenRowsTag != null) { + successfulWrittenRows = writeRecordsResult.get(successfulWrittenRowsTag); + } + return WriteResult.in( - input.getPipeline(), null, null, null, null, null, failedRowsTag, insertErrors); 
+ input.getPipeline(), + null, + null, + null, + null, + null, + failedRowsTag, + insertErrors, + successfulWrittenRowsTag, + successfulWrittenRows); } private PCollection, StorageApiWritePayload>> createShardedKeyValuePairs(PCollectionTuple pCollection) { return pCollection - .get(successfulRowsTag) + .get(successfulConvertedRowsTag) .apply( "AddShard", ParDo.of( @@ -268,20 +306,22 @@ public WriteResult expandUntriggered( dynamicDestinations, bqServices, failedRowsTag, - successfulRowsTag, + successfulConvertedRowsTag, BigQueryStorageApiInsertErrorCoder.of(), successCoder)); PCollectionTuple writeRecordsResult = convertMessagesResult - .get(successfulRowsTag) + .get(successfulConvertedRowsTag) .apply( "StorageApiWriteUnsharded", new StorageApiWriteUnshardedRecords<>( dynamicDestinations, bqServices, failedRowsTag, + successfulWrittenRowsTag, BigQueryStorageApiInsertErrorCoder.of(), + TableRowJsonCoder.of(), autoUpdateSchema, ignoreUnknownValues)); @@ -290,7 +330,21 @@ public WriteResult expandUntriggered( .and(writeRecordsResult.get(failedRowsTag)) .apply("flattenErrors", Flatten.pCollections()); + @Nullable PCollection successfulWrittenRows = null; + if (successfulWrittenRowsTag != null) { + successfulWrittenRows = writeRecordsResult.get(successfulWrittenRowsTag); + } + return WriteResult.in( - input.getPipeline(), null, null, null, null, null, failedRowsTag, insertErrors); + input.getPipeline(), + null, + null, + null, + null, + null, + failedRowsTag, + insertErrors, + successfulWrittenRowsTag, + successfulWrittenRows); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritePayload.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritePayload.java index 85a0c3b4fe61..5b6f27949870 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritePayload.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritePayload.java @@ -25,6 +25,7 @@ import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.util.CoderUtils; +import org.joda.time.Instant; /** Class used to wrap elements being sent to the Storage API sinks. 
*/ @AutoValue @@ -36,6 +37,21 @@ public abstract class StorageApiWritePayload { @SuppressWarnings("mutable") public abstract @Nullable byte[] getUnknownFieldsPayload(); + public abstract @Nullable Instant getTimestamp(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setPayload(byte[] value); + + public abstract Builder setUnknownFieldsPayload(@Nullable byte[] value); + + public abstract Builder setTimestamp(@Nullable Instant value); + + public abstract StorageApiWritePayload build(); + } + + public abstract Builder toBuilder(); + @SuppressWarnings("nullness") static StorageApiWritePayload of(byte[] payload, @Nullable TableRow unknownFields) throws IOException { @@ -43,7 +59,15 @@ static StorageApiWritePayload of(byte[] payload, @Nullable TableRow unknownField if (unknownFields != null) { unknownFieldsPayload = CoderUtils.encodeToByteArray(TableRowJsonCoder.of(), unknownFields); } - return new AutoValue_StorageApiWritePayload(payload, unknownFieldsPayload); + return new AutoValue_StorageApiWritePayload.Builder() + .setPayload(payload) + .setUnknownFieldsPayload(unknownFieldsPayload) + .setTimestamp(null) + .build(); + } + + public StorageApiWritePayload withTimestamp(Instant instant) { + return toBuilder().setTimestamp(instant).build(); } public @Memoized @Nullable TableRow getUnknownFields() throws IOException { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java index 343b7e1c81a7..7c6445fcf11c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWriteRecordsInconsistent.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; +import com.google.api.services.bigquery.model.TableRow; +import javax.annotation.Nullable; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -38,8 +40,10 @@ public class StorageApiWriteRecordsInconsistent private final StorageApiDynamicDestinations dynamicDestinations; private final BigQueryServices bqServices; private final TupleTag failedRowsTag; + private final @Nullable TupleTag successfulRowsTag; private final TupleTag> finalizeTag = new TupleTag<>("finalizeTag"); private final Coder failedRowsCoder; + private final Coder successfulRowsCoder; private final boolean autoUpdateSchema; private final boolean ignoreUnknownValues; @@ -47,13 +51,17 @@ public StorageApiWriteRecordsInconsistent( StorageApiDynamicDestinations dynamicDestinations, BigQueryServices bqServices, TupleTag failedRowsTag, + @Nullable TupleTag successfulRowsTag, Coder failedRowsCoder, + Coder successfulRowsCoder, boolean autoUpdateSchema, boolean ignoreUnknownValues) { this.dynamicDestinations = dynamicDestinations; this.bqServices = bqServices; this.failedRowsTag = failedRowsTag; this.failedRowsCoder = failedRowsCoder; + this.successfulRowsCoder = successfulRowsCoder; + this.successfulRowsTag = successfulRowsTag; this.autoUpdateSchema = autoUpdateSchema; this.ignoreUnknownValues = ignoreUnknownValues; } @@ -63,6 +71,10 @@ public PCollectionTuple expand(PCollection private final StorageApiDynamicDestinations dynamicDestinations; private final BigQueryServices 
bqServices; private final TupleTag failedRowsTag; + private final @Nullable TupleTag successfulRowsTag; private final TupleTag> finalizeTag = new TupleTag<>("finalizeTag"); private final Coder failedRowsCoder; + private final Coder successfulRowsCoder; private final boolean autoUpdateSchema; private final boolean ignoreUnknownValues; private static final ExecutorService closeWriterExecutor = Executors.newCachedThreadPool(); @@ -147,13 +149,17 @@ public StorageApiWriteUnshardedRecords( StorageApiDynamicDestinations dynamicDestinations, BigQueryServices bqServices, TupleTag failedRowsTag, + @Nullable TupleTag successfulRowsTag, Coder failedRowsCoder, + Coder successfulRowsCoder, boolean autoUpdateSchema, boolean ignoreUnknownValues) { this.dynamicDestinations = dynamicDestinations; this.bqServices = bqServices; this.failedRowsTag = failedRowsTag; + this.successfulRowsTag = successfulRowsTag; this.failedRowsCoder = failedRowsCoder; + this.successfulRowsCoder = successfulRowsCoder; this.autoUpdateSchema = autoUpdateSchema; this.ignoreUnknownValues = ignoreUnknownValues; } @@ -165,6 +171,10 @@ public PCollectionTuple expand(PCollection private final Counter forcedFlushes = Metrics.counter(WriteRecordsDoFn.class, "forcedFlushes"); private final TupleTag> finalizeTag; private final TupleTag failedRowsTag; + private final @Nullable TupleTag successfulRowsTag; private final boolean autoUpdateSchema; private final boolean ignoreUnknownValues; static class AppendRowsContext extends RetryManager.Operation.Context { long offset; ProtoRows protoRows; + List timestamps; - public AppendRowsContext(long offset, ProtoRows protoRows) { + public AppendRowsContext( + long offset, ProtoRows protoRows, List timestamps) { this.offset = offset; this.protoRows = protoRows; + this.timestamps = timestamps; } } @@ -220,6 +238,7 @@ class DestinationState { private @Nullable AppendClientInfo appendClientInfo = null; private long currentOffset = 0; private List pendingMessages; + private List pendingTimestamps; private transient @Nullable DatasetService maybeDatasetService; private final Counter recordsAppended = Metrics.counter(WriteRecordsDoFn.class, "recordsAppended"); @@ -250,6 +269,7 @@ public DestinationState( throws Exception { this.tableUrn = tableUrn; this.pendingMessages = Lists.newArrayList(); + this.pendingTimestamps = Lists.newArrayList(); this.maybeDatasetService = datasetService; this.useDefaultStream = useDefaultStream; this.initialTableSchema = messageConverter.getTableSchema(); @@ -404,6 +424,7 @@ void invalidateWriteStream() { void addMessage( StorageApiWritePayload payload, + org.joda.time.Instant elementTs, OutputReceiver failedRowsReceiver) throws Exception { maybeTickleCache(); @@ -428,17 +449,23 @@ void addMessage( // 24926 is fixed, we need to merge the unknownFields back into the main row before // outputting to the // failed-rows consumer. - failedRowsReceiver.output(new BigQueryStorageApiInsertError(tableRow, e.toString())); + org.joda.time.Instant timestamp = payload.getTimestamp(); + failedRowsReceiver.outputWithTimestamp( + new BigQueryStorageApiInsertError(tableRow, e.toString()), + timestamp != null ? timestamp : elementTs); return; } } } pendingMessages.add(payloadBytes); + org.joda.time.Instant timestamp = payload.getTimestamp(); + pendingTimestamps.add(timestamp != null ? 
timestamp : elementTs); } long flush( RetryManager retryManager, - OutputReceiver failedRowsReceiver) + OutputReceiver failedRowsReceiver, + @Nullable OutputReceiver successfulRowsReceiver) throws Exception { if (pendingMessages.isEmpty()) { return 0; @@ -447,7 +474,8 @@ long flush( final ProtoRows.Builder insertsBuilder = ProtoRows.newBuilder(); insertsBuilder.addAllSerializedRows(pendingMessages); final ProtoRows inserts = insertsBuilder.build(); - pendingMessages.clear(); + List insertTimestamps = pendingTimestamps; + pendingTimestamps = Lists.newArrayList(); // Handle the case where the request is too large. if (inserts.getSerializedSize() >= maxRequestSize) { @@ -461,14 +489,17 @@ long flush( + maxRequestSize + ". This is unexpected. All rows in the request will be sent to the failed-rows PCollection."); } - for (ByteString rowBytes : inserts.getSerializedRowsList()) { + for (int i = 0; i < inserts.getSerializedRowsCount(); ++i) { + ByteString rowBytes = inserts.getSerializedRows(i); + org.joda.time.Instant timestamp = insertTimestamps.get(i); TableRow failedRow = TableRowToStorageApiProto.tableRowFromMessage( DynamicMessage.parseFrom( getAppendClientInfo(true, null).getDescriptor(), rowBytes)); - failedRowsReceiver.output( + failedRowsReceiver.outputWithTimestamp( new BigQueryStorageApiInsertError( - failedRow, "Row payload too large. Maximum size " + maxRequestSize)); + failedRow, "Row payload too large. Maximum size " + maxRequestSize), + timestamp); } return 0; } @@ -478,7 +509,8 @@ long flush( offset = this.currentOffset; this.currentOffset += inserts.getSerializedRowsCount(); } - AppendRowsContext appendRowsContext = new AppendRowsContext(offset, inserts); + AppendRowsContext appendRowsContext = + new AppendRowsContext(offset, inserts, insertTimestamps); retryManager.addOperation( c -> { @@ -518,15 +550,17 @@ long flush( for (int failedIndex : failedRowIndices) { // Convert the message to a TableRow and send it to the failedRows collection. ByteString protoBytes = failedContext.protoRows.getSerializedRows(failedIndex); + org.joda.time.Instant timestamp = failedContext.timestamps.get(failedIndex); try { TableRow failedRow = TableRowToStorageApiProto.tableRowFromMessage( DynamicMessage.parseFrom( Preconditions.checkStateNotNull(appendClientInfo).getDescriptor(), protoBytes)); - failedRowsReceiver.output( + failedRowsReceiver.outputWithTimestamp( new BigQueryStorageApiInsertError( - failedRow, error.getRowIndexToErrorMessage().get(failedIndex))); + failedRow, error.getRowIndexToErrorMessage().get(failedIndex)), + timestamp); } catch (InvalidProtocolBufferException e) { LOG.error("Failed to insert row and could not parse the result!"); } @@ -536,13 +570,16 @@ long flush( // Remove the failed row from the payload, so we retry the batch without the failed // rows. ProtoRows.Builder retryRows = ProtoRows.newBuilder(); + List retryTimestamps = Lists.newArrayList(); for (int i = 0; i < failedContext.protoRows.getSerializedRowsCount(); ++i) { if (!failedRowIndices.contains(i)) { ByteString rowBytes = failedContext.protoRows.getSerializedRows(i); retryRows.addSerializedRows(rowBytes); + retryTimestamps.add(failedContext.timestamps.get(i)); } } failedContext.protoRows = retryRows.build(); + failedContext.timestamps = retryTimestamps; // Since we removed rows, we need to update the insert offsets for all remaining // rows. 
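For context on what the timestamp plumbing in this file ultimately enables: when successful-row propagation is switched on, rows that are durably written are re-emitted with their original element timestamps alongside the existing failed-rows output. A minimal pipeline-level sketch follows (Java, mirroring the updated BigQueryIOWriteTest later in this patch); the withPropagateSuccessfulStorageApiWrites option and the getSuccessfulStorageApiInserts/getFailedStorageApiInserts getters are the ones this patch exercises, while the destination table, coder setup, and input row are placeholders and the table is assumed to already exist.

    import com.google.api.services.bigquery.model.TableRow;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO;
    import org.apache.beam.sdk.io.gcp.bigquery.TableRowJsonCoder;
    import org.apache.beam.sdk.io.gcp.bigquery.WriteResult;
    import org.apache.beam.sdk.transforms.Create;
    import org.apache.beam.sdk.values.PCollection;

    public class PropagateSuccessfulWritesSketch {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create();
        WriteResult result =
            p.apply(Create.of(new TableRow().set("name", "a")).withCoder(TableRowJsonCoder.of()))
                .apply(
                    BigQueryIO.writeTableRows()
                        .to("my-project:my_dataset.my_table") // placeholder; table assumed to exist
                        .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_NEVER)
                        .withMethod(BigQueryIO.Write.Method.STORAGE_WRITE_API)
                        .withPropagateSuccessfulStorageApiWrites(true));
        // Rows that were durably written; per this patch they keep their original
        // element timestamps and are emitted in the global window.
        PCollection<TableRow> written = result.getSuccessfulStorageApiInserts();
        // Rows that persistently failed conversion or insertion.
        result.getFailedStorageApiInserts();
        p.run().waitUntilFinish();
      }
    }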
@@ -564,7 +601,25 @@ long flush( appendFailures.inc(); return RetryType.RETRY_ALL_OPERATIONS; }, - c -> recordsAppended.inc(c.protoRows.getSerializedRowsCount()), + c -> { + recordsAppended.inc(c.protoRows.getSerializedRowsCount()); + if (successfulRowsReceiver != null) { + for (int i = 0; i < c.protoRows.getSerializedRowsCount(); ++i) { + ByteString rowBytes = c.protoRows.getSerializedRowsList().get(i); + try { + TableRow row = + TableRowToStorageApiProto.tableRowFromMessage( + DynamicMessage.parseFrom( + Preconditions.checkStateNotNull(appendClientInfo).getDescriptor(), + rowBytes)); + org.joda.time.Instant timestamp = c.timestamps.get(i); + successfulRowsReceiver.outputWithTimestamp(row, timestamp); + } catch (InvalidProtocolBufferException e) { + LOG.warn("Failure parsing TableRow: " + e); + } + } + } + }, appendRowsContext); maybeTickleCache(); return inserts.getSerializedRowsCount(); @@ -623,6 +678,7 @@ void postFlush() { int streamAppendClientCount, TupleTag> finalizeTag, TupleTag failedRowsTag, + @Nullable TupleTag successfulRowsTag, boolean autoUpdateSchema, boolean ignoreUnknownValues) { this.messageConverters = new TwoLevelMessageConverterCache<>(operationName); @@ -634,6 +690,7 @@ void postFlush() { this.streamAppendClientCount = streamAppendClientCount; this.finalizeTag = finalizeTag; this.failedRowsTag = failedRowsTag; + this.successfulRowsTag = successfulRowsTag; this.autoUpdateSchema = autoUpdateSchema; this.ignoreUnknownValues = ignoreUnknownValues; } @@ -642,18 +699,22 @@ boolean shouldFlush() { return numPendingRecords > flushThresholdCount || numPendingRecordBytes > flushThresholdBytes; } - void flushIfNecessary(OutputReceiver failedRowsReceiver) + void flushIfNecessary( + OutputReceiver failedRowsReceiver, + @Nullable OutputReceiver successfulRowsReceiver) throws Exception { if (shouldFlush()) { forcedFlushes.inc(); // Too much memory being used. Flush the state and wait for it to drain out. // TODO(reuvenlax): Consider waiting for memory usage to drop instead of waiting for all the // appends to finish. - flushAll(failedRowsReceiver); + flushAll(failedRowsReceiver, successfulRowsReceiver); } } - void flushAll(OutputReceiver failedRowsReceiver) + void flushAll( + OutputReceiver failedRowsReceiver, + @Nullable OutputReceiver successfulRowsReceiver) throws Exception { List> retryManagers = Lists.newArrayListWithCapacity(Preconditions.checkStateNotNull(destinations).size()); @@ -663,7 +724,8 @@ void flushAll(OutputReceiver failedRowsReceiver) RetryManager retryManager = new RetryManager<>(Duration.standardSeconds(1), Duration.standardSeconds(10), 1000); retryManagers.add(retryManager); - numRowsWritten += destinationState.flush(retryManager, failedRowsReceiver); + numRowsWritten += + destinationState.flush(retryManager, failedRowsReceiver, successfulRowsReceiver); retryManager.run(false); } if (numRowsWritten > 0) { @@ -731,6 +793,7 @@ public void process( ProcessContext c, PipelineOptions pipelineOptions, @Element KV element, + @Timestamp org.joda.time.Instant elementTs, MultiOutputReceiver o) throws Exception { DatasetService initializedDatasetService = initializeDatasetService(pipelineOptions); @@ -744,15 +807,18 @@ public void process( c, k, initializedDatasetService, pipelineOptions.as(BigQueryOptions.class))); OutputReceiver failedRowsReceiver = o.get(failedRowsTag); - flushIfNecessary(failedRowsReceiver); - state.addMessage(element.getValue(), failedRowsReceiver); + @Nullable + OutputReceiver successfulRowsReceiver = + (successfulRowsTag != null) ? 
o.get(successfulRowsTag) : null; + flushIfNecessary(failedRowsReceiver, successfulRowsReceiver); + state.addMessage(element.getValue(), elementTs, failedRowsReceiver); ++numPendingRecords; numPendingRecordBytes += element.getValue().getPayload().length; } @FinishBundle public void finishBundle(FinishBundleContext context) throws Exception { - flushAll( + OutputReceiver failedRowsReceiver = new OutputReceiver() { @Override public void output(BigQueryStorageApiInsertError output) { @@ -764,7 +830,24 @@ public void outputWithTimestamp( BigQueryStorageApiInsertError output, org.joda.time.Instant timestamp) { context.output(failedRowsTag, output, timestamp, GlobalWindow.INSTANCE); } - }); + }; + @Nullable OutputReceiver successfulRowsReceiver = null; + if (successfulRowsTag != null) { + successfulRowsReceiver = + new OutputReceiver() { + @Override + public void output(TableRow output) { + outputWithTimestamp(output, GlobalWindow.INSTANCE.maxTimestamp()); + } + + @Override + public void outputWithTimestamp(TableRow output, org.joda.time.Instant timestamp) { + context.output(successfulRowsTag, output, timestamp, GlobalWindow.INSTANCE); + } + }; + } + + flushAll(failedRowsReceiver, successfulRowsReceiver); final Map destinations = Preconditions.checkStateNotNull(this.destinations); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java index a9d814ffe890..cd23b7be2c52 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiWritesShardedRecords.java @@ -123,8 +123,12 @@ public class StorageApiWritesShardedRecords failedRowsCoder; private final boolean autoUpdateSchema; private final boolean ignoreUnknownValues; + private final Duration streamIdleTime = DEFAULT_STREAM_IDLE_TIME; private final TupleTag failedRowsTag; + private final @Nullable TupleTag successfulRowsTag; + private final Coder succussfulRowsCoder; + private final TupleTag> flushTag = new TupleTag<>("flushTag"); private static final ExecutorService closeWriterExecutor = Executors.newCachedThreadPool(); @@ -138,9 +142,13 @@ class AppendRowsContext extends RetryManager.Operation.Context key, ProtoRows protoRows) { + List timestamps; + + AppendRowsContext( + ShardedKey key, ProtoRows protoRows, List timestamps) { this.key = key; this.protoRows = protoRows; + this.timestamps = timestamps; } @Override @@ -203,7 +211,9 @@ public StorageApiWritesShardedRecords( BigQueryServices bqServices, Coder destinationCoder, Coder failedRowsCoder, + Coder successfulRowsCoder, TupleTag failedRowsTag, + @Nullable TupleTag successfulRowsTag, boolean autoUpdateSchema, boolean ignoreUnknownValues) { this.dynamicDestinations = dynamicDestinations; @@ -213,6 +223,8 @@ public StorageApiWritesShardedRecords( this.destinationCoder = destinationCoder; this.failedRowsCoder = failedRowsCoder; this.failedRowsTag = failedRowsTag; + this.successfulRowsTag = successfulRowsTag; + this.succussfulRowsCoder = successfulRowsCoder; this.autoUpdateSchema = autoUpdateSchema; this.ignoreUnknownValues = ignoreUnknownValues; } @@ -225,13 +237,17 @@ public PCollectionTuple expand( final long maxRequestSize = bigQueryOptions.getStorageWriteApiMaxRequestSize(); String operationName = input.getName() + "/" + 
getName(); + TupleTagList tupleTagList = TupleTagList.of(failedRowsTag); + if (successfulRowsTag != null) { + tupleTagList = tupleTagList.and(successfulRowsTag); + } // Append records to the Storage API streams. PCollectionTuple writeRecordsResult = input.apply( "Write Records", ParDo.of(new WriteRecordsDoFn(operationName, streamIdleTime, splitSize, maxRequestSize)) .withSideInputs(dynamicDestinations.getSideInputs()) - .withOutputTags(flushTag, TupleTagList.of(failedRowsTag))); + .withOutputTags(flushTag, tupleTagList)); SchemaCoder operationCoder; try { @@ -261,6 +277,9 @@ public PCollectionTuple expand( .apply( "Flush and finalize writes", ParDo.of(new StorageApiFlushAndFinalizeDoFn(bqServices))); writeRecordsResult.get(failedRowsTag).setCoder(failedRowsCoder); + if (successfulRowsTag != null) { + writeRecordsResult.get(successfulRowsTag).setCoder(succussfulRowsCoder); + } return writeRecordsResult; } @@ -377,6 +396,7 @@ public void process( ProcessContext c, final PipelineOptions pipelineOptions, @Element KV, Iterable> element, + @Timestamp org.joda.time.Instant elementTs, final @AlwaysFetched @StateId("streamName") ValueState streamName, final @AlwaysFetched @StateId("streamOffset") ValueState streamOffset, final @StateId("updatedSchema") ValueState updatedSchema, @@ -468,7 +488,7 @@ public void process( // Each ProtoRows object contains at most 1MB of rows. // TODO: Push messageFromTableRow up to top level. That we we cans skip TableRow entirely if // already proto or already schema. - Iterable messages = + Iterable messages = new SplittingIterable( element.getValue(), splitSize, @@ -476,9 +496,12 @@ public void process( bytes -> appendClientInfo.get().toTableRow(bytes), (failedRow, errorMessage) -> o.get(failedRowsTag) - .output(new BigQueryStorageApiInsertError(failedRow, errorMessage)), + .outputWithTimestamp( + new BigQueryStorageApiInsertError(failedRow.getValue(), errorMessage), + failedRow.getTimestamp()), autoUpdateSchema, - ignoreUnknownValues); + ignoreUnknownValues, + elementTs); // Initialize stream names and offsets for all contexts. This will be called initially, but // will also be called if we roll over to a new stream on a retry. @@ -566,23 +589,28 @@ public void process( // Convert the message to a TableRow and send it to the failedRows collection. ByteString protoBytes = failedContext.protoRows.getSerializedRows(failedIndex); TableRow failedRow = appendClientInfo.get().toTableRow(protoBytes); + org.joda.time.Instant timestamp = failedContext.timestamps.get(failedIndex); o.get(failedRowsTag) - .output( + .outputWithTimestamp( new BigQueryStorageApiInsertError( - failedRow, error.getRowIndexToErrorMessage().get(failedIndex))); + failedRow, error.getRowIndexToErrorMessage().get(failedIndex)), + timestamp); } rowsSentToFailedRowsCollection.inc(failedRowIndices.size()); // Remove the failed row from the payload, so we retry the batch without the failed // rows. ProtoRows.Builder retryRows = ProtoRows.newBuilder(); + @Nullable List timestamps = Lists.newArrayList(); for (int i = 0; i < failedContext.protoRows.getSerializedRowsCount(); ++i) { if (!failedRowIndices.contains(i)) { ByteString rowBytes = failedContext.protoRows.getSerializedRows(i); retryRows.addSerializedRows(rowBytes); + timestamps.add(failedContext.timestamps.get(i)); } } failedContext.protoRows = retryRows.build(); + failedContext.timestamps = timestamps; // Since we removed rows, we need to update the insert offsets for all remaining rows. 
long offset = failedContext.offset; @@ -655,16 +683,25 @@ public void process( context.offset + context.protoRows.getSerializedRowsCount() - 1, false))); flushesScheduled.inc(context.protoRows.getSerializedRowsCount()); + + if (successfulRowsTag != null) { + for (int i = 0; i < context.protoRows.getSerializedRowsCount(); ++i) { + ByteString protoBytes = context.protoRows.getSerializedRows(i); + org.joda.time.Instant timestamp = context.timestamps.get(i); + o.get(successfulRowsTag) + .outputWithTimestamp(appendClientInfo.get().toTableRow(protoBytes), timestamp); + } + } }; Instant now = Instant.now(); List contexts = Lists.newArrayList(); RetryManager retryManager = new RetryManager<>(Duration.standardSeconds(1), Duration.standardSeconds(10), 1000); int numAppends = 0; - for (ProtoRows protoRows : messages) { + for (SplittingIterable.Value splitValue : messages) { // Handle the case of a row that is too large. - if (protoRows.getSerializedSize() >= maxRequestSize) { - if (protoRows.getSerializedRowsCount() > 1) { + if (splitValue.getProtoRows().getSerializedSize() >= maxRequestSize) { + if (splitValue.getProtoRows().getSerializedRowsCount() > 1) { // TODO(reuvenlax): Is it worth trying to handle this case by splitting the protoRows? // Given that we split // the ProtoRows iterable at 2MB and the max request size is 10MB, this scenario seems @@ -674,20 +711,25 @@ public void process( + maxRequestSize + ". This is unexpected. All rows in the request will be sent to the failed-rows PCollection."); } - for (ByteString rowBytes : protoRows.getSerializedRowsList()) { + for (int i = 0; i < splitValue.getProtoRows().getSerializedRowsCount(); ++i) { + ByteString rowBytes = splitValue.getProtoRows().getSerializedRows(i); + org.joda.time.Instant timestamp = splitValue.getTimestamps().get(i); TableRow failedRow = appendClientInfo.get().toTableRow(rowBytes); o.get(failedRowsTag) - .output( + .outputWithTimestamp( new BigQueryStorageApiInsertError( - failedRow, "Row payload too large. Maximum size " + maxRequestSize)); + failedRow, "Row payload too large. Maximum size " + maxRequestSize), + timestamp); } } else { ++numAppends; // RetryManager - AppendRowsContext context = new AppendRowsContext(element.getKey(), protoRows); + AppendRowsContext context = + new AppendRowsContext( + element.getKey(), splitValue.getProtoRows(), splitValue.getTimestamps()); contexts.add(context); retryManager.addOperation(runOperation, onError, onSuccess, context); - recordsAppended.inc(protoRows.getSerializedRowsCount()); + recordsAppended.inc(splitValue.getProtoRows().getSerializedRowsCount()); appendSizeDistribution.update(context.protoRows.getSerializedRowsCount()); } } @@ -709,7 +751,6 @@ public void process( if (autoUpdateSchema) { @Nullable StreamAppendClient streamAppendClient = appendClientInfo.get().getStreamAppendClient(); - ; @Nullable TableSchema newSchema = (streamAppendClient != null) ? 
streamAppendClient.getUpdatedSchema() : null; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java index 23cda2f57891..00779193262d 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StreamingWriteTables.java @@ -337,6 +337,8 @@ public WriteResult expand(PCollection> input) { null, null, null, + null, + null, null); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteResult.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteResult.java index d18dabfc8ea1..7ea2f959ce67 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteResult.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteResult.java @@ -42,6 +42,8 @@ public final class WriteResult implements POutput { private final @Nullable PCollection successfulBatchInserts; private final @Nullable TupleTag failedStorageApiInsertsTag; private final @Nullable PCollection failedStorageApiInserts; + private final @Nullable TupleTag successfulStorageApiInsertsTag; + private final @Nullable PCollection successfulStorageApiInserts; /** Creates a {@link WriteResult} in the given {@link Pipeline}. */ static WriteResult in( @@ -52,7 +54,9 @@ static WriteResult in( @Nullable TupleTag successfulBatchInsertsTag, @Nullable PCollection successfulBatchInserts, @Nullable TupleTag failedStorageApiInsertsTag, - @Nullable PCollection failedStorageApiInserts) { + @Nullable PCollection failedStorageApiInserts, + @Nullable TupleTag successfulStorageApiInsertsTag, + @Nullable PCollection successfulStorageApiInserts) { return new WriteResult( pipeline, failedInsertsTag, @@ -63,7 +67,9 @@ static WriteResult in( successfulBatchInsertsTag, successfulBatchInserts, failedStorageApiInsertsTag, - failedStorageApiInserts); + failedStorageApiInserts, + successfulStorageApiInsertsTag, + successfulStorageApiInserts); } static WriteResult withExtendedErrors( @@ -81,6 +87,8 @@ static WriteResult withExtendedErrors( null, null, null, + null, + null, null); } @@ -116,7 +124,9 @@ private WriteResult( @Nullable TupleTag successfulInsertsTag, @Nullable PCollection successfulBatchInserts, @Nullable TupleTag failedStorageApiInsertsTag, - @Nullable PCollection failedStorageApiInserts) { + @Nullable PCollection failedStorageApiInserts, + @Nullable TupleTag successfulStorageApiInsertsTag, + @Nullable PCollection successfulStorageApiInserts) { this.pipeline = pipeline; this.failedInsertsTag = failedInsertsTag; this.failedInserts = failedInserts; @@ -127,6 +137,8 @@ private WriteResult( this.successfulBatchInserts = successfulBatchInserts; this.failedStorageApiInsertsTag = failedStorageApiInsertsTag; this.failedStorageApiInserts = failedStorageApiInserts; + this.successfulStorageApiInsertsTag = successfulStorageApiInsertsTag; + this.successfulStorageApiInserts = successfulStorageApiInserts; } /** @@ -194,6 +206,11 @@ public PCollection getFailedInsertsWithErr() { + " extended errors. Use getFailedInserts or getFailedStorageApiInserts instead"); } + /** + * Return any rows that persistently fail to insert when using a storage-api method. 
For example: + * rows with values that do not match the BigQuery schema or rows that are too large to insert. + * This collection is in the global window. + */ public PCollection getFailedStorageApiInserts() { Preconditions.checkStateNotNull( failedStorageApiInsertsTag, @@ -203,6 +220,23 @@ public PCollection getFailedStorageApiInserts() { "Cannot use getFailedStorageApiInserts as this insert didn't use the storage API."); } + /** + * Return all rows successfully inserted using one of the storage-api insert methods. Rows undergo + * a conversion process, so while these TableRow objects are logically the same as the rows in the + * initial PCollection, they may not be physically identical. This PCollection is in the global + * window. + */ + public PCollection getSuccessfulStorageApiInserts() { + Preconditions.checkStateNotNull( + successfulStorageApiInsertsTag, + "Can only getSuccessfulStorageApiInserts if using the storage API and " + + "withPropagateSuccessfulStorageApiWrites() is set."); + return Preconditions.checkStateNotNull( + successfulStorageApiInserts, + "Can only getSuccessfulStorageApiInserts if using the storage API and " + + "withPropagateSuccessfulStorageApiWrites() is set."); + } + @Override public Pipeline getPipeline() { return pipeline; diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java index fc86c256a6d6..13a06b9a25a0 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java @@ -2814,6 +2814,7 @@ public void testStorageApiErrors() throws Exception { .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) .withSchema(tableSchema) .withFailedInsertRetryPolicy(InsertRetryPolicy.retryTransientErrors()) + .withPropagateSuccessfulStorageApiWrites(true) .withTestServices(fakeBqServices) .withoutValidation()); @@ -2823,10 +2824,15 @@ public void testStorageApiErrors() throws Exception { .apply( MapElements.into(TypeDescriptor.of(TableRow.class)) .via(BigQueryStorageApiInsertError::getRow)); + PCollection successfulRows = result.getSuccessfulStorageApiInserts(); PAssert.that(deadRows) .containsInAnyOrder( Iterables.concat(badRows, Iterables.filter(goodRows, shouldFailRow::apply))); + PAssert.that(successfulRows) + .containsInAnyOrder( + Iterables.toArray( + Iterables.filter(goodRows, r -> !shouldFailRow.apply(r)), TableRow.class)); p.run(); assertThat( From 8a4e029ca255be78be55e9c8f2f2198a1645a9e7 Mon Sep 17 00:00:00 2001 From: Bjorn Pedersen Date: Wed, 22 Mar 2023 17:13:51 -0400 Subject: [PATCH 11/13] using fileio.BlobReader in GcsDownloader + linting --- sdks/python/apache_beam/io/gcp/gcsio.py | 92 +++++--------------- sdks/python/apache_beam/io/gcp/gcsio_test.py | 5 +- 2 files changed, 23 insertions(+), 74 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py index 6d5ad597af34..18d4fac76856 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio.py +++ b/sdks/python/apache_beam/io/gcp/gcsio.py @@ -38,12 +38,8 @@ import time import traceback from itertools import islice -from typing import Optional -from typing import Union from google.cloud import storage -import apache_beam -from apache_beam.internal.http_client import get_new_http from 
apache_beam.internal.metrics.metric import ServiceCallMetric from apache_beam.io.filesystemio import Downloader from apache_beam.io.filesystemio import DownloaderStream @@ -199,14 +195,13 @@ def get_bucket(self, bucket_name): def create_bucket(self, bucket_name, project, kms_key=None, location=None): """Create and return a GCS bucket in a specific project.""" - encryption = None - + try: bucket = self.client.create_bucket( - bucket_or_name=bucket_name, - project=project, - location=location, - ) + bucket_or_name=bucket_name, + project=project, + location=location, + ) if kms_key: bucket.default_kms_key_name(kms_key) return self.get_bucket(bucket_name) @@ -265,7 +260,7 @@ def delete(self, path): """ bucket_name, target_name = parse_gcs_path(path) try: - bucket = self.client.get_bucket(bucket_name) + bucket = self.get_bucket(bucket_name) bucket.delete_blob(target_name) except HttpError as http_error: if http_error.status_code == 404: @@ -318,56 +313,25 @@ def delete_batch(self, paths): @retry.with_exponential_backoff( retry_filter=retry.retry_on_server_errors_and_timeout_filter) - def copy( - self, - src, - dest): - # dest_kms_key_name=None, - # max_bytes_rewritten_per_call=None): + def copy(self, src, dest): """Copies the given GCS object from src to dest. Args: src: GCS file path pattern in the form gs:///. dest: GCS file path pattern in the form gs:///. !!! - dest_kms_key_name: Experimental. No backwards compatibility guarantees. - Encrypt dest with this Cloud KMS key. If None, will use dest bucket - encryption defaults. - max_bytes_rewritten_per_call: Experimental. No backwards compatibility - guarantees. Each rewrite API call will return after these many bytes. - Used for testing. !!! Raises: TimeoutError: on timeout. """ src_bucket_name, src_path = parse_gcs_path(src) dest_bucket_name, dest_path = parse_gcs_path(dest) - # request = storage.StorageObjectsRewriteRequest( - # sourceBucket=src_bucket, - # sourceObject=src_path, - # destinationBucket=dest_bucket, - # destinationObject=dest_path, - # destinationKmsKeyName=dest_kms_key_name, - # maxBytesRewrittenPerCall=max_bytes_rewritten_per_call) src_bucket = self.get_bucket(src_bucket_name) src_blob = src_bucket.get_blob(src_path) dest_bucket = self.get_bucket(dest_bucket_name) if not dest_path: dest_path = None - response = src_bucket.copy_blob(src_blob, dest_bucket, new_name=dest_path) - # !!! while not response.done: - # _LOGGER.debug( - # 'Rewrite progress: %d of %d bytes, %s to %s', - # response.totalBytesRewritten, - # response.objectSize, - # src, - # dest) - # request.rewriteToken = response.rewriteToken - # response = self.client.objects.Rewrite(request) - # if self._rewrite_cb is not None: - # self._rewrite_cb(response) - - # _LOGGER.debug('Rewrite done: %s to %s', src, dest) !!! + src_bucket.copy_blob(src_blob, dest_bucket, new_name=dest_path) # We intentionally do not decorate this method with a retry, as retrying is # handled in BatchApiRequest.Execute(). 
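For reference, the rewritten copy() above maps one-to-one onto the google-cloud-storage client API that these patches adopt. A standalone sketch of the same call pattern outside of Beam follows; the bucket and object names are placeholders.

    from google.cloud import storage

    client = storage.Client()  # uses application-default credentials
    src_bucket = client.get_bucket('source-bucket')        # placeholder
    dest_bucket = client.get_bucket('destination-bucket')  # placeholder
    src_blob = src_bucket.get_blob('path/to/object')       # placeholder

    # copy_blob performs a server-side copy; new_name keeps or renames the object.
    src_bucket.copy_blob(src_blob, dest_bucket, new_name='path/to/copy')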
@@ -620,11 +584,11 @@ def list_files(self, path, with_metadata=False): "Finished computing file information of: %s files", len(file_info)) else: - _LOGGER.info( - "Finished computing size of: %s files", len(file_info)) + _LOGGER.info("Finished computing size of: %s files", len(file_info)) if with_metadata: - yield file_name, (item.size(), self._updated_to_seconds(item.updated())) + yield file_name, ( + item.size(), self._updated_to_seconds(item.updated())) else: yield file_name, item.size() @@ -654,8 +618,8 @@ def __init__(self, client, path, buffer_size, get_project_number): # Create a request count metric resource = resource_identifiers.GoogleCloudStorageBucket(self._bucket) labels = { - monitoring_infos.SERVICE_LABEL: 'Storage', - monitoring_infos.METHOD_LABEL: 'Objects.get', + monitoring_infos.SERVICE_LABEL: 'GCS Client', + monitoring_infos.METHOD_LABEL: 'BlobReader.read', monitoring_infos.RESOURCE_LABEL: resource, monitoring_infos.GCS_BUCKET_LABEL: self._bucket } @@ -664,7 +628,7 @@ def __init__(self, client, path, buffer_size, get_project_number): labels[monitoring_infos.GCS_PROJECT_ID_LABEL] = str(project_number) else: _LOGGER.debug( - 'Possibly missing storage.buckets.get permission to ' + 'Possibly missing storage.get_bucket permission to ' 'bucket %s. Label %s is not added to the counter because it ' 'cannot be identified.', self._bucket, @@ -674,12 +638,9 @@ def __init__(self, client, path, buffer_size, get_project_number): request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN, base_labels=labels) - # Get object state. - self._get_request = ( - storage.StorageObjectsGetRequest( - bucket=self._bucket, object=self._name)) try: - metadata = self._get_object_metadata(self._get_request) + bucket = self._client.get_bucket(self._bucket) + metadata = bucket.get_blob(self._name) except HttpError as http_error: service_call_metric.call(http_error) if http_error.status_code == 404: @@ -691,31 +652,20 @@ def __init__(self, client, path, buffer_size, get_project_number): else: service_call_metric.call('ok') - self._size = metadata.size - - # Ensure read is from file of the correct generation. - self._get_request.generation = metadata.generation - - # Initialize read buffer state. 
- self._download_stream = io.BytesIO() - self._downloader = transfer.Download( - self._download_stream, - auto_transfer=False, - chunksize=self._buffer_size, - num_retries=20) + self._size = metadata.size() try: - self._client.objects.Get(self._get_request, download=self._downloader) + reader = storage.fileio.BlobReader(metadata, chunk_size=self._buffer_size) + reader.read() service_call_metric.call('ok') except HttpError as e: service_call_metric.call(e) raise + finally: + reader.close() @retry.with_exponential_backoff( retry_filter=retry.retry_on_server_errors_and_timeout_filter) - def _get_object_metadata(self, get_request): - return self._client.objects.Get(get_request) - @property def size(self): return self._size diff --git a/sdks/python/apache_beam/io/gcp/gcsio_test.py b/sdks/python/apache_beam/io/gcp/gcsio_test.py index 7d504368ea0c..c0ecd472efb3 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsio_test.py @@ -411,8 +411,7 @@ def test_kms_key(self): file_size = 1234 kms_key = "dummy" - self._insert_random_file( - self.client, file_name, file_size, kms_key=kms_key) + self._insert_random_file(self.client, file_name, file_size, kms_key=kms_key) self.assertTrue(self.gcs.exists(file_name)) self.assertEqual(kms_key, self.gcs.kms_key(file_name)) @@ -537,7 +536,7 @@ def test_copy(self): self.assertFalse( gcsio.parse_gcs_path(dest_file_name) in self.client.objects.files) - self.gcs.copy(src_file_name, dest_file_name, dest_kms_key_name='kms_key') + self.gcs.copy(src_file_name, dest_file_name) self.assertTrue( gcsio.parse_gcs_path(src_file_name) in self.client.objects.files) From 08bfd4fb7ddd521f185e8eb5c3a61fbc5914937e Mon Sep 17 00:00:00 2001 From: Bjorn Pedersen Date: Thu, 23 Mar 2023 15:27:41 -0400 Subject: [PATCH 12/13] replaced GcsUploader and GcsDownloader --- .../assets/symbols/python.g.yaml | 8 - sdks/python/apache_beam/io/gcp/gcsio.py | 197 +----------------- 2 files changed, 8 insertions(+), 197 deletions(-) diff --git a/playground/frontend/playground_components/assets/symbols/python.g.yaml b/playground/frontend/playground_components/assets/symbols/python.g.yaml index a47447225a68..0b9e5e142ded 100644 --- a/playground/frontend/playground_components/assets/symbols/python.g.yaml +++ b/playground/frontend/playground_components/assets/symbols/python.g.yaml @@ -4790,10 +4790,6 @@ GBKTransform: - from_runner_api_parameter - to_runner_api_parameter GcpTestIOError: {} -GcsDownloader: - methods: - - get_range - - size GCSFileSystem: methods: - checksum @@ -4837,10 +4833,6 @@ GcsIOError: {} GcsIOOverrides: methods: - retry_func -GcsUploader: - methods: - - finish - - put GeneralPurposeConsumerSet: methods: - flush diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py index 18d4fac76856..1787f0cb6e20 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio.py +++ b/sdks/python/apache_beam/io/gcp/gcsio.py @@ -30,24 +30,12 @@ # pytype: skip-file import errno -import io import logging -import multiprocessing import re -import threading import time -import traceback from itertools import islice from google.cloud import storage -from apache_beam.internal.metrics.metric import ServiceCallMetric -from apache_beam.io.filesystemio import Downloader -from apache_beam.io.filesystemio import DownloaderStream -from apache_beam.io.filesystemio import PipeStream -from apache_beam.io.filesystemio import Uploader -from apache_beam.io.filesystemio import UploaderStream -from apache_beam.io.gcp import resource_identifiers -from 
apache_beam.metrics import monitoring_infos from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.utils import retry from apache_beam.utils.annotations import deprecated @@ -63,7 +51,6 @@ # pylint: disable=ungrouped-imports from apitools.base.py.batch import BatchApiRequest from apitools.base.py.exceptions import HttpError - from apitools.base.py import transfer from apache_beam.internal.gcp import auth except ImportError: raise ImportError( @@ -229,24 +216,15 @@ def open( Raises: ValueError: Invalid open file mode. """ + bucket_name, blob_name = parse_gcs_path(filename) + bucket = self.client.get_bucket(bucket_name) + blob = bucket.get_blob(blob_name) + if mode == 'r' or mode == 'rb': - downloader = GcsDownloader( - self.client, - filename, - buffer_size=read_buffer_size, - get_project_number=self.get_project_number) - return io.BufferedReader( - DownloaderStream( - downloader, read_buffer_size=read_buffer_size, mode=mode), - buffer_size=read_buffer_size) + return storage.fileio.BlobReader(blob, chunk_size=read_buffer_size) elif mode == 'w' or mode == 'wb': - uploader = GcsUploader( - self.client, - filename, - mime_type, - get_project_number=self.get_project_number) - return io.BufferedWriter( - UploaderStream(uploader, mode=mode), buffer_size=128 * 1024) + return storage.fileio.BlobReader( + blob, chunk_size=read_buffer_size, content_type=mime_type) else: raise ValueError('Invalid file open mode: %s.' % mode) @@ -260,7 +238,7 @@ def delete(self, path): """ bucket_name, target_name = parse_gcs_path(path) try: - bucket = self.get_bucket(bucket_name) + bucket = self.client.get_bucket(bucket_name) bucket.delete_blob(target_name) except HttpError as http_error: if http_error.status_code == 404: @@ -605,162 +583,3 @@ def _updated_to_seconds(updated): return ( time.mktime(updated.timetuple()) - time.timezone + updated.microsecond / 1000000.0) - - -class GcsDownloader(Downloader): - def __init__(self, client, path, buffer_size, get_project_number): - self._client = client - self._path = path - self._bucket, self._name = parse_gcs_path(path) - self._buffer_size = buffer_size - self._get_project_number = get_project_number - - # Create a request count metric - resource = resource_identifiers.GoogleCloudStorageBucket(self._bucket) - labels = { - monitoring_infos.SERVICE_LABEL: 'GCS Client', - monitoring_infos.METHOD_LABEL: 'BlobReader.read', - monitoring_infos.RESOURCE_LABEL: resource, - monitoring_infos.GCS_BUCKET_LABEL: self._bucket - } - project_number = self._get_project_number(self._bucket) - if project_number: - labels[monitoring_infos.GCS_PROJECT_ID_LABEL] = str(project_number) - else: - _LOGGER.debug( - 'Possibly missing storage.get_bucket permission to ' - 'bucket %s. 
Label %s is not added to the counter because it ' - 'cannot be identified.', - self._bucket, - monitoring_infos.GCS_PROJECT_ID_LABEL) - - service_call_metric = ServiceCallMetric( - request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN, - base_labels=labels) - - try: - bucket = self._client.get_bucket(self._bucket) - metadata = bucket.get_blob(self._name) - except HttpError as http_error: - service_call_metric.call(http_error) - if http_error.status_code == 404: - raise IOError(errno.ENOENT, 'Not found: %s' % self._path) - else: - _LOGGER.error( - 'HTTP error while requesting file %s: %s', self._path, http_error) - raise - else: - service_call_metric.call('ok') - - self._size = metadata.size() - - try: - reader = storage.fileio.BlobReader(metadata, chunk_size=self._buffer_size) - reader.read() - service_call_metric.call('ok') - except HttpError as e: - service_call_metric.call(e) - raise - finally: - reader.close() - - @retry.with_exponential_backoff( - retry_filter=retry.retry_on_server_errors_and_timeout_filter) - @property - def size(self): - return self._size - - def get_range(self, start, end): - self._download_stream.seek(0) - self._download_stream.truncate(0) - self._downloader.GetRange(start, end - 1) - return self._download_stream.getvalue() - - -class GcsUploader(Uploader): - def __init__(self, client, path, mime_type, get_project_number): - self._client = client - self._path = path - self._bucket, self._name = parse_gcs_path(path) - self._mime_type = mime_type - self._get_project_number = get_project_number - - # Set up communication with child thread. - parent_conn, child_conn = multiprocessing.Pipe() - self._child_conn = child_conn - self._conn = parent_conn - - # Set up uploader. - self._insert_request = ( - storage.StorageObjectsInsertRequest( - bucket=self._bucket, name=self._name)) - self._upload = transfer.Upload( - PipeStream(self._child_conn), - self._mime_type, - chunksize=WRITE_CHUNK_SIZE) - self._upload.strategy = transfer.RESUMABLE_UPLOAD - - # Start uploading thread. - self._upload_thread = threading.Thread(target=self._start_upload) - self._upload_thread.daemon = True - self._upload_thread.last_error = None - self._upload_thread.start() - - # TODO(silviuc): Refactor so that retry logic can be applied. - # There is retry logic in the underlying transfer library but we should make - # it more explicit so we can control the retry parameters. - @retry.no_retries # Using no_retries marks this as an integration point. - def _start_upload(self): - # This starts the uploader thread. We are forced to run the uploader in - # another thread because the apitools uploader insists on taking a stream - # as input. Happily, this also means we get asynchronous I/O to GCS. - # - # The uploader by default transfers data in chunks of 1024 * 1024 bytes at - # a time, buffering writes until that size is reached. 
- - project_number = self._get_project_number(self._bucket) - - # Create a request count metric - resource = resource_identifiers.GoogleCloudStorageBucket(self._bucket) - labels = { - monitoring_infos.SERVICE_LABEL: 'Storage', - monitoring_infos.METHOD_LABEL: 'Objects.insert', - monitoring_infos.RESOURCE_LABEL: resource, - monitoring_infos.GCS_BUCKET_LABEL: self._bucket, - monitoring_infos.GCS_PROJECT_ID_LABEL: str(project_number) - } - service_call_metric = ServiceCallMetric( - request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN, - base_labels=labels) - try: - self._client.objects.Insert(self._insert_request, upload=self._upload) - service_call_metric.call('ok') - except Exception as e: # pylint: disable=broad-except - service_call_metric.call(e) - _LOGGER.error( - 'Error in _start_upload while inserting file %s: %s', - self._path, - traceback.format_exc()) - self._upload_thread.last_error = e - finally: - self._child_conn.close() - - def put(self, data): - try: - self._conn.send_bytes(data.tobytes()) - except EOFError: - if self._upload_thread.last_error is not None: - raise self._upload_thread.last_error # pylint: disable=raising-bad-type - raise - - def finish(self): - self._conn.close() - # TODO(udim): Add timeout=DEFAULT_HTTP_TIMEOUT_SECONDS * 2 and raise if - # isAlive is True. - self._upload_thread.join() - # Check for exception since the last put() call. - if self._upload_thread.last_error is not None: - raise type(self._upload_thread.last_error)( - "Error while uploading file %s: %s", - self._path, - self._upload_thread.last_error.message) # pylint: disable=raising-bad-type From d7c693a71445fde6eb561a2f97cb42cf6f3df18b Mon Sep 17 00:00:00 2001 From: Bjorn Pedersen Date: Thu, 23 Mar 2023 15:39:55 -0400 Subject: [PATCH 13/13] updated interactive/utils.py --- .../apache_beam/runners/interactive/utils.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/sdks/python/apache_beam/runners/interactive/utils.py b/sdks/python/apache_beam/runners/interactive/utils.py index bdd9ab4d1a89..8c854d06e679 100644 --- a/sdks/python/apache_beam/runners/interactive/utils.py +++ b/sdks/python/apache_beam/runners/interactive/utils.py @@ -27,14 +27,14 @@ from typing import Dict from typing import Tuple +from google.cloud import storage + import pandas as pd import apache_beam as beam from apache_beam.dataframe.convert import to_pcollection from apache_beam.dataframe.frame_base import DeferredBase from apache_beam.internal.gcp import auth -from apache_beam.internal.http_client import get_new_http -from apache_beam.io.gcp.internal.clients import storage from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.pipeline import Pipeline from apache_beam.portability.api import beam_runner_api_pb2 @@ -452,13 +452,9 @@ def assert_bucket_exists(bucket_name): """ try: from apitools.base.py.exceptions import HttpError - storage_client = storage.StorageV1( - credentials=auth.get_service_credentials(PipelineOptions()), - get_credentials=False, - http=get_new_http(), - response_encoding='utf8') - request = storage.StorageBucketsGetRequest(bucket=bucket_name) - storage_client.buckets.Get(request) + storage_client = storage.Client( + credentials=auth.get_service_credentials(PipelineOptions())) + storage_client.get_bucket(bucket_name) except HttpError as e: if e.status_code == 404: _LOGGER.error('%s bucket does not exist!', bucket_name)
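For reference, assert_bucket_exists now goes through the same google-cloud-storage client used elsewhere in these patches. A minimal standalone sketch of that existence check follows, with a placeholder bucket name; note that this client signals a missing bucket by raising google.api_core.exceptions.NotFound (HTTP 404).

    from google.api_core.exceptions import NotFound
    from google.cloud import storage

    def bucket_exists(bucket_name):
        """Returns True if the bucket exists and is readable with the ambient credentials."""
        client = storage.Client()  # uses application-default credentials
        try:
            client.get_bucket(bucket_name)  # raises NotFound on HTTP 404
            return True
        except NotFound:
            return False

    print(bucket_exists('some-example-bucket'))  # placeholder name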