From 81a50f2331b96277ffa9bb6359250c379d927d8b Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Wed, 11 Sep 2024 11:25:41 -0700 Subject: [PATCH 001/181] Add various utility meta-transforms to Beam. --- CHANGES.md | 2 + .../apache/beam/sdk/transforms/Flatten.java | 77 ++++++++++++++++ .../org/apache/beam/sdk/transforms/Tee.java | 91 +++++++++++++++++++ .../beam/sdk/transforms/FlattenTest.java | 27 ++++++ .../apache/beam/sdk/transforms/TeeTest.java | 84 +++++++++++++++++ sdks/python/apache_beam/transforms/core.py | 28 ++++++ .../apache_beam/transforms/ptransform_test.py | 12 +++ sdks/python/apache_beam/transforms/util.py | 34 +++++++ .../apache_beam/transforms/util_test.py | 29 ++++++ 9 files changed, 384 insertions(+) create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java create mode 100644 sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/TeeTest.java diff --git a/CHANGES.md b/CHANGES.md index d58ceffeb411..a9d6eeba10d8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -69,6 +69,8 @@ * Dataflow worker can install packages from Google Artifact Registry Python repositories (Python) ([#32123](https://github.com/apache/beam/issues/32123)). * Added support for Zstd codec in SerializableAvroCodecFactory (Java) ([#32349](https://github.com/apache/beam/issues/32349)) * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Add new meta-transform FlattenWith and Tee that allow one to introduce branching + without breaking the linear/chaining style of pipeline construction. ## Breaking Changes diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java index 8da3cf71af9f..afc11353f1a5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java @@ -20,6 +20,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableLikeCoder; import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionList; @@ -82,6 +83,82 @@ public static Iterables iterables() { return new Iterables<>(); } + /** + * Returns a {@link PTransform} that flattens the input {@link PCollection} with a given a {@link + * PCollection} resulting in a {@link PCollection} containing all the elements of both {@link + * PCollection}s as its output. + * + *

This is equivalent to creating a {@link PCollectionList} containing both the input and + * {@code other} and then applying {@link #pCollections()}, but has the advantage that it can be + * more easily used inline. + * + *

Both {@cpde PCollections} must have equal {@link WindowFn}s. The output elements of {@code + * Flatten} are in the same windows and have the same timestamps as their corresponding input + * elements. The output {@code PCollection} will have the same {@link WindowFn} as both inputs. + * + * @param other the other PCollection to flatten with the input + * @param the type of the elements in the input and output {@code PCollection}s. + */ + public static PTransform, PCollection> with(PCollection other) { + return new FlattenWithPCollection<>(other); + } + + /** Implementation of {@link #with(PCollection)}. */ + private static class FlattenWithPCollection + extends PTransform, PCollection> { + // We only need to access this at pipeline construction time. + private final transient PCollection other; + + public FlattenWithPCollection(PCollection other) { + this.other = other; + } + + @Override + public PCollection expand(PCollection input) { + return PCollectionList.of(input).and(other).apply(pCollections()); + } + + @Override + public String getKindString() { + return "Flatten.With"; + } + } + + /** + * Returns a {@link PTransform} that flattens the input {@link PCollection} with the output of + * another {@link PTransform} resulting in a {@link PCollection} containing all the elements of + * both the input {@link PCollection}s and the output of the given {@link PTransform} as its + * output. + * + *

This is equivalent to creating a {@link PCollectionList} containing both the input and the + * output of {@code other} and then applying {@link #pCollections()}, but has the advantage that + * it can be more easily used inline. + * + *

Both {@cpde PCollections} must have equal {@link WindowFn}s. The output elements of {@code + * Flatten} are in the same windows and have the same timestamps as their corresponding input + * elements. The output {@code PCollection} will have the same {@link WindowFn} as both inputs. + * + * @param the type of the elements in the input and output {@code PCollection}s. + * @param other a PTransform whose ouptput should be flattened with the input + * @param the type of the elements in the input and output {@code PCollection}s. + */ + public static PTransform, PCollection> with( + PTransform> other) { + return new PTransform, PCollection>() { + @Override + public PCollection expand(PCollection input) { + return PCollectionList.of(input) + .and(input.getPipeline().apply(other)) + .apply(pCollections()); + } + + @Override + public String getKindString() { + return "Flatten.With"; + } + }; + } + /** * A {@link PTransform} that flattens a {@link PCollectionList} into a {@link PCollection} * containing all the elements of all the {@link PCollection}s in its input. Implements {@link diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java new file mode 100644 index 000000000000..bb65cbf94632 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import java.util.function.Consumer; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; + +/** + * A PTransform that returns its input, but also applies its input to an auxiliary PTransform, akin + * to the shell {@code tee} command. + * + *

This can be useful to write out or otherwise process an intermediate transform without + * breaking the linear flow of a chain of transforms, e.g. + * + *


+ * {@literal PCollection} input = ... ;
+ * {@literal PCollection} result =
+ *     {@literal input.apply(...)}
+ *     ...
+ *     {@literal input.apply(Tee.of(someSideTransform))}
+ *     ...
+ *     {@literal input.apply(...)};
+ * 
+ * + * @param the element type of the input PCollection + */ +public class Tee extends PTransform, PCollection> { + private final PTransform, ?> consumer; + + /** + * Returns a new Tee PTransform that will apply an auxilary transform to the input as well as pass + * it on. + * + * @param consumer An additional PTransform that should process the input PCollection. Its output + * will be ignored. + * @param the type of the elements in the input {@code PCollection}. + */ + public static Tee of(PTransform, ?> consumer) { + return new Tee<>(consumer); + } + + /** + * Returns a new Tee PTransform that will apply an auxilary transform to the input as well as pass + * it on. + * + * @param consumer An arbitrary {@link Consumer} that will be wrapped in a PTransform and applied + * to the input. Its output will be ignored. + * @param the type of the elements in the input {@code PCollection}. + */ + public static Tee of(Consumer> consumer) { + return of( + new PTransform, PCollectionTuple>() { + @Override + public PCollectionTuple expand(PCollection input) { + consumer.accept(input); + return PCollectionTuple.empty(input.getPipeline()); + } + }); + } + + private Tee(PTransform, ?> consumer) { + this.consumer = consumer; + } + + @Override + public PCollection expand(PCollection input) { + input.apply(consumer); + return input; + } + + @Override + protected String getKindString() { + return "Tee(" + consumer.getName() + ")"; + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java index 282a41bed0dc..7a02d95a5046 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlattenTest.java @@ -402,6 +402,32 @@ public void testFlattenWithDifferentInputAndOutputCoders2() { ///////////////////////////////////////////////////////////////////////////// + @Test + @Category(NeedsRunner.class) + public void testFlattenWithPCollection() { + PCollection output = + p.apply(Create.of(LINES)) + .apply("FlattenWithLines1", Flatten.with(p.apply("Create1", Create.of(LINES)))) + .apply("FlattenWithLines2", Flatten.with(p.apply("Create2", Create.of(LINES2)))); + + PAssert.that(output).containsInAnyOrder(flattenLists(Arrays.asList(LINES, LINES2, LINES))); + p.run(); + } + + @Test + @Category(NeedsRunner.class) + public void testFlattenWithPTransform() { + PCollection output = + p.apply(Create.of(LINES)) + .apply("Create1", Flatten.with(Create.of(LINES))) + .apply("Create2", Flatten.with(Create.of(LINES2))); + + PAssert.that(output).containsInAnyOrder(flattenLists(Arrays.asList(LINES, LINES2, LINES))); + p.run(); + } + + ///////////////////////////////////////////////////////////////////////////// + @Test @Category(NeedsRunner.class) public void testEqualWindowFnPropagation() { @@ -470,6 +496,7 @@ public void testIncompatibleWindowFnPropagationFailure() { public void testFlattenGetName() { Assert.assertEquals("Flatten.Iterables", Flatten.iterables().getName()); Assert.assertEquals("Flatten.PCollections", Flatten.pCollections().getName()); + Assert.assertEquals("Flatten.With", Flatten.with((PCollection) null).getName()); } ///////////////////////////////////////////////////////////////////////////// diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/TeeTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/TeeTest.java new file mode 100644 index 000000000000..ee3a00c46caa 
--- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/TeeTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.UUID; +import org.apache.beam.sdk.testing.NeedsRunner; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashMultimap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimaps; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for Tee. */ +@RunWith(JUnit4.class) +public class TeeTest { + + @Rule public final transient TestPipeline p = TestPipeline.create(); + + @Test + @Category(NeedsRunner.class) + public void testTee() { + List elements = Arrays.asList("a", "b", "c"); + CollectToMemory collector = new CollectToMemory<>(); + PCollection output = p.apply(Create.of(elements)).apply(Tee.of(collector)); + + PAssert.that(output).containsInAnyOrder(elements); + p.run().waitUntilFinish(); + + // Here we assert that this "sink" had the correct side effects. 
+ assertThat(collector.get(), containsInAnyOrder(elements.toArray(new String[3]))); + } + + private static class CollectToMemory extends PTransform, PCollection> { + + private static final Multimap ALL_ELEMENTS = + Multimaps.synchronizedMultimap(HashMultimap.create()); + + UUID uuid = UUID.randomUUID(); + + @Override + public PCollection expand(PCollection input) { + return input.apply( + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext c) { + ALL_ELEMENTS.put(uuid, c.element()); + } + })); + } + + @SuppressWarnings("unchecked") + public Collection get() { + return (Collection) ALL_ELEMENTS.get(uuid); + } + } +} diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index d7415e8d8135..5671779e5811 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -103,6 +103,7 @@ 'Windowing', 'WindowInto', 'Flatten', + 'FlattenWith', 'Create', 'Impulse', 'RestrictionProvider', @@ -3836,6 +3837,33 @@ def from_runner_api_parameter( common_urns.primitives.FLATTEN.urn, None, Flatten.from_runner_api_parameter) +class FlattenWith(PTransform): + """A PTransform that flattens its input with other PCollections. + + This is equivalent to creating a tuple containing both the input and the + other PCollection(s), but has the advantage that it can be more easily used + inline. + + Root PTransforms can be passed as well as PCollections, in which case their + outputs will be flattened. + """ + def __init__(self, *others): + self._others = others + + def expand(self, pcoll): + pcolls = [pcoll] + for other in self._others: + if isinstance(other, pvalue.PCollection): + pcolls.append(other) + elif isinstance(other, PTransform): + pcolls.append(pcoll.pipeline | other) + else: + raise TypeError( + 'FlattenWith only takes other PCollections and PTransforms, ' + f'got {other}') + return tuple(pcolls) | Flatten() + + class Create(PTransform): """A transform that creates a PCollection from an iterable.""" def __init__(self, values, reshuffle=True): diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index a51d5cd83d26..2c9037185286 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -787,6 +787,18 @@ def split_even_odd(element): assert_that(even_length, equal_to(['AA', 'CC']), label='assert:even') assert_that(odd_length, equal_to(['BBB']), label='assert:odd') + def test_flatten_with(self): + with TestPipeline() as pipeline: + input = pipeline | 'Start' >> beam.Create(['AA', 'BBB', 'CC']) + + result = ( + input + | 'WithPCollection' >> beam.FlattenWith(input | beam.Map(str.lower)) + | 'WithPTransform' >> beam.FlattenWith(beam.Create(['x', 'y']))) + + assert_that( + result, equal_to(['AA', 'BBB', 'CC', 'aa', 'bbb', 'cc', 'x', 'y'])) + def test_group_by_key_input_must_be_kv_pairs(self): with self.assertRaises(typehints.TypeCheckError) as e: with TestPipeline() as pipeline: diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index a27c7aca9e20..a03652de2496 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -30,6 +30,7 @@ import uuid from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterable from typing import List from typing import Tuple @@ -44,6 +45,7 @@ from apache_beam.portability 
import common_urns from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.pvalue import AsSideInput +from apache_beam.pvalue import PCollection from apache_beam.transforms import window from apache_beam.transforms.combiners import CountCombineFn from apache_beam.transforms.core import CombinePerKey @@ -92,6 +94,7 @@ 'RemoveDuplicates', 'Reshuffle', 'ToString', + 'Tee', 'Values', 'WithKeys', 'GroupIntoBatches' @@ -1665,6 +1668,37 @@ def _process(element): return pcoll | FlatMap(_process) +@typehints.with_input_types(T) +@typehints.with_output_types(T) +class Tee(PTransform): + """A PTransform that returns its input, but also applies its input elsewhere. + + Similar to the shell {@code tee} command. This can be useful to write out or + otherwise process an intermediate transform without breaking the linear flow + of a chain of transforms, e.g.:: + + (input + | SomePTransform() + | ... + | Tee(SomeSideTransform()) + | ...) + """ + def __init__( + self, + *consumers: Union[PTransform[PCollection[T], Any], + Callable[[PCollection[T]], Any]]): + self._consumers = consumers + + def expand(self, input): + for consumer in self._consumers: + print("apply", consumer) + if callable(consumer): + _ = input | ptransform_fn(consumer)() + else: + _ = input | consumer + return input + + @typehints.with_input_types(T) @typehints.with_output_types(T) class WaitOn(PTransform): diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 74d9f438a5df..9c131504e6f4 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -19,6 +19,8 @@ # pytype: skip-file +import collections +import importlib import logging import math import random @@ -27,6 +29,7 @@ import unittest import warnings from datetime import datetime +from typing import Mapping import pytest import pytz @@ -1812,6 +1815,32 @@ def test_split_without_empty(self): assert_that(result, equal_to(expected_result)) +class TeeTest(unittest.TestCase): + _side_effects: Mapping[str, int] = collections.defaultdict(int) + + def test_tee(self): + # The imports here are to avoid issues with the class (and its attributes) + # possibly being pickled rather than referenced. + def cause_side_effect(element): + importlib.import_module(__name__).TeeTest._side_effects[element] += 1 + + def count_side_effects(element): + return importlib.import_module(__name__).TeeTest._side_effects[element] + + with TestPipeline() as p: + result = ( + p + | beam.Create(['a', 'b', 'c']) + | 'TeePTransform' >> beam.Tee(beam.Map(cause_side_effect)) + | 'TeeCallable' >> beam.Tee( + lambda pcoll: pcoll | beam.Map( + lambda element: cause_side_effect('X' + element)))) + assert_that(result, equal_to(['a', 'b', 'c'])) + + self.assertEqual(count_side_effects('a'), 1) + self.assertEqual(count_side_effects('Xa'), 1) + + class WaitOnTest(unittest.TestCase): def test_find(self): # We need shared reference that survives pickling. From 71d97b6896a7de7fdbb48f0f7835081e411964fc Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Tue, 17 Sep 2024 15:11:09 -0700 Subject: [PATCH 002/181] Add note about FlattenWith to the documentation. 
--- .../apache_beam/examples/snippets/snippets.py | 33 +++++++++++++++++++ .../examples/snippets/snippets_test.py | 6 ++++ .../en/documentation/programming-guide.md | 27 ++++++++++++++- 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/examples/snippets/snippets.py b/sdks/python/apache_beam/examples/snippets/snippets.py index 715011d302d2..2636f7d2637d 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets.py +++ b/sdks/python/apache_beam/examples/snippets/snippets.py @@ -1143,6 +1143,39 @@ def model_multiple_pcollections_flatten(contents, output_path): merged | beam.io.WriteToText(output_path) +def model_multiple_pcollections_flatten_with(contents, output_path): + """Merging a PCollection with FlattenWith.""" + some_hash_fn = lambda s: ord(s[0]) + partition_fn = lambda element, partitions: some_hash_fn(element) % partitions + import apache_beam as beam + with TestPipeline() as pipeline: # Use TestPipeline for testing. + + # Partition into deciles + partitioned = pipeline | beam.Create(contents) | beam.Partition( + partition_fn, 3) + pcoll1 = partitioned[0] + pcoll2 = partitioned[1] + pcoll3 = partitioned[2] + SomeTransform = lambda: beam.Map(lambda x: x) + SomeOtherTransform = lambda: beam.Map(lambda x: x) + + # Flatten them back into 1 + + # A collection of PCollection objects can be represented simply + # as a tuple (or list) of PCollections. + # (The SDK for Python has no separate type to store multiple + # PCollection objects, whether containing the same or different + # types.) + # [START model_multiple_pcollections_flatten_with] + merged = ( + pcoll1 + | SomeTransform() + | beam.FlattenWith(pcoll2, pcoll3) + | SomeOtherTransform()) + # [END model_multiple_pcollections_flatten_with] + merged | beam.io.WriteToText(output_path) + + def model_multiple_pcollections_partition(contents, output_path): """Splitting a PCollection with Partition.""" some_hash_fn = lambda s: ord(s[0]) diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test.py b/sdks/python/apache_beam/examples/snippets/snippets_test.py index e8cb8960cf4d..0560e9710f03 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets_test.py +++ b/sdks/python/apache_beam/examples/snippets/snippets_test.py @@ -917,6 +917,12 @@ def test_model_multiple_pcollections_flatten(self): snippets.model_multiple_pcollections_flatten(contents, result_path) self.assertEqual(contents, self.get_output(result_path)) + def test_model_multiple_pcollections_flatten_with(self): + contents = ['a', 'b', 'c', 'd', 'e', 'f'] + result_path = self.create_temp_file() + snippets.model_multiple_pcollections_flatten_with(contents, result_path) + self.assertEqual(contents, self.get_output(result_path)) + def test_model_multiple_pcollections_partition(self): contents = [17, 42, 64, 32, 0, 99, 53, 89] result_path = self.create_temp_file() diff --git a/website/www/site/content/en/documentation/programming-guide.md b/website/www/site/content/en/documentation/programming-guide.md index c716c7554db4..cdf82d566a4f 100644 --- a/website/www/site/content/en/documentation/programming-guide.md +++ b/website/www/site/content/en/documentation/programming-guide.md @@ -2024,7 +2024,7 @@ playerAccuracies := ... // PCollection #### 4.2.5. 
Flatten {#flatten} [`Flatten`](https://beam.apache.org/releases/javadoc/{{< param release_latest >}}/index.html?org/apache/beam/sdk/transforms/Flatten.html) -[`Flatten`](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/transforms/core.py) +[`Flatten`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.core.html#apache_beam.transforms.core.Flatten) [`Flatten`](https://github.com/apache/beam/blob/master/sdks/go/pkg/beam/flatten.go) `Flatten` is a Beam transform for `PCollection` objects that store the same data type. @@ -2045,6 +2045,22 @@ PCollectionList collections = PCollectionList.of(pc1).and(pc2).and(pc3); PCollection merged = collections.apply(Flatten.pCollections()); {{< /highlight >}} +{{< paragraph class="language-java" >}} +One can also use the [`FlattenWith`](https://beam.apache.org/releases/javadoc/{{< param release_latest >}}/index.html?org/apache/beam/sdk/transforms/Flatten.html) +transform to merge PCollections into an output PCollection in a manner more compatible with chaining. +{{< /paragraph >}} + +{{< highlight java >}} +PCollection merged = pc1 + .apply(...) + // Merges the elements of pc2 in at this point... + .apply(FlattenWith.of(pc2)) + .apply(...) + // and the elements of pc3 at this point. + .apply(FlattenWith.of(pc3)) + .apply(...); +{{< /highlight >}} + {{< highlight py >}} # Flatten takes a tuple of PCollection objects. @@ -2052,6 +2068,15 @@ PCollection merged = collections.apply(Flatten.pCollections()); {{< code_sample "sdks/python/apache_beam/examples/snippets/snippets.py" model_multiple_pcollections_flatten >}} {{< /highlight >}} +{{< paragraph class="language-py" >}} +One can also use the [`FlattenWith`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.core.html#apache_beam.transforms.core.FlattenWith) +transform to merge PCollections into an output PCollection in a manner more compatible with chaining. +{{< /paragraph >}} + +{{< highlight py >}} +{{< code_sample "sdks/python/apache_beam/examples/snippets/snippets.py" model_multiple_pcollections_flatten_with >}} +{{< /highlight >}} + {{< highlight go >}} // Flatten accepts any number of PCollections of the same element type. // Returns a single PCollection that contains all of the elements in input PCollections. From efcf1bab3d3cb3b9f295d5db22fab16897ef7f88 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Mon, 23 Sep 2024 10:23:11 -0700 Subject: [PATCH 003/181] Fix checkstyle rule. --- .../src/main/java/org/apache/beam/sdk/transforms/Flatten.java | 1 - 1 file changed, 1 deletion(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java index afc11353f1a5..6d785c3bc591 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java @@ -140,7 +140,6 @@ public String getKindString() { * * @param the type of the elements in the input and output {@code PCollection}s. * @param other a PTransform whose ouptput should be flattened with the input - * @param the type of the elements in the input and output {@code PCollection}s. 
*/ public static PTransform, PCollection> with( PTransform> other) { From 6707c2c991ef91562a3775d21ea9e823bcaa0326 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Thu, 10 Oct 2024 16:54:04 +0200 Subject: [PATCH 004/181] [blog] - Beam YAML protobuf blogpost --- .../site/content/en/blog/beam-yaml-proto.md | 274 ++++++++++++++++++ 1 file changed, 274 insertions(+) create mode 100644 website/www/site/content/en/blog/beam-yaml-proto.md diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md new file mode 100644 index 000000000000..1d0590a99304 --- /dev/null +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -0,0 +1,274 @@ +--- +title: "Efficient Streaming Data Processing with Beam YAML and Protobuf" +date: "2024-09-20T11:53:38+02:00" +categories: + - blog +authors: + - ffernandez92 +--- + + +As streaming data processing grows, so do its maintenance, complexity, and costs. +This post will explain how to efficiently scale pipelines using [Protobuf](https://protobuf.dev/), +ensuring they are reusable and quick to deploy. Our goal is to keep this process simple +for engineers to implement using [Beam YAML](https://beam.apache.org/documentation/sdks/yaml/). + + + +# Simplifying Pipelines with Beam YAML + +Creating a pipeline in Beam can be somewhat difficult, especially for newcomers with little experience with Beam. +Setting up the project, managing dependencies, and so on can be challenging. +Beam YAML helps eliminate most of the boilerplate code, +allowing you to focus solely on the most important part: data transformation. + +Some of the main key benefits include: + +* **Readability:** By using a declarative language ([YAML](https://yaml.org/)), we improve the human readability + aspect of the pipeline configuration. +* **Reusability:** It is much simpler to reuse the same components across different pipelines. +* **Maintainability:** It simplifies pipeline maintenance and updates. + +The following template shows an example of reading events from a [Kafka](https://kafka.apache.org/intro) topic and +writing them into [BigQuery](https://cloud.google.com/bigquery?hl=en). 
+ +```yaml +pipeline: + transforms: + - type: ReadFromKafka + name: ReadProtoMovieEvents + config: + topic: 'TOPIC_NAME' + format: RAW/AVRO/JSON/PROTO + bootstrap_servers: 'BOOTSTRAP_SERVERS' + schema: 'SCHEMA' + - type: WriteToBigQuery + name: WriteMovieEvents + input: ReadProtoMovieEvents + config: + table: 'PROJECT_ID.DATASET.MOVIE_EVENTS_TABLE' + useAtLeastOnceSemantics: true + +options: + streaming: true + dataflow_service_options: [streaming_mode_at_least_once] +``` + +# Bringing It All Together + +### Let's create a simple proto event: + +```protobuf +// events/v1/movie_event.proto + +syntax = "proto3"; + +package event.v1; + +import "bq_field.proto"; +import "bq_table.proto"; +import "buf/validate/validate.proto"; +import "google/protobuf/wrappers.proto"; + +message MovieEvent { + option (gen_bq_schema.bigquery_opts).table_name = "movie_table"; + google.protobuf.StringValue event_id = 1 [(gen_bq_schema.bigquery).description = "Unique Event ID"]; + google.protobuf.StringValue user_id = 2 [(gen_bq_schema.bigquery).description = "Unique User ID"]; + google.protobuf.StringValue movie_id = 3 [(gen_bq_schema.bigquery).description = "Unique Movie ID"]; + google.protobuf.Int32Value rating = 4 [(buf.validate.field).int32 = { + // validates the average rating is at least 0 + gte: 0, + // validates the average rating is at most 100 + lte: 100 + }, (gen_bq_schema.bigquery).description = "Movie rating"]; + string event_dt = 5 [ + (gen_bq_schema.bigquery).type_override = "DATETIME", + (gen_bq_schema.bigquery).description = "UTC Datetime representing when we received this event. Format: YYYY-MM-DDTHH:MM:SS", + (buf.validate.field) = { + string: { + pattern: "^\\d{4}-\\d{2}-\\d{2}T\\d{2}:\\d{2}:\\d{2}$" + }, + ignore_empty: false, + } + ]; +} +``` + +As you can see here, there are important points to consider. Since we are planning to write these events to BigQuery, +we have imported the *[bq_field](https://buf.build/googlecloudplatform/bq-schema-api/file/main:bq_field.proto)* +and *[bq_table](https://buf.build/googlecloudplatform/bq-schema-api/file/main:bq_table.proto)* proto. +These proto files help generate the BigQuery JSON schema. +In our example, we are also advocating for a shift-left approach, which means we want to move testing, quality, +and performance as early as possible in the development process. This is why we have included the *buf.validate* +elements to ensure that only valid events are generated from the source. + +Once we have our *movie_event.proto* in the *events/v1* folder, we can generate +the necessary [file descriptor](https://buf.build/docs/reference/descriptors). +Essentially, a file descriptor is a compiled representation of the schema that allows various tools and systems +to understand and work with Protobuf data dynamically. To simplify the process, we are using Buf in this example, +so we will need the following configuration files. 
+ + +Buf configuration: + +```yaml +# buf.yaml + +version: v2 +deps: + - buf.build/googlecloudplatform/bq-schema-api + - buf.build/bufbuild/protovalidate +breaking: + use: + - FILE +lint: + use: + - DEFAULT +``` + +```yaml +# buf.gen.yaml + +version: v2 +managed: + enabled: true +plugins: + # Python Plugins + - remote: buf.build/protocolbuffers/python + out: gen/python + - remote: buf.build/grpc/python + out: gen/python + + # Java Plugins + - remote: buf.build/protocolbuffers/java:v25.2 + out: gen/maven/src/main/java + - remote: buf.build/grpc/java + out: gen/maven/src/main/java + + # BQ Schemas + - remote: buf.build/googlecloudplatform/bq-schema:v1.1.0 + out: protoc-gen/bq_schema + +``` + +Running the following two commands we will generate the necessary Java, Python, BigQuery schema, and Descriptor File: + +```bash +// Generate the buf.lock file +buf deps update + +// It will generate the descriptor in descriptor.binp. +buf build . -o descriptor.binp --exclude-imports + +// It will generate the Java, Python and BigQuery schema as described in buf.gen.yaml +buf generate --include-imports +``` + +# Let’s make our Beam YAML read proto: + +These are the modifications we need to make to the YAML file: + +```yaml +# movie_events_pipeline.yml + +pipeline: + transforms: + - type: ReadFromKafka + name: ReadProtoMovieEvents + config: + topic: 'movie_proto' + format: PROTO + bootstrap_servers: '' + file_descriptor_path: 'gs://my_proto_bucket/movie/v1.0.0/descriptor.binp' + message_name: 'event.v1.MovieEvent' + - type: WriteToBigQuery + name: WriteMovieEvents + input: ReadProtoMovieEvents + config: + table: '.raw.movie_table' + useAtLeastOnceSemantics: true +options: + streaming: true + dataflow_service_options: [streaming_mode_at_least_once] +``` + +As you can see, we just changed the format to be *PROTO* and added the *file_descriptor_path* and the *message_name*. + +### Let’s use Terraform to deploy it + +We can consider using [Terraform](https://www.terraform.io/) to deploy our Beam YAML pipeline +with [Dataflow](https://cloud.google.com/products/dataflow?hl=en) as the runner. +The following Terraform code example demonstrates how to achieve this: + +```hcl +// Enable Dataflow API. +resource "google_project_service" "enable_dataflow_api" { + project = var.gcp_project_id + service = "dataflow.googleapis.com" +} + +// DF Beam YAML +resource "google_dataflow_flex_template_job" "data_movie_job" { + provider = google-beta + project = var.gcp_project_id + name = "movie-proto-events" + container_spec_gcs_path = "gs://dataflow-templates-${var.gcp_region}/latest/flex/Yaml_Template" + region = var.gcp_region + on_delete = "drain" + machine_type = "n2d-standard-4" + enable_streaming_engine = true + subnetwork = var.subnetwork + skip_wait_on_job_termination = true + parameters = { + yaml_pipeline_file = "gs://${var.bucket_name}/yamls/${var.package_version}/movie_events_pipeline.yml" + max_num_workers = 40 + worker_zone = var.gcp_zone + } + depends_on = [google_project_service.enable_dataflow_api] +} +``` + +Assuming we have created the BigQuery table, which can also be done using Terraform and Proto as described earlier, +the previous code should create a Dataflow job using our Beam YAML code that reads Proto events from +Kafka and writes them into BigQuery. 
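
### Let's send a test event to Kafka

Before relying on the deployed job, it can help to publish a couple of events to the Kafka topic and watch them
land in BigQuery. The snippet below is only a minimal sketch, not part of the pipeline itself: it assumes the Python
classes generated earlier by Buf (under *gen/python*) are on the `PYTHONPATH`, that the
[kafka-python](https://kafka-python.readthedocs.io/) client is installed, and that the broker address and topic name
match the ones configured in the Beam YAML pipeline.

```python
# Minimal sketch, not production code. Assumes gen/python is on PYTHONPATH,
# kafka-python is installed, and the broker/topic below match your environment.
from kafka import KafkaProducer

from events.v1 import movie_event_pb2  # module generated by `buf generate`

event = movie_event_pb2.MovieEvent()
event.event_id.value = "event-0001"     # wrapper fields (StringValue, Int32Value) are set via .value
event.user_id.value = "user-42"
event.movie_id.value = "movie-7"
event.rating.value = 87                 # must stay within the 0-100 validation range
event.event_dt = "2024-01-01T12:00:00"  # plain string field, stored as DATETIME in BigQuery

producer = KafkaProducer(bootstrap_servers="localhost:9092")
# The Beam YAML source uses format: PROTO, so the record value is the raw serialized message.
producer.send("movie_proto", event.SerializeToString())
producer.flush()
```
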
+ +# Improvements and Conclusions + +Some potential improvements that can be done as part of community contributions to the previous Beam YAML code are: + +* **Supporting Schema Registries:** Integrate with schema registries such as Buf Registry or Apicurio for +better schema management. In the current solution, we generate the descriptors via Buf and store them in GCS. +We could store them in a schema registry instead. + + +* **Enhanced Monitoring:** Implement advanced monitoring and alerting mechanisms to quickly identify and address +issues in the data pipeline. + +As a conclusion, by leveraging Beam YAML and Protobuf, we have streamlined the creation and maintenance of +data processing pipelines, significantly reducing complexity. This approach ensures that engineers can more +efficiently implement and scale robust, reusable pipelines, compared to writing the equivalent Beam code manually. + +## Contributing + +Developers who wish to help build out and add functionalities are welcome to start contributing to the effort in the +Beam YAML module found [here](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/yaml). + +There is also a list of open [bugs](https://github.com/apache/beam/issues?q=is%3Aopen+is%3Aissue+label%3Ayaml) found +on the GitHub repo - now marked with the 'yaml' tag. + +While Beam YAML has been marked stable as of Beam 2.52, it is still under heavy development, with new features being +added with each release. Those who wish to be part of the design decisions and give insights to how the framework is +being used are highly encouraged to join the dev mailing list as those discussions will be directed there. A link to +the dev list can be found [here](https://beam.apache.org/community/contact-us/). From f717393db57271b159259f64c118b7e0fc4baed8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Thu, 10 Oct 2024 17:17:56 +0200 Subject: [PATCH 005/181] [blog] - update authors --- website/www/site/data/authors.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/website/www/site/data/authors.yml b/website/www/site/data/authors.yml index 6f16ae12f01a..324e675bfc8a 100644 --- a/website/www/site/data/authors.yml +++ b/website/www/site/data/authors.yml @@ -283,4 +283,7 @@ jkinard: email: jkinard@google.com jkim: name: Jaehyeon Kim - email: dottami@gmail.com \ No newline at end of file + email: dottami@gmail.com +ffernandez92: + name: Ferran Fernandez + email: ffernandez.upc@gmail.com From bc8eb6650c6d0c58fac3991b292eef196c9a6626 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Thu, 10 Oct 2024 17:35:33 +0200 Subject: [PATCH 006/181] [blog] - trail whitespace --- .../site/content/en/blog/beam-yaml-proto.md | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index 1d0590a99304..833dcee8ef32 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -20,7 +20,7 @@ See the License for the specific language governing permissions and limitations under the License. --> -As streaming data processing grows, so do its maintenance, complexity, and costs. +As streaming data processing grows, so do its maintenance, complexity, and costs. 
This post will explain how to efficiently scale pipelines using [Protobuf](https://protobuf.dev/), ensuring they are reusable and quick to deploy. Our goal is to keep this process simple for engineers to implement using [Beam YAML](https://beam.apache.org/documentation/sdks/yaml/). @@ -37,7 +37,7 @@ allowing you to focus solely on the most important part: data transformation. Some of the main key benefits include: * **Readability:** By using a declarative language ([YAML](https://yaml.org/)), we improve the human readability - aspect of the pipeline configuration. +aspect of the pipeline configuration. * **Reusability:** It is much simpler to reuse the same components across different pipelines. * **Maintainability:** It simplifies pipeline maintenance and updates. @@ -107,14 +107,14 @@ message MovieEvent { ``` As you can see here, there are important points to consider. Since we are planning to write these events to BigQuery, -we have imported the *[bq_field](https://buf.build/googlecloudplatform/bq-schema-api/file/main:bq_field.proto)* -and *[bq_table](https://buf.build/googlecloudplatform/bq-schema-api/file/main:bq_table.proto)* proto. +we have imported the *[bq_field](https://buf.build/googlecloudplatform/bq-schema-api/file/main:bq_field.proto)* +and *[bq_table](https://buf.build/googlecloudplatform/bq-schema-api/file/main:bq_table.proto)* proto. These proto files help generate the BigQuery JSON schema. In our example, we are also advocating for a shift-left approach, which means we want to move testing, quality, and performance as early as possible in the development process. This is why we have included the *buf.validate* elements to ensure that only valid events are generated from the source. -Once we have our *movie_event.proto* in the *events/v1* folder, we can generate +Once we have our *movie_event.proto* in the *events/v1* folder, we can generate the necessary [file descriptor](https://buf.build/docs/reference/descriptors). Essentially, a file descriptor is a compiled representation of the schema that allows various tools and systems to understand and work with Protobuf data dynamically. To simplify the process, we are using Buf in this example, @@ -169,7 +169,7 @@ Running the following two commands we will generate the necessary Java, Python, // Generate the buf.lock file buf deps update -// It will generate the descriptor in descriptor.binp. +// It will generate the descriptor in descriptor.binp. buf build . -o descriptor.binp --exclude-imports // It will generate the Java, Python and BigQuery schema as described in buf.gen.yaml @@ -208,8 +208,8 @@ As you can see, we just changed the format to be *PROTO* and added the *file_des ### Let’s use Terraform to deploy it -We can consider using [Terraform](https://www.terraform.io/) to deploy our Beam YAML pipeline -with [Dataflow](https://cloud.google.com/products/dataflow?hl=en) as the runner. +We can consider using [Terraform](https://www.terraform.io/) to deploy our Beam YAML pipeline +with [Dataflow](https://cloud.google.com/products/dataflow?hl=en) as the runner. The following Terraform code example demonstrates how to achieve this: ```hcl @@ -248,7 +248,7 @@ Kafka and writes them into BigQuery. 
Some potential improvements that can be done as part of community contributions to the previous Beam YAML code are: -* **Supporting Schema Registries:** Integrate with schema registries such as Buf Registry or Apicurio for +* **Supporting Schema Registries:** Integrate with schema registries such as Buf Registry or Apicurio for better schema management. In the current solution, we generate the descriptors via Buf and store them in GCS. We could store them in a schema registry instead. @@ -256,7 +256,7 @@ We could store them in a schema registry instead. * **Enhanced Monitoring:** Implement advanced monitoring and alerting mechanisms to quickly identify and address issues in the data pipeline. -As a conclusion, by leveraging Beam YAML and Protobuf, we have streamlined the creation and maintenance of +As a conclusion, by leveraging Beam YAML and Protobuf, we have streamlined the creation and maintenance of data processing pipelines, significantly reducing complexity. This approach ensures that engineers can more efficiently implement and scale robust, reusable pipelines, compared to writing the equivalent Beam code manually. From bade784f3e78968d048d0d8591e451756ff56ee9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Thu, 10 Oct 2024 17:46:57 +0200 Subject: [PATCH 007/181] [blog] - trail whitespace2 --- website/www/site/content/en/blog/beam-yaml-proto.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index 833dcee8ef32..44a3d0f4352f 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -41,7 +41,7 @@ aspect of the pipeline configuration. * **Reusability:** It is much simpler to reuse the same components across different pipelines. * **Maintainability:** It simplifies pipeline maintenance and updates. -The following template shows an example of reading events from a [Kafka](https://kafka.apache.org/intro) topic and +The following template shows an example of reading events from a [Kafka](https://kafka.apache.org/intro) topic and writing them into [BigQuery](https://cloud.google.com/bigquery?hl=en). 
```yaml From f151824100fcfd6375403c0ee46e60346df82411 Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Thu, 10 Oct 2024 12:05:06 -0400 Subject: [PATCH 008/181] Remove beam logging in playground --- playground/backend/internal/preparers/python_preparers.go | 2 +- playground/backend/internal/preparers/python_preparers_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/playground/backend/internal/preparers/python_preparers.go b/playground/backend/internal/preparers/python_preparers.go index f050237492b1..4b3d556af861 100644 --- a/playground/backend/internal/preparers/python_preparers.go +++ b/playground/backend/internal/preparers/python_preparers.go @@ -26,7 +26,7 @@ import ( ) const ( - addLogHandlerCode = "import logging\nlogging.basicConfig(\n level=logging.INFO,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" + addLogHandlerCode = "" oneIndentation = " " findWithPipelinePattern = `(\s*)with.+Pipeline.+as (.+):` indentationPattern = `^(%s){0,1}\w+` diff --git a/playground/backend/internal/preparers/python_preparers_test.go b/playground/backend/internal/preparers/python_preparers_test.go index b2cfa7eccaac..549fe8431783 100644 --- a/playground/backend/internal/preparers/python_preparers_test.go +++ b/playground/backend/internal/preparers/python_preparers_test.go @@ -53,7 +53,7 @@ func TestGetPythonPreparers(t *testing.T) { } func Test_addCodeToFile(t *testing.T) { - wantCode := "import logging\nlogging.basicConfig(\n level=logging.INFO,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" + pyCode + wantCode := pyCode type args struct { args []interface{} From b20f36e2e584135ba069e3b9a82a596198d74dee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Fri, 11 Oct 2024 13:47:50 +0200 Subject: [PATCH 009/181] Update website/www/site/content/en/blog/beam-yaml-proto.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --- website/www/site/content/en/blog/beam-yaml-proto.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index 44a3d0f4352f..71fbef015cd0 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -21,7 +21,7 @@ limitations under the License. --> As streaming data processing grows, so do its maintenance, complexity, and costs. -This post will explain how to efficiently scale pipelines using [Protobuf](https://protobuf.dev/), +This post explains how to efficiently scale pipelines by using [Protobuf](https://protobuf.dev/), ensuring they are reusable and quick to deploy. Our goal is to keep this process simple for engineers to implement using [Beam YAML](https://beam.apache.org/documentation/sdks/yaml/). 
From c5c83b2e7bf47f642f36e3d12e80c7ff2883ec27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Fri, 11 Oct 2024 13:48:34 +0200 Subject: [PATCH 010/181] Update website/www/site/content/en/blog/beam-yaml-proto.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --- website/www/site/content/en/blog/beam-yaml-proto.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index 71fbef015cd0..7d95af50b8ca 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -22,7 +22,7 @@ limitations under the License. As streaming data processing grows, so do its maintenance, complexity, and costs. This post explains how to efficiently scale pipelines by using [Protobuf](https://protobuf.dev/), -ensuring they are reusable and quick to deploy. Our goal is to keep this process simple +which ensures that pipelines are reusable and quick to deploy. The goal is to keep this process simple for engineers to implement using [Beam YAML](https://beam.apache.org/documentation/sdks/yaml/). From b224fb9f9da79a47dad9d0a6d3f5c9d3f512990f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Fri, 11 Oct 2024 13:48:41 +0200 Subject: [PATCH 011/181] Update website/www/site/content/en/blog/beam-yaml-proto.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --- website/www/site/content/en/blog/beam-yaml-proto.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index 7d95af50b8ca..a5c39c1021e6 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -27,7 +27,7 @@ for engineers to implement using [Beam YAML](https://beam.apache.org/documentati -# Simplifying Pipelines with Beam YAML +## Simplify pipelines with Beam YAML Creating a pipeline in Beam can be somewhat difficult, especially for newcomers with little experience with Beam. Setting up the project, managing dependencies, and so on can be challenging. From 42db58d10a680eaa02d61b0a66a31ca92ca5d948 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Fri, 11 Oct 2024 13:48:50 +0200 Subject: [PATCH 012/181] Update website/www/site/content/en/blog/beam-yaml-proto.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --- website/www/site/content/en/blog/beam-yaml-proto.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index a5c39c1021e6..420b328efa29 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -29,7 +29,7 @@ for engineers to implement using [Beam YAML](https://beam.apache.org/documentati ## Simplify pipelines with Beam YAML -Creating a pipeline in Beam can be somewhat difficult, especially for newcomers with little experience with Beam. +Creating a pipeline in Beam can be somewhat difficult, especially for new Apache Beam users. Setting up the project, managing dependencies, and so on can be challenging. Beam YAML helps eliminate most of the boilerplate code, allowing you to focus solely on the most important part: data transformation. 
From d213d38ac1c8f9bc6739cb06d7bef462b7786d78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Fri, 11 Oct 2024 13:49:02 +0200 Subject: [PATCH 013/181] Update website/www/site/content/en/blog/beam-yaml-proto.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --- website/www/site/content/en/blog/beam-yaml-proto.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index 420b328efa29..1a3e70ac0eb7 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -31,7 +31,7 @@ for engineers to implement using [Beam YAML](https://beam.apache.org/documentati Creating a pipeline in Beam can be somewhat difficult, especially for new Apache Beam users. Setting up the project, managing dependencies, and so on can be challenging. -Beam YAML helps eliminate most of the boilerplate code, +By using Beam YAML, you can eliminate most of the boilerplate code, allowing you to focus solely on the most important part: data transformation. Some of the main key benefits include: From e6dc05f879ef68426b4cafd509c1fc486d517f45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Fri, 11 Oct 2024 13:51:42 +0200 Subject: [PATCH 014/181] Update website/www/site/content/en/blog/beam-yaml-proto.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --- website/www/site/content/en/blog/beam-yaml-proto.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index 1a3e70ac0eb7..4ddf06133753 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -32,7 +32,7 @@ for engineers to implement using [Beam YAML](https://beam.apache.org/documentati Creating a pipeline in Beam can be somewhat difficult, especially for new Apache Beam users. Setting up the project, managing dependencies, and so on can be challenging. By using Beam YAML, you can eliminate most of the boilerplate code, -allowing you to focus solely on the most important part: data transformation. +which allows you to focus on the most important part of the work: data transformation. Some of the main key benefits include: From 7b5bf93e3745de6dd77ff4ccee633a1169998696 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Fri, 11 Oct 2024 13:51:52 +0200 Subject: [PATCH 015/181] Update website/www/site/content/en/blog/beam-yaml-proto.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --- website/www/site/content/en/blog/beam-yaml-proto.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index 4ddf06133753..3cdfe9476eb3 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -266,7 +266,7 @@ Developers who wish to help build out and add functionalities are welcome to sta Beam YAML module found [here](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/yaml). There is also a list of open [bugs](https://github.com/apache/beam/issues?q=is%3Aopen+is%3Aissue+label%3Ayaml) found -on the GitHub repo - now marked with the 'yaml' tag. 
+on the GitHub repo - now marked with the `yaml` tag. While Beam YAML has been marked stable as of Beam 2.52, it is still under heavy development, with new features being added with each release. Those who wish to be part of the design decisions and give insights to how the framework is From cb3c673f512f3cb5daa6218f68e74d3d72ad8f99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Fri, 11 Oct 2024 13:52:05 +0200 Subject: [PATCH 016/181] Update website/www/site/content/en/blog/beam-yaml-proto.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --- website/www/site/content/en/blog/beam-yaml-proto.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index 3cdfe9476eb3..1c67add75c80 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -268,7 +268,7 @@ Beam YAML module found [here](https://github.com/apache/beam/tree/master/sdks/py There is also a list of open [bugs](https://github.com/apache/beam/issues?q=is%3Aopen+is%3Aissue+label%3Ayaml) found on the GitHub repo - now marked with the `yaml` tag. -While Beam YAML has been marked stable as of Beam 2.52, it is still under heavy development, with new features being +Although Beam YAML is marked stable as of Beam 2.52, it is still under heavy development, with new features being added with each release. Those who wish to be part of the design decisions and give insights to how the framework is being used are highly encouraged to join the dev mailing list as those discussions will be directed there. A link to the dev list can be found [here](https://beam.apache.org/community/contact-us/). From 94cb34a4a9b5709689f26c353943354ffae7f51f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Fri, 11 Oct 2024 13:52:14 +0200 Subject: [PATCH 017/181] Update website/www/site/content/en/blog/beam-yaml-proto.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --- website/www/site/content/en/blog/beam-yaml-proto.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index 1c67add75c80..9126135fae80 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -269,6 +269,6 @@ There is also a list of open [bugs](https://github.com/apache/beam/issues?q=is%3 on the GitHub repo - now marked with the `yaml` tag. Although Beam YAML is marked stable as of Beam 2.52, it is still under heavy development, with new features being -added with each release. Those who wish to be part of the design decisions and give insights to how the framework is +added with each release. Those who want to be part of the design decisions and give insights to how the framework is being used are highly encouraged to join the dev mailing list as those discussions will be directed there. A link to the dev list can be found [here](https://beam.apache.org/community/contact-us/). 
From 2af4991280febf5906b67dfbc80d8be0141b647c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Fri, 11 Oct 2024 13:52:23 +0200 Subject: [PATCH 018/181] Update website/www/site/content/en/blog/beam-yaml-proto.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --- website/www/site/content/en/blog/beam-yaml-proto.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index 9126135fae80..d83f8b8016c7 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -270,5 +270,5 @@ on the GitHub repo - now marked with the `yaml` tag. Although Beam YAML is marked stable as of Beam 2.52, it is still under heavy development, with new features being added with each release. Those who want to be part of the design decisions and give insights to how the framework is -being used are highly encouraged to join the dev mailing list as those discussions will be directed there. A link to +being used are highly encouraged to join the [dev mailing list](https://beam.apache.org/community/contact-us/), where those discussions are occurring. the dev list can be found [here](https://beam.apache.org/community/contact-us/). From 923f053cfb8d4a0232b904641457893e98f3f338 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Fri, 11 Oct 2024 13:52:34 +0200 Subject: [PATCH 019/181] Update website/www/site/content/en/blog/beam-yaml-proto.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --- website/www/site/content/en/blog/beam-yaml-proto.md | 1 - 1 file changed, 1 deletion(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index d83f8b8016c7..440f183c0277 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -271,4 +271,3 @@ on the GitHub repo - now marked with the `yaml` tag. Although Beam YAML is marked stable as of Beam 2.52, it is still under heavy development, with new features being added with each release. Those who want to be part of the design decisions and give insights to how the framework is being used are highly encouraged to join the [dev mailing list](https://beam.apache.org/community/contact-us/), where those discussions are occurring. -the dev list can be found [here](https://beam.apache.org/community/contact-us/). From f1bfaa30fe7fe271bd2a739b93e5da1f367bf1ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Fri, 11 Oct 2024 13:54:44 +0200 Subject: [PATCH 020/181] Apply suggestions from code review Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --- .../site/content/en/blog/beam-yaml-proto.md | 74 +++++++++---------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index 440f183c0277..db56154750d5 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -34,12 +34,11 @@ Setting up the project, managing dependencies, and so on can be challenging. By using Beam YAML, you can eliminate most of the boilerplate code, which allows you to focus on the most important part of the work: data transformation. 
-Some of the main key benefits include: +Some of the key benefits of Beam YAML include: -* **Readability:** By using a declarative language ([YAML](https://yaml.org/)), we improve the human readability -aspect of the pipeline configuration. -* **Reusability:** It is much simpler to reuse the same components across different pipelines. -* **Maintainability:** It simplifies pipeline maintenance and updates. +* **Readability:** By using a declarative language ([YAML](https://yaml.org/)), the pipeline configuration is more human readable. +* **Reusability:** Reusing the same components across different pipelines is simplified. +* **Maintainability:** Pipeline maintenance and updates are easier. The following template shows an example of reading events from a [Kafka](https://kafka.apache.org/intro) topic and writing them into [BigQuery](https://cloud.google.com/bigquery?hl=en). @@ -66,9 +65,11 @@ options: dataflow_service_options: [streaming_mode_at_least_once] ``` -# Bringing It All Together +## The complete workflow -### Let's create a simple proto event: +This section demonstrates the complete workflow for this pipeline. + +### Create a simple proto event ```protobuf // events/v1/movie_event.proto @@ -106,19 +107,18 @@ message MovieEvent { } ``` -As you can see here, there are important points to consider. Since we are planning to write these events to BigQuery, -we have imported the *[bq_field](https://buf.build/googlecloudplatform/bq-schema-api/file/main:bq_field.proto)* -and *[bq_table](https://buf.build/googlecloudplatform/bq-schema-api/file/main:bq_table.proto)* proto. +Because these events are written to BigQuery, +the [`bq_field`](https://buf.build/googlecloudplatform/bq-schema-api/file/main:bq_field.proto) proto +and the [`bq_table`](https://buf.build/googlecloudplatform/bq-schema-api/file/main:bq_table.proto) proto are imported. These proto files help generate the BigQuery JSON schema. -In our example, we are also advocating for a shift-left approach, which means we want to move testing, quality, -and performance as early as possible in the development process. This is why we have included the *buf.validate* -elements to ensure that only valid events are generated from the source. +This example also demonstrates a shift-left approach, which moves testing, quality, +and performance as early as possible in the development process. For example, to ensure that only valid events are generated from the source, the `buf.validate` elements are included. -Once we have our *movie_event.proto* in the *events/v1* folder, we can generate +After you create the `movie_event.proto` proto in the `events/v1` folder, you can generate the necessary [file descriptor](https://buf.build/docs/reference/descriptors). -Essentially, a file descriptor is a compiled representation of the schema that allows various tools and systems -to understand and work with Protobuf data dynamically. To simplify the process, we are using Buf in this example, -so we will need the following configuration files. +A file descriptor is a compiled representation of the schema that allows various tools and systems +to understand and work with protobuf data dynamically. To simplify the process, this example uses Buf, +which requires the following configuration files. 
Buf configuration: @@ -163,22 +163,22 @@ plugins: ``` -Running the following two commands we will generate the necessary Java, Python, BigQuery schema, and Descriptor File: +Run the following two commands to generate the necessary Java, Python, BigQuery schema, and Descriptor file: ```bash // Generate the buf.lock file buf deps update -// It will generate the descriptor in descriptor.binp. +// It generates the descriptor in descriptor.binp. buf build . -o descriptor.binp --exclude-imports -// It will generate the Java, Python and BigQuery schema as described in buf.gen.yaml +// It generates the Java, Python and BigQuery schema as described in buf.gen.yaml buf generate --include-imports ``` -# Let’s make our Beam YAML read proto: +### Make the Beam YAML read proto -These are the modifications we need to make to the YAML file: +Make the following modifications to the to the YAML file: ```yaml # movie_events_pipeline.yml @@ -204,11 +204,11 @@ options: dataflow_service_options: [streaming_mode_at_least_once] ``` -As you can see, we just changed the format to be *PROTO* and added the *file_descriptor_path* and the *message_name*. +This step changes the format to `PROTO` and adds the `file_descriptor_path` and the `message_name`. -### Let’s use Terraform to deploy it +### Deploy the pipeline with Terraform -We can consider using [Terraform](https://www.terraform.io/) to deploy our Beam YAML pipeline +You can use [Terraform](https://www.terraform.io/) to deploy the Beam YAML pipeline with [Dataflow](https://cloud.google.com/products/dataflow?hl=en) as the runner. The following Terraform code example demonstrates how to achieve this: @@ -240,30 +240,30 @@ resource "google_dataflow_flex_template_job" "data_movie_job" { } ``` -Assuming we have created the BigQuery table, which can also be done using Terraform and Proto as described earlier, -the previous code should create a Dataflow job using our Beam YAML code that reads Proto events from +Assuming the BigQuery table exists, which you can do by using Terraform and Proto, +this code creates a Dataflow job by using the Beam YAML code that reads Proto events from Kafka and writes them into BigQuery. -# Improvements and Conclusions +## Improvements and conclusions -Some potential improvements that can be done as part of community contributions to the previous Beam YAML code are: +The following community contributions could improve the Beam YAML code in this example: -* **Supporting Schema Registries:** Integrate with schema registries such as Buf Registry or Apicurio for -better schema management. In the current solution, we generate the descriptors via Buf and store them in GCS. -We could store them in a schema registry instead. +* **Support schema registries:** Integrate with schema registries such as Buf Registry or Apicurio for +better schema management. The current workflow generates the descriptors by using Buf and store them in Google Cloud Storage. +The descriptors could be stored in a schema registry instead. * **Enhanced Monitoring:** Implement advanced monitoring and alerting mechanisms to quickly identify and address issues in the data pipeline. -As a conclusion, by leveraging Beam YAML and Protobuf, we have streamlined the creation and maintenance of +Leveraging Beam YAML and Protobuf lets us streamline the creation and maintenance of data processing pipelines, significantly reducing complexity. 
This approach ensures that engineers can more -efficiently implement and scale robust, reusable pipelines, compared to writing the equivalent Beam code manually. +efficiently implement and scale robust, reusable pipelines without needs to manually write Beam code. -## Contributing +## Contribute -Developers who wish to help build out and add functionalities are welcome to start contributing to the effort in the -Beam YAML module found [here](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/yaml). +Developers who want to help build out and add functionalities are welcome to start contributing to the effort in the +[Beam YAML module](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/yaml). There is also a list of open [bugs](https://github.com/apache/beam/issues?q=is%3Aopen+is%3Aissue+label%3Ayaml) found on the GitHub repo - now marked with the `yaml` tag. From f713b3e30e9ef1c6b56df81abdd32e22c20209ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ferran=20Fern=C3=A1ndez=20Garrido?= Date: Fri, 11 Oct 2024 23:00:54 +0200 Subject: [PATCH 021/181] Apply suggestions from code review Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> --- website/www/site/content/en/blog/beam-yaml-proto.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/website/www/site/content/en/blog/beam-yaml-proto.md b/website/www/site/content/en/blog/beam-yaml-proto.md index db56154750d5..995b59b978c1 100644 --- a/website/www/site/content/en/blog/beam-yaml-proto.md +++ b/website/www/site/content/en/blog/beam-yaml-proto.md @@ -20,6 +20,8 @@ See the License for the specific language governing permissions and limitations under the License. --> +# Efficient Streaming Data Processing with Beam YAML and Protobuf + As streaming data processing grows, so do its maintenance, complexity, and costs. This post explains how to efficiently scale pipelines by using [Protobuf](https://protobuf.dev/), which ensures that pipelines are reusable and quick to deploy. The goal is to keep this process simple @@ -31,7 +33,7 @@ for engineers to implement using [Beam YAML](https://beam.apache.org/documentati Creating a pipeline in Beam can be somewhat difficult, especially for new Apache Beam users. Setting up the project, managing dependencies, and so on can be challenging. -By using Beam YAML, you can eliminate most of the boilerplate code, +Beam YAML eliminates most of the boilerplate code, which allows you to focus on the most important part of the work: data transformation. Some of the key benefits of Beam YAML include: @@ -71,6 +73,8 @@ This section demonstrates the complete workflow for this pipeline. ### Create a simple proto event +The following code creates a simple movie event. 
+ ```protobuf // events/v1/movie_event.proto From 80c7450f77f58c9e02087cc578ad8abe07648736 Mon Sep 17 00:00:00 2001 From: Jeffrey Kinard Date: Mon, 14 Oct 2024 18:57:01 -0400 Subject: [PATCH 022/181] Propogate field_descriptions to RowTypeConstraint Signed-off-by: Jeffrey Kinard --- sdks/python/apache_beam/typehints/row_type.py | 20 +++++++--- sdks/python/apache_beam/typehints/schemas.py | 1 + .../apache_beam/typehints/schemas_test.py | 38 +++++++++++++++++++ 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py index fd7885ad59c4..880a897bbbe8 100644 --- a/sdks/python/apache_beam/typehints/row_type.py +++ b/sdks/python/apache_beam/typehints/row_type.py @@ -49,7 +49,8 @@ def __init__( fields: Sequence[Tuple[str, type]], user_type, schema_options: Optional[Sequence[Tuple[str, Any]]] = None, - field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None): + field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None, + field_descriptions: Optional[Dict[str, str]] = None): """For internal use only, no backwards comatibility guaratees. See https://beam.apache.org/documentation/programming-guide/#schemas-for-pl-types for guidance on creating PCollections with inferred schemas. @@ -96,6 +97,7 @@ def __init__( self._schema_options = schema_options or [] self._field_options = field_options or {} + self._field_descriptions = field_descriptions or {} @staticmethod def from_user_type( @@ -107,12 +109,15 @@ def from_user_type( fields = [(name, user_type.__annotations__[name]) for name in user_type._fields] + field_descriptions = getattr(user_type, '_field_descriptions', None) + if _user_type_is_generated(user_type): return RowTypeConstraint.from_fields( fields, schema_id=getattr(user_type, _BEAM_SCHEMA_ID), schema_options=schema_options, - field_options=field_options) + field_options=field_options, + field_descriptions=field_descriptions) # TODO(https://github.com/apache/beam/issues/22125): Add user API for # specifying schema/field options @@ -120,7 +125,8 @@ def from_user_type( fields=fields, user_type=user_type, schema_options=schema_options, - field_options=field_options) + field_options=field_options, + field_descriptions=field_descriptions) return None @@ -131,13 +137,15 @@ def from_fields( schema_options: Optional[Sequence[Tuple[str, Any]]] = None, field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None, schema_registry: Optional[SchemaTypeRegistry] = None, + field_descriptions: Optional[Dict[str, str]] = None, ) -> RowTypeConstraint: return GeneratedClassRowTypeConstraint( fields, schema_id=schema_id, schema_options=schema_options, field_options=field_options, - schema_registry=schema_registry) + schema_registry=schema_registry, + field_descriptions=field_descriptions) def __call__(self, *args, **kwargs): # We make RowTypeConstraint callable (defers to constructing the user type) @@ -206,6 +214,7 @@ def __init__( schema_options: Optional[Sequence[Tuple[str, Any]]] = None, field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None, schema_registry: Optional[SchemaTypeRegistry] = None, + field_descriptions: Optional[Dict[str, str]] = None, ): from apache_beam.typehints.schemas import named_fields_to_schema from apache_beam.typehints.schemas import named_tuple_from_schema @@ -224,7 +233,8 @@ def __init__( fields, user_type, schema_options=schema_options, - field_options=field_options) + field_options=field_options, + 
field_descriptions=field_descriptions) def __reduce__(self): return ( diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index ef82ca91044c..fea9b3534b0c 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -274,6 +274,7 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType: self.option_to_runner_api(option_tuple) for option_tuple in type_.field_options(field_name) ], + description=type_._field_descriptions.get(field_name, None), ) for (field_name, field_type) in type_._fields ], id=schema_id, diff --git a/sdks/python/apache_beam/typehints/schemas_test.py b/sdks/python/apache_beam/typehints/schemas_test.py index 5d38b16d9783..15144c6c2c17 100644 --- a/sdks/python/apache_beam/typehints/schemas_test.py +++ b/sdks/python/apache_beam/typehints/schemas_test.py @@ -489,6 +489,44 @@ def test_row_type_constraint_to_schema_with_field_options(self): ] self.assertEqual(list(field.options), expected) + def test_row_type_constraint_to_schema_with_field_descriptions(self): + row_type_with_options = row_type.RowTypeConstraint.from_fields( + [ + ('foo', np.int8), + ('bar', float), + ('baz', bytes), + ], + field_descriptions={ + 'foo': 'foo description', + 'bar': 'bar description', + 'baz': 'baz description', + }) + result_type = typing_to_runner_api(row_type_with_options) + + self.assertIsInstance(result_type, schema_pb2.FieldType) + self.assertEqual(result_type.WhichOneof("type_info"), "row_type") + + fields = result_type.row_type.schema.fields + + expected = [ + schema_pb2.Field( + name='foo', + description='foo description', + type=schema_pb2.FieldType(atomic_type=schema_pb2.BYTE), + ), + schema_pb2.Field( + name='bar', + description='bar description', + type=schema_pb2.FieldType(atomic_type=schema_pb2.DOUBLE), + ), + schema_pb2.Field( + name='baz', + description='baz description', + type=schema_pb2.FieldType(atomic_type=schema_pb2.BYTES), + ), + ] + self.assertEqual(list(fields), expected) + def assert_namedtuple_equivalent(self, actual, expected): # Two types are only considered equal if they are literally the same # object (i.e. `actual == expected` is the same as `actual is expected` in From ff1e0cc806e960dcfc42a164ea5a26d8542af525 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 17 Oct 2024 07:54:03 -0400 Subject: [PATCH 023/181] Note Flink 1.19 support in CHANGES.md --- CHANGES.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 840c60fc1149..2ce6a9710224 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -25,8 +25,6 @@ * New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)). * New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)). -* Added support for Flink 1.19 - ## I/Os * Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). @@ -62,6 +60,7 @@ * New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)). * New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)). 
* [Python] Introduce Managed Transforms API ([#31495](https://github.com/apache/beam/pull/31495)) +* Flink 1.19 support added ([#32648](https://github.com/apache/beam/pull/32648)) ## I/Os From acd52f9748f50f98c213b6ec906bbb902abcd559 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 17 Oct 2024 09:14:43 -0700 Subject: [PATCH 024/181] Apply suggestions from code review Co-authored-by: tvalentyn --- .../src/main/java/org/apache/beam/sdk/transforms/Flatten.java | 4 ++-- .../src/main/java/org/apache/beam/sdk/transforms/Tee.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java index 6d785c3bc591..159f92cd5e87 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Flatten.java @@ -84,7 +84,7 @@ public static Iterables iterables() { } /** - * Returns a {@link PTransform} that flattens the input {@link PCollection} with a given a {@link + * Returns a {@link PTransform} that flattens the input {@link PCollection} with a given {@link * PCollection} resulting in a {@link PCollection} containing all the elements of both {@link * PCollection}s as its output. * @@ -134,7 +134,7 @@ public String getKindString() { * output of {@code other} and then applying {@link #pCollections()}, but has the advantage that * it can be more easily used inline. * - *

Both {@cpde PCollections} must have equal {@link WindowFn}s. The output elements of {@code + *

Both {@code PCollections} must have equal {@link WindowFn}s. The output elements of {@code * Flatten} are in the same windows and have the same timestamps as their corresponding input * elements. The output {@code PCollection} will have the same {@link WindowFn} as both inputs. * diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java index bb65cbf94632..492a1cc84f74 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Tee.java @@ -23,7 +23,7 @@ /** * A PTransform that returns its input, but also applies its input to an auxiliary PTransform, akin - * to the shell {@code tee} command. + * to the shell {@code tee} command, which is named after the T-splitter used in plumbing. * *

This can be useful to write out or otherwise process an intermediate transform without * breaking the linear flow of a chain of transforms, e.g. From 177c9ab2af445524c6a444e8e92c98e3d4bbf299 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 17 Oct 2024 09:49:26 -0700 Subject: [PATCH 025/181] Add example of FlattenWith accepting a root PTransform. --- .../apache_beam/examples/snippets/snippets.py | 21 +++++++++++++++++++ .../examples/snippets/snippets_test.py | 7 +++++++ .../en/documentation/programming-guide.md | 11 ++++++++++ 3 files changed, 39 insertions(+) diff --git a/sdks/python/apache_beam/examples/snippets/snippets.py b/sdks/python/apache_beam/examples/snippets/snippets.py index 2636f7d2637d..a16e4c7eef61 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets.py +++ b/sdks/python/apache_beam/examples/snippets/snippets.py @@ -1176,6 +1176,27 @@ def model_multiple_pcollections_flatten_with(contents, output_path): merged | beam.io.WriteToText(output_path) +def model_multiple_pcollections_flatten_with_transform(contents, output_path): + """Merging output of PTransform with FlattenWith.""" + some_hash_fn = lambda s: ord(s[0]) + partition_fn = lambda element, partitions: some_hash_fn(element) % partitions + import apache_beam as beam + with TestPipeline() as pipeline: # Use TestPipeline for testing. + + pcoll = pipeline | beam.Create(contents) + SomeTransform = lambda: beam.Map(lambda x: x) + SomeOtherTransform = lambda: beam.Map(lambda x: x) + + # [START model_multiple_pcollections_flatten_with] + merged = ( + pcoll + | SomeTransform() + | beam.FlattenWith(beam.Create(['x', 'y', 'z'])) + | SomeOtherTransform()) + # [END model_multiple_pcollections_flatten_with] + merged | beam.io.WriteToText(output_path) + + def model_multiple_pcollections_partition(contents, output_path): """Splitting a PCollection with Partition.""" some_hash_fn = lambda s: ord(s[0]) diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test.py b/sdks/python/apache_beam/examples/snippets/snippets_test.py index 0560e9710f03..54a57673b5f4 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets_test.py +++ b/sdks/python/apache_beam/examples/snippets/snippets_test.py @@ -923,6 +923,13 @@ def test_model_multiple_pcollections_flatten_with(self): snippets.model_multiple_pcollections_flatten_with(contents, result_path) self.assertEqual(contents, self.get_output(result_path)) + def test_model_multiple_pcollections_flatten_with_transform(self): + contents = ['a', 'b', 'c', 'd', 'e', 'f'] + result_path = self.create_temp_file() + snippets.model_multiple_pcollections_flatten_with_transform( + contents, result_path) + self.assertEqual(contents + ['x', 'y', 'z'], self.get_output(result_path)) + def test_model_multiple_pcollections_partition(self): contents = [17, 42, 64, 32, 0, 99, 53, 89] result_path = self.create_temp_file() diff --git a/website/www/site/content/en/documentation/programming-guide.md b/website/www/site/content/en/documentation/programming-guide.md index cdf82d566a4f..04f7bf62499d 100644 --- a/website/www/site/content/en/documentation/programming-guide.md +++ b/website/www/site/content/en/documentation/programming-guide.md @@ -2077,6 +2077,17 @@ transform to merge PCollections into an output PCollection in a manner more comp {{< code_sample "sdks/python/apache_beam/examples/snippets/snippets.py" model_multiple_pcollections_flatten_with >}} {{< /highlight >}} +{{< paragraph class="language-py" >}} +`FlattenWith` can take root `PCollection`-producing transforms +(such as `Create` 
and `Read`) as well as already constructed PCollections, +and will apply them and flatten their outputs into the resulting output +PCollection. +{{< /paragraph >}} + +{{< highlight py >}} +{{< code_sample "sdks/python/apache_beam/examples/snippets/snippets.py" model_multiple_pcollections_flatten_with_transforms >}} +{{< /highlight >}} + {{< highlight go >}} // Flatten accepts any number of PCollections of the same element type. // Returns a single PCollection that contains all of the elements in input PCollections. From b88df7b12fc73eb573afce5f989a4341c1dafc99 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Thu, 17 Oct 2024 12:51:58 -0400 Subject: [PATCH 026/181] Using single temp maven local repository for release validation tasks --- .../beam_PostRelease_NightlySnapshot.yml | 7 ++++ .../beam/gradle/BeamModulePlugin.groovy | 5 +++ release/src/main/groovy/TestScripts.groovy | 38 ++++++++++++------- 3 files changed, 37 insertions(+), 13 deletions(-) diff --git a/.github/workflows/beam_PostRelease_NightlySnapshot.yml b/.github/workflows/beam_PostRelease_NightlySnapshot.yml index 9b7fb2af2f2d..0c4144af1a4f 100644 --- a/.github/workflows/beam_PostRelease_NightlySnapshot.yml +++ b/.github/workflows/beam_PostRelease_NightlySnapshot.yml @@ -59,6 +59,9 @@ jobs: uses: ./.github/actions/setup-environment-action with: java-version: default + - name: Setup temp local maven + id: setup_local_maven + run: echo "NEW_TEMP_DIR=$(mktemp -d)" >> $GITHUB_OUTPUT - name: run PostRelease validation script uses: ./.github/actions/gradle-command-self-hosted-action with: @@ -66,3 +69,7 @@ jobs: arguments: | -Pver='${{ github.event.inputs.RELEASE }}' \ -Prepourl='${{ github.event.inputs.SNAPSHOT_URL }}' \ + -PmavenLocalPath='${{ steps.setup_local_maven.outputs.NEW_TEMP_DIR }}' + - name: Clean up local maven + if: steps.setup_local_maven.outcome == 'success' + run: rm -rf '${{ steps.setup_local_maven.outputs.NEW_TEMP_DIR }}' diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index dd3b129e6c34..576b8defb40b 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -2513,6 +2513,8 @@ class BeamModulePlugin implements Plugin { def taskName = "run${config.type}Java${config.runner}" def releaseVersion = project.findProperty('ver') ?: project.version def releaseRepo = project.findProperty('repourl') ?: 'https://repository.apache.org/content/repositories/snapshots' + // shared maven local path for maven archetype projects + def sharedMavenLocal = project.findProperty('mavenLocalPath') ?: '' def argsNeeded = [ "--ver=${releaseVersion}", "--repourl=${releaseRepo}" @@ -2532,6 +2534,9 @@ class BeamModulePlugin implements Plugin { if (config.pubsubTopic) { argsNeeded.add("--pubsubTopic=${config.pubsubTopic}") } + if (sharedMavenLocal) { + argsNeeded.add("--mavenLocalPath=${sharedMavenLocal}") + } project.evaluationDependsOn(':release') project.task(taskName, dependsOn: ':release:classes', type: JavaExec) { group = "Verification" diff --git a/release/src/main/groovy/TestScripts.groovy b/release/src/main/groovy/TestScripts.groovy index d5042aa61941..dc2438007ac1 100644 --- a/release/src/main/groovy/TestScripts.groovy +++ b/release/src/main/groovy/TestScripts.groovy @@ -36,6 +36,7 @@ class TestScripts { static String gcsBucket static String bqDataset static String pubsubTopic + static String mavenLocalPath } def 
TestScripts(String[] args) { @@ -47,6 +48,7 @@ class TestScripts { cli.gcsBucket(args:1, 'Google Cloud Storage Bucket') cli.bqDataset(args:1, "BigQuery Dataset") cli.pubsubTopic(args:1, "PubSub Topic") + cli.mavenLocalPath(args:1, "Maven local path") def options = cli.parse(args) var.repoUrl = options.repourl @@ -73,6 +75,10 @@ class TestScripts { var.pubsubTopic = options.pubsubTopic println "PubSub Topic: ${var.pubsubTopic}" } + if (options.mavenLocalPath) { + var.mavenLocalPath = options.mavenLocalPath + println "Maven local path: ${var.mavenLocalPath}" + } } def ver() { @@ -189,11 +195,16 @@ class TestScripts { } } - // Run a maven command, setting up a new local repository and a settings.xml with a custom repository + // Run a maven command, setting up a new local repository and a settings.xml with a custom repository if needed private String _mvn(String args) { - def m2 = new File(var.startDir, ".m2/repository") + String mvnlocalPath = var.mavenLocalPath + if (!(var.mavenLocalPath)) { + mvnlocalPath = var.startDir + } + def m2 = new File(mvnlocalPath, ".m2/repository") m2.mkdirs() - def settings = new File(var.startDir, "settings.xml") + def settings = new File(mvnlocalPath, "settings.xml") + if(!settings.exists()) { settings.write """ ${m2.absolutePath} @@ -209,16 +220,17 @@ class TestScripts { - """ - def cmd = "mvn ${args} -s ${settings.absolutePath} -Ptestrel -B" - String path = System.getenv("PATH"); - // Set the path on jenkins executors to use a recent maven - // MAVEN_HOME is not set on some executors, so default to 3.5.2 - String maven_home = System.getenv("MAVEN_HOME") ?: '/home/jenkins/tools/maven/apache-maven-3.5.4' - println "Using maven ${maven_home}" - def mvnPath = "${maven_home}/bin" - def setPath = "export PATH=\"${mvnPath}:${path}\" && " - return _execute(setPath + cmd) + """ + } + def cmd = "mvn ${args} -s ${settings.absolutePath} -Ptestrel -B" + String path = System.getenv("PATH"); + // Set the path on jenkins executors to use a recent maven + // MAVEN_HOME is not set on some executors, so default to 3.5.2 + String maven_home = System.getenv("MAVEN_HOME") ?: '/usr/local/maven' + println "Using maven ${maven_home}" + def mvnPath = "${maven_home}/bin" + def setPath = "export PATH=\"${mvnPath}:${path}\" && " + return _execute(setPath + cmd) } // Clean up and report error From 418a1d40fff8924de3e13dee19fefe8a5cfb50e8 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 17 Oct 2024 09:03:35 -0700 Subject: [PATCH 027/181] Use WindowedValueParam for interactive cache. --- .../apache_beam/runners/interactive/caching/reify.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/sdks/python/apache_beam/runners/interactive/caching/reify.py b/sdks/python/apache_beam/runners/interactive/caching/reify.py index ce82785b2585..c82033dc1b9b 100644 --- a/sdks/python/apache_beam/runners/interactive/caching/reify.py +++ b/sdks/python/apache_beam/runners/interactive/caching/reify.py @@ -28,7 +28,6 @@ import apache_beam as beam from apache_beam.runners.interactive import cache_manager as cache from apache_beam.testing import test_stream -from apache_beam.transforms.window import WindowedValue READ_CACHE = 'ReadCache_' WRITE_CACHE = 'WriteCache_' @@ -40,13 +39,8 @@ class Reify(beam.DoFn): Internally used to capture window info with each element into cache for replayability. 
""" - def process( - self, - e, - w=beam.DoFn.WindowParam, - p=beam.DoFn.PaneInfoParam, - t=beam.DoFn.TimestampParam): - yield test_stream.WindowedValueHolder(WindowedValue(e, t, [w], p)) + def process(self, e, wv=beam.DoFn.WindowedValueParam): + yield test_stream.WindowedValueHolder(wv) class Unreify(beam.DoFn): From d363a1712c37d8f6ba34990c65a9b1aaca4a60e8 Mon Sep 17 00:00:00 2001 From: Jeffrey Kinard Date: Thu, 17 Oct 2024 18:57:19 -0400 Subject: [PATCH 028/181] Fix failing YAML mapping IT Signed-off-by: Jeffrey Kinard --- sdks/python/apache_beam/yaml/tests/map.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/yaml/tests/map.yaml b/sdks/python/apache_beam/yaml/tests/map.yaml index b676966ad6bd..31fb442085fb 100644 --- a/sdks/python/apache_beam/yaml/tests/map.yaml +++ b/sdks/python/apache_beam/yaml/tests/map.yaml @@ -30,8 +30,8 @@ pipelines: config: append: true fields: - named_field: element literal_int: 10 + named_field: element literal_float: 1.5 literal_str: '"abc"' @@ -42,5 +42,5 @@ pipelines: - type: AssertEqual config: elements: - - {element: 100, named_field: 100, literal_int: 10, literal_float: 1.5, literal_str: "abc"} - - {element: 200, named_field: 200, literal_int: 10, literal_float: 1.5, literal_str: "abc"} + - {element: 100, literal_int: 10, named_field: 100, literal_float: 1.5, literal_str: "abc"} + - {element: 200, literal_int: 10, named_field: 200, literal_float: 1.5, literal_str: "abc"} From 621cdfbb9284d46cea97eb239e8458626a549248 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 17 Oct 2024 12:07:35 -0700 Subject: [PATCH 029/181] Fix doc references. --- sdks/python/apache_beam/examples/snippets/snippets.py | 4 ++-- sdks/python/scripts/generate_pydoc.sh | 3 +++ .../www/site/content/en/documentation/programming-guide.md | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/examples/snippets/snippets.py b/sdks/python/apache_beam/examples/snippets/snippets.py index a16e4c7eef61..c849af4a00b3 100644 --- a/sdks/python/apache_beam/examples/snippets/snippets.py +++ b/sdks/python/apache_beam/examples/snippets/snippets.py @@ -1187,13 +1187,13 @@ def model_multiple_pcollections_flatten_with_transform(contents, output_path): SomeTransform = lambda: beam.Map(lambda x: x) SomeOtherTransform = lambda: beam.Map(lambda x: x) - # [START model_multiple_pcollections_flatten_with] + # [START model_multiple_pcollections_flatten_with_transform] merged = ( pcoll | SomeTransform() | beam.FlattenWith(beam.Create(['x', 'y', 'z'])) | SomeOtherTransform()) - # [END model_multiple_pcollections_flatten_with] + # [END model_multiple_pcollections_flatten_with_transform] merged | beam.io.WriteToText(output_path) diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh index 3462429190c8..827df30861cb 100755 --- a/sdks/python/scripts/generate_pydoc.sh +++ b/sdks/python/scripts/generate_pydoc.sh @@ -244,6 +244,9 @@ ignore_identifiers = [ # IPython Magics py:class reference target not found 'IPython.core.magic.Magics', + + # Type variables. + 'apache_beam.transforms.util.T', ] ignore_references = [ 'BeamIOError', diff --git a/website/www/site/content/en/documentation/programming-guide.md b/website/www/site/content/en/documentation/programming-guide.md index 04f7bf62499d..f4058e604288 100644 --- a/website/www/site/content/en/documentation/programming-guide.md +++ b/website/www/site/content/en/documentation/programming-guide.md @@ -2085,7 +2085,7 @@ PCollection. 
{{< /paragraph >}} {{< highlight py >}} -{{< code_sample "sdks/python/apache_beam/examples/snippets/snippets.py" model_multiple_pcollections_flatten_with_transforms >}} +{{< code_sample "sdks/python/apache_beam/examples/snippets/snippets.py" model_multiple_pcollections_flatten_with_transform >}} {{< /highlight >}} {{< highlight go >}} From 6e3516baf2894b806e9cd3592257ee896c03fe15 Mon Sep 17 00:00:00 2001 From: liferoad Date: Thu, 17 Oct 2024 21:33:33 -0400 Subject: [PATCH 030/181] Revert "Update pyproject.toml by using grpcio-tools==1.65.5" --- sdks/python/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/pyproject.toml b/sdks/python/pyproject.toml index a99599a2ce2b..4eb827297019 100644 --- a/sdks/python/pyproject.toml +++ b/sdks/python/pyproject.toml @@ -21,7 +21,7 @@ requires = [ "setuptools", "wheel>=0.36.0", - "grpcio-tools==1.65.5", + "grpcio-tools==1.62.1", "mypy-protobuf==3.5.0", # Avoid https://github.com/pypa/virtualenv/issues/2006 "distlib==0.3.7", From b412928be498065feae40bc8d14b79d9bcda6f30 Mon Sep 17 00:00:00 2001 From: Robert Burke Date: Fri, 18 Oct 2024 06:36:10 -0700 Subject: [PATCH 031/181] [#32847][prism] Add Github Action for Prism as a Python precommit (#32845) * Add Github Action for Prism as a Python precommit * Update the execution condition. --------- Co-authored-by: lostluck <13907733+lostluck@users.noreply.github.com> --- .../workflows/beam_PreCommit_Prism_Python.yml | 109 ++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 .github/workflows/beam_PreCommit_Prism_Python.yml diff --git a/.github/workflows/beam_PreCommit_Prism_Python.yml b/.github/workflows/beam_PreCommit_Prism_Python.yml new file mode 100644 index 000000000000..5eb26d139ef5 --- /dev/null +++ b/.github/workflows/beam_PreCommit_Prism_Python.yml @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +name: PreCommit Prism Python + +on: + push: + tags: ['v*'] + branches: ['master', 'release-*'] + paths: + - 'model/**' + - 'sdks/go/pkg/beam/runners/prism/**' + - 'sdks/python/**' + - 'release/**' + - '.github/workflows/beam_PreCommit_Prism_Python.yml' + pull_request_target: + branches: ['master', 'release-*'] + paths: + - 'model/**' + - 'sdks/go/pkg/beam/runners/prism/**' + - 'sdks/python/**' + - 'release/**' + - 'release/trigger_all_tests.json' + - '.github/trigger_files/beam_PreCommit_Prism_Python.json' + issue_comment: + types: [created] + schedule: + - cron: '30 2/6 * * *' + workflow_dispatch: + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: read + checks: read + contents: read + deployments: read + id-token: none + issues: read + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.event.pull_request.head.label || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + +jobs: + beam_PreCommit_Prism_Python: + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }} ${{ matrix.python_version }}) + timeout-minutes: 120 + runs-on: ['self-hosted', ubuntu-20.04, main] + strategy: + fail-fast: false + matrix: + job_name: ['beam_PreCommit_Prism_Python'] + job_phrase: ['Run Prism_Python PreCommit'] + python_version: ['3.9', '3.12'] + if: | + github.event_name == 'push' || + github.event_name == 'pull_request_target' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') || + startsWith(github.event.comment.body, 'Run Prism_Python PreCommit') + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} ${{ matrix.python_version }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }} ${{ matrix.python_version }}) + - name: Setup environment + uses: ./.github/actions/setup-environment-action + with: + java-version: default + python-version: | + ${{ matrix.python_version }} + 3.9 + - name: Set PY_VER_CLEAN + id: set_py_ver_clean + run: | + PY_VER=${{ matrix.python_version }} + PY_VER_CLEAN=${PY_VER//.} + echo "py_ver_clean=$PY_VER_CLEAN" >> $GITHUB_OUTPUT + - name: run Prism Python Validates Runner script + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :sdks:python:test-suites:portable:py${{steps.set_py_ver_clean.outputs.py_ver_clean}}:prismValidatesRunner \ No newline at end of file From 4a4da907003ec2eaca2824179fc79a46e28d576a Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Fri, 18 Oct 2024 10:18:58 -0400 Subject: [PATCH 032/181] Follow up website and change.md after 2.60 release (#32853) * Update release date in CHANGE.md * Update latest version for website --- CHANGES.md | 2 +- website/www/site/config.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 
766f74fc3be0..6e589c318dd9 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -97,7 +97,7 @@ * ([#X](https://github.com/apache/beam/issues/X)). -# [2.60.0] - Unreleased +# [2.60.0] - 2024-10-17 ## Highlights diff --git a/website/www/site/config.toml b/website/www/site/config.toml index e937289fbde7..d769f8434a7f 100644 --- a/website/www/site/config.toml +++ b/website/www/site/config.toml @@ -104,7 +104,7 @@ github_project_repo = "https://github.com/apache/beam" [params] description = "Apache Beam is an open source, unified model and set of language-specific SDKs for defining and executing data processing workflows, and also data ingestion and integration flows, supporting Enterprise Integration Patterns (EIPs) and Domain Specific Languages (DSLs). Dataflow pipelines simplify the mechanics of large-scale batch and streaming data processing and can run on a number of runtimes like Apache Flink, Apache Spark, and Google Cloud Dataflow (a cloud service). Beam also brings DSL in different languages, allowing users to easily implement their data integration processes." -release_latest = "2.59.0" +release_latest = "2.60.0" # The repository and branch where the files live in Github or Colab. This is used # to serve and stage from your local branch, but publish to the master branch. # e.g. https://github.com/{{< param branch_repo >}}/path/to/notebook.ipynb From 79528e17b5f3a959b9d52087eeb30a6fb4806f0f Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Fri, 18 Oct 2024 11:14:31 -0400 Subject: [PATCH 033/181] Add RAG to docs (#32859) --- .../ml/large-language-modeling.md | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/website/www/site/content/en/documentation/ml/large-language-modeling.md b/website/www/site/content/en/documentation/ml/large-language-modeling.md index 90bbd43383c0..b8bd0704d20e 100644 --- a/website/www/site/content/en/documentation/ml/large-language-modeling.md +++ b/website/www/site/content/en/documentation/ml/large-language-modeling.md @@ -170,3 +170,32 @@ class MyModelHandler(): def run_inference(self, batch: Sequence[str], model: MyWrapper, inference_args): return model.predict(unpickleable_object) ``` + +## RAG and Prompt Engineering in Beam + +Beam is also an excellent tool for improving the quality of your LLM prompts using Retrieval Augmented Generation (RAG). +Retrieval augmented generation is a technique that enhances large language models (LLMs) by connecting them to external knowledge sources. +This allows the LLM to access and process real-time information, improving the accuracy, relevance, and factuality of its responses. + +Beam has several mechanisms to make this process simpler: + +1. Beam's [MLTransform](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.embeddings.html) provides an embeddings package to generate the embeddings used for RAG. You can also use RunInference to generate embeddings if you have a model without an embeddings handler. +2. Beam's [Enrichment transform](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) makes it easy to look up embeddings or other information in an external storage system like a [vector database](https://www.pinecone.io/learn/vector-database/). + +Collectively, you can use these to perform RAG using the following steps: + +**Pipeline 1 - generate knowledge base:** + +1. Ingest data from external source using one of [Beam's IO connectors](https://beam.apache.org/documentation/io/connectors/) +2. 
Generate embeddings on that data using [MLTransform](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.embeddings.html) +3. Write those embeddings to a vector DB using a [ParDo](https://beam.apache.org/documentation/programming-guide/#pardo) + +**Pipeline 2 - use knowledge base to perform RAG:** + +1. Ingest data from external source using one of [Beam's IO connectors](https://beam.apache.org/documentation/io/connectors/) +2. Generate embeddings on that data using [MLTransform](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.transforms.embeddings.html) +3. Enrich that data with additional embeddings from your vector DB using [Enrichment](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) +4. Use that enriched data to prompt your LLM with [RunInference](https://beam.apache.org/documentation/transforms/python/elementwise/runinference/) +5. Write that data to your desired sink using one of [Beam's IO connectors](https://beam.apache.org/documentation/io/connectors/) + +To view an example pipeline performing RAG, see https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/rag_usecase/beam_rag_notebook.ipynb From fa9eb2fe17f5f96b40275fe7b0a3981f4a52e0df Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Fri, 18 Oct 2024 11:15:02 -0400 Subject: [PATCH 034/181] Enrichment pydoc improvements (#32861) --- .../apache_beam/yaml/yaml_enrichment.py | 62 +++++++------------ 1 file changed, 21 insertions(+), 41 deletions(-) diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment.py b/sdks/python/apache_beam/yaml/yaml_enrichment.py index 00f2a5c1b1d1..9bea17f78fdd 100644 --- a/sdks/python/apache_beam/yaml/yaml_enrichment.py +++ b/sdks/python/apache_beam/yaml/yaml_enrichment.py @@ -48,7 +48,19 @@ def enrichment_transform( """ The Enrichment transform allows you to dynamically enhance elements in a pipeline by performing key-value - lookups against external services like APIs or databases. + lookups against external services like APIs or databases. + + Example Usage:: + + - type: Enrichment + config: + enrichment_handler: 'BigTable' + handler_config: + project_id: 'apache-beam-testing' + instance_id: 'beam-test' + table_id: 'bigtable-enrichment-test' + row_key: 'product_id' + timeout: 30 Args: enrichment_handler: Specifies the source from @@ -58,46 +70,14 @@ def enrichment_transform( "BigTable", "FeastFeatureStore", "VertexAIFeatureStore"]. handler_config: Specifies the parameters for - the respective enrichment_handler in a dictionary format. 
- BigQuery = ( - "BigQuery: " - "project, table_name, row_restriction_template, " - "fields, column_names, "condition_value_fn, " - "query_fn, min_batch_size, max_batch_size" - ) - - BigTable = ( - "BigTable: " - "project_id, instance_id, table_id, " - "row_key, row_filter, app_profile_id, " - "encoding, ow_key_fn, exception_level, include_timestamp" - ) - - FeastFeatureStore = ( - "FeastFeatureStore: " - "feature_store_yaml_path, feature_names, " - "feature_service_name, full_feature_names, " - "entity_row_fn, exception_level" - ) - - VertexAIFeatureStore = ( - "VertexAIFeatureStore: " - "project, location, api_endpoint, feature_store_name, " - "feature_view_name, row_key, exception_level" - ) - - Example Usage: - - - type: Enrichment - config: - enrichment_handler: 'BigTable' - handler_config: - project_id: 'apache-beam-testing' - instance_id: 'beam-test' - table_id: 'bigtable-enrichment-test' - row_key: 'product_id' - timeout: 30 - + the respective enrichment_handler in a dictionary format. + To see the full set of handler_config parameters, see + their corresponding doc pages: + + - :class:`~apache_beam.transforms.enrichment_handlers.bigquery.BigQueryEnrichmentHandler` # pylint: disable=line-too-long + - :class:`~apache_beam.transforms.enrichment_handlers.bigtable.BigTableEnrichmentHandler` # pylint: disable=line-too-long + - :class:`~apache_beam.transforms.enrichment_handlers.feast_feature_store.FeastFeatureStoreEnrichmentHandler` # pylint: disable=line-too-long + - :class:`~apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store.VertexAIFeatureStoreEnrichmentHandler` # pylint: disable=line-too-long """ options.YamlOptions.check_enabled(pcoll.pipeline, 'Enrichment') From 3fd3db30d07e1f25ec91df9dece707c371977a52 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 17 Oct 2024 08:26:39 -0400 Subject: [PATCH 035/181] Drop Flink 1.15 support --- CHANGES.md | 1 + gradle.properties | 2 +- .../runner-concepts/description.md | 8 ++-- runners/flink/1.15/build.gradle | 25 ----------- .../1.15/job-server-container/build.gradle | 26 ----------- runners/flink/1.15/job-server/build.gradle | 31 ------------- .../types/CoderTypeSerializer.java | 0 .../streaming/MemoryStateBackendWrapper.java | 0 .../flink/streaming/StreamSources.java | 0 runners/flink/flink_runner.gradle | 43 ++++++------------- .../src/apache_beam/runners/flink.ts | 2 +- settings.gradle.kts | 4 -- .../content/en/documentation/runners/flink.md | 3 +- 13 files changed, 21 insertions(+), 124 deletions(-) delete mode 100644 runners/flink/1.15/build.gradle delete mode 100644 runners/flink/1.15/job-server-container/build.gradle delete mode 100644 runners/flink/1.15/job-server/build.gradle rename runners/flink/{1.15 => 1.16}/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java (100%) rename runners/flink/{1.15 => 1.16}/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java (100%) rename runners/flink/{1.15 => 1.16}/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java (100%) diff --git a/CHANGES.md b/CHANGES.md index 30f904d7733a..2ccaeeb49f7e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -82,6 +82,7 @@ ## Deprecations +* Removed support for Flink 1.15 * X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)). 
## Bugfixes diff --git a/gradle.properties b/gradle.properties index f6e143690a34..868c7501ac31 100644 --- a/gradle.properties +++ b/gradle.properties @@ -39,6 +39,6 @@ docker_image_default_repo_root=apache docker_image_default_repo_prefix=beam_ # supported flink versions -flink_versions=1.15,1.16,1.17,1.18,1.19 +flink_versions=1.16,1.17,1.18,1.19 # supported python versions python_versions=3.8,3.9,3.10,3.11,3.12 diff --git a/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md b/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md index 6eb1c04e966a..063e7f35f876 100644 --- a/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md +++ b/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md @@ -191,8 +191,8 @@ $ wordcount --input gs://dataflow-samples/shakespeare/kinglear.txt \ {{if (eq .Sdk "java")}} ##### Portable -1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.15`, `Flink 1.16`, `Flink 1.17`, `Flink 1.18`. -2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.10_job_server:latest` +1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.16`, `Flink 1.17`, `Flink 1.18`, `Flink 1.19`. +2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.19_job_server:latest` 3. Submit the pipeline to the above endpoint by using the PortableRunner, job_endpoint set to localhost:8099 (this is the default address of the JobService). Optionally set environment_type set to LOOPBACK. For example: ``` @@ -233,8 +233,8 @@ mvn exec:java -Dexec.mainClass=org.apache.beam.examples.WordCount \ {{end}} {{if (eq .Sdk "python")}} -1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.10`, `Flink 1.11`, `Flink 1.12`, `Flink 1.13`, `Flink 1.14`. -2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.10_job_server:latest` +1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.16`, `Flink 1.17`, `Flink 1.18`, `Flink 1.19`. +2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.19_job_server:latest` 3. Submit the pipeline to the above endpoint by using the PortableRunner, job_endpoint set to localhost:8099 (this is the default address of the JobService). Optionally set environment_type set to LOOPBACK. For example: ``` diff --git a/runners/flink/1.15/build.gradle b/runners/flink/1.15/build.gradle deleted file mode 100644 index 8055cf593ad0..000000000000 --- a/runners/flink/1.15/build.gradle +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -project.ext { - flink_major = '1.15' - flink_version = '1.15.0' -} - -// Load the main build script which contains all build logic. -apply from: "../flink_runner.gradle" diff --git a/runners/flink/1.15/job-server-container/build.gradle b/runners/flink/1.15/job-server-container/build.gradle deleted file mode 100644 index afdb68a0fc91..000000000000 --- a/runners/flink/1.15/job-server-container/build.gradle +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -def basePath = '../../job-server-container' - -project.ext { - resource_path = basePath -} - -// Load the main build script which contains all build logic. -apply from: "$basePath/flink_job_server_container.gradle" diff --git a/runners/flink/1.15/job-server/build.gradle b/runners/flink/1.15/job-server/build.gradle deleted file mode 100644 index 05ad8feb5b78..000000000000 --- a/runners/flink/1.15/job-server/build.gradle +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -def basePath = '../../job-server' - -project.ext { - // Look for the source code in the parent module - main_source_dirs = ["$basePath/src/main/java"] - test_source_dirs = ["$basePath/src/test/java"] - main_resources_dirs = ["$basePath/src/main/resources"] - test_resources_dirs = ["$basePath/src/test/resources"] - archives_base_name = 'beam-runners-flink-1.15-job-server' -} - -// Load the main build script which contains all build logic. 
-apply from: "$basePath/flink_job_server.gradle" diff --git a/runners/flink/1.15/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java b/runners/flink/1.16/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java similarity index 100% rename from runners/flink/1.15/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java rename to runners/flink/1.16/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java diff --git a/runners/flink/1.15/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java b/runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java similarity index 100% rename from runners/flink/1.15/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java rename to runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java diff --git a/runners/flink/1.15/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java b/runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java similarity index 100% rename from runners/flink/1.15/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java rename to runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java diff --git a/runners/flink/flink_runner.gradle b/runners/flink/flink_runner.gradle index c8f492a901d3..d13e1c5faf6e 100644 --- a/runners/flink/flink_runner.gradle +++ b/runners/flink/flink_runner.gradle @@ -173,36 +173,19 @@ dependencies { implementation library.java.joda_time implementation library.java.args4j - // Flink 1.15 shades all remaining scala dependencies and therefor does not depend on a specific version of Scala anymore - if (flink_version.compareTo("1.15") >= 0) { - implementation "org.apache.flink:flink-clients:$flink_version" - // Runtime dependencies are not included in Beam's generated pom.xml, so we must declare flink-clients in implementation - // configuration (https://issues.apache.org/jira/browse/BEAM-11732). - permitUnusedDeclared "org.apache.flink:flink-clients:$flink_version" - - implementation "org.apache.flink:flink-streaming-java:$flink_version" - // RocksDB state backend (included in the Flink distribution) - provided "org.apache.flink:flink-statebackend-rocksdb:$flink_version" - testImplementation "org.apache.flink:flink-statebackend-rocksdb:$flink_version" - testImplementation "org.apache.flink:flink-streaming-java:$flink_version:tests" - testImplementation "org.apache.flink:flink-test-utils:$flink_version" - - miniCluster "org.apache.flink:flink-runtime-web:$flink_version" - } else { - implementation "org.apache.flink:flink-clients_2.12:$flink_version" - // Runtime dependencies are not included in Beam's generated pom.xml, so we must declare flink-clients in implementation - // configuration (https://issues.apache.org/jira/browse/BEAM-11732). 
- permitUnusedDeclared "org.apache.flink:flink-clients_2.12:$flink_version" - - implementation "org.apache.flink:flink-streaming-java_2.12:$flink_version" - // RocksDB state backend (included in the Flink distribution) - provided "org.apache.flink:flink-statebackend-rocksdb_2.12:$flink_version" - testImplementation "org.apache.flink:flink-statebackend-rocksdb_2.12:$flink_version" - testImplementation "org.apache.flink:flink-streaming-java_2.12:$flink_version:tests" - testImplementation "org.apache.flink:flink-test-utils_2.12:$flink_version" - - miniCluster "org.apache.flink:flink-runtime-web_2.12:$flink_version" - } + implementation "org.apache.flink:flink-clients:$flink_version" + // Runtime dependencies are not included in Beam's generated pom.xml, so we must declare flink-clients in implementation + // configuration (https://issues.apache.org/jira/browse/BEAM-11732). + permitUnusedDeclared "org.apache.flink:flink-clients:$flink_version" + + implementation "org.apache.flink:flink-streaming-java:$flink_version" + // RocksDB state backend (included in the Flink distribution) + provided "org.apache.flink:flink-statebackend-rocksdb:$flink_version" + testImplementation "org.apache.flink:flink-statebackend-rocksdb:$flink_version" + testImplementation "org.apache.flink:flink-streaming-java:$flink_version:tests" + testImplementation "org.apache.flink:flink-test-utils:$flink_version" + + miniCluster "org.apache.flink:flink-runtime-web:$flink_version" implementation "org.apache.flink:flink-core:$flink_version" implementation "org.apache.flink:flink-metrics-core:$flink_version" diff --git a/sdks/typescript/src/apache_beam/runners/flink.ts b/sdks/typescript/src/apache_beam/runners/flink.ts index e21876c0d517..b68d3070a720 100644 --- a/sdks/typescript/src/apache_beam/runners/flink.ts +++ b/sdks/typescript/src/apache_beam/runners/flink.ts @@ -28,7 +28,7 @@ import { JavaJarService } from "../utils/service"; const MAGIC_HOST_NAMES = ["[local]", "[auto]"]; // These should stay in sync with gradle.properties. -const PUBLISHED_FLINK_VERSIONS = ["1.15", "1.16", "1.17", "1.18", "1.19"]; +const PUBLISHED_FLINK_VERSIONS = ["1.16", "1.17", "1.18", "1.19"]; const defaultOptions = { flinkMaster: "[local]", diff --git a/settings.gradle.kts b/settings.gradle.kts index b71ed1ede134..67e499e1ea31 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -125,10 +125,6 @@ include(":runners:extensions-java:metrics") * verify versions in website/www/site/content/en/documentation/runners/flink.md * verify version in sdks/python/apache_beam/runners/interactive/interactive_beam.py */ -// Flink 1.15 -include(":runners:flink:1.15") -include(":runners:flink:1.15:job-server") -include(":runners:flink:1.15:job-server-container") // Flink 1.16 include(":runners:flink:1.16") include(":runners:flink:1.16:job-server") diff --git a/website/www/site/content/en/documentation/runners/flink.md b/website/www/site/content/en/documentation/runners/flink.md index 2c28aa7062ec..9bf99cf9e4c2 100644 --- a/website/www/site/content/en/documentation/runners/flink.md +++ b/website/www/site/content/en/documentation/runners/flink.md @@ -196,7 +196,6 @@ The optional `flink_version` option may be required as well for older versions o {{< paragraph class="language-portable" >}} Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: -[Flink 1.15](https://hub.docker.com/r/apache/beam_flink1.15_job_server). [Flink 1.16](https://hub.docker.com/r/apache/beam_flink1.16_job_server). 
[Flink 1.17](https://hub.docker.com/r/apache/beam_flink1.17_job_server). [Flink 1.18](https://hub.docker.com/r/apache/beam_flink1.18_job_server). @@ -350,7 +349,7 @@ To find out which version of Flink is compatible with Beam please see the table 1.15.x beam-runners-flink-1.15 - ≥ 2.40.0 + 2.40.0 - 2.60.0 1.14.x From 56b54a4b66fbaafae57bcd5dd019ac0c183ee141 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 17 Oct 2024 13:58:07 -0400 Subject: [PATCH 036/181] Drop Flink 1.16 support --- CHANGES.md | 2 +- gradle.properties | 2 +- .../runner-concepts/description.md | 4 +- runners/flink/1.16/build.gradle | 25 --- .../1.16/job-server-container/build.gradle | 26 --- runners/flink/1.16/job-server/build.gradle | 31 --- .../types/CoderTypeSerializer.java | 195 ------------------ .../streaming/MemoryStateBackendWrapper.java | 0 .../flink/streaming/StreamSources.java | 0 .../src/apache_beam/runners/flink.ts | 2 +- settings.gradle.kts | 4 - .../content/en/documentation/runners/flink.md | 7 +- 12 files changed, 8 insertions(+), 290 deletions(-) delete mode 100644 runners/flink/1.16/build.gradle delete mode 100644 runners/flink/1.16/job-server-container/build.gradle delete mode 100644 runners/flink/1.16/job-server/build.gradle delete mode 100644 runners/flink/1.16/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java rename runners/flink/{1.16 => 1.17}/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java (100%) rename runners/flink/{1.16 => 1.17}/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java (100%) diff --git a/CHANGES.md b/CHANGES.md index 2ccaeeb49f7e..0167b575f1de 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -82,7 +82,7 @@ ## Deprecations -* Removed support for Flink 1.15 +* Removed support for Flink 1.15 and 1.16 * X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)). ## Bugfixes diff --git a/gradle.properties b/gradle.properties index 868c7501ac31..db1db368beb0 100644 --- a/gradle.properties +++ b/gradle.properties @@ -39,6 +39,6 @@ docker_image_default_repo_root=apache docker_image_default_repo_prefix=beam_ # supported flink versions -flink_versions=1.16,1.17,1.18,1.19 +flink_versions=1.17,1.18,1.19 # supported python versions python_versions=3.8,3.9,3.10,3.11,3.12 diff --git a/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md b/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md index 063e7f35f876..c0d7b37725ac 100644 --- a/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md +++ b/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md @@ -191,7 +191,7 @@ $ wordcount --input gs://dataflow-samples/shakespeare/kinglear.txt \ {{if (eq .Sdk "java")}} ##### Portable -1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.16`, `Flink 1.17`, `Flink 1.18`, `Flink 1.19`. +1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.17`, `Flink 1.18`, `Flink 1.19`. 2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.19_job_server:latest` 3. Submit the pipeline to the above endpoint by using the PortableRunner, job_endpoint set to localhost:8099 (this is the default address of the JobService). 
Optionally set environment_type set to LOOPBACK. For example: @@ -233,7 +233,7 @@ mvn exec:java -Dexec.mainClass=org.apache.beam.examples.WordCount \ {{end}} {{if (eq .Sdk "python")}} -1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.16`, `Flink 1.17`, `Flink 1.18`, `Flink 1.19`. +1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.17`, `Flink 1.18`, `Flink 1.19`. 2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.19_job_server:latest` 3. Submit the pipeline to the above endpoint by using the PortableRunner, job_endpoint set to localhost:8099 (this is the default address of the JobService). Optionally set environment_type set to LOOPBACK. For example: diff --git a/runners/flink/1.16/build.gradle b/runners/flink/1.16/build.gradle deleted file mode 100644 index 21a222864a27..000000000000 --- a/runners/flink/1.16/build.gradle +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -project.ext { - flink_major = '1.16' - flink_version = '1.16.0' -} - -// Load the main build script which contains all build logic. -apply from: "../flink_runner.gradle" diff --git a/runners/flink/1.16/job-server-container/build.gradle b/runners/flink/1.16/job-server-container/build.gradle deleted file mode 100644 index afdb68a0fc91..000000000000 --- a/runners/flink/1.16/job-server-container/build.gradle +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -def basePath = '../../job-server-container' - -project.ext { - resource_path = basePath -} - -// Load the main build script which contains all build logic. 
-apply from: "$basePath/flink_job_server_container.gradle" diff --git a/runners/flink/1.16/job-server/build.gradle b/runners/flink/1.16/job-server/build.gradle deleted file mode 100644 index 99dc00275a0c..000000000000 --- a/runners/flink/1.16/job-server/build.gradle +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -def basePath = '../../job-server' - -project.ext { - // Look for the source code in the parent module - main_source_dirs = ["$basePath/src/main/java"] - test_source_dirs = ["$basePath/src/test/java"] - main_resources_dirs = ["$basePath/src/main/resources"] - test_resources_dirs = ["$basePath/src/test/resources"] - archives_base_name = 'beam-runners-flink-1.16-job-server' -} - -// Load the main build script which contains all build logic. -apply from: "$basePath/flink_job_server.gradle" diff --git a/runners/flink/1.16/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java b/runners/flink/1.16/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java deleted file mode 100644 index 956aad428d8b..000000000000 --- a/runners/flink/1.16/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.flink.translation.types; - -import java.io.EOFException; -import java.io.IOException; -import org.apache.beam.runners.core.construction.SerializablePipelineOptions; -import org.apache.beam.runners.flink.FlinkPipelineOptions; -import org.apache.beam.runners.flink.translation.wrappers.DataInputViewWrapper; -import org.apache.beam.runners.flink.translation.wrappers.DataOutputViewWrapper; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.util.CoderUtils; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot; -import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility; -import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; -import org.apache.flink.core.io.VersionedIOReadableWritable; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; -import org.checkerframework.checker.nullness.qual.Nullable; - -/** - * Flink {@link org.apache.flink.api.common.typeutils.TypeSerializer} for Beam {@link - * org.apache.beam.sdk.coders.Coder Coders}. - */ -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public class CoderTypeSerializer extends TypeSerializer { - - private static final long serialVersionUID = 7247319138941746449L; - - private final Coder coder; - - /** - * {@link SerializablePipelineOptions} deserialization will cause {@link - * org.apache.beam.sdk.io.FileSystems} registration needed for {@link - * org.apache.beam.sdk.transforms.Reshuffle} translation. 
- */ - private final SerializablePipelineOptions pipelineOptions; - - private final boolean fasterCopy; - - public CoderTypeSerializer(Coder coder, SerializablePipelineOptions pipelineOptions) { - Preconditions.checkNotNull(coder); - Preconditions.checkNotNull(pipelineOptions); - this.coder = coder; - this.pipelineOptions = pipelineOptions; - - FlinkPipelineOptions options = pipelineOptions.get().as(FlinkPipelineOptions.class); - this.fasterCopy = options.getFasterCopy(); - } - - @Override - public boolean isImmutableType() { - return false; - } - - @Override - public CoderTypeSerializer duplicate() { - return new CoderTypeSerializer<>(coder, pipelineOptions); - } - - @Override - public T createInstance() { - return null; - } - - @Override - public T copy(T t) { - if (fasterCopy) { - return t; - } - try { - return CoderUtils.clone(coder, t); - } catch (CoderException e) { - throw new RuntimeException("Could not clone.", e); - } - } - - @Override - public T copy(T t, T reuse) { - return copy(t); - } - - @Override - public int getLength() { - return -1; - } - - @Override - public void serialize(T t, DataOutputView dataOutputView) throws IOException { - DataOutputViewWrapper outputWrapper = new DataOutputViewWrapper(dataOutputView); - coder.encode(t, outputWrapper); - } - - @Override - public T deserialize(DataInputView dataInputView) throws IOException { - try { - DataInputViewWrapper inputWrapper = new DataInputViewWrapper(dataInputView); - return coder.decode(inputWrapper); - } catch (CoderException e) { - Throwable cause = e.getCause(); - if (cause instanceof EOFException) { - throw (EOFException) cause; - } else { - throw e; - } - } - } - - @Override - public T deserialize(T t, DataInputView dataInputView) throws IOException { - return deserialize(dataInputView); - } - - @Override - public void copy(DataInputView dataInputView, DataOutputView dataOutputView) throws IOException { - serialize(deserialize(dataInputView), dataOutputView); - } - - @Override - public boolean equals(@Nullable Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - CoderTypeSerializer that = (CoderTypeSerializer) o; - return coder.equals(that.coder); - } - - @Override - public int hashCode() { - return coder.hashCode(); - } - - @Override - public TypeSerializerSnapshot snapshotConfiguration() { - return new UnversionedTypeSerializerSnapshot<>(this); - } - - /** - * A legacy snapshot which does not care about schema compatibility. This is used only for state - * restore of state created by Beam 2.54.0 and below for Flink 1.16 and below. - */ - public static class LegacySnapshot extends TypeSerializerConfigSnapshot { - - /** Needs to be public to work with {@link VersionedIOReadableWritable}. 
*/ - public LegacySnapshot() {} - - public LegacySnapshot(CoderTypeSerializer serializer) { - setPriorSerializer(serializer); - } - - @Override - public int getVersion() { - // We always return the same version - return 1; - } - - @Override - public TypeSerializerSchemaCompatibility resolveSchemaCompatibility( - TypeSerializer newSerializer) { - - // We assume compatibility because we don't have a way of checking schema compatibility - return TypeSerializerSchemaCompatibility.compatibleAsIs(); - } - } - - @Override - public String toString() { - return "CoderTypeSerializer{" + "coder=" + coder + '}'; - } -} diff --git a/runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java b/runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java similarity index 100% rename from runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java rename to runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/MemoryStateBackendWrapper.java diff --git a/runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java b/runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java similarity index 100% rename from runners/flink/1.16/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java rename to runners/flink/1.17/src/test/java/org/apache/beam/runners/flink/streaming/StreamSources.java diff --git a/sdks/typescript/src/apache_beam/runners/flink.ts b/sdks/typescript/src/apache_beam/runners/flink.ts index b68d3070a720..ab2d641b3302 100644 --- a/sdks/typescript/src/apache_beam/runners/flink.ts +++ b/sdks/typescript/src/apache_beam/runners/flink.ts @@ -28,7 +28,7 @@ import { JavaJarService } from "../utils/service"; const MAGIC_HOST_NAMES = ["[local]", "[auto]"]; // These should stay in sync with gradle.properties. -const PUBLISHED_FLINK_VERSIONS = ["1.16", "1.17", "1.18", "1.19"]; +const PUBLISHED_FLINK_VERSIONS = ["1.17", "1.18", "1.19"]; const defaultOptions = { flinkMaster: "[local]", diff --git a/settings.gradle.kts b/settings.gradle.kts index 67e499e1ea31..a38f69dac09e 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -125,10 +125,6 @@ include(":runners:extensions-java:metrics") * verify versions in website/www/site/content/en/documentation/runners/flink.md * verify version in sdks/python/apache_beam/runners/interactive/interactive_beam.py */ -// Flink 1.16 -include(":runners:flink:1.16") -include(":runners:flink:1.16:job-server") -include(":runners:flink:1.16:job-server-container") // Flink 1.17 include(":runners:flink:1.17") include(":runners:flink:1.17:job-server") diff --git a/website/www/site/content/en/documentation/runners/flink.md b/website/www/site/content/en/documentation/runners/flink.md index 9bf99cf9e4c2..fb897805cfd6 100644 --- a/website/www/site/content/en/documentation/runners/flink.md +++ b/website/www/site/content/en/documentation/runners/flink.md @@ -196,7 +196,6 @@ The optional `flink_version` option may be required as well for older versions o {{< paragraph class="language-portable" >}} Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: -[Flink 1.16](https://hub.docker.com/r/apache/beam_flink1.16_job_server). [Flink 1.17](https://hub.docker.com/r/apache/beam_flink1.17_job_server). [Flink 1.18](https://hub.docker.com/r/apache/beam_flink1.18_job_server). 
[Flink 1.19](https://hub.docker.com/r/apache/beam_flink1.19_job_server). @@ -312,8 +311,8 @@ reference. ## Flink Version Compatibility The Flink cluster version has to match the minor version used by the FlinkRunner. -The minor version is the first two numbers in the version string, e.g. in `1.16.0` the -minor version is `1.16`. +The minor version is the first two numbers in the version string, e.g. in `1.19.0` the +minor version is `1.19`. We try to track the latest version of Apache Flink at the time of the Beam release. A Flink version is supported by Beam for the time it is supported by the Flink community. @@ -344,7 +343,7 @@ To find out which version of Flink is compatible with Beam please see the table 1.16.x beam-runners-flink-1.16 - ≥ 2.47.0 + 2.47.0 - 2.60.0 1.15.x From dfa54b23e8d4143275f4bd2c0f90d85944ae76ee Mon Sep 17 00:00:00 2001 From: Jan Lukavsky Date: Fri, 18 Oct 2024 09:50:48 +0200 Subject: [PATCH 037/181] [flink] #32838 remove removed flink version references --- sdks/go/examples/wasm/README.md | 2 +- sdks/python/apache_beam/options/pipeline_options.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/go/examples/wasm/README.md b/sdks/go/examples/wasm/README.md index a78649134305..103bef88642b 100644 --- a/sdks/go/examples/wasm/README.md +++ b/sdks/go/examples/wasm/README.md @@ -68,7 +68,7 @@ cd $BEAM_HOME Expected output should include the following, from which you acquire the latest flink runner version. ```shell -'flink_versions: 1.15,1.16,1.17,1.18,1.19' +'flink_versions: 1.17,1.18,1.19' ``` #### 2. Set to the latest flink runner version i.e. 1.16 diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 837dc0f5439f..455d12b4d3c1 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -1679,7 +1679,7 @@ def _add_argparse_args(cls, parser): class FlinkRunnerOptions(PipelineOptions): # These should stay in sync with gradle.properties. - PUBLISHED_FLINK_VERSIONS = ['1.15', '1.16', '1.17', '1.18', '1.19'] + PUBLISHED_FLINK_VERSIONS = ['1.17', '1.18', '1.19'] @classmethod def _add_argparse_args(cls, parser): From ef6caf4c65c09990e169311750804856a734e5bd Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Fri, 18 Oct 2024 14:10:14 -0700 Subject: [PATCH 038/181] Added a TODO. --- sdks/python/apache_beam/yaml/tests/map.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/sdks/python/apache_beam/yaml/tests/map.yaml b/sdks/python/apache_beam/yaml/tests/map.yaml index 31fb442085fb..bbb7fc4527de 100644 --- a/sdks/python/apache_beam/yaml/tests/map.yaml +++ b/sdks/python/apache_beam/yaml/tests/map.yaml @@ -30,6 +30,7 @@ pipelines: config: append: true fields: + # TODO(https://github.com/apache/beam/issues/32832): Figure out why Java sometimes re-orders these fields. literal_int: 10 named_field: element literal_float: 1.5 From eafe08b7d9c4bb28695d27d9933ff144cd657714 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Fri, 18 Oct 2024 14:44:57 -0700 Subject: [PATCH 039/181] Update docs on error handling output schema. 
--- .../www/site/content/en/documentation/sdks/yaml-errors.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/website/www/site/content/en/documentation/sdks/yaml-errors.md b/website/www/site/content/en/documentation/sdks/yaml-errors.md index 8c0d9f06ade3..903e18d6b3c7 100644 --- a/website/www/site/content/en/documentation/sdks/yaml-errors.md +++ b/website/www/site/content/en/documentation/sdks/yaml-errors.md @@ -37,7 +37,8 @@ The `output` parameter is a name that must referenced as an input to another transform that will process the errors (e.g. by writing them out). For example, the following code will write all "good" processed records to one file and -any "bad" records to a separate file. +any "bad" records, along with metadata about what error was encountered, +to a separate file. ``` pipeline: @@ -77,6 +78,8 @@ for a robust pipeline). Note also that the exact format of the error outputs is still being finalized. They can be safely printed and written to outputs, but their precise schema may change in a future version of Beam and should not yet be depended on. +Currently it has, at the very least, an `element` field which holds the element +that caused the error. Some transforms allow for extra arguments in their error_handling config, e.g. for Python functions one can give a `threshold` which limits the relative number From 1ba33b888fc76a1e25cd7ad45ec9fde642b6f572 Mon Sep 17 00:00:00 2001 From: Robert Burke Date: Sun, 20 Oct 2024 08:39:04 -0700 Subject: [PATCH 040/181] [#30703][prism] Update logging handling (#32826) * Migrate to standard library slog package * Add dev logger dependency for pre printed development logs * Improve logging output for prism and user side logs, and emit container logs. * Fix missed lines from artifact and worker. --------- Co-authored-by: lostluck <13907733+lostluck@users.noreply.github.com> --- sdks/go.mod | 4 +- sdks/go.sum | 2 + sdks/go/cmd/prism/prism.go | 47 ++++++++++++++ .../beam/core/runtime/metricsx/metricsx.go | 2 +- .../pkg/beam/runners/prism/internal/coders.go | 2 +- .../runners/prism/internal/engine/data.go | 2 +- .../prism/internal/engine/elementmanager.go | 4 +- .../runners/prism/internal/environments.go | 32 ++++++++-- .../beam/runners/prism/internal/execute.go | 14 ++-- .../runners/prism/internal/handlerunner.go | 2 +- .../prism/internal/jobservices/artifact.go | 4 +- .../runners/prism/internal/jobservices/job.go | 4 +- .../prism/internal/jobservices/management.go | 3 +- .../prism/internal/jobservices/metrics.go | 4 +- .../prism/internal/jobservices/server.go | 6 +- .../beam/runners/prism/internal/preprocess.go | 2 +- .../runners/prism/internal/separate_test.go | 12 ++-- .../pkg/beam/runners/prism/internal/stage.go | 4 +- .../beam/runners/prism/internal/web/web.go | 2 +- .../runners/prism/internal/worker/bundle.go | 2 +- .../runners/prism/internal/worker/worker.go | 64 +++++++++++-------- 21 files changed, 153 insertions(+), 65 deletions(-) diff --git a/sdks/go.mod b/sdks/go.mod index 74556ee12a55..706be73f97f6 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -20,7 +20,7 @@ // directory. 
module github.com/apache/beam/sdks/v2 -go 1.21 +go 1.21.0 require ( cloud.google.com/go/bigquery v1.63.1 @@ -69,6 +69,8 @@ require ( require ( github.com/avast/retry-go/v4 v4.6.0 github.com/fsouza/fake-gcs-server v1.49.2 + github.com/golang-cz/devslog v0.0.11 + github.com/golang/protobuf v1.5.4 golang.org/x/exp v0.0.0-20231006140011-7918f672742d ) diff --git a/sdks/go.sum b/sdks/go.sum index af68a630addd..fa3c75bd3395 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -853,6 +853,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-cz/devslog v0.0.11 h1:v4Yb9o0ZpuZ/D8ZrtVw1f9q5XrjnkxwHF1XmWwO8IHg= +github.com/golang-cz/devslog v0.0.11/go.mod h1:bSe5bm0A7Nyfqtijf1OMNgVJHlWEuVSXnkuASiE1vV8= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= diff --git a/sdks/go/cmd/prism/prism.go b/sdks/go/cmd/prism/prism.go index 39c19df00dc3..070d2f023b74 100644 --- a/sdks/go/cmd/prism/prism.go +++ b/sdks/go/cmd/prism/prism.go @@ -22,9 +22,14 @@ import ( "flag" "fmt" "log" + "log/slog" + "os" + "strings" + "time" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism" + "github.com/golang-cz/devslog" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -37,10 +42,52 @@ var ( idleShutdownTimeout = flag.Duration("idle_shutdown_timeout", -1, "duration that prism will wait for a new job before shutting itself down. Negative durations disable auto shutdown. Defaults to never shutting down.") ) +// Logging flags +var ( + debug = flag.Bool("debug", false, + "Enables full verbosity debug logging from the runner by default. Used to build SDKs or debug Prism itself.") + logKind = flag.String("log_kind", "dev", + "Determines the format of prism's logging to std err: valid values are `dev', 'json', or 'text'. Default is `dev`.") +) + +var logLevel = new(slog.LevelVar) + func main() { flag.Parse() ctx, cancel := context.WithCancelCause(context.Background()) + var logHandler slog.Handler + loggerOutput := os.Stderr + handlerOpts := &slog.HandlerOptions{ + Level: logLevel, + AddSource: *debug, + } + if *debug { + logLevel.Set(slog.LevelDebug) + // Print the Prism source line for a log in debug mode. 
+ handlerOpts.AddSource = true + } + switch strings.ToLower(*logKind) { + case "dev": + logHandler = + devslog.NewHandler(loggerOutput, &devslog.Options{ + TimeFormat: "[" + time.RFC3339Nano + "]", + StringerFormatter: true, + HandlerOptions: handlerOpts, + StringIndentation: false, + NewLineAfterLog: true, + MaxErrorStackTrace: 3, + }) + case "json": + logHandler = slog.NewJSONHandler(loggerOutput, handlerOpts) + case "text": + logHandler = slog.NewTextHandler(loggerOutput, handlerOpts) + default: + log.Fatalf("Invalid value for log_kind: %v, must be 'dev', 'json', or 'text'", *logKind) + } + + slog.SetDefault(slog.New(logHandler)) + cli, err := makeJobClient(ctx, prism.Options{ Port: *jobPort, diff --git a/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go b/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go index c71ead208364..06bb727178fc 100644 --- a/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go +++ b/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go @@ -19,12 +19,12 @@ import ( "bytes" "fmt" "log" + "log/slog" "time" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/metrics" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" - "golang.org/x/exp/slog" ) // FromMonitoringInfos extracts metrics from monitored states and diff --git a/sdks/go/pkg/beam/runners/prism/internal/coders.go b/sdks/go/pkg/beam/runners/prism/internal/coders.go index eb8abe16ecf8..ffea90e79065 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/coders.go +++ b/sdks/go/pkg/beam/runners/prism/internal/coders.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "log/slog" "strings" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" @@ -28,7 +29,6 @@ import ( pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go index eaaf7f831712..7b8689f95112 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go @@ -18,12 +18,12 @@ package engine import ( "bytes" "fmt" + "log/slog" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" - "golang.org/x/exp/slog" ) // StateData is a "union" between Bag state and MultiMap state to increase common code. 
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go index f7229853e4d3..3cfde4701a8f 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go +++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go @@ -23,6 +23,7 @@ import ( "context" "fmt" "io" + "log/slog" "sort" "strings" "sync" @@ -36,7 +37,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" ) type element struct { @@ -1607,7 +1607,7 @@ func (ss *stageState) bundleReady(em *ElementManager, emNow mtime.Time) (mtime.T inputW := ss.input _, upstreamW := ss.UpstreamWatermark() if inputW == upstreamW { - slog.Debug("bundleReady: insufficient upstream watermark", + slog.Debug("bundleReady: unchanged upstream watermark", slog.String("stage", ss.ID), slog.Group("watermark", slog.Any("upstream", upstreamW), diff --git a/sdks/go/pkg/beam/runners/prism/internal/environments.go b/sdks/go/pkg/beam/runners/prism/internal/environments.go index add7f769a702..2f960a04f0cb 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/environments.go +++ b/sdks/go/pkg/beam/runners/prism/internal/environments.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "log/slog" "os" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" @@ -27,7 +28,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" - "golang.org/x/exp/slog" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/proto" @@ -42,7 +42,7 @@ import ( // TODO move environment handling to the worker package. func runEnvironment(ctx context.Context, j *jobservices.Job, env string, wk *worker.W) error { - logger := slog.With(slog.String("envID", wk.Env)) + logger := j.Logger.With(slog.String("envID", wk.Env)) // TODO fix broken abstraction. // We're starting a worker pool here, because that's the loopback environment. // It's sort of a mess, largely because of loopback, which has @@ -56,7 +56,7 @@ func runEnvironment(ctx context.Context, j *jobservices.Job, env string, wk *wor } go func() { externalEnvironment(ctx, ep, wk) - slog.Debug("environment stopped", slog.String("job", j.String())) + logger.Debug("environment stopped", slog.String("job", j.String())) }() return nil case urns.EnvDocker: @@ -129,6 +129,8 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock credEnv := fmt.Sprintf("%v=%v", gcloudCredsEnv, dockerGcloudCredsFile) envs = append(envs, credEnv) } + } else { + logger.Debug("local GCP credentials environment variable not found") } if _, _, err := cli.ImageInspectWithRaw(ctx, dp.GetContainerImage()); err != nil { // We don't have a local image, so we should pull it. 
@@ -140,6 +142,7 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock logger.Warn("unable to pull image and it's not local", "error", err) } } + logger.Debug("creating container", "envs", envs, "mounts", mounts) ccr, err := cli.ContainerCreate(ctx, &container.Config{ Image: dp.GetContainerImage(), @@ -169,17 +172,32 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock return fmt.Errorf("unable to start container image %v with docker for env %v, err: %w", dp.GetContainerImage(), wk.Env, err) } + logger.Debug("container started") + // Start goroutine to wait on container state. go func() { defer cli.Close() defer wk.Stop() + defer func() { + logger.Debug("container stopped") + }() - statusCh, errCh := cli.ContainerWait(ctx, containerID, container.WaitConditionNotRunning) + bgctx := context.Background() + statusCh, errCh := cli.ContainerWait(bgctx, containerID, container.WaitConditionNotRunning) select { case <-ctx.Done(): - // Can't use command context, since it's already canceled here. - err := cli.ContainerKill(context.Background(), containerID, "") + rc, err := cli.ContainerLogs(bgctx, containerID, container.LogsOptions{Details: true, ShowStdout: true, ShowStderr: true}) if err != nil { + logger.Error("error fetching container logs error on context cancellation", "error", err) + } + if rc != nil { + defer rc.Close() + var buf bytes.Buffer + stdcopy.StdCopy(&buf, &buf, rc) + logger.Info("container being killed", slog.Any("cause", context.Cause(ctx)), slog.Any("containerLog", buf)) + } + // Can't use command context, since it's already canceled here. + if err := cli.ContainerKill(bgctx, containerID, ""); err != nil { logger.Error("docker container kill error", "error", err) } case err := <-errCh: @@ -189,7 +207,7 @@ func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.Dock case resp := <-statusCh: logger.Info("docker container has self terminated", "status_code", resp.StatusCode) - rc, err := cli.ContainerLogs(ctx, containerID, container.LogsOptions{Details: true, ShowStdout: true, ShowStderr: true}) + rc, err := cli.ContainerLogs(bgctx, containerID, container.LogsOptions{Details: true, ShowStdout: true, ShowStderr: true}) if err != nil { logger.Error("docker container logs error", "error", err) } diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go index d7605f34f5f2..614edee47721 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/execute.go +++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "io" + "log/slog" "sort" "sync/atomic" "time" @@ -34,7 +35,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" ) @@ -311,7 +311,7 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic return fmt.Errorf("prism error building stage %v: \n%w", stage.ID, err) } stages[stage.ID] = stage - slog.Debug("pipelineBuild", slog.Group("stage", slog.String("ID", stage.ID), slog.String("transformName", t.GetUniqueName()))) + j.Logger.Debug("pipelineBuild", slog.Group("stage", slog.String("ID", stage.ID), slog.String("transformName", t.GetUniqueName()))) outputs := maps.Keys(stage.OutputsToCoders) sort.Strings(outputs) em.AddStage(stage.ID, 
[]string{stage.primaryInput}, outputs, stage.sideInputs) @@ -322,9 +322,7 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic em.StageProcessingTimeTimers(stage.ID, stage.processingTimeTimers) } default: - err := fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId()) - slog.Error("Execute", err) - return err + return fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId()) } } @@ -344,11 +342,13 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic for { select { case <-ctx.Done(): - return context.Cause(ctx) + err := context.Cause(ctx) + j.Logger.Debug("context canceled", slog.Any("cause", err)) + return err case rb, ok := <-bundles: if !ok { err := eg.Wait() - slog.Debug("pipeline done!", slog.String("job", j.String()), slog.Any("error", err)) + j.Logger.Debug("pipeline done!", slog.String("job", j.String()), slog.Any("error", err), slog.Any("topo", topo)) return err } eg.Go(func() error { diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go index 8590fd0d4ced..be9d39ad02b7 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go +++ b/sdks/go/pkg/beam/runners/prism/internal/handlerunner.go @@ -19,6 +19,7 @@ import ( "bytes" "fmt" "io" + "log/slog" "reflect" "sort" @@ -31,7 +32,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go index 99b786d45980..e42e3e7ca666 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go @@ -20,9 +20,9 @@ import ( "context" "fmt" "io" + "log/slog" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" ) @@ -77,7 +77,7 @@ func (s *Server) ReverseArtifactRetrievalService(stream jobpb.ArtifactStagingSer case *jobpb.ArtifactResponseWrapper_ResolveArtifactResponse: err := fmt.Errorf("unexpected ResolveArtifactResponse to GetArtifact: %v", in.GetResponse()) - slog.Error("GetArtifact failure", err) + slog.Error("GetArtifact failure", slog.Any("error", err)) return err } } diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go index 1407feafe325..deef259a99d1 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go @@ -27,6 +27,7 @@ package jobservices import ( "context" "fmt" + "log/slog" "sort" "strings" "sync" @@ -37,7 +38,6 @@ import ( jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" - "golang.org/x/exp/slog" "google.golang.org/protobuf/types/known/structpb" ) @@ -88,6 +88,8 @@ type Job struct { // Context used to terminate this job. RootCtx context.Context CancelFn context.CancelCauseFunc + // Logger for this job. 
+ Logger *slog.Logger metrics metricsStore } diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go index b957b99ca63d..a2840760bf7a 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go @@ -19,6 +19,7 @@ import ( "context" "errors" "fmt" + "log/slog" "sync" "sync/atomic" @@ -27,7 +28,6 @@ import ( pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -92,6 +92,7 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (_ * cancelFn(err) terminalOnceWrap() }, + Logger: s.logger, // TODO substitute with a configured logger. artifactEndpoint: s.Endpoint(), } // Stop the idle timer when a new job appears. diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go index 03d5b0a98369..bbbdfd1eba4f 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/metrics.go @@ -19,6 +19,7 @@ import ( "bytes" "fmt" "hash/maphash" + "log/slog" "math" "sort" "sync" @@ -28,7 +29,6 @@ import ( fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "golang.org/x/exp/constraints" - "golang.org/x/exp/slog" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -589,7 +589,7 @@ func (m *metricsStore) AddShortIDs(resp *fnpb.MonitoringInfosMetadataResponse) { urn := mi.GetUrn() ops, ok := mUrn2Ops[urn] if !ok { - slog.Debug("unknown metrics urn", slog.String("urn", urn)) + slog.Debug("unknown metrics urn", slog.String("shortID", short), slog.String("urn", urn), slog.String("type", mi.Type)) continue } key := ops.keyFn(urn, mi.GetLabels()) diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go index 320159f54c06..bdfe2aff2dd4 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go +++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go @@ -18,6 +18,7 @@ package jobservices import ( "context" "fmt" + "log/slog" "math" "net" "os" @@ -27,7 +28,6 @@ import ( fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" - "golang.org/x/exp/slog" "google.golang.org/grpc" ) @@ -53,6 +53,7 @@ type Server struct { terminatedJobCount uint32 // Use with atomics. idleTimeout time.Duration cancelFn context.CancelCauseFunc + logger *slog.Logger // execute defines how a job is executed. execute func(*Job) @@ -71,8 +72,9 @@ func NewServer(port int, execute func(*Job)) *Server { lis: lis, jobs: make(map[string]*Job), execute: execute, + logger: slog.Default(), // TODO substitute with a configured logger. 
} - slog.Info("Serving JobManagement", slog.String("endpoint", s.Endpoint())) + s.logger.Info("Serving JobManagement", slog.String("endpoint", s.Endpoint())) opts := []grpc.ServerOption{ grpc.MaxRecvMsgSize(math.MaxInt32), } diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go index 7de32f85b7ee..dceaa9ab8fcb 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go +++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go @@ -17,6 +17,7 @@ package internal import ( "fmt" + "log/slog" "sort" "strings" @@ -26,7 +27,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "google.golang.org/protobuf/proto" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/separate_test.go b/sdks/go/pkg/beam/runners/prism/internal/separate_test.go index 1be3d3e70841..650932f525c8 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/separate_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/separate_test.go @@ -18,6 +18,7 @@ package internal_test import ( "context" "fmt" + "log/slog" "net" "net/http" "net/rpc" @@ -34,7 +35,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/stats" - "golang.org/x/exp/slog" ) // separate_test.go retains structures and tests to ensure the runner can @@ -286,7 +286,7 @@ func (ws *Watchers) Check(args *Args, unblocked *bool) error { w.mu.Lock() *unblocked = w.sentinelCount >= w.sentinelCap w.mu.Unlock() - slog.Debug("sentinel target for watcher%d is %d/%d. 
unblocked=%v", args.WatcherID, w.sentinelCount, w.sentinelCap, *unblocked) + slog.Debug("sentinel watcher status", slog.Int("watcher", args.WatcherID), slog.Int("sentinelCount", w.sentinelCount), slog.Int("sentinelCap", w.sentinelCap), slog.Bool("unblocked", *unblocked)) return nil } @@ -360,7 +360,7 @@ func (fn *sepHarnessBase) setup() error { sepClientOnce.Do(func() { client, err := rpc.DialHTTP("tcp", fn.LocalService) if err != nil { - slog.Error("failed to dial sentinels server", err, slog.String("endpoint", fn.LocalService)) + slog.Error("failed to dial sentinels server", slog.Any("error", err), slog.String("endpoint", fn.LocalService)) panic(fmt.Sprintf("dialing sentinels server %v: %v", fn.LocalService, err)) } sepClient = client @@ -385,7 +385,7 @@ func (fn *sepHarnessBase) setup() error { var unblock bool err := sepClient.Call("Watchers.Check", &Args{WatcherID: id}, &unblock) if err != nil { - slog.Error("Watchers.Check: sentinels server error", err, slog.String("endpoint", fn.LocalService)) + slog.Error("Watchers.Check: sentinels server error", slog.Any("error", err), slog.String("endpoint", fn.LocalService)) panic("sentinel server error") } if unblock { @@ -406,7 +406,7 @@ func (fn *sepHarnessBase) block() { var ignored bool err := sepClient.Call("Watchers.Block", &Args{WatcherID: fn.WatcherID}, &ignored) if err != nil { - slog.Error("Watchers.Block error", err, slog.String("endpoint", fn.LocalService)) + slog.Error("Watchers.Block error", slog.Any("error", err), slog.String("endpoint", fn.LocalService)) panic(err) } c := sepWaitMap[fn.WatcherID] @@ -423,7 +423,7 @@ func (fn *sepHarnessBase) delay() bool { var delay bool err := sepClient.Call("Watchers.Delay", &Args{WatcherID: fn.WatcherID}, &delay) if err != nil { - slog.Error("Watchers.Delay error", err) + slog.Error("Watchers.Delay error", slog.Any("error", err)) panic(err) } return delay diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go index f33754b2ca0a..9f00c22789b6 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "log/slog" "runtime/debug" "sync/atomic" "time" @@ -33,7 +34,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" ) @@ -361,7 +361,7 @@ func portFor(wInCid string, wk *worker.W) []byte { } sourcePortBytes, err := proto.Marshal(sourcePort) if err != nil { - slog.Error("bad port", err, slog.String("endpoint", sourcePort.ApiServiceDescriptor.GetUrl())) + slog.Error("bad port", slog.Any("error", err), slog.String("endpoint", sourcePort.ApiServiceDescriptor.GetUrl())) } return sourcePortBytes } diff --git a/sdks/go/pkg/beam/runners/prism/internal/web/web.go b/sdks/go/pkg/beam/runners/prism/internal/web/web.go index 9fabe22cee3a..b14778e4462c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/web/web.go +++ b/sdks/go/pkg/beam/runners/prism/internal/web/web.go @@ -26,6 +26,7 @@ import ( "fmt" "html/template" "io" + "log/slog" "net/http" "sort" "strings" @@ -40,7 +41,6 @@ import ( jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "golang.org/x/exp/maps" - "golang.org/x/exp/slog" 
"golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" ) diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go index 3ccafdb81e9a..55cdb97f258c 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go @@ -19,12 +19,12 @@ import ( "bytes" "context" "fmt" + "log/slog" "sync/atomic" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" - "golang.org/x/exp/slog" ) // SideInputKey is for data lookups for a given bundle. diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go index f9ec03793488..1f129595abef 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go +++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go @@ -22,10 +22,9 @@ import ( "context" "fmt" "io" + "log/slog" "math" "net" - "strconv" - "strings" "sync" "sync/atomic" "time" @@ -39,7 +38,6 @@ import ( pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine" "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns" - "golang.org/x/exp/slog" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -203,30 +201,46 @@ func (wk *W) Logging(stream fnpb.BeamFnLogging_LoggingServer) error { case codes.Canceled: return nil default: - slog.Error("logging.Recv", err, "worker", wk) + slog.Error("logging.Recv", slog.Any("error", err), slog.Any("worker", wk)) return err } } for _, l := range in.GetLogEntries() { - if l.Severity >= minsev { - // TODO: Connect to the associated Job for this worker instead of - // logging locally for SDK side logging. - file := l.GetLogLocation() - i := strings.LastIndex(file, ":") - line, _ := strconv.Atoi(file[i+1:]) - if i > 0 { - file = file[:i] - } + // TODO base this on a per pipeline logging setting. + if l.Severity < minsev { + continue + } + + // Ideally we'd be writing these to per-pipeline files, but for now re-log them on the Prism process. + // We indicate they're from the SDK, and which worker, keeping the same log severity. + // SDK specific and worker specific fields are in separate groups for legibility. - slog.LogAttrs(stream.Context(), toSlogSev(l.GetSeverity()), l.GetMessage(), - slog.Any(slog.SourceKey, &slog.Source{ - File: file, - Line: line, - }), - slog.Time(slog.TimeKey, l.GetTimestamp().AsTime()), - slog.Any("worker", wk), - ) + attrs := []any{ + slog.String("transformID", l.GetTransformId()), // TODO: pull the unique name from the pipeline graph. 
+ slog.String("location", l.GetLogLocation()), + slog.Time(slog.TimeKey, l.GetTimestamp().AsTime()), + slog.String(slog.MessageKey, l.GetMessage()), } + if fs := l.GetCustomData().GetFields(); len(fs) > 0 { + var grp []any + for n, v := range l.GetCustomData().GetFields() { + var attr slog.Attr + switch v.Kind.(type) { + case *structpb.Value_BoolValue: + attr = slog.Bool(n, v.GetBoolValue()) + case *structpb.Value_NumberValue: + attr = slog.Float64(n, v.GetNumberValue()) + case *structpb.Value_StringValue: + attr = slog.String(n, v.GetStringValue()) + default: + attr = slog.Any(n, v.AsInterface()) + } + grp = append(grp, attr) + } + attrs = append(attrs, slog.Group("customData", grp...)) + } + + slog.LogAttrs(stream.Context(), toSlogSev(l.GetSeverity()), "log from SDK worker", slog.Any("worker", wk), slog.Group("sdk", attrs...)) } } } @@ -298,7 +312,7 @@ func (wk *W) Control(ctrl fnpb.BeamFnControl_ControlServer) error { if b, ok := wk.activeInstructions[resp.GetInstructionId()]; ok { b.Respond(resp) } else { - slog.Debug("ctrl.Recv: %v", resp) + slog.Debug("ctrl.Recv", slog.Any("response", resp)) } wk.mu.Unlock() } @@ -355,7 +369,7 @@ func (wk *W) Data(data fnpb.BeamFnData_DataServer) error { case codes.Canceled: return default: - slog.Error("data.Recv failed", err, "worker", wk) + slog.Error("data.Recv failed", slog.Any("error", err), slog.Any("worker", wk)) panic(err) } } @@ -434,7 +448,7 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { case codes.Canceled: return default: - slog.Error("state.Recv failed", err, "worker", wk) + slog.Error("state.Recv failed", slog.Any("error", err), slog.Any("worker", wk)) panic(err) } } @@ -584,7 +598,7 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) error { }() for resp := range responses { if err := state.Send(resp); err != nil { - slog.Error("state.Send error", err) + slog.Error("state.Send", slog.Any("error", err)) } } return nil From ac87d7b48e86e0c3e863d13b5e8d52469134a446 Mon Sep 17 00:00:00 2001 From: Thiago Nunes Date: Mon, 21 Oct 2024 18:51:09 +1100 Subject: [PATCH 041/181] fix: generate random index name for change streams (#32689) Generates index names for change stream partition metadata table using a random UUID. This prevents issues if the job is being redeployed in an existing database. 
--- .../beam/sdk/io/gcp/spanner/SpannerIO.java | 19 ++- .../spanner/changestreams/NameGenerator.java | 52 ------- .../spanner/changestreams/dao/DaoFactory.java | 12 +- .../dao/PartitionMetadataAdminDao.java | 58 +++---- .../dao/PartitionMetadataDao.java | 35 +++++ .../dao/PartitionMetadataTableNames.java | 144 ++++++++++++++++++ .../dofn/CleanUpReadChangeStreamDoFn.java | 4 +- .../changestreams/dofn/InitializeDoFn.java | 1 + .../changestreams/NameGeneratorTest.java | 41 ----- .../dao/PartitionMetadataAdminDaoTest.java | 56 +++++-- .../dao/PartitionMetadataTableNamesTest.java | 73 +++++++++ 11 files changed, 344 insertions(+), 151 deletions(-) delete mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGenerator.java create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNames.java delete mode 100644 sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGeneratorTest.java create mode 100644 sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNamesTest.java diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index 435bbba9ae8e..d9dde11a3081 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -25,7 +25,6 @@ import static org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants.DEFAULT_RPC_PRIORITY; import static org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants.MAX_INCLUSIVE_END_AT; import static org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants.THROUGHPUT_WINDOW_SECONDS; -import static org.apache.beam.sdk.io.gcp.spanner.changestreams.NameGenerator.generatePartitionMetadataTableName; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; @@ -61,6 +60,7 @@ import java.util.HashMap; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.OptionalInt; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -77,6 +77,7 @@ import org.apache.beam.sdk.io.gcp.spanner.changestreams.MetadataSpannerConfigFactory; import org.apache.beam.sdk.io.gcp.spanner.changestreams.action.ActionFactory; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.DaoFactory; +import org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.PartitionMetadataTableNames; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.CleanUpReadChangeStreamDoFn; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.DetectNewPartitionsDoFn; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn.InitializeDoFn; @@ -1772,9 +1773,13 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta + fullPartitionMetadataDatabaseId + " has dialect " + metadataDatabaseDialect); - final String partitionMetadataTableName = - MoreObjects.firstNonNull( - getMetadataTable(), generatePartitionMetadataTableName(partitionMetadataDatabaseId)); + 
PartitionMetadataTableNames partitionMetadataTableNames = + Optional.ofNullable(getMetadataTable()) + .map( + table -> + PartitionMetadataTableNames.fromExistingTable( + partitionMetadataDatabaseId, table)) + .orElse(PartitionMetadataTableNames.generateRandom(partitionMetadataDatabaseId)); final String changeStreamName = getChangeStreamName(); final Timestamp startTimestamp = getInclusiveStartAt(); // Uses (Timestamp.MAX - 1ns) at max for end timestamp, because we add 1ns to transform the @@ -1791,7 +1796,7 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta changeStreamSpannerConfig, changeStreamName, partitionMetadataSpannerConfig, - partitionMetadataTableName, + partitionMetadataTableNames, rpcPriority, input.getPipeline().getOptions().getJobName(), changeStreamDatabaseDialect, @@ -1807,7 +1812,9 @@ && getInclusiveStartAt().toSqlTimestamp().after(getInclusiveEndAt().toSqlTimesta final PostProcessingMetricsDoFn postProcessingMetricsDoFn = new PostProcessingMetricsDoFn(metrics); - LOG.info("Partition metadata table that will be used is " + partitionMetadataTableName); + LOG.info( + "Partition metadata table that will be used is " + + partitionMetadataTableNames.getTableName()); final PCollection impulseOut = input.apply(Impulse.create()); final PCollection partitionsOut = diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGenerator.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGenerator.java deleted file mode 100644 index 322e85cb07a2..000000000000 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGenerator.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.gcp.spanner.changestreams; - -import java.util.UUID; - -/** - * This class generates a unique name for the partition metadata table, which is created when the - * Connector is initialized. - */ -public class NameGenerator { - - private static final String PARTITION_METADATA_TABLE_NAME_FORMAT = "Metadata_%s_%s"; - private static final int MAX_TABLE_NAME_LENGTH = 63; - - /** - * Generates an unique name for the partition metadata table in the form of {@code - * "Metadata__"}. - * - * @param databaseId The database id where the table will be created - * @return the unique generated name of the partition metadata table - */ - public static String generatePartitionMetadataTableName(String databaseId) { - // There are 11 characters in the name format. - // Maximum Spanner database ID length is 30 characters. - // UUID always generates a String with 36 characters. 
- // Since the Postgres table name length is 63, we may need to truncate the table name depending - // on the database length. - String fullString = - String.format(PARTITION_METADATA_TABLE_NAME_FORMAT, databaseId, UUID.randomUUID()) - .replaceAll("-", "_"); - if (fullString.length() < MAX_TABLE_NAME_LENGTH) { - return fullString; - } - return fullString.substring(0, MAX_TABLE_NAME_LENGTH); - } -} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java index b9718fdb675e..787abad02e02 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/DaoFactory.java @@ -44,7 +44,7 @@ public class DaoFactory implements Serializable { private final SpannerConfig metadataSpannerConfig; private final String changeStreamName; - private final String partitionMetadataTableName; + private final PartitionMetadataTableNames partitionMetadataTableNames; private final RpcPriority rpcPriority; private final String jobName; private final Dialect spannerChangeStreamDatabaseDialect; @@ -56,7 +56,7 @@ public class DaoFactory implements Serializable { * @param changeStreamSpannerConfig the configuration for the change streams DAO * @param changeStreamName the name of the change stream for the change streams DAO * @param metadataSpannerConfig the metadata tables configuration - * @param partitionMetadataTableName the name of the created partition metadata table + * @param partitionMetadataTableNames the names of the partition metadata ddl objects * @param rpcPriority the priority of the requests made by the DAO queries * @param jobName the name of the running job */ @@ -64,7 +64,7 @@ public DaoFactory( SpannerConfig changeStreamSpannerConfig, String changeStreamName, SpannerConfig metadataSpannerConfig, - String partitionMetadataTableName, + PartitionMetadataTableNames partitionMetadataTableNames, RpcPriority rpcPriority, String jobName, Dialect spannerChangeStreamDatabaseDialect, @@ -78,7 +78,7 @@ public DaoFactory( this.changeStreamSpannerConfig = changeStreamSpannerConfig; this.changeStreamName = changeStreamName; this.metadataSpannerConfig = metadataSpannerConfig; - this.partitionMetadataTableName = partitionMetadataTableName; + this.partitionMetadataTableNames = partitionMetadataTableNames; this.rpcPriority = rpcPriority; this.jobName = jobName; this.spannerChangeStreamDatabaseDialect = spannerChangeStreamDatabaseDialect; @@ -102,7 +102,7 @@ public synchronized PartitionMetadataAdminDao getPartitionMetadataAdminDao() { databaseAdminClient, metadataSpannerConfig.getInstanceId().get(), metadataSpannerConfig.getDatabaseId().get(), - partitionMetadataTableName, + partitionMetadataTableNames, this.metadataDatabaseDialect); } return partitionMetadataAdminDao; @@ -120,7 +120,7 @@ public synchronized PartitionMetadataDao getPartitionMetadataDao() { if (partitionMetadataDaoInstance == null) { partitionMetadataDaoInstance = new PartitionMetadataDao( - this.partitionMetadataTableName, + this.partitionMetadataTableNames.getTableName(), spannerAccessor.getDatabaseClient(), this.metadataDatabaseDialect); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java index 368cab7022b3..3e6045d8858b 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDao.java @@ -79,19 +79,13 @@ public class PartitionMetadataAdminDao { */ public static final String COLUMN_FINISHED_AT = "FinishedAt"; - /** Metadata table index for queries over the watermark column. */ - public static final String WATERMARK_INDEX = "WatermarkIndex"; - - /** Metadata table index for queries over the created at / start timestamp columns. */ - public static final String CREATED_AT_START_TIMESTAMP_INDEX = "CreatedAtStartTimestampIndex"; - private static final int TIMEOUT_MINUTES = 10; private static final int TTL_AFTER_PARTITION_FINISHED_DAYS = 1; private final DatabaseAdminClient databaseAdminClient; private final String instanceId; private final String databaseId; - private final String tableName; + private final PartitionMetadataTableNames names; private final Dialect dialect; /** @@ -101,18 +95,18 @@ public class PartitionMetadataAdminDao { * table * @param instanceId the instance where the metadata table will reside * @param databaseId the database where the metadata table will reside - * @param tableName the name of the metadata table + * @param names the names of the metadata table ddl objects */ PartitionMetadataAdminDao( DatabaseAdminClient databaseAdminClient, String instanceId, String databaseId, - String tableName, + PartitionMetadataTableNames names, Dialect dialect) { this.databaseAdminClient = databaseAdminClient; this.instanceId = instanceId; this.databaseId = databaseId; - this.tableName = tableName; + this.names = names; this.dialect = dialect; } @@ -128,8 +122,8 @@ public void createPartitionMetadataTable() { if (this.isPostgres()) { // Literals need be added around literals to preserve casing. 
ddl.add( - "CREATE TABLE \"" - + tableName + "CREATE TABLE IF NOT EXISTS \"" + + names.getTableName() + "\"(\"" + COLUMN_PARTITION_TOKEN + "\" text NOT NULL,\"" @@ -163,20 +157,20 @@ public void createPartitionMetadataTable() { + COLUMN_FINISHED_AT + "\""); ddl.add( - "CREATE INDEX \"" - + WATERMARK_INDEX + "CREATE INDEX IF NOT EXISTS \"" + + names.getWatermarkIndexName() + "\" on \"" - + tableName + + names.getTableName() + "\" (\"" + COLUMN_WATERMARK + "\") INCLUDE (\"" + COLUMN_STATE + "\")"); ddl.add( - "CREATE INDEX \"" - + CREATED_AT_START_TIMESTAMP_INDEX + "CREATE INDEX IF NOT EXISTS \"" + + names.getCreatedAtIndexName() + "\" ON \"" - + tableName + + names.getTableName() + "\" (\"" + COLUMN_CREATED_AT + "\",\"" @@ -184,8 +178,8 @@ public void createPartitionMetadataTable() { + "\")"); } else { ddl.add( - "CREATE TABLE " - + tableName + "CREATE TABLE IF NOT EXISTS " + + names.getTableName() + " (" + COLUMN_PARTITION_TOKEN + " STRING(MAX) NOT NULL," @@ -218,20 +212,20 @@ public void createPartitionMetadataTable() { + TTL_AFTER_PARTITION_FINISHED_DAYS + " DAY))"); ddl.add( - "CREATE INDEX " - + WATERMARK_INDEX + "CREATE INDEX IF NOT EXISTS " + + names.getWatermarkIndexName() + " on " - + tableName + + names.getTableName() + " (" + COLUMN_WATERMARK + ") STORING (" + COLUMN_STATE + ")"); ddl.add( - "CREATE INDEX " - + CREATED_AT_START_TIMESTAMP_INDEX + "CREATE INDEX IF NOT EXISTS " + + names.getCreatedAtIndexName() + " ON " - + tableName + + names.getTableName() + " (" + COLUMN_CREATED_AT + "," @@ -261,16 +255,14 @@ public void createPartitionMetadataTable() { * Drops the metadata table. This operation should complete in {@link * PartitionMetadataAdminDao#TIMEOUT_MINUTES} minutes. */ - public void deletePartitionMetadataTable() { + public void deletePartitionMetadataTable(List indexes) { List ddl = new ArrayList<>(); if (this.isPostgres()) { - ddl.add("DROP INDEX \"" + CREATED_AT_START_TIMESTAMP_INDEX + "\""); - ddl.add("DROP INDEX \"" + WATERMARK_INDEX + "\""); - ddl.add("DROP TABLE \"" + tableName + "\""); + indexes.forEach(index -> ddl.add("DROP INDEX \"" + index + "\"")); + ddl.add("DROP TABLE \"" + names.getTableName() + "\""); } else { - ddl.add("DROP INDEX " + CREATED_AT_START_TIMESTAMP_INDEX); - ddl.add("DROP INDEX " + WATERMARK_INDEX); - ddl.add("DROP TABLE " + tableName); + indexes.forEach(index -> ddl.add("DROP INDEX " + index)); + ddl.add("DROP TABLE " + names.getTableName()); } OperationFuture op = databaseAdminClient.updateDatabaseDdl(instanceId, databaseId, ddl, null); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java index 7867932cd1ad..654fd946663c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataDao.java @@ -96,6 +96,41 @@ public boolean tableExists() { } } + /** + * Finds all indexes for the metadata table. + * + * @return a list of index names for the metadata table. 
+ */ + public List findAllTableIndexes() { + String indexesStmt; + if (this.isPostgres()) { + indexesStmt = + "SELECT index_name FROM information_schema.indexes" + + " WHERE table_schema = 'public'" + + " AND table_name = '" + + metadataTableName + + "' AND index_type != 'PRIMARY_KEY'"; + } else { + indexesStmt = + "SELECT index_name FROM information_schema.indexes" + + " WHERE table_schema = ''" + + " AND table_name = '" + + metadataTableName + + "' AND index_type != 'PRIMARY_KEY'"; + } + + List result = new ArrayList<>(); + try (ResultSet queryResultSet = + databaseClient + .singleUseReadOnlyTransaction() + .executeQuery(Statement.of(indexesStmt), Options.tag("query=findAllTableIndexes"))) { + while (queryResultSet.next()) { + result.add(queryResultSet.getString("index_name")); + } + } + return result; + } + /** * Fetches the partition metadata row data for the given partition token. * diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNames.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNames.java new file mode 100644 index 000000000000..07d7b80676de --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNames.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.spanner.changestreams.dao; + +import java.io.Serializable; +import java.util.Objects; +import java.util.UUID; +import javax.annotation.Nullable; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; + +/** + * Configuration for a partition metadata table. It encapsulates the name of the metadata table and + * indexes. + */ +public class PartitionMetadataTableNames implements Serializable { + + private static final long serialVersionUID = 8848098877671834584L; + + /** PostgreSQL max table and index length is 63 bytes. */ + @VisibleForTesting static final int MAX_NAME_LENGTH = 63; + + private static final String PARTITION_METADATA_TABLE_NAME_FORMAT = "Metadata_%s_%s"; + private static final String WATERMARK_INDEX_NAME_FORMAT = "WatermarkIdx_%s_%s"; + private static final String CREATED_AT_START_TIMESTAMP_INDEX_NAME_FORMAT = "CreatedAtIdx_%s_%s"; + + /** + * Generates a unique name for the partition metadata table and its indexes. The table name will + * be in the form of {@code "Metadata__"}. The watermark index will be in the + * form of {@code "WatermarkIdx__}. The createdAt / start timestamp index will + * be in the form of {@code "CreatedAtIdx__}. 
+ * + * @param databaseId The database id where the table will be created + * @return the unique generated names of the partition metadata ddl + */ + public static PartitionMetadataTableNames generateRandom(String databaseId) { + UUID uuid = UUID.randomUUID(); + + String table = generateName(PARTITION_METADATA_TABLE_NAME_FORMAT, databaseId, uuid); + String watermarkIndex = generateName(WATERMARK_INDEX_NAME_FORMAT, databaseId, uuid); + String createdAtIndex = + generateName(CREATED_AT_START_TIMESTAMP_INDEX_NAME_FORMAT, databaseId, uuid); + + return new PartitionMetadataTableNames(table, watermarkIndex, createdAtIndex); + } + + /** + * Encapsulates a selected table name. Index names are generated, but will only be used if the + * given table does not exist. The watermark index will be in the form of {@code + * "WatermarkIdx__}. The createdAt / start timestamp index will be in the form + * of {@code "CreatedAtIdx__}. + * + * @param databaseId The database id for the table + * @param table The table name to be used + * @return an instance with the table name and generated index names + */ + public static PartitionMetadataTableNames fromExistingTable(String databaseId, String table) { + UUID uuid = UUID.randomUUID(); + + String watermarkIndex = generateName(WATERMARK_INDEX_NAME_FORMAT, databaseId, uuid); + String createdAtIndex = + generateName(CREATED_AT_START_TIMESTAMP_INDEX_NAME_FORMAT, databaseId, uuid); + return new PartitionMetadataTableNames(table, watermarkIndex, createdAtIndex); + } + + private static String generateName(String template, String databaseId, UUID uuid) { + String name = String.format(template, databaseId, uuid).replaceAll("-", "_"); + if (name.length() > MAX_NAME_LENGTH) { + return name.substring(0, MAX_NAME_LENGTH); + } + return name; + } + + private final String tableName; + private final String watermarkIndexName; + private final String createdAtIndexName; + + public PartitionMetadataTableNames( + String tableName, String watermarkIndexName, String createdAtIndexName) { + this.tableName = tableName; + this.watermarkIndexName = watermarkIndexName; + this.createdAtIndexName = createdAtIndexName; + } + + public String getTableName() { + return tableName; + } + + public String getWatermarkIndexName() { + return watermarkIndexName; + } + + public String getCreatedAtIndexName() { + return createdAtIndexName; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof PartitionMetadataTableNames)) { + return false; + } + PartitionMetadataTableNames that = (PartitionMetadataTableNames) o; + return Objects.equals(tableName, that.tableName) + && Objects.equals(watermarkIndexName, that.watermarkIndexName) + && Objects.equals(createdAtIndexName, that.createdAtIndexName); + } + + @Override + public int hashCode() { + return Objects.hash(tableName, watermarkIndexName, createdAtIndexName); + } + + @Override + public String toString() { + return "PartitionMetadataTableNames{" + + "tableName='" + + tableName + + '\'' + + ", watermarkIndexName='" + + watermarkIndexName + + '\'' + + ", createdAtIndexName='" + + createdAtIndexName + + '\'' + + '}'; + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java index a048c885a001..f8aa497292bf 100644 --- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/CleanUpReadChangeStreamDoFn.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.gcp.spanner.changestreams.dofn; import java.io.Serializable; +import java.util.List; import org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.DaoFactory; import org.apache.beam.sdk.transforms.DoFn; @@ -33,6 +34,7 @@ public CleanUpReadChangeStreamDoFn(DaoFactory daoFactory) { @ProcessElement public void processElement(OutputReceiver receiver) { - daoFactory.getPartitionMetadataAdminDao().deletePartitionMetadataTable(); + List indexes = daoFactory.getPartitionMetadataDao().findAllTableIndexes(); + daoFactory.getPartitionMetadataAdminDao().deletePartitionMetadataTable(indexes); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java index 387ffd603b14..ca93f34bf1ba 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dofn/InitializeDoFn.java @@ -64,6 +64,7 @@ public InitializeDoFn( public void processElement(OutputReceiver receiver) { PartitionMetadataDao partitionMetadataDao = daoFactory.getPartitionMetadataDao(); if (!partitionMetadataDao.tableExists()) { + // Creates partition metadata table and associated indexes daoFactory.getPartitionMetadataAdminDao().createPartitionMetadataTable(); createFakeParentPartition(); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGeneratorTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGeneratorTest.java deleted file mode 100644 index f15fc5307374..000000000000 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/NameGeneratorTest.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.gcp.spanner.changestreams; - -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -import org.junit.Test; - -public class NameGeneratorTest { - private static final int MAXIMUM_POSTGRES_TABLE_NAME_LENGTH = 63; - - @Test - public void testGenerateMetadataTableNameRemovesHyphens() { - final String tableName = - NameGenerator.generatePartitionMetadataTableName("my-database-id-12345"); - assertFalse(tableName.contains("-")); - } - - @Test - public void testGenerateMetadataTableNameIsShorterThan64Characters() { - final String tableName = - NameGenerator.generatePartitionMetadataTableName("my-database-id1-maximum-length"); - assertTrue(tableName.length() <= MAXIMUM_POSTGRES_TABLE_NAME_LENGTH); - } -} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java index 3752c2fb3afc..02b9d111583b 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataAdminDaoTest.java @@ -33,7 +33,9 @@ import com.google.cloud.spanner.ErrorCode; import com.google.cloud.spanner.SpannerException; import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata; +import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.Iterator; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -58,6 +60,8 @@ public class PartitionMetadataAdminDaoTest { private static final String DATABASE_ID = "SPANNER_DATABASE"; private static final String TABLE_NAME = "SPANNER_TABLE"; + private static final String WATERMARK_INDEX_NAME = "WATERMARK_INDEX"; + private static final String CREATED_AT_INDEX_NAME = "CREATED_AT_INDEX"; private static final int TIMEOUT_MINUTES = 10; @@ -68,12 +72,14 @@ public class PartitionMetadataAdminDaoTest { @Before public void setUp() { databaseAdminClient = mock(DatabaseAdminClient.class); + PartitionMetadataTableNames names = + new PartitionMetadataTableNames(TABLE_NAME, WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME); partitionMetadataAdminDao = new PartitionMetadataAdminDao( - databaseAdminClient, INSTANCE_ID, DATABASE_ID, TABLE_NAME, Dialect.GOOGLE_STANDARD_SQL); + databaseAdminClient, INSTANCE_ID, DATABASE_ID, names, Dialect.GOOGLE_STANDARD_SQL); partitionMetadataAdminDaoPostgres = new PartitionMetadataAdminDao( - databaseAdminClient, INSTANCE_ID, DATABASE_ID, TABLE_NAME, Dialect.POSTGRESQL); + databaseAdminClient, INSTANCE_ID, DATABASE_ID, names, Dialect.POSTGRESQL); op = (OperationFuture) mock(OperationFuture.class); statements = ArgumentCaptor.forClass(Iterable.class); when(databaseAdminClient.updateDatabaseDdl( @@ -89,9 +95,9 @@ public void testCreatePartitionMetadataTable() throws Exception { .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); Iterator it = statements.getValue().iterator(); - assertTrue(it.next().contains("CREATE TABLE")); - assertTrue(it.next().contains("CREATE INDEX")); - assertTrue(it.next().contains("CREATE INDEX")); + assertTrue(it.next().contains("CREATE TABLE IF NOT EXISTS")); + 
assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS")); + assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS")); } @Test @@ -102,9 +108,9 @@ public void testCreatePartitionMetadataTablePostgres() throws Exception { .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); Iterator it = statements.getValue().iterator(); - assertTrue(it.next().contains("CREATE TABLE \"")); - assertTrue(it.next().contains("CREATE INDEX \"")); - assertTrue(it.next().contains("CREATE INDEX \"")); + assertTrue(it.next().contains("CREATE TABLE IF NOT EXISTS \"")); + assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS \"")); + assertTrue(it.next().contains("CREATE INDEX IF NOT EXISTS \"")); } @Test @@ -133,7 +139,8 @@ public void testCreatePartitionMetadataTableWithInterruptedException() throws Ex @Test public void testDeletePartitionMetadataTable() throws Exception { when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); - partitionMetadataAdminDao.deletePartitionMetadataTable(); + partitionMetadataAdminDao.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); verify(databaseAdminClient, times(1)) .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); @@ -143,10 +150,22 @@ public void testDeletePartitionMetadataTable() throws Exception { assertTrue(it.next().contains("DROP TABLE")); } + @Test + public void testDeletePartitionMetadataTableWithNoIndexes() throws Exception { + when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); + partitionMetadataAdminDao.deletePartitionMetadataTable(Collections.emptyList()); + verify(databaseAdminClient, times(1)) + .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); + assertEquals(1, ((Collection) statements.getValue()).size()); + Iterator it = statements.getValue().iterator(); + assertTrue(it.next().contains("DROP TABLE")); + } + @Test public void testDeletePartitionMetadataTablePostgres() throws Exception { when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); - partitionMetadataAdminDaoPostgres.deletePartitionMetadataTable(); + partitionMetadataAdminDaoPostgres.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); verify(databaseAdminClient, times(1)) .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); assertEquals(3, ((Collection) statements.getValue()).size()); @@ -156,11 +175,23 @@ public void testDeletePartitionMetadataTablePostgres() throws Exception { assertTrue(it.next().contains("DROP TABLE \"")); } + @Test + public void testDeletePartitionMetadataTablePostgresWithNoIndexes() throws Exception { + when(op.get(TIMEOUT_MINUTES, TimeUnit.MINUTES)).thenReturn(null); + partitionMetadataAdminDaoPostgres.deletePartitionMetadataTable(Collections.emptyList()); + verify(databaseAdminClient, times(1)) + .updateDatabaseDdl(eq(INSTANCE_ID), eq(DATABASE_ID), statements.capture(), isNull()); + assertEquals(1, ((Collection) statements.getValue()).size()); + Iterator it = statements.getValue().iterator(); + assertTrue(it.next().contains("DROP TABLE \"")); + } + @Test public void testDeletePartitionMetadataTableWithTimeoutException() throws Exception { when(op.get(10, TimeUnit.MINUTES)).thenThrow(new TimeoutException(TIMED_OUT)); try { - partitionMetadataAdminDao.deletePartitionMetadataTable(); + 
partitionMetadataAdminDao.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); fail(); } catch (SpannerException e) { assertTrue(e.getMessage().contains(TIMED_OUT)); @@ -171,7 +202,8 @@ public void testDeletePartitionMetadataTableWithTimeoutException() throws Except public void testDeletePartitionMetadataTableWithInterruptedException() throws Exception { when(op.get(10, TimeUnit.MINUTES)).thenThrow(new InterruptedException(INTERRUPTED)); try { - partitionMetadataAdminDao.deletePartitionMetadataTable(); + partitionMetadataAdminDao.deletePartitionMetadataTable( + Arrays.asList(WATERMARK_INDEX_NAME, CREATED_AT_INDEX_NAME)); fail(); } catch (SpannerException e) { assertEquals(ErrorCode.CANCELLED, e.getErrorCode()); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNamesTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNamesTest.java new file mode 100644 index 000000000000..2aae5b26a2cb --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/dao/PartitionMetadataTableNamesTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.spanner.changestreams.dao; + +import static org.apache.beam.sdk.io.gcp.spanner.changestreams.dao.PartitionMetadataTableNames.MAX_NAME_LENGTH; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; + +public class PartitionMetadataTableNamesTest { + @Test + public void testGeneratePartitionMetadataNamesRemovesHyphens() { + String databaseId = "my-database-id-12345"; + + PartitionMetadataTableNames names1 = PartitionMetadataTableNames.generateRandom(databaseId); + assertFalse(names1.getTableName().contains("-")); + assertFalse(names1.getWatermarkIndexName().contains("-")); + assertFalse(names1.getCreatedAtIndexName().contains("-")); + + PartitionMetadataTableNames names2 = PartitionMetadataTableNames.generateRandom(databaseId); + assertNotEquals(names1.getTableName(), names2.getTableName()); + assertNotEquals(names1.getWatermarkIndexName(), names2.getWatermarkIndexName()); + assertNotEquals(names1.getCreatedAtIndexName(), names2.getCreatedAtIndexName()); + } + + @Test + public void testGeneratePartitionMetadataNamesIsShorterThan64Characters() { + PartitionMetadataTableNames names = + PartitionMetadataTableNames.generateRandom( + "my-database-id-larger-than-maximum-length-1234567890-1234567890-1234567890"); + assertTrue(names.getTableName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getWatermarkIndexName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getCreatedAtIndexName().length() <= MAX_NAME_LENGTH); + + names = PartitionMetadataTableNames.generateRandom("d"); + assertTrue(names.getTableName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getWatermarkIndexName().length() <= MAX_NAME_LENGTH); + assertTrue(names.getCreatedAtIndexName().length() <= MAX_NAME_LENGTH); + } + + @Test + public void testPartitionMetadataNamesFromExistingTable() { + PartitionMetadataTableNames names1 = + PartitionMetadataTableNames.fromExistingTable("databaseid", "mytable"); + assertEquals("mytable", names1.getTableName()); + assertFalse(names1.getWatermarkIndexName().contains("-")); + assertFalse(names1.getCreatedAtIndexName().contains("-")); + + PartitionMetadataTableNames names2 = + PartitionMetadataTableNames.fromExistingTable("databaseid", "mytable"); + assertEquals("mytable", names2.getTableName()); + assertNotEquals(names1.getWatermarkIndexName(), names2.getWatermarkIndexName()); + assertNotEquals(names1.getCreatedAtIndexName(), names2.getCreatedAtIndexName()); + } +} From 68f1543b6bfe4ceaa752c7f16fc2cae7393211fd Mon Sep 17 00:00:00 2001 From: martin trieu Date: Mon, 21 Oct 2024 03:05:43 -0600 Subject: [PATCH 042/181] Simplify budget distribution logic and new worker metadata consumption (#32775) --- .../FanOutStreamingEngineWorkerHarness.java | 379 ++++++++-------- .../harness/GlobalDataStreamSender.java | 63 +++ ...tate.java => StreamingEngineBackends.java} | 30 +- .../harness/WindmillStreamSender.java | 25 +- .../worker/windmill/WindmillEndpoints.java | 28 +- .../windmill/WindmillServiceAddress.java | 22 +- .../windmill/client/WindmillStream.java | 7 +- .../client/grpc/GrpcDirectGetWorkStream.java | 286 ++++++++----- .../client/grpc/GrpcGetDataStream.java | 2 +- .../client/grpc/GrpcGetWorkStream.java | 10 +- .../grpc/GrpcWindmillStreamFactory.java | 6 +- .../grpc/stubs/WindmillChannelFactory.java | 17 +- .../budget/EvenGetWorkBudgetDistributor.java | 59 +-- .../budget/GetWorkBudgetDistributors.java | 
6 +- .../work/budget/GetWorkBudgetSpender.java | 8 +- .../dataflow/worker/FakeWindmillServer.java | 10 +- ...anOutStreamingEngineWorkerHarnessTest.java | 111 ++--- .../harness/WindmillStreamSenderTest.java | 4 +- .../grpc/GrpcDirectGetWorkStreamTest.java | 405 ++++++++++++++++++ .../EvenGetWorkBudgetDistributorTest.java | 186 ++------ 20 files changed, 998 insertions(+), 666 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java rename runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/{StreamingEngineConnectionState.java => StreamingEngineBackends.java} (55%) create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 3556b7ce2919..458cf57ca8e7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -20,20 +20,25 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet; -import java.util.Collection; -import java.util.List; +import java.io.Closeable; +import java.util.HashSet; import java.util.Map.Entry; +import java.util.NoSuchElementException; import java.util.Optional; -import java.util.Queue; -import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; +import java.util.stream.Collectors; import javax.annotation.CheckReturnValue; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; @@ -54,18 +59,14 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetDistributor; -import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetRefresher; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.util.MoreFutures; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -80,32 +81,39 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorkerHarness { private static final Logger LOG = LoggerFactory.getLogger(FanOutStreamingEngineWorkerHarness.class); - private static final String PUBLISH_NEW_WORKER_METADATA_THREAD = "PublishNewWorkerMetadataThread"; - private static final String CONSUME_NEW_WORKER_METADATA_THREAD = "ConsumeNewWorkerMetadataThread"; + private static final String WORKER_METADATA_CONSUMER_THREAD_NAME = + "WindmillWorkerMetadataConsumerThread"; + private static final String STREAM_MANAGER_THREAD_NAME = "WindmillStreamManager-%d"; private final JobHeader jobHeader; private final GrpcWindmillStreamFactory streamFactory; private final WorkItemScheduler workItemScheduler; private final ChannelCachingStubFactory channelCachingStubFactory; private final GrpcDispatcherClient dispatcherClient; - private final AtomicBoolean isBudgetRefreshPaused; - private final GetWorkBudgetRefresher getWorkBudgetRefresher; - private final AtomicReference lastBudgetRefresh; + private final GetWorkBudgetDistributor getWorkBudgetDistributor; + private final GetWorkBudget totalGetWorkBudget; private final ThrottleTimer getWorkerMetadataThrottleTimer; - private final ExecutorService newWorkerMetadataPublisher; - private final ExecutorService newWorkerMetadataConsumer; - private final long clientId; - private final Supplier getWorkerMetadataStream; - private final Queue newWindmillEndpoints; private final Function workCommitterFactory; private final ThrottlingGetDataMetricTracker getDataMetricTracker; + private final ExecutorService windmillStreamManager; + private final ExecutorService workerMetadataConsumer; + private final Object metadataLock = new Object(); /** Writes are guarded by synchronization, reads are lock free. 
*/ - private final AtomicReference connections; + private final AtomicReference backends; - private volatile boolean started; + @GuardedBy("this") + private long activeMetadataVersion; + + @GuardedBy("metadataLock") + private long pendingMetadataVersion; + + @GuardedBy("this") + private boolean started; + + @GuardedBy("this") + private @Nullable GetWorkerMetadataStream getWorkerMetadataStream; - @SuppressWarnings("FutureReturnValueIgnored") private FanOutStreamingEngineWorkerHarness( JobHeader jobHeader, GetWorkBudget totalGetWorkBudget, @@ -114,7 +122,6 @@ private FanOutStreamingEngineWorkerHarness( ChannelCachingStubFactory channelCachingStubFactory, GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, - long clientId, Function workCommitterFactory, ThrottlingGetDataMetricTracker getDataMetricTracker) { this.jobHeader = jobHeader; @@ -122,42 +129,21 @@ private FanOutStreamingEngineWorkerHarness( this.started = false; this.streamFactory = streamFactory; this.workItemScheduler = workItemScheduler; - this.connections = new AtomicReference<>(StreamingEngineConnectionState.EMPTY); + this.backends = new AtomicReference<>(StreamingEngineBackends.EMPTY); this.channelCachingStubFactory = channelCachingStubFactory; this.dispatcherClient = dispatcherClient; - this.isBudgetRefreshPaused = new AtomicBoolean(false); this.getWorkerMetadataThrottleTimer = new ThrottleTimer(); - this.newWorkerMetadataPublisher = - singleThreadedExecutorServiceOf(PUBLISH_NEW_WORKER_METADATA_THREAD); - this.newWorkerMetadataConsumer = - singleThreadedExecutorServiceOf(CONSUME_NEW_WORKER_METADATA_THREAD); - this.clientId = clientId; - this.lastBudgetRefresh = new AtomicReference<>(Instant.EPOCH); - this.newWindmillEndpoints = Queues.synchronizedQueue(EvictingQueue.create(1)); - this.getWorkBudgetRefresher = - new GetWorkBudgetRefresher( - isBudgetRefreshPaused::get, - () -> { - getWorkBudgetDistributor.distributeBudget( - connections.get().windmillStreams().values(), totalGetWorkBudget); - lastBudgetRefresh.set(Instant.now()); - }); - this.getWorkerMetadataStream = - Suppliers.memoize( - () -> - streamFactory.createGetWorkerMetadataStream( - dispatcherClient.getWindmillMetadataServiceStubBlocking(), - getWorkerMetadataThrottleTimer, - endpoints -> - // Run this on a separate thread than the grpc stream thread. 
- newWorkerMetadataPublisher.submit( - () -> newWindmillEndpoints.add(endpoints)))); + this.windmillStreamManager = + Executors.newCachedThreadPool( + new ThreadFactoryBuilder().setNameFormat(STREAM_MANAGER_THREAD_NAME).build()); + this.workerMetadataConsumer = + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder().setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME).build()); + this.getWorkBudgetDistributor = getWorkBudgetDistributor; + this.totalGetWorkBudget = totalGetWorkBudget; + this.activeMetadataVersion = Long.MIN_VALUE; this.workCommitterFactory = workCommitterFactory; - } - - private static ExecutorService singleThreadedExecutorServiceOf(String threadName) { - return Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder().setNameFormat(threadName).build()); + this.getWorkerMetadataStream = null; } /** @@ -183,7 +169,6 @@ public static FanOutStreamingEngineWorkerHarness create( channelCachingStubFactory, getWorkBudgetDistributor, dispatcherClient, - /* clientId= */ new Random().nextLong(), workCommitterFactory, getDataMetricTracker); } @@ -197,7 +182,6 @@ static FanOutStreamingEngineWorkerHarness forTesting( ChannelCachingStubFactory stubFactory, GetWorkBudgetDistributor getWorkBudgetDistributor, GrpcDispatcherClient dispatcherClient, - long clientId, Function workCommitterFactory, ThrottlingGetDataMetricTracker getDataMetricTracker) { FanOutStreamingEngineWorkerHarness fanOutStreamingEngineWorkProvider = @@ -209,201 +193,218 @@ static FanOutStreamingEngineWorkerHarness forTesting( stubFactory, getWorkBudgetDistributor, dispatcherClient, - clientId, workCommitterFactory, getDataMetricTracker); fanOutStreamingEngineWorkProvider.start(); return fanOutStreamingEngineWorkProvider; } - @SuppressWarnings("ReturnValueIgnored") @Override public synchronized void start() { - Preconditions.checkState(!started, "StreamingEngineClient cannot start twice."); - // Starts the stream, this value is memoized. - getWorkerMetadataStream.get(); - startWorkerMetadataConsumer(); - getWorkBudgetRefresher.start(); + Preconditions.checkState(!started, "FanOutStreamingEngineWorkerHarness cannot start twice."); + getWorkerMetadataStream = + streamFactory.createGetWorkerMetadataStream( + dispatcherClient.getWindmillMetadataServiceStubBlocking(), + getWorkerMetadataThrottleTimer, + this::consumeWorkerMetadata); started = true; } public ImmutableSet currentWindmillEndpoints() { - return connections.get().windmillConnections().keySet().stream() + return backends.get().windmillStreams().keySet().stream() .map(Endpoint::directEndpoint) .filter(Optional::isPresent) .map(Optional::get) - .filter( - windmillServiceAddress -> - windmillServiceAddress.getKind() != WindmillServiceAddress.Kind.IPV6) - .map( - windmillServiceAddress -> - windmillServiceAddress.getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS - ? windmillServiceAddress.gcpServiceAddress() - : windmillServiceAddress.authenticatedGcpServiceAddress().gcpServiceAddress()) + .map(WindmillServiceAddress::getServiceAddress) .collect(toImmutableSet()); } /** - * Fetches {@link GetDataStream} mapped to globalDataKey if one exists, or defaults to {@link - * GetDataStream} pointing to dispatcher. + * Fetches {@link GetDataStream} mapped to globalDataKey if or throws {@link + * NoSuchElementException} if one is not found. 
*/ private GetDataStream getGlobalDataStream(String globalDataKey) { - return Optional.ofNullable(connections.get().globalDataStreams().get(globalDataKey)) - .map(Supplier::get) - .orElseGet( - () -> - streamFactory.createGetDataStream( - dispatcherClient.getWindmillServiceStub(), new ThrottleTimer())); - } - - @SuppressWarnings("FutureReturnValueIgnored") - private void startWorkerMetadataConsumer() { - newWorkerMetadataConsumer.submit( - () -> { - while (true) { - Optional.ofNullable(newWindmillEndpoints.poll()) - .ifPresent(this::consumeWindmillWorkerEndpoints); - } - }); + return Optional.ofNullable(backends.get().globalDataStreams().get(globalDataKey)) + .map(GlobalDataStreamSender::get) + .orElseThrow( + () -> new NoSuchElementException("No endpoint for global data tag: " + globalDataKey)); } @VisibleForTesting @Override public synchronized void shutdown() { - Preconditions.checkState(started, "StreamingEngineClient never started."); - getWorkerMetadataStream.get().halfClose(); - getWorkBudgetRefresher.stop(); - newWorkerMetadataPublisher.shutdownNow(); - newWorkerMetadataConsumer.shutdownNow(); + Preconditions.checkState(started, "FanOutStreamingEngineWorkerHarness never started."); + Preconditions.checkNotNull(getWorkerMetadataStream).shutdown(); + workerMetadataConsumer.shutdownNow(); + closeStreamsNotIn(WindmillEndpoints.none()); channelCachingStubFactory.shutdown(); + + try { + Preconditions.checkNotNull(getWorkerMetadataStream).awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted waiting for GetWorkerMetadataStream to shutdown.", e); + } + + windmillStreamManager.shutdown(); + boolean isStreamManagerShutdown = false; + try { + isStreamManagerShutdown = windmillStreamManager.awaitTermination(30, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOG.warn("Interrupted waiting for windmillStreamManager to shutdown.", e); + } + if (!isStreamManagerShutdown) { + windmillStreamManager.shutdownNow(); + } + } + + private void consumeWorkerMetadata(WindmillEndpoints windmillEndpoints) { + synchronized (metadataLock) { + // Only process versions greater than what we currently have to prevent double processing of + // metadata. workerMetadataConsumer is single-threaded so we maintain ordering. + if (windmillEndpoints.version() > pendingMetadataVersion) { + pendingMetadataVersion = windmillEndpoints.version(); + workerMetadataConsumer.execute(() -> consumeWindmillWorkerEndpoints(windmillEndpoints)); + } + } } - /** - * {@link java.util.function.Consumer} used to update {@link #connections} on - * new backend worker metadata. - */ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWindmillEndpoints) { - isBudgetRefreshPaused.set(true); - LOG.info("Consuming new windmill endpoints: {}", newWindmillEndpoints); - ImmutableMap newWindmillConnections = - createNewWindmillConnections(newWindmillEndpoints.windmillEndpoints()); - - StreamingEngineConnectionState newConnectionsState = - StreamingEngineConnectionState.builder() - .setWindmillConnections(newWindmillConnections) - .setWindmillStreams( - closeStaleStreamsAndCreateNewStreams(newWindmillConnections.values())) + // Since this is run on a single threaded executor, multiple versions of the metadata maybe + // queued up while a previous version of the windmillEndpoints were being consumed. Only consume + // the endpoints if they are the most current version. 
+ synchronized (metadataLock) { + if (newWindmillEndpoints.version() < pendingMetadataVersion) { + return; + } + } + + LOG.debug( + "Consuming new endpoints: {}. previous metadata version: {}, current metadata version: {}", + newWindmillEndpoints, + activeMetadataVersion, + newWindmillEndpoints.version()); + closeStreamsNotIn(newWindmillEndpoints); + ImmutableMap newStreams = + createAndStartNewStreams(newWindmillEndpoints.windmillEndpoints()).join(); + StreamingEngineBackends newBackends = + StreamingEngineBackends.builder() + .setWindmillStreams(newStreams) .setGlobalDataStreams( createNewGlobalDataStreams(newWindmillEndpoints.globalDataEndpoints())) .build(); + backends.set(newBackends); + getWorkBudgetDistributor.distributeBudget(newStreams.values(), totalGetWorkBudget); + activeMetadataVersion = newWindmillEndpoints.version(); + } + + /** Close the streams that are no longer valid asynchronously. */ + private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) { + StreamingEngineBackends currentBackends = backends.get(); + currentBackends.windmillStreams().entrySet().stream() + .filter( + connectionAndStream -> + !newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey())) + .forEach( + entry -> + windmillStreamManager.execute( + () -> closeStreamSender(entry.getKey(), entry.getValue()))); - LOG.info( - "Setting new connections: {}. Previous connections: {}.", - newConnectionsState, - connections.get()); - connections.set(newConnectionsState); - isBudgetRefreshPaused.set(false); - getWorkBudgetRefresher.requestBudgetRefresh(); + Set newGlobalDataEndpoints = + new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values()); + currentBackends.globalDataStreams().values().stream() + .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint())) + .forEach( + sender -> + windmillStreamManager.execute(() -> closeStreamSender(sender.endpoint(), sender))); + } + + private void closeStreamSender(Endpoint endpoint, Closeable sender) { + LOG.debug("Closing streams to endpoint={}, sender={}", endpoint, sender); + try { + sender.close(); + endpoint.directEndpoint().ifPresent(channelCachingStubFactory::remove); + LOG.debug("Successfully closed streams to {}", endpoint); + } catch (Exception e) { + LOG.error("Error closing streams to endpoint={}, sender={}", endpoint, sender); + } + } + + private synchronized CompletableFuture> + createAndStartNewStreams(ImmutableSet newWindmillEndpoints) { + ImmutableMap currentStreams = backends.get().windmillStreams(); + return MoreFutures.allAsList( + newWindmillEndpoints.stream() + .map(endpoint -> getOrCreateWindmillStreamSenderFuture(endpoint, currentStreams)) + .collect(Collectors.toList())) + .thenApply( + backends -> backends.stream().collect(toImmutableMap(Pair::getLeft, Pair::getRight))) + .toCompletableFuture(); + } + + private CompletionStage> + getOrCreateWindmillStreamSenderFuture( + Endpoint endpoint, ImmutableMap currentStreams) { + return MoreFutures.supplyAsync( + () -> + Pair.of( + endpoint, + Optional.ofNullable(currentStreams.get(endpoint)) + .orElseGet(() -> createAndStartWindmillStreamSender(endpoint))), + windmillStreamManager); } /** Add up all the throttle times of all streams including GetWorkerMetadataStream. 
*/ - public long getAndResetThrottleTimes() { - return connections.get().windmillStreams().values().stream() + public long getAndResetThrottleTime() { + return backends.get().windmillStreams().values().stream() .map(WindmillStreamSender::getAndResetThrottleTime) .reduce(0L, Long::sum) + getWorkerMetadataThrottleTimer.getAndResetThrottleTime(); } public long currentActiveCommitBytes() { - return connections.get().windmillStreams().values().stream() + return backends.get().windmillStreams().values().stream() .map(WindmillStreamSender::getCurrentActiveCommitBytes) .reduce(0L, Long::sum); } @VisibleForTesting - StreamingEngineConnectionState getCurrentConnections() { - return connections.get(); - } - - private synchronized ImmutableMap createNewWindmillConnections( - List newWindmillEndpoints) { - ImmutableMap currentConnections = - connections.get().windmillConnections(); - return newWindmillEndpoints.stream() - .collect( - toImmutableMap( - Function.identity(), - endpoint -> - // Reuse existing stubs if they exist. Optional.orElseGet only calls the - // supplier if the value is not present, preventing constructing expensive - // objects. - Optional.ofNullable(currentConnections.get(endpoint)) - .orElseGet( - () -> WindmillConnection.from(endpoint, this::createWindmillStub)))); + StreamingEngineBackends currentBackends() { + return backends.get(); } - private synchronized ImmutableMap - closeStaleStreamsAndCreateNewStreams(Collection newWindmillConnections) { - ImmutableMap currentStreams = - connections.get().windmillStreams(); - - // Close the streams that are no longer valid. - currentStreams.entrySet().stream() - .filter( - connectionAndStream -> !newWindmillConnections.contains(connectionAndStream.getKey())) - .forEach( - entry -> { - entry.getValue().closeAllStreams(); - entry.getKey().directEndpoint().ifPresent(channelCachingStubFactory::remove); - }); - - return newWindmillConnections.stream() - .collect( - toImmutableMap( - Function.identity(), - newConnection -> - Optional.ofNullable(currentStreams.get(newConnection)) - .orElseGet(() -> createAndStartWindmillStreamSenderFor(newConnection)))); - } - - private ImmutableMap> createNewGlobalDataStreams( + private ImmutableMap createNewGlobalDataStreams( ImmutableMap newGlobalDataEndpoints) { - ImmutableMap> currentGlobalDataStreams = - connections.get().globalDataStreams(); + ImmutableMap currentGlobalDataStreams = + backends.get().globalDataStreams(); return newGlobalDataEndpoints.entrySet().stream() .collect( toImmutableMap( Entry::getKey, keyedEndpoint -> - existingOrNewGetDataStreamFor(keyedEndpoint, currentGlobalDataStreams))); + getOrCreateGlobalDataSteam(keyedEndpoint, currentGlobalDataStreams))); } - private Supplier existingOrNewGetDataStreamFor( + private GlobalDataStreamSender getOrCreateGlobalDataSteam( Entry keyedEndpoint, - ImmutableMap> currentGlobalDataStreams) { - return Preconditions.checkNotNull( - currentGlobalDataStreams.getOrDefault( - keyedEndpoint.getKey(), + ImmutableMap currentGlobalDataStreams) { + return Optional.ofNullable(currentGlobalDataStreams.get(keyedEndpoint.getKey())) + .orElseGet( () -> - streamFactory.createGetDataStream( - newOrExistingStubFor(keyedEndpoint.getValue()), new ThrottleTimer()))); - } - - private CloudWindmillServiceV1Alpha1Stub newOrExistingStubFor(Endpoint endpoint) { - return Optional.ofNullable(connections.get().windmillConnections().get(endpoint)) - .map(WindmillConnection::stub) - .orElseGet(() -> createWindmillStub(endpoint)); + new GlobalDataStreamSender( + () -> + 
streamFactory.createGetDataStream( + createWindmillStub(keyedEndpoint.getValue()), new ThrottleTimer()), + keyedEndpoint.getValue())); } - private WindmillStreamSender createAndStartWindmillStreamSenderFor( - WindmillConnection connection) { - // Initially create each stream with no budget. The budget will be eventually assigned by the - // GetWorkBudgetDistributor. + private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint endpoint) { WindmillStreamSender windmillStreamSender = WindmillStreamSender.create( - connection, + WindmillConnection.from(endpoint, this::createWindmillStub), GetWorkRequest.newBuilder() - .setClientId(clientId) + .setClientId(jobHeader.getClientId()) .setJobId(jobHeader.getJobId()) .setProjectId(jobHeader.getProjectId()) .setWorkerId(jobHeader.getWorkerId()) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java new file mode 100644 index 000000000000..ce5f3a7b6bfc --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.streaming.harness; + +import java.io.Closeable; +import java.util.function.Supplier; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; + +@Internal +@ThreadSafe +// TODO (m-trieu): replace Supplier with Stream after github.com/apache/beam/pull/32774/ is +// merged +final class GlobalDataStreamSender implements Closeable, Supplier { + private final Endpoint endpoint; + private final Supplier delegate; + private volatile boolean started; + + GlobalDataStreamSender(Supplier delegate, Endpoint endpoint) { + // Ensures that the Supplier is thread-safe + this.delegate = Suppliers.memoize(delegate::get); + this.started = false; + this.endpoint = endpoint; + } + + @Override + public GetDataStream get() { + if (!started) { + started = true; + } + + return delegate.get(); + } + + @Override + public void close() { + if (started) { + delegate.get().shutdown(); + } + } + + Endpoint endpoint() { + return endpoint; + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java similarity index 55% rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java index 3c85ee6abe1f..14290b486830 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineConnectionState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamingEngineBackends.java @@ -18,47 +18,37 @@ package org.apache.beam.runners.dataflow.worker.streaming.harness; import com.google.auto.value.AutoValue; -import java.util.function.Supplier; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; -import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; /** - * Represents the current state of connections to Streaming Engine. Connections are updated when - * backend workers assigned to the key ranges being processed by this user worker change during + * Represents the current state of connections to the Streaming Engine backend. Backends are updated + * when backend workers assigned to the key ranges being processed by this user worker change during * pipeline execution. For example, changes can happen via autoscaling, load-balancing, or other * backend updates. 
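Aside (illustrative, not part of this patch): GlobalDataStreamSender above combines three small ideas — lazy creation through Suppliers.memoize, a started flag, and a close() that only touches the stream if it was ever created. A minimal standalone sketch of that pattern, with hypothetical names:

    import java.io.Closeable;
    import java.util.function.Consumer;
    import java.util.function.Supplier;
    import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers;

    final class LazyCloseableResource<T> implements Closeable, Supplier<T> {
      private final Supplier<T> memoized;
      private final Consumer<T> closer;
      private volatile boolean started = false;

      LazyCloseableResource(Supplier<T> factory, Consumer<T> closer) {
        // memoize() guarantees the factory runs at most once, even under concurrent get() calls.
        this.memoized = Suppliers.memoize(factory::get);
        this.closer = closer;
      }

      @Override
      public T get() {
        started = true;
        return memoized.get();
      }

      @Override
      public void close() {
        // Only touch the underlying resource if someone actually created it.
        if (started) {
          closer.accept(memoized.get());
        }
      }
    }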
*/ @AutoValue -abstract class StreamingEngineConnectionState { - static final StreamingEngineConnectionState EMPTY = builder().build(); +abstract class StreamingEngineBackends { + static final StreamingEngineBackends EMPTY = builder().build(); static Builder builder() { - return new AutoValue_StreamingEngineConnectionState.Builder() - .setWindmillConnections(ImmutableMap.of()) + return new AutoValue_StreamingEngineBackends.Builder() .setWindmillStreams(ImmutableMap.of()) .setGlobalDataStreams(ImmutableMap.of()); } - abstract ImmutableMap windmillConnections(); - - abstract ImmutableMap windmillStreams(); + abstract ImmutableMap windmillStreams(); /** Mapping of GlobalDataIds and the direct GetDataStreams used fetch them. */ - abstract ImmutableMap> globalDataStreams(); + abstract ImmutableMap globalDataStreams(); @AutoValue.Builder abstract static class Builder { - public abstract Builder setWindmillConnections( - ImmutableMap value); - - public abstract Builder setWindmillStreams( - ImmutableMap value); + public abstract Builder setWindmillStreams(ImmutableMap value); public abstract Builder setGlobalDataStreams( - ImmutableMap> value); + ImmutableMap value); - public abstract StreamingEngineConnectionState build(); + public abstract StreamingEngineBackends build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java index 45aa403ee71b..744c3d74445f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; +import java.io.Closeable; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -49,7 +50,7 @@ * {@link GetWorkBudget} is set. * *

Once started, the underlying streams are "alive" until they are manually closed via {@link - * #closeAllStreams()}. + * #close()}. * *
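Aside (illustrative, not part of this patch): the lifecycle described in this javadoc — the budget may be set before any stream exists, streams start lazily, and close() is a no-op for never-started streams — boils down to buffering state behind a started flag. A rough sketch with hypothetical names (the real sender forwards to its GetWorkStream and related streams):

    import java.util.concurrent.atomic.AtomicBoolean;
    import java.util.concurrent.atomic.AtomicReference;
    import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;

    final class BufferedBudgetSender {
      private final AtomicBoolean started = new AtomicBoolean(false);
      private final AtomicReference<GetWorkBudget> budget =
          new AtomicReference<>(GetWorkBudget.noBudget());

      /** Remembers the budget; it is only pushed to a live stream once started. */
      void setBudget(long items, long bytes) {
        GetWorkBudget newBudget = GetWorkBudget.builder().setItems(items).setBytes(bytes).build();
        budget.set(newBudget);
        if (started.get()) {
          // In the real sender this forwards the new budget to the started GetWorkStream.
        }
      }

      /** Starts at most once; the streams pick up the latest buffered budget. */
      void start() {
        if (started.compareAndSet(false, true)) {
          // Start the underlying streams here.
        }
      }
    }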

If closed, it means that the backend endpoint is no longer in the worker set. Once closed, * these instances are not reused. @@ -59,7 +60,7 @@ */ @Internal @ThreadSafe -final class WindmillStreamSender implements GetWorkBudgetSpender { +final class WindmillStreamSender implements GetWorkBudgetSpender, Closeable { private final AtomicBoolean started; private final AtomicReference getWorkBudget; private final Supplier getWorkStream; @@ -103,9 +104,9 @@ private WindmillStreamSender( connection, withRequestBudget(getWorkRequest, getWorkBudget.get()), streamingEngineThrottleTimers.getWorkThrottleTimer(), - () -> FixedStreamHeartbeatSender.create(getDataStream.get()), - () -> getDataClientFactory.apply(getDataStream.get()), - workCommitter, + FixedStreamHeartbeatSender.create(getDataStream.get()), + getDataClientFactory.apply(getDataStream.get()), + workCommitter.get(), workItemScheduler)); } @@ -141,7 +142,8 @@ void startStreams() { started.set(true); } - void closeAllStreams() { + @Override + public void close() { // Supplier.get() starts the stream which is an expensive operation as it initiates the // streaming RPCs by possibly making calls over the network. Do not close the streams unless // they have already been started. @@ -154,18 +156,13 @@ void closeAllStreams() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { - getWorkBudget.set(getWorkBudget.get().apply(itemsDelta, bytesDelta)); + public void setBudget(long items, long bytes) { + getWorkBudget.set(getWorkBudget.get().apply(items, bytes)); if (started.get()) { - getWorkStream.get().adjustBudget(itemsDelta, bytesDelta); + getWorkStream.get().setBudget(items, bytes); } } - @Override - public GetWorkBudget remainingBudget() { - return started.get() ? getWorkStream.get().remainingBudget() : getWorkBudget.get(); - } - long getAndResetThrottleTime() { return streamingEngineThrottleTimers.getAndResetThrottleTime(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java index d7ed83def43e..eb269eef848f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java @@ -17,8 +17,8 @@ */ package org.apache.beam.runners.dataflow.worker.windmill; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet; import com.google.auto.value.AutoValue; import java.net.Inet6Address; @@ -27,8 +27,8 @@ import java.util.Map; import java.util.Optional; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress.AuthenticatedGcpServiceAddress; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.slf4j.Logger; import 
org.slf4j.LoggerFactory; @@ -41,6 +41,14 @@ public abstract class WindmillEndpoints { private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class); + public static WindmillEndpoints none() { + return WindmillEndpoints.builder() + .setVersion(Long.MAX_VALUE) + .setWindmillEndpoints(ImmutableSet.of()) + .setGlobalDataEndpoints(ImmutableMap.of()) + .build(); + } + public static WindmillEndpoints from( Windmill.WorkerMetadataResponse workerMetadataResponseProto) { ImmutableMap globalDataServers = @@ -53,14 +61,15 @@ public static WindmillEndpoints from( endpoint.getValue(), workerMetadataResponseProto.getExternalEndpoint()))); - ImmutableList windmillServers = + ImmutableSet windmillServers = workerMetadataResponseProto.getWorkEndpointsList().stream() .map( endpointProto -> Endpoint.from(endpointProto, workerMetadataResponseProto.getExternalEndpoint())) - .collect(toImmutableList()); + .collect(toImmutableSet()); return WindmillEndpoints.builder() + .setVersion(workerMetadataResponseProto.getMetadataVersion()) .setGlobalDataEndpoints(globalDataServers) .setWindmillEndpoints(windmillServers) .build(); @@ -123,6 +132,9 @@ private static Optional tryParseDirectEndpointIntoIpV6Address( directEndpointAddress.getHostAddress(), (int) endpointProto.getPort())); } + /** Version of the endpoints which increases with every modification. */ + public abstract long version(); + /** * Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns a map where the key * is a global data tag and the value is the endpoint where the data associated with the global @@ -138,7 +150,7 @@ private static Optional tryParseDirectEndpointIntoIpV6Address( * Windmill servers. Returns a list of endpoints used to communicate with the corresponding * Windmill servers. */ - public abstract ImmutableList windmillEndpoints(); + public abstract ImmutableSet windmillEndpoints(); /** * Representation of an endpoint in {@link Windmill.WorkerMetadataResponse.Endpoint} proto with @@ -204,13 +216,15 @@ public abstract static class Builder { @AutoValue.Builder public abstract static class Builder { + public abstract Builder setVersion(long version); + public abstract Builder setGlobalDataEndpoints( ImmutableMap globalDataServers); public abstract Builder setWindmillEndpoints( - ImmutableList windmillServers); + ImmutableSet windmillServers); - abstract ImmutableList.Builder windmillEndpointsBuilder(); + abstract ImmutableSet.Builder windmillEndpointsBuilder(); public final Builder addWindmillEndpoint(WindmillEndpoints.Endpoint endpoint) { windmillEndpointsBuilder().add(endpoint); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java index 90f93b072673..0b895652efe2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServiceAddress.java @@ -19,38 +19,36 @@ import com.google.auto.value.AutoOneOf; import com.google.auto.value.AutoValue; -import java.net.Inet6Address; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; /** Used to create channels to communicate with Streaming Engine via gRpc. 
*/ @AutoOneOf(WindmillServiceAddress.Kind.class) public abstract class WindmillServiceAddress { - public static WindmillServiceAddress create(Inet6Address ipv6Address) { - return AutoOneOf_WindmillServiceAddress.ipv6(ipv6Address); - } public static WindmillServiceAddress create(HostAndPort gcpServiceAddress) { return AutoOneOf_WindmillServiceAddress.gcpServiceAddress(gcpServiceAddress); } - public abstract Kind getKind(); + public static WindmillServiceAddress create( + AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) { + return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress( + authenticatedGcpServiceAddress); + } - public abstract Inet6Address ipv6(); + public abstract Kind getKind(); public abstract HostAndPort gcpServiceAddress(); public abstract AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress(); - public static WindmillServiceAddress create( - AuthenticatedGcpServiceAddress authenticatedGcpServiceAddress) { - return AutoOneOf_WindmillServiceAddress.authenticatedGcpServiceAddress( - authenticatedGcpServiceAddress); + public final HostAndPort getServiceAddress() { + return getKind() == WindmillServiceAddress.Kind.GCP_SERVICE_ADDRESS + ? gcpServiceAddress() + : authenticatedGcpServiceAddress().gcpServiceAddress(); } public enum Kind { - IPV6, GCP_SERVICE_ADDRESS, - // TODO(m-trieu): Use for direct connections when ALTS is enabled. AUTHENTICATED_GCP_SERVICE_ADDRESS } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index 31bd4e146a78..f26c56b14ec2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -56,10 +56,11 @@ public interface WindmillStream { @ThreadSafe interface GetWorkStream extends WindmillStream { /** Adjusts the {@link GetWorkBudget} for the stream. */ - void adjustBudget(long itemsDelta, long bytesDelta); + void setBudget(GetWorkBudget newBudget); - /** Returns the remaining in-flight {@link GetWorkBudget}. */ - GetWorkBudget remainingBudget(); + default void setBudget(long newItems, long newBytes) { + setBudget(GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build()); + } } /** Interface for streaming GetDataRequests to Windmill. 
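Aside (illustrative, not part of this patch): with the GetWorkStream change above, callers now hand the stream a complete GetWorkBudget rather than item/byte deltas; the long-based overload is a default method that simply builds the same budget. A hypothetical caller (values are arbitrary):

    import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream;
    import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;

    final class BudgetExample {
      /** Sets an absolute budget on the stream; both calls below are equivalent. */
      static void applyBudget(GetWorkStream stream) {
        GetWorkBudget budget =
            GetWorkBudget.builder().setItems(10_000).setBytes(100_000_000).build();
        stream.setBudget(budget);
        stream.setBudget(10_000, 100_000_000); // default method builds the same budget
      }
    }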
*/ diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index 19de998b1da8..b27ebc8e9eee 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -21,9 +21,11 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; +import javax.annotation.concurrent.GuardedBy; +import net.jcip.annotations.ThreadSafe; import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -44,8 +46,8 @@ import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Implementation of {@link GetWorkStream} that passes along a specific {@link @@ -55,9 +57,10 @@ * these direct streams are used to facilitate these RPC calls to specific backend workers. */ @Internal -public final class GrpcDirectGetWorkStream +final class GrpcDirectGetWorkStream extends AbstractWindmillStream implements GetWorkStream { + private static final Logger LOG = LoggerFactory.getLogger(GrpcDirectGetWorkStream.class); private static final StreamingGetWorkRequest HEALTH_CHECK_REQUEST = StreamingGetWorkRequest.newBuilder() .setRequestExtension( @@ -67,15 +70,14 @@ public final class GrpcDirectGetWorkStream .build()) .build(); - private final AtomicReference inFlightBudget; - private final AtomicReference nextBudgetAdjustment; - private final AtomicReference pendingResponseBudget; - private final GetWorkRequest request; + private final GetWorkBudgetTracker budgetTracker; + private final GetWorkRequest requestHeader; private final WorkItemScheduler workItemScheduler; private final ThrottleTimer getWorkThrottleTimer; - private final Supplier heartbeatSender; - private final Supplier workCommitter; - private final Supplier getDataClient; + private final HeartbeatSender heartbeatSender; + private final WorkCommitter workCommitter; + private final GetDataClient getDataClient; + private final AtomicReference lastRequest; /** * Map of stream IDs to their buffers. 
Used to aggregate streaming gRPC response chunks as they @@ -92,15 +94,15 @@ private GrpcDirectGetWorkStream( StreamObserver, StreamObserver> startGetWorkRpcFn, - GetWorkRequest request, + GetWorkRequest requestHeader, BackOff backoff, StreamObserverFactory streamObserverFactory, Set> streamRegistry, int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { super( "GetWorkStream", @@ -110,19 +112,23 @@ private GrpcDirectGetWorkStream( streamRegistry, logEveryNStreamFailures, backendWorkerToken); - this.request = request; + this.requestHeader = requestHeader; this.getWorkThrottleTimer = getWorkThrottleTimer; this.workItemScheduler = workItemScheduler; this.workItemAssemblers = new ConcurrentHashMap<>(); - this.heartbeatSender = Suppliers.memoize(heartbeatSender::get); - this.workCommitter = Suppliers.memoize(workCommitter::get); - this.getDataClient = Suppliers.memoize(getDataClient::get); - this.inFlightBudget = new AtomicReference<>(GetWorkBudget.noBudget()); - this.nextBudgetAdjustment = new AtomicReference<>(GetWorkBudget.noBudget()); - this.pendingResponseBudget = new AtomicReference<>(GetWorkBudget.noBudget()); + this.heartbeatSender = heartbeatSender; + this.workCommitter = workCommitter; + this.getDataClient = getDataClient; + this.lastRequest = new AtomicReference<>(); + this.budgetTracker = + new GetWorkBudgetTracker( + GetWorkBudget.builder() + .setItems(requestHeader.getMaxItems()) + .setBytes(requestHeader.getMaxBytes()) + .build()); } - public static GrpcDirectGetWorkStream create( + static GrpcDirectGetWorkStream create( String backendWorkerToken, Function< StreamObserver, @@ -134,9 +140,9 @@ public static GrpcDirectGetWorkStream create( Set> streamRegistry, int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { GrpcDirectGetWorkStream getWorkStream = new GrpcDirectGetWorkStream( @@ -165,46 +171,52 @@ private static Watermarks createWatermarks( .build(); } - private void sendRequestExtension(GetWorkBudget adjustment) { - inFlightBudget.getAndUpdate(budget -> budget.apply(adjustment)); - StreamingGetWorkRequest extension = - StreamingGetWorkRequest.newBuilder() - .setRequestExtension( - Windmill.StreamingGetWorkRequestExtension.newBuilder() - .setMaxItems(adjustment.items()) - .setMaxBytes(adjustment.bytes())) - .build(); - - executor() - .execute( - () -> { - try { - send(extension); - } catch (IllegalStateException e) { - // Stream was closed. - } - }); + /** + * @implNote Do not lock/synchronize here due to this running on grpc serial executor for message + * which can deadlock since we send on the stream beneath the synchronization. {@link + * AbstractWindmillStream#send(Object)} is synchronized so the sends are already guarded. 
+ */ + private void maybeSendRequestExtension(GetWorkBudget extension) { + if (extension.items() > 0 || extension.bytes() > 0) { + executeSafely( + () -> { + StreamingGetWorkRequest request = + StreamingGetWorkRequest.newBuilder() + .setRequestExtension( + Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(extension.items()) + .setMaxBytes(extension.bytes())) + .build(); + lastRequest.set(request); + budgetTracker.recordBudgetRequested(extension); + try { + send(request); + } catch (IllegalStateException e) { + // Stream was closed. + } + }); + } } @Override protected synchronized void onNewStream() { workItemAssemblers.clear(); - // Add the current in-flight budget to the next adjustment. Only positive values are allowed - // here - // with negatives defaulting to 0, since GetWorkBudgets cannot be created with negative values. - GetWorkBudget budgetAdjustment = nextBudgetAdjustment.get().apply(inFlightBudget.get()); - inFlightBudget.set(budgetAdjustment); - send( - StreamingGetWorkRequest.newBuilder() - .setRequest( - request - .toBuilder() - .setMaxBytes(budgetAdjustment.bytes()) - .setMaxItems(budgetAdjustment.items())) - .build()); - - // We just sent the budget, reset it. - nextBudgetAdjustment.set(GetWorkBudget.noBudget()); + if (!isShutdown()) { + budgetTracker.reset(); + GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension(); + StreamingGetWorkRequest request = + StreamingGetWorkRequest.newBuilder() + .setRequest( + requestHeader + .toBuilder() + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build()) + .build(); + lastRequest.set(request); + budgetTracker.recordBudgetRequested(initialGetWorkBudget); + send(request); + } } @Override @@ -216,8 +228,9 @@ protected boolean hasPendingRequests() { public void appendSpecificHtml(PrintWriter writer) { // Number of buffers is same as distinct workers that sent work on this stream. writer.format( - "GetWorkStream: %d buffers, %s inflight budget allowed.", - workItemAssemblers.size(), inFlightBudget.get()); + "GetWorkStream: %d buffers, " + "last sent request: %s; ", + workItemAssemblers.size(), lastRequest.get()); + writer.print(budgetTracker.debugString()); } @Override @@ -235,30 +248,22 @@ protected void onResponse(StreamingGetWorkResponseChunk chunk) { } private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) { - // Record the fact that there are now fewer outstanding messages and bytes on the stream. 
- inFlightBudget.updateAndGet(budget -> budget.subtract(1, assembledWorkItem.bufferedSize())); WorkItem workItem = assembledWorkItem.workItem(); GetWorkResponseChunkAssembler.ComputationMetadata metadata = assembledWorkItem.computationMetadata(); - pendingResponseBudget.getAndUpdate(budget -> budget.apply(1, workItem.getSerializedSize())); - try { - workItemScheduler.scheduleWork( - workItem, - createWatermarks(workItem, Preconditions.checkNotNull(metadata)), - createProcessingContext(Preconditions.checkNotNull(metadata.computationId())), - assembledWorkItem.latencyAttributions()); - } finally { - pendingResponseBudget.getAndUpdate(budget -> budget.apply(-1, -workItem.getSerializedSize())); - } + workItemScheduler.scheduleWork( + workItem, + createWatermarks(workItem, metadata), + createProcessingContext(metadata.computationId()), + assembledWorkItem.latencyAttributions()); + budgetTracker.recordBudgetReceived(assembledWorkItem.bufferedSize()); + GetWorkBudget extension = budgetTracker.computeBudgetExtension(); + maybeSendRequestExtension(extension); } private Work.ProcessingContext createProcessingContext(String computationId) { return Work.createProcessingContext( - computationId, - getDataClient.get(), - workCommitter.get()::commit, - heartbeatSender.get(), - backendWorkerToken()); + computationId, getDataClient, workCommitter::commit, heartbeatSender, backendWorkerToken()); } @Override @@ -267,25 +272,110 @@ protected void startThrottleTimer() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { - GetWorkBudget adjustment = - nextBudgetAdjustment - // Get the current value, and reset the nextBudgetAdjustment. This will be set again - // when adjustBudget is called. - .getAndUpdate(unused -> GetWorkBudget.noBudget()) - .apply(itemsDelta, bytesDelta); - sendRequestExtension(adjustment); + public void setBudget(GetWorkBudget newBudget) { + GetWorkBudget extension = budgetTracker.consumeAndComputeBudgetUpdate(newBudget); + maybeSendRequestExtension(extension); } - @Override - public GetWorkBudget remainingBudget() { - // Snapshot the current budgets. - GetWorkBudget currentPendingResponseBudget = pendingResponseBudget.get(); - GetWorkBudget currentNextBudgetAdjustment = nextBudgetAdjustment.get(); - GetWorkBudget currentInflightBudget = inFlightBudget.get(); - - return currentPendingResponseBudget - .apply(currentNextBudgetAdjustment) - .apply(currentInflightBudget); + private void executeSafely(Runnable runnable) { + try { + executor().execute(runnable); + } catch (RejectedExecutionException e) { + LOG.debug("{} has been shutdown.", getClass()); + } + } + + /** + * Tracks sent, received, max {@link GetWorkBudget} and uses this information to generate request + * extensions. 
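Aside (illustrative, not part of this patch): the extension logic below can be followed with concrete numbers. With a max budget of 100 items / 10,000 bytes, 100 items and 10,000 bytes requested, and 80 items / 8,000 bytes received, the in-flight budget is 20 items / 2,000 bytes and the candidate extension is 80 items / 8,000 bytes; since in-flight is no longer more than half of the candidate, the extension is sent. If only 30 items / 3,000 bytes had been received, plenty would still be in flight and no extension would go out. A standalone sketch of the same arithmetic (hypothetical helper, mirroring computeBudgetExtension below):

    final class BudgetExtensionMath {
      /** Returns {items, bytes} to request as an extension, or {0, 0} to skip. */
      static long[] computeExtension(
          long maxItems, long maxBytes,
          long itemsRequested, long bytesRequested,
          long itemsReceived, long bytesReceived) {
        long inFlightItems = itemsRequested - itemsReceived;
        long inFlightBytes = bytesRequested - bytesReceived;
        long requestItems = Math.max(0, maxItems - inFlightItems);
        long requestBytes = Math.max(0, maxBytes - inFlightBytes);
        boolean stillWellStocked =
            inFlightItems > requestItems / 2 && inFlightBytes > requestBytes / 2;
        return stillWellStocked ? new long[] {0, 0} : new long[] {requestItems, requestBytes};
      }

      public static void main(String[] args) {
        // 80 of 100 items consumed -> requests an extension of {80, 8000}.
        System.out.println(java.util.Arrays.toString(
            computeExtension(100, 10_000, 100, 10_000, 80, 8_000)));
        // Only 30 items consumed -> plenty in flight, no extension ({0, 0}).
        System.out.println(java.util.Arrays.toString(
            computeExtension(100, 10_000, 100, 10_000, 30, 3_000)));
      }
    }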
+ */ + @ThreadSafe + private static final class GetWorkBudgetTracker { + + @GuardedBy("GetWorkBudgetTracker.this") + private GetWorkBudget maxGetWorkBudget; + + @GuardedBy("GetWorkBudgetTracker.this") + private long itemsRequested = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long bytesRequested = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long itemsReceived = 0; + + @GuardedBy("GetWorkBudgetTracker.this") + private long bytesReceived = 0; + + private GetWorkBudgetTracker(GetWorkBudget maxGetWorkBudget) { + this.maxGetWorkBudget = maxGetWorkBudget; + } + + private synchronized void reset() { + itemsRequested = 0; + bytesRequested = 0; + itemsReceived = 0; + bytesReceived = 0; + } + + private synchronized String debugString() { + return String.format( + "max budget: %s; " + + "in-flight budget: %s; " + + "total budget requested: %s; " + + "total budget received: %s.", + maxGetWorkBudget, inFlightBudget(), totalRequestedBudget(), totalReceivedBudget()); + } + + /** Consumes the new budget and computes an extension based on the new budget. */ + private synchronized GetWorkBudget consumeAndComputeBudgetUpdate(GetWorkBudget newBudget) { + maxGetWorkBudget = newBudget; + return computeBudgetExtension(); + } + + private synchronized void recordBudgetRequested(GetWorkBudget budgetRequested) { + itemsRequested += budgetRequested.items(); + bytesRequested += budgetRequested.bytes(); + } + + private synchronized void recordBudgetReceived(long returnedBudget) { + itemsReceived++; + bytesReceived += returnedBudget; + } + + /** + * If the outstanding items or bytes limit has gotten too low, top both off with a + * GetWorkExtension. The goal is to keep the limits relatively close to their maximum values + * without sending too many extension requests. + */ + private synchronized GetWorkBudget computeBudgetExtension() { + // Expected items and bytes can go negative here, since WorkItems returned might be larger + // than the initially requested budget. + long inFlightItems = itemsRequested - itemsReceived; + long inFlightBytes = bytesRequested - bytesReceived; + + // Don't send negative budget extensions. + long requestBytes = Math.max(0, maxGetWorkBudget.bytes() - inFlightBytes); + long requestItems = Math.max(0, maxGetWorkBudget.items() - inFlightItems); + + return (inFlightItems > requestItems / 2 && inFlightBytes > requestBytes / 2) + ? 
GetWorkBudget.noBudget() + : GetWorkBudget.builder().setItems(requestItems).setBytes(requestBytes).build(); + } + + private synchronized GetWorkBudget inFlightBudget() { + return GetWorkBudget.builder() + .setItems(itemsRequested - itemsReceived) + .setBytes(bytesRequested - bytesReceived) + .build(); + } + + private synchronized GetWorkBudget totalRequestedBudget() { + return GetWorkBudget.builder().setItems(itemsRequested).setBytes(bytesRequested).build(); + } + + private synchronized GetWorkBudget totalReceivedBudget() { + return GetWorkBudget.builder().setItems(itemsReceived).setBytes(bytesReceived).build(); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 0e9a0c6316ee..c99e05a77074 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -59,7 +59,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public final class GrpcGetDataStream +final class GrpcGetDataStream extends AbstractWindmillStream implements GetDataStream { private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStream.class); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java index 09ecbf3f3051..a368f3fec235 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java @@ -194,15 +194,7 @@ protected void startThrottleTimer() { } @Override - public void adjustBudget(long itemsDelta, long bytesDelta) { + public void setBudget(GetWorkBudget newBudget) { // no-op } - - @Override - public GetWorkBudget remainingBudget() { - return GetWorkBudget.builder() - .setBytes(request.getMaxBytes() - inflightBytes.get()) - .setItems(request.getMaxItems() - inflightMessages.get()) - .build(); - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index 92f031db9972..9e6a02d135e2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -198,9 +198,9 @@ public GetWorkStream createDirectGetWorkStream( WindmillConnection connection, GetWorkRequest request, ThrottleTimer getWorkThrottleTimer, - Supplier heartbeatSender, - Supplier getDataClient, - Supplier workCommitter, + HeartbeatSender heartbeatSender, + GetDataClient getDataClient, + 
WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { return GrpcDirectGetWorkStream.create( connection.backendWorkerToken(), diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java index 9aec29a3ba4d..f0ea2f550a74 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/stubs/WindmillChannelFactory.java @@ -36,7 +36,6 @@ /** Utility class used to create different RPC Channels. */ public final class WindmillChannelFactory { public static final String LOCALHOST = "localhost"; - private static final int DEFAULT_GRPC_PORT = 443; private static final int MAX_REMOTE_TRACE_EVENTS = 100; private WindmillChannelFactory() {} @@ -55,8 +54,6 @@ public static Channel localhostChannel(int port) { public static ManagedChannel remoteChannel( WindmillServiceAddress windmillServiceAddress, int windmillServiceRpcChannelTimeoutSec) { switch (windmillServiceAddress.getKind()) { - case IPV6: - return remoteChannel(windmillServiceAddress.ipv6(), windmillServiceRpcChannelTimeoutSec); case GCP_SERVICE_ADDRESS: return remoteChannel( windmillServiceAddress.gcpServiceAddress(), windmillServiceRpcChannelTimeoutSec); @@ -67,7 +64,8 @@ public static ManagedChannel remoteChannel( windmillServiceRpcChannelTimeoutSec); default: throw new UnsupportedOperationException( - "Only IPV6, GCP_SERVICE_ADDRESS, AUTHENTICATED_GCP_SERVICE_ADDRESS are supported WindmillServiceAddresses."); + "Only GCP_SERVICE_ADDRESS and AUTHENTICATED_GCP_SERVICE_ADDRESS are supported" + + " WindmillServiceAddresses."); } } @@ -105,17 +103,6 @@ public static Channel remoteChannel( } } - public static ManagedChannel remoteChannel( - Inet6Address directEndpoint, int windmillServiceRpcChannelTimeoutSec) { - try { - return createRemoteChannel( - NettyChannelBuilder.forAddress(new InetSocketAddress(directEndpoint, DEFAULT_GRPC_PORT)), - windmillServiceRpcChannelTimeoutSec); - } catch (SSLException sslException) { - throw new WindmillChannelCreationException(directEndpoint.toString(), sslException); - } - } - @SuppressWarnings("nullness") private static ManagedChannel createRemoteChannel( NettyChannelBuilder channelBuilder, int windmillServiceRpcChannelTimeoutSec) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java index 403bb99efb4c..8a1ba2556cf2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java @@ -17,18 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; -import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.DoubleMath.roundToLong; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.math.LongMath.divide; import java.math.RoundingMode; -import java.util.Map; -import java.util.Map.Entry; -import java.util.function.Function; -import java.util.function.Supplier; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,22 +29,11 @@ @Internal final class EvenGetWorkBudgetDistributor implements GetWorkBudgetDistributor { private static final Logger LOG = LoggerFactory.getLogger(EvenGetWorkBudgetDistributor.class); - private final Supplier activeWorkBudgetSupplier; - - EvenGetWorkBudgetDistributor(Supplier activeWorkBudgetSupplier) { - this.activeWorkBudgetSupplier = activeWorkBudgetSupplier; - } - - private static boolean isBelowFiftyPercentOfTarget( - GetWorkBudget remaining, GetWorkBudget target) { - return remaining.items() < roundToLong(target.items() * 0.5, RoundingMode.CEILING) - || remaining.bytes() < roundToLong(target.bytes() * 0.5, RoundingMode.CEILING); - } @Override public void distributeBudget( - ImmutableCollection budgetOwners, GetWorkBudget getWorkBudget) { - if (budgetOwners.isEmpty()) { + ImmutableCollection budgetSpenders, GetWorkBudget getWorkBudget) { + if (budgetSpenders.isEmpty()) { LOG.debug("Cannot distribute budget to no owners."); return; } @@ -61,38 +43,15 @@ public void distributeBudget( return; } - Map desiredBudgets = computeDesiredBudgets(budgetOwners, getWorkBudget); - - for (Entry streamAndDesiredBudget : desiredBudgets.entrySet()) { - GetWorkBudgetSpender getWorkBudgetSpender = streamAndDesiredBudget.getKey(); - GetWorkBudget desired = streamAndDesiredBudget.getValue(); - GetWorkBudget remaining = getWorkBudgetSpender.remainingBudget(); - if (isBelowFiftyPercentOfTarget(remaining, desired)) { - GetWorkBudget adjustment = desired.subtract(remaining); - getWorkBudgetSpender.adjustBudget(adjustment); - } - } + GetWorkBudget budgetPerStream = computeDesiredPerStreamBudget(budgetSpenders, getWorkBudget); + budgetSpenders.forEach(getWorkBudgetSpender -> getWorkBudgetSpender.setBudget(budgetPerStream)); } - private ImmutableMap computeDesiredBudgets( + private GetWorkBudget computeDesiredPerStreamBudget( ImmutableCollection streams, GetWorkBudget totalGetWorkBudget) { - GetWorkBudget activeWorkBudget = activeWorkBudgetSupplier.get(); - LOG.info("Current active work budget: {}", activeWorkBudget); - // TODO: Fix possibly non-deterministic handing out of budgets. - // Rounding up here will drift upwards over the lifetime of the streams. 
- GetWorkBudget budgetPerStream = - GetWorkBudget.builder() - .setItems( - divide( - totalGetWorkBudget.items() - activeWorkBudget.items(), - streams.size(), - RoundingMode.CEILING)) - .setBytes( - divide( - totalGetWorkBudget.bytes() - activeWorkBudget.bytes(), - streams.size(), - RoundingMode.CEILING)) - .build(); - return streams.stream().collect(toImmutableMap(Function.identity(), unused -> budgetPerStream)); + return GetWorkBudget.builder() + .setItems(divide(totalGetWorkBudget.items(), streams.size(), RoundingMode.CEILING)) + .setBytes(divide(totalGetWorkBudget.bytes(), streams.size(), RoundingMode.CEILING)) + .build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java index 43c0d46139da..2013c9ff1cb7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetDistributors.java @@ -17,13 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import java.util.function.Supplier; import org.apache.beam.sdk.annotations.Internal; @Internal public final class GetWorkBudgetDistributors { - public static GetWorkBudgetDistributor distributeEvenly( - Supplier activeWorkBudgetSupplier) { - return new EvenGetWorkBudgetDistributor(activeWorkBudgetSupplier); + public static GetWorkBudgetDistributor distributeEvenly() { + return new EvenGetWorkBudgetDistributor(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java index 254b2589062e..decf101a641b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetSpender.java @@ -22,11 +22,9 @@ * org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget} */ public interface GetWorkBudgetSpender { - void adjustBudget(long itemsDelta, long bytesDelta); + void setBudget(long items, long bytes); - default void adjustBudget(GetWorkBudget adjustment) { - adjustBudget(adjustment.items(), adjustment.bytes()); + default void setBudget(GetWorkBudget budget) { + setBudget(budget.items(), budget.bytes()); } - - GetWorkBudget remainingBudget(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index b3f7467cdbd3..90ffb3d3fbcf 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -245,18 +245,10 @@ public void halfClose() { } @Override - public void 
adjustBudget(long itemsDelta, long bytesDelta) { + public void setBudget(GetWorkBudget newBudget) { // no-op. } - @Override - public GetWorkBudget remainingBudget() { - return GetWorkBudget.builder() - .setItems(request.getMaxItems()) - .setBytes(request.getMaxBytes()) - .build(); - } - @Override public boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException { while (done.getCount() > 0) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java index ed8815c48e76..0092fcc7bcd1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java @@ -30,9 +30,7 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Comparator; import java.util.HashSet; -import java.util.List; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; @@ -46,7 +44,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; import org.apache.beam.runners.dataflow.worker.windmill.WindmillServiceAddress; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker; @@ -71,7 +68,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableCollection; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.junit.After; import org.junit.Before; @@ -92,7 +88,6 @@ public class FanOutStreamingEngineWorkerHarnessTest { .setDirectEndpoint(DEFAULT_WINDMILL_SERVICE_ADDRESS.gcpServiceAddress().toString()) .build()); - private static final long CLIENT_ID = 1L; private static final String JOB_ID = "jobId"; private static final String PROJECT_ID = "projectId"; private static final String WORKER_ID = "workerId"; @@ -101,6 +96,7 @@ public class FanOutStreamingEngineWorkerHarnessTest { .setJobId(JOB_ID) .setProjectId(PROJECT_ID) .setWorkerId(WORKER_ID) + .setClientId(1L) .build(); @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @@ -134,7 +130,7 @@ private static GetWorkRequest getWorkRequest(long items, long bytes) { .setJobId(JOB_ID) .setProjectId(PROJECT_ID) .setWorkerId(WORKER_ID) - .setClientId(CLIENT_ID) + .setClientId(JOB_HEADER.getClientId()) .setMaxItems(items) .setMaxBytes(bytes) .build(); @@ -174,7 +170,7 @@ public void cleanUp() { stubFactory.shutdown(); } - private FanOutStreamingEngineWorkerHarness newStreamingEngineClient( + private 
FanOutStreamingEngineWorkerHarness newFanOutStreamingEngineWorkerHarness( GetWorkBudget getWorkBudget, GetWorkBudgetDistributor getWorkBudgetDistributor, WorkItemScheduler workItemScheduler) { @@ -186,7 +182,6 @@ private FanOutStreamingEngineWorkerHarness newStreamingEngineClient( stubFactory, getWorkBudgetDistributor, dispatcherClient, - CLIENT_ID, ignored -> mock(WorkCommitter.class), new ThrottlingGetDataMetricTracker(mock(MemoryMonitor.class))); } @@ -201,7 +196,7 @@ public void testStreamsStartCorrectly() throws InterruptedException { spy(new TestGetWorkBudgetDistributor(numBudgetDistributionsExpected)); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(items).setBytes(bytes).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); @@ -219,16 +214,14 @@ public void testStreamsStartCorrectly() throws InterruptedException { getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); - StreamingEngineConnectionState currentConnections = - fanOutStreamingEngineWorkProvider.getCurrentConnections(); + StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); - assertEquals(2, currentConnections.windmillConnections().size()); - assertEquals(2, currentConnections.windmillStreams().size()); + assertEquals(2, currentBackends.windmillStreams().size()); Set workerTokens = - currentConnections.windmillConnections().values().stream() - .map(WindmillConnection::backendWorkerToken) + currentBackends.windmillStreams().keySet().stream() + .map(endpoint -> endpoint.workerToken().orElseThrow(IllegalStateException::new)) .collect(Collectors.toSet()); assertTrue(workerTokens.contains(workerToken)); @@ -252,27 +245,6 @@ public void testStreamsStartCorrectly() throws InterruptedException { verify(streamFactory, times(2)).createCommitWorkStream(any(), any()); } - @Test - public void testScheduledBudgetRefresh() throws InterruptedException { - TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(2)); - fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( - GetWorkBudget.builder().setItems(1L).setBytes(1L).build(), - getWorkBudgetDistributor, - noOpProcessWorkItemFn()); - - getWorkerMetadataReady.await(); - fakeGetWorkerMetadataStub.injectWorkerMetadata( - WorkerMetadataResponse.newBuilder() - .setMetadataVersion(1) - .addWorkEndpoints(metadataResponseEndpoint("workerToken")) - .putAllGlobalDataEndpoints(DEFAULT) - .build()); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - verify(getWorkBudgetDistributor, atLeast(2)).distributeBudget(any(), any()); - } - @Test public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() throws InterruptedException { @@ -280,7 +252,7 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor(metadataCount)); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(1).setBytes(1).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); @@ -309,32 +281,28 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() WorkerMetadataResponse.Endpoint.newBuilder() 
.setBackendWorkerToken(workerToken3) .build()) - .putAllGlobalDataEndpoints(DEFAULT) .build(); getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - StreamingEngineConnectionState currentConnections = - fanOutStreamingEngineWorkProvider.getCurrentConnections(); - assertEquals(1, currentConnections.windmillConnections().size()); - assertEquals(1, currentConnections.windmillStreams().size()); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); + StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); + assertEquals(1, currentBackends.windmillStreams().size()); Set workerTokens = - fanOutStreamingEngineWorkProvider.getCurrentConnections().windmillConnections().values() - .stream() - .map(WindmillConnection::backendWorkerToken) + fanOutStreamingEngineWorkProvider.currentBackends().windmillStreams().keySet().stream() + .map(endpoint -> endpoint.workerToken().orElseThrow(IllegalStateException::new)) .collect(Collectors.toSet()); assertFalse(workerTokens.contains(workerToken)); assertFalse(workerTokens.contains(workerToken2)); + assertTrue(currentBackends.globalDataStreams().isEmpty()); } @Test public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedException { String workerToken = "workerToken1"; String workerToken2 = "workerToken2"; - String workerToken3 = "workerToken3"; WorkerMetadataResponse firstWorkerMetadata = WorkerMetadataResponse.newBuilder() @@ -354,42 +322,24 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce .build()) .putAllGlobalDataEndpoints(DEFAULT) .build(); - WorkerMetadataResponse thirdWorkerMetadata = - WorkerMetadataResponse.newBuilder() - .setMetadataVersion(3) - .addWorkEndpoints( - WorkerMetadataResponse.Endpoint.newBuilder() - .setBackendWorkerToken(workerToken3) - .build()) - .putAllGlobalDataEndpoints(DEFAULT) - .build(); - - List workerMetadataResponses = - Lists.newArrayList(firstWorkerMetadata, secondWorkerMetadata, thirdWorkerMetadata); TestGetWorkBudgetDistributor getWorkBudgetDistributor = - spy(new TestGetWorkBudgetDistributor(workerMetadataResponses.size())); + spy(new TestGetWorkBudgetDistributor(1)); fanOutStreamingEngineWorkProvider = - newStreamingEngineClient( + newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(1).setBytes(1).build(), getWorkBudgetDistributor, noOpProcessWorkItemFn()); getWorkerMetadataReady.await(); - // Make sure we are injecting the metadata from smallest to largest. 
- workerMetadataResponses.stream() - .sorted(Comparator.comparingLong(WorkerMetadataResponse::getMetadataVersion)) - .forEach(fakeGetWorkerMetadataStub::injectWorkerMetadata); - - waitForWorkerMetadataToBeConsumed(getWorkBudgetDistributor); - verify(getWorkBudgetDistributor, atLeast(workerMetadataResponses.size())) - .distributeBudget(any(), any()); - } + fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); + getWorkBudgetDistributor.expectNumDistributions(1); + fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); - private void waitForWorkerMetadataToBeConsumed( - TestGetWorkBudgetDistributor getWorkBudgetDistributor) throws InterruptedException { - getWorkBudgetDistributor.waitForBudgetDistribution(); + verify(getWorkBudgetDistributor, times(2)).distributeBudget(any(), any()); } private static class GetWorkerMetadataTestStub @@ -434,21 +384,24 @@ private void injectWorkerMetadata(WorkerMetadataResponse response) { } private static class TestGetWorkBudgetDistributor implements GetWorkBudgetDistributor { - private final CountDownLatch getWorkBudgetDistributorTriggered; + private CountDownLatch getWorkBudgetDistributorTriggered; private TestGetWorkBudgetDistributor(int numBudgetDistributionsExpected) { this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected); } - @SuppressWarnings("ReturnValueIgnored") - private void waitForBudgetDistribution() throws InterruptedException { - getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS); + private boolean waitForBudgetDistribution() throws InterruptedException { + return getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS); + } + + private void expectNumDistributions(int numBudgetDistributionsExpected) { + this.getWorkBudgetDistributorTriggered = new CountDownLatch(numBudgetDistributionsExpected); } @Override public void distributeBudget( ImmutableCollection streams, GetWorkBudget getWorkBudget) { - streams.forEach(stream -> stream.adjustBudget(getWorkBudget.items(), getWorkBudget.bytes())); + streams.forEach(stream -> stream.setBudget(getWorkBudget.items(), getWorkBudget.bytes())); getWorkBudgetDistributorTriggered.countDown(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java index dc6cc5641055..32d1f5738086 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java @@ -193,7 +193,7 @@ public void testCloseAllStreams_doesNotCloseUnstartedStreams() { WindmillStreamSender windmillStreamSender = newWindmillStreamSender(GetWorkBudget.builder().setBytes(1L).setItems(1L).build()); - windmillStreamSender.closeAllStreams(); + windmillStreamSender.close(); verifyNoInteractions(streamFactory); } @@ -230,7 +230,7 @@ public void testCloseAllStreams_closesAllStreams() { mockStreamFactory); windmillStreamSender.startStreams(); - windmillStreamSender.closeAllStreams(); + windmillStreamSender.close(); verify(mockGetWorkStream).shutdown(); 
verify(mockGetDataStream).shutdown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java new file mode 100644 index 000000000000..fd2b30238836 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.junit.After; +import 
org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcDirectGetWorkStreamTest { + private static final WorkItemScheduler NO_OP_WORK_ITEM_SCHEDULER = + (workItem, watermarks, processingContext, getWorkStreamLatencies) -> {}; + private static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setClientId(1L) + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String FAKE_SERVER_NAME = "Fake server for GrpcDirectGetWorkStreamTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + private GrpcDirectGetWorkStream stream; + + private static Windmill.StreamingGetWorkRequestExtension extension(GetWorkBudget budget) { + return Windmill.StreamingGetWorkRequestExtension.newBuilder() + .setMaxItems(budget.items()) + .setMaxBytes(budget.bytes()) + .build(); + } + + private static void assertHeader( + Windmill.StreamingGetWorkRequest getWorkRequest, GetWorkBudget expectedInitialBudget) { + assertTrue(getWorkRequest.hasRequest()); + assertFalse(getWorkRequest.hasRequestExtension()); + assertThat(getWorkRequest.getRequest()) + .isEqualTo( + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(expectedInitialBudget.items()) + .setMaxBytes(expectedInitialBudget.bytes()) + .build()); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + checkNotNull(stream).shutdown(); + } + + private GrpcDirectGetWorkStream createGetWorkStream( + GetWorkStreamTestStub testStub, + GetWorkBudget initialGetWorkBudget, + ThrottleTimer throttleTimer, + WorkItemScheduler workItemScheduler) { + serviceRegistry.addService(testStub); + return (GrpcDirectGetWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createDirectGetWorkStream( + WindmillConnection.builder() + .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build(), + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + .setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build(), + throttleTimer, + mock(HeartbeatSender.class), + mock(GetDataClient.class), + mock(WorkCommitter.class), + workItemScheduler); + } + + private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) { + return Windmill.StreamingGetWorkResponseChunk.newBuilder() + .setStreamId(1L) + .setComputationMetadata( + 
Windmill.ComputationWorkItemMetadata.newBuilder() + .setComputationId("compId") + .setInputDataWatermark(1L) + .setDependentRealtimeInputWatermark(1L) + .build()) + .setSerializedWorkItem(workItem.toByteString()) + .setRemainingBytesForWorkItem(0) + .build(); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream.setBudget(newBudget); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + // Header and extension. + assertThat(requestObserver.sent()).hasSize(expectedRequests); + assertHeader(requestObserver.sent().get(0), GetWorkBudget.noBudget()); + assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension()) + .isEqualTo(extension(newBudget)); + } + + @Test + public void testSetBudget_computesAndSendsCorrectExtension_existingBudget() + throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(10).setBytes(10).build(); + stream = + createGetWorkStream( + testStub, initialBudget, new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + GetWorkBudget newBudget = GetWorkBudget.builder().setItems(100).setBytes(100).build(); + stream.setBudget(newBudget); + GetWorkBudget diff = newBudget.subtract(initialBudget); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Header and extension. + assertThat(requests).hasSize(expectedRequests); + assertHeader(requests.get(0), initialBudget); + assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff)); + } + + @Test + public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(); + stream = + createGetWorkStream( + testStub, initialBudget, new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build()); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. 
+ assertThat(requests).hasSize(expectedRequests); + assertHeader(Iterables.getOnlyElement(requests), initialBudget); + } + + @Test + public void testSetBudget_doesNothingIfStreamShutdown() throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), new ThrottleTimer(), NO_OP_WORK_ITEM_SCHEDULER); + stream.shutdown(); + stream.setBudget( + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build()); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + List requests = requestObserver.sent(); + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(1); + assertHeader(Iterables.getOnlyElement(requests), GetWorkBudget.noBudget()); + } + + @Test + public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws InterruptedException { + int expectedRequests = 2; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + GetWorkBudget initialBudget = GetWorkBudget.builder().setItems(1).setBytes(100).build(); + Set scheduledWorkItems = new HashSet<>(); + stream = + createGetWorkStream( + testStub, + initialBudget, + new ThrottleTimer(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> { + scheduledWorkItems.add(work); + }); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List requests = requestObserver.sent(); + long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize(); + + assertThat(requests).hasSize(expectedRequests); + assertHeader(requests.get(0), initialBudget); + assertThat(Iterables.getLast(requests).getRequestExtension()) + .isEqualTo( + extension( + GetWorkBudget.builder() + .setItems(1) + .setBytes(initialBudget.bytes() - inFlightBytes) + .build())); + } + + @Test + public void testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh() + throws InterruptedException { + int expectedRequests = 1; + CountDownLatch waitForRequests = new CountDownLatch(expectedRequests); + TestGetWorkRequestObserver requestObserver = new TestGetWorkRequestObserver(waitForRequests); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + Set scheduledWorkItems = new HashSet<>(); + GetWorkBudget initialBudget = + GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(); + stream = + createGetWorkStream( + testStub, + initialBudget, + new ThrottleTimer(), + (work, watermarks, processingContext, getWorkStreamLatencies) -> + scheduledWorkItems.add(work)); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.copyFromUtf8("somewhat_long_key")) + .setWorkToken(1L) + .setShardingKey(1L) + .setCacheToken(1L) + .build(); + + testStub.injectResponse(createResponse(workItem)); + + 
assertTrue(waitForRequests.await(5, TimeUnit.SECONDS)); + + assertThat(scheduledWorkItems).containsExactly(workItem); + List requests = requestObserver.sent(); + + // Assert that the extension was never sent, only the header. + assertThat(requests).hasSize(expectedRequests); + assertHeader(Iterables.getOnlyElement(requests), initialBudget); + } + + @Test + public void testOnResponse_stopsThrottling() { + ThrottleTimer throttleTimer = new ThrottleTimer(); + TestGetWorkRequestObserver requestObserver = + new TestGetWorkRequestObserver(new CountDownLatch(1)); + GetWorkStreamTestStub testStub = new GetWorkStreamTestStub(requestObserver); + stream = + createGetWorkStream( + testStub, GetWorkBudget.noBudget(), throttleTimer, NO_OP_WORK_ITEM_SCHEDULER); + stream.startThrottleTimer(); + assertTrue(throttleTimer.throttled()); + testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance()); + assertFalse(throttleTimer.throttled()); + } + + private static class GetWorkStreamTestStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private final TestGetWorkRequestObserver requestObserver; + private @Nullable StreamObserver responseObserver; + + private GetWorkStreamTestStub(TestGetWorkRequestObserver requestObserver) { + this.requestObserver = requestObserver; + } + + @Override + public StreamObserver getWorkStream( + StreamObserver responseObserver) { + if (this.responseObserver == null) { + this.responseObserver = responseObserver; + requestObserver.responseObserver = this.responseObserver; + } + + return requestObserver; + } + + private void injectResponse(Windmill.StreamingGetWorkResponseChunk responseChunk) { + checkNotNull(responseObserver).onNext(responseChunk); + } + } + + private static class TestGetWorkRequestObserver + implements StreamObserver { + private final List requests = + Collections.synchronizedList(new ArrayList<>()); + private final CountDownLatch waitForRequests; + private @Nullable volatile StreamObserver + responseObserver; + + public TestGetWorkRequestObserver(CountDownLatch waitForRequests) { + this.waitForRequests = waitForRequests; + } + + @Override + public void onNext(Windmill.StreamingGetWorkRequest request) { + requests.add(request); + waitForRequests.countDown(); + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + + List sent() { + return requests; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java index 3cda4559c100..c76d5a584184 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java @@ -17,9 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.budget; -import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -40,169 +38,79 @@ public class 
EvenGetWorkBudgetDistributorTest { @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); - private static GetWorkBudgetDistributor createBudgetDistributor(GetWorkBudget activeWorkBudget) { - return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget); - } + private static GetWorkBudgetSpender createGetWorkBudgetOwner() { + // Lambdas are final and cannot be spied. + return spy( + new GetWorkBudgetSpender() { - private static GetWorkBudgetDistributor createBudgetDistributor(long activeWorkItemsAndBytes) { - return createBudgetDistributor( - GetWorkBudget.builder() - .setItems(activeWorkItemsAndBytes) - .setBytes(activeWorkItemsAndBytes) - .build()); + @Override + public void setBudget(long items, long bytes) {} + }); } @Test public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() { - createBudgetDistributor(1L) + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.of(), GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); } @Test public void testDistributeBudget_doesNothingWithNoBudget() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget())); - createBudgetDistributor(1L) + GetWorkBudgetSpender getWorkBudgetSpender = createGetWorkBudgetOwner(); + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget(ImmutableList.of(getWorkBudgetSpender), GetWorkBudget.noBudget()); verifyNoInteractions(getWorkBudgetSpender); } @Test - public void testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighNoActiveWork() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy( - createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget.builder().setItems(10L).setBytes(10L).build())); - createBudgetDistributor(0L) - .distributeBudget( - ImmutableList.of(getWorkBudgetSpender), - GetWorkBudget.builder().setItems(10L).setBytes(10L).build()); - - verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong()); - } - - @Test - public void - testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighWithActiveWork() { - GetWorkBudgetSpender getWorkBudgetSpender = - spy( - createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget.builder().setItems(5L).setBytes(5L).build())); - createBudgetDistributor(10L) + public void testDistributeBudget_distributesBudgetEvenlyIfPossible() { + int totalStreams = 10; + long totalItems = 10L; + long totalBytes = 100L; + List streams = new ArrayList<>(); + for (int i = 0; i < totalStreams; i++) { + streams.add(createGetWorkBudgetOwner()); + } + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( - ImmutableList.of(getWorkBudgetSpender), - GetWorkBudget.builder().setItems(20L).setBytes(20L).build()); - - verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong()); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithNoActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(1L).setBytes(10L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(0L) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - 
eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(1L).setBytes(10L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - long activeWorkItemsAndBytes = 2L; - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(activeWorkItemsAndBytes) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq( - totalGetWorkBudget.items() - - streamRemainingBudget.items() - - activeWorkItemsAndBytes), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowNoActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(10L).setBytes(1L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(0L) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); - - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes())); - } - - @Test - public void - testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowWithActiveWork() { - GetWorkBudget streamRemainingBudget = - GetWorkBudget.builder().setItems(10L).setBytes(1L).build(); - GetWorkBudget totalGetWorkBudget = GetWorkBudget.builder().setItems(10L).setBytes(10L).build(); - long activeWorkItemsAndBytes = 2L; - - GetWorkBudgetSpender getWorkBudgetSpender = - spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget)); - createBudgetDistributor(activeWorkItemsAndBytes) - .distributeBudget(ImmutableList.of(getWorkBudgetSpender), totalGetWorkBudget); + ImmutableList.copyOf(streams), + GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build()); - verify(getWorkBudgetSpender, times(1)) - .adjustBudget( - eq(totalGetWorkBudget.items() - streamRemainingBudget.items()), - eq( - totalGetWorkBudget.bytes() - - streamRemainingBudget.bytes() - - activeWorkItemsAndBytes)); + streams.forEach( + stream -> + verify(stream, times(1)) + .setBudget(eq(GetWorkBudget.builder().setItems(1L).setBytes(10L).build()))); } @Test - public void testDistributeBudget_distributesBudgetEvenlyIfPossible() { - long totalItemsAndBytes = 10L; + public void testDistributeBudget_distributesFairlyWhenNotEven() { + long totalItems = 10L; + long totalBytes = 19L; List streams = new ArrayList<>(); - for (int i = 0; i < totalItemsAndBytes; i++) { - streams.add(spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()))); + for (int i = 0; i < 3; i++) { + streams.add(createGetWorkBudgetOwner()); } - createBudgetDistributor(0L) + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.copyOf(streams), - GetWorkBudget.builder() - .setItems(totalItemsAndBytes) - .setBytes(totalItemsAndBytes) - .build()); + 
GetWorkBudget.builder().setItems(totalItems).setBytes(totalBytes).build()); - long itemsAndBytesPerStream = totalItemsAndBytes / streams.size(); streams.forEach( stream -> verify(stream, times(1)) - .adjustBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); + .setBudget(eq(GetWorkBudget.builder().setItems(4L).setBytes(7L).build()))); } @Test - public void testDistributeBudget_distributesFairlyWhenNotEven() { + public void testDistributeBudget_distributesBudgetEvenly() { long totalItemsAndBytes = 10L; List streams = new ArrayList<>(); - for (int i = 0; i < 3; i++) { - streams.add(spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()))); + for (int i = 0; i < totalItemsAndBytes; i++) { + streams.add(createGetWorkBudgetOwner()); } - createBudgetDistributor(0L) + + GetWorkBudgetDistributors.distributeEvenly() .distributeBudget( ImmutableList.copyOf(streams), GetWorkBudget.builder() @@ -210,24 +118,10 @@ public void testDistributeBudget_distributesFairlyWhenNotEven() { .setBytes(totalItemsAndBytes) .build()); - long itemsAndBytesPerStream = (long) Math.ceil(totalItemsAndBytes / (streams.size() * 1.0)); + long itemsAndBytesPerStream = totalItemsAndBytes / streams.size(); streams.forEach( stream -> verify(stream, times(1)) - .adjustBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); - } - - private GetWorkBudgetSpender createGetWorkBudgetOwnerWithRemainingBudgetOf( - GetWorkBudget getWorkBudget) { - return spy( - new GetWorkBudgetSpender() { - @Override - public void adjustBudget(long itemsDelta, long bytesDelta) {} - - @Override - public GetWorkBudget remainingBudget() { - return getWorkBudget; - } - }); + .setBudget(eq(itemsAndBytesPerStream), eq(itemsAndBytesPerStream))); } } From b8a2b9aff9fa09e68bf0848278610d56cdc23345 Mon Sep 17 00:00:00 2001 From: Reuven Lax Date: Mon, 21 Oct 2024 11:25:02 -0700 Subject: [PATCH 043/181] Revert "Merge pull request #32757: Schema inference parameterized types" This reverts commit a50f91c386c00940b08ef8a5e4d0817422ea230f. 
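For context, the reverted pull request threaded a map of bound type variables
(built via ReflectUtils.getAllBoundTypes) through FieldValueTypeInformation and
the schema utilities so that schemas could be inferred for parameterized
classes. The sketch below is a hypothetical illustration (the class names are
invented here, not taken from the patch) of the kind of type that inference
targeted: a generic POJO whose type variables are only bound in a concrete
subclass.

    import org.apache.beam.sdk.schemas.JavaFieldSchema;
    import org.apache.beam.sdk.schemas.annotations.DefaultSchema;

    // A generic POJO: its schema depends on how K and V are bound.
    @DefaultSchema(JavaFieldSchema.class)
    class KvPojo<K, V> {
      public K key;
      public V value;
    }

    // Binds K -> String and V -> Long. The reverted change resolved these
    // bindings during inference; after the revert they are not propagated.
    @DefaultSchema(JavaFieldSchema.class)
    class StringLongKvPojo extends KvPojo<String, Long> {}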
--- .../beam/sdk/schemas/AutoValueSchema.java | 8 +- .../schemas/FieldValueTypeInformation.java | 89 +++++------ .../beam/sdk/schemas/JavaBeanSchema.java | 12 +- .../beam/sdk/schemas/JavaFieldSchema.java | 10 +- .../beam/sdk/schemas/SchemaProvider.java | 3 +- .../beam/sdk/schemas/SchemaRegistry.java | 39 +++-- .../transforms/providers/JavaRowUdf.java | 3 +- .../sdk/schemas/utils/AutoValueUtils.java | 20 +-- .../sdk/schemas/utils/ByteBuddyUtils.java | 53 +++---- .../sdk/schemas/utils/ConvertHelpers.java | 6 +- .../beam/sdk/schemas/utils/JavaBeanUtils.java | 10 +- .../beam/sdk/schemas/utils/POJOUtils.java | 20 +-- .../beam/sdk/schemas/utils/ReflectUtils.java | 83 ++-------- .../schemas/utils/StaticSchemaInference.java | 91 ++++++----- .../beam/sdk/schemas/AutoValueSchemaTest.java | 149 ------------------ .../beam/sdk/schemas/JavaBeanSchemaTest.java | 124 --------------- .../beam/sdk/schemas/JavaFieldSchemaTest.java | 120 -------------- .../sdk/schemas/utils/JavaBeanUtilsTest.java | 33 +--- .../beam/sdk/schemas/utils/POJOUtilsTest.java | 36 ++--- .../beam/sdk/schemas/utils/TestJavaBeans.java | 91 ----------- .../beam/sdk/schemas/utils/TestPOJOs.java | 121 +------------- .../schemas/utils/AvroByteBuddyUtils.java | 6 +- .../avro/schemas/utils/AvroUtils.java | 10 +- .../protobuf/ProtoByteBuddyUtils.java | 4 +- .../protobuf/ProtoMessageSchema.java | 8 +- .../python/PythonExternalTransform.java | 4 +- .../beam/sdk/io/thrift/ThriftSchema.java | 5 +- 27 files changed, 197 insertions(+), 961 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java index c369eefeb65c..5ccfe39b92af 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java @@ -19,10 +19,8 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; -import java.lang.reflect.Type; import java.util.Comparator; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.annotations.SchemaIgnore; import org.apache.beam.sdk.schemas.utils.AutoValueUtils; @@ -63,9 +61,8 @@ public List get(TypeDescriptor typeDescriptor) { .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); - Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i, boundTypes)); + types.add(FieldValueTypeInformation.forGetter(methods.get(i), i)); } types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); @@ -146,8 +143,7 @@ public SchemaUserTypeCreator schemaTypeCreator( @Override public @Nullable Schema schemaFor(TypeDescriptor typeDescriptor) { - Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); return JavaBeanUtils.schemaFromJavaBeanClass( - typeDescriptor, AbstractGetterTypeSupplier.INSTANCE, boundTypes); + typeDescriptor, AbstractGetterTypeSupplier.INSTANCE); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java index 64687e6d3381..750709192c08 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java +++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java @@ -24,12 +24,10 @@ import java.lang.reflect.Field; import java.lang.reflect.Member; import java.lang.reflect.Method; -import java.lang.reflect.Type; import java.util.Arrays; import java.util.Collections; import java.util.Map; import java.util.stream.Stream; -import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.annotations.SchemaFieldName; @@ -46,7 +44,6 @@ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" }) -@Internal public abstract class FieldValueTypeInformation implements Serializable { /** Optionally returns the field index. */ public abstract @Nullable Integer getNumber(); @@ -128,10 +125,8 @@ public static FieldValueTypeInformation forOneOf( .build(); } - public static FieldValueTypeInformation forField( - Field field, int index, Map boundTypes) { - TypeDescriptor type = - TypeDescriptor.of(ReflectUtils.resolveType(field.getGenericType(), boundTypes)); + public static FieldValueTypeInformation forField(Field field, int index) { + TypeDescriptor type = TypeDescriptor.of(field.getGenericType()); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(field.getName(), field)) .setNumber(getNumberOverride(index, field)) @@ -139,9 +134,9 @@ public static FieldValueTypeInformation forField( .setType(type) .setRawType(type.getRawType()) .setField(field) - .setElementType(getIterableComponentType(field, boundTypes)) - .setMapKeyType(getMapKeyType(field, boundTypes)) - .setMapValueType(getMapValueType(field, boundTypes)) + .setElementType(getIterableComponentType(field)) + .setMapKeyType(getMapKeyType(field)) + .setMapValueType(getMapValueType(field)) .setOneOfTypes(Collections.emptyMap()) .setDescription(getFieldDescription(field)) .build(); @@ -189,8 +184,7 @@ public static String getNameOverride( return fieldDescription.value(); } - public static FieldValueTypeInformation forGetter( - Method method, int index, Map boundTypes) { + public static FieldValueTypeInformation forGetter(Method method, int index) { String name; if (method.getName().startsWith("get")) { name = ReflectUtils.stripPrefix(method.getName(), "get"); @@ -200,8 +194,7 @@ public static FieldValueTypeInformation forGetter( throw new RuntimeException("Getter has wrong prefix " + method.getName()); } - TypeDescriptor type = - TypeDescriptor.of(ReflectUtils.resolveType(method.getGenericReturnType(), boundTypes)); + TypeDescriptor type = TypeDescriptor.of(method.getGenericReturnType()); boolean nullable = hasNullableReturnType(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(name, method)) @@ -210,9 +203,9 @@ public static FieldValueTypeInformation forGetter( .setType(type) .setRawType(type.getRawType()) .setMethod(method) - .setElementType(getIterableComponentType(type, boundTypes)) - .setMapKeyType(getMapKeyType(type, boundTypes)) - .setMapValueType(getMapValueType(type, boundTypes)) + .setElementType(getIterableComponentType(type)) + .setMapKeyType(getMapKeyType(type)) + .setMapValueType(getMapValueType(type)) .setOneOfTypes(Collections.emptyMap()) .setDescription(getFieldDescription(method)) .build(); @@ -259,13 +252,11 @@ private static boolean isNullableAnnotation(Annotation annotation) { return annotation.annotationType().getSimpleName().equals("Nullable"); } - public 
static FieldValueTypeInformation forSetter( - Method method, Map boundParameters) { - return forSetter(method, "set", boundParameters); + public static FieldValueTypeInformation forSetter(Method method) { + return forSetter(method, "set"); } - public static FieldValueTypeInformation forSetter( - Method method, String setterPrefix, Map boundTypes) { + public static FieldValueTypeInformation forSetter(Method method, String setterPrefix) { String name; if (method.getName().startsWith(setterPrefix)) { name = ReflectUtils.stripPrefix(method.getName(), setterPrefix); @@ -273,9 +264,7 @@ public static FieldValueTypeInformation forSetter( throw new RuntimeException("Setter has wrong prefix " + method.getName()); } - TypeDescriptor type = - TypeDescriptor.of( - ReflectUtils.resolveType(method.getGenericParameterTypes()[0], boundTypes)); + TypeDescriptor type = TypeDescriptor.of(method.getGenericParameterTypes()[0]); boolean nullable = hasSingleNullableParameter(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(name) @@ -283,9 +272,9 @@ public static FieldValueTypeInformation forSetter( .setType(type) .setRawType(type.getRawType()) .setMethod(method) - .setElementType(getIterableComponentType(type, boundTypes)) - .setMapKeyType(getMapKeyType(type, boundTypes)) - .setMapValueType(getMapValueType(type, boundTypes)) + .setElementType(getIterableComponentType(type)) + .setMapKeyType(getMapKeyType(type)) + .setMapValueType(getMapValueType(type)) .setOneOfTypes(Collections.emptyMap()) .build(); } @@ -294,15 +283,13 @@ public FieldValueTypeInformation withName(String name) { return toBuilder().setName(name).build(); } - private static FieldValueTypeInformation getIterableComponentType( - Field field, Map boundTypes) { - return getIterableComponentType(TypeDescriptor.of(field.getGenericType()), boundTypes); + private static FieldValueTypeInformation getIterableComponentType(Field field) { + return getIterableComponentType(TypeDescriptor.of(field.getGenericType())); } - static @Nullable FieldValueTypeInformation getIterableComponentType( - TypeDescriptor valueType, Map boundTypes) { + static @Nullable FieldValueTypeInformation getIterableComponentType(TypeDescriptor valueType) { // TODO: Figure out nullable elements. - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType, boundTypes); + TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType); if (componentType == null) { return null; } @@ -312,43 +299,41 @@ private static FieldValueTypeInformation getIterableComponentType( .setNullable(false) .setType(componentType) .setRawType(componentType.getRawType()) - .setElementType(getIterableComponentType(componentType, boundTypes)) - .setMapKeyType(getMapKeyType(componentType, boundTypes)) - .setMapValueType(getMapValueType(componentType, boundTypes)) + .setElementType(getIterableComponentType(componentType)) + .setMapKeyType(getMapKeyType(componentType)) + .setMapValueType(getMapValueType(componentType)) .setOneOfTypes(Collections.emptyMap()) .build(); } // If the Field is a map type, returns the key type, otherwise returns a null reference. 
- private static @Nullable FieldValueTypeInformation getMapKeyType( - Field field, Map boundTypes) { - return getMapKeyType(TypeDescriptor.of(field.getGenericType()), boundTypes); + private static @Nullable FieldValueTypeInformation getMapKeyType(Field field) { + return getMapKeyType(TypeDescriptor.of(field.getGenericType())); } private static @Nullable FieldValueTypeInformation getMapKeyType( - TypeDescriptor typeDescriptor, Map boundTypes) { - return getMapType(typeDescriptor, 0, boundTypes); + TypeDescriptor typeDescriptor) { + return getMapType(typeDescriptor, 0); } // If the Field is a map type, returns the value type, otherwise returns a null reference. - private static @Nullable FieldValueTypeInformation getMapValueType( - Field field, Map boundTypes) { - return getMapType(TypeDescriptor.of(field.getGenericType()), 1, boundTypes); + private static @Nullable FieldValueTypeInformation getMapValueType(Field field) { + return getMapType(TypeDescriptor.of(field.getGenericType()), 1); } private static @Nullable FieldValueTypeInformation getMapValueType( - TypeDescriptor typeDescriptor, Map boundTypes) { - return getMapType(typeDescriptor, 1, boundTypes); + TypeDescriptor typeDescriptor) { + return getMapType(typeDescriptor, 1); } // If the Field is a map type, returns the key or value type (0 is key type, 1 is value). // Otherwise returns a null reference. @SuppressWarnings("unchecked") private static @Nullable FieldValueTypeInformation getMapType( - TypeDescriptor valueType, int index, Map boundTypes) { - TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index, boundTypes); + TypeDescriptor valueType, int index) { + TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index); if (mapType == null) { return null; } @@ -357,9 +342,9 @@ private static FieldValueTypeInformation getIterableComponentType( .setNullable(false) .setType(mapType) .setRawType(mapType.getRawType()) - .setElementType(getIterableComponentType(mapType, boundTypes)) - .setMapKeyType(getMapKeyType(mapType, boundTypes)) - .setMapValueType(getMapValueType(mapType, boundTypes)) + .setElementType(getIterableComponentType(mapType)) + .setMapKeyType(getMapKeyType(mapType)) + .setMapValueType(getMapValueType(mapType)) .setOneOfTypes(Collections.emptyMap()) .build(); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java index ad71576670bf..a9cf01c52057 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java @@ -19,10 +19,8 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Method; -import java.lang.reflect.Type; import java.util.Comparator; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; import org.apache.beam.sdk.schemas.annotations.SchemaFieldName; @@ -69,9 +67,8 @@ public List get(TypeDescriptor typeDescriptor) { .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); - Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i, boundTypes)); + types.add(FieldValueTypeInformation.forGetter(methods.get(i), i)); } 
types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); @@ -114,11 +111,10 @@ public static class SetterTypeSupplier implements FieldValueTypeSupplier { @Override public List get(TypeDescriptor typeDescriptor) { - Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); return ReflectUtils.getMethods(typeDescriptor.getRawType()).stream() .filter(ReflectUtils::isSetter) .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) - .map(m -> FieldValueTypeInformation.forSetter(m, boundTypes)) + .map(FieldValueTypeInformation::forSetter) .map( t -> { if (t.getMethod().getAnnotation(SchemaFieldNumber.class) != null) { @@ -160,10 +156,8 @@ public boolean equals(@Nullable Object obj) { @Override public Schema schemaFor(TypeDescriptor typeDescriptor) { - Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); Schema schema = - JavaBeanUtils.schemaFromJavaBeanClass( - typeDescriptor, GetterTypeSupplier.INSTANCE, boundTypes); + JavaBeanUtils.schemaFromJavaBeanClass(typeDescriptor, GetterTypeSupplier.INSTANCE); // If there are no creator methods, then validate that we have setters for every field. // Otherwise, we will have no way of creating instances of the class. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java index da0f59c8ee96..21f07c47b47f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java @@ -21,10 +21,8 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; -import java.lang.reflect.Type; import java.util.Comparator; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -64,11 +62,9 @@ public List get(TypeDescriptor typeDescriptor) { ReflectUtils.getFields(typeDescriptor.getRawType()).stream() .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) .collect(Collectors.toList()); - List types = Lists.newArrayListWithCapacity(fields.size()); - Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); for (int i = 0; i < fields.size(); ++i) { - types.add(FieldValueTypeInformation.forField(fields.get(i), i, boundTypes)); + types.add(FieldValueTypeInformation.forField(fields.get(i), i)); } types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); @@ -115,9 +111,7 @@ private static void validateFieldNumbers(List types) @Override public Schema schemaFor(TypeDescriptor typeDescriptor) { - Map boundTypes = ReflectUtils.getAllBoundTypes(typeDescriptor); - return POJOUtils.schemaFromPojoClass( - typeDescriptor, JavaFieldTypeSupplier.INSTANCE, boundTypes); + return POJOUtils.schemaFromPojoClass(typeDescriptor, JavaFieldTypeSupplier.INSTANCE); } @Override diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java index b7e3cdf60c18..37b4952e529c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaProvider.java @@ -38,7 +38,8 @@ public interface SchemaProvider extends Serializable { * Given a type, return a function that converts that type to a {@link Row} object If no schema * exists, returns null. 
*/ - @Nullable SerializableFunction toRowFunction(TypeDescriptor typeDescriptor); + @Nullable + SerializableFunction toRowFunction(TypeDescriptor typeDescriptor); /** * Given a type, returns a function that converts from a {@link Row} object to that type. If no diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java index 5d8b7aab6193..679a1fcf54fc 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java @@ -76,12 +76,13 @@ void registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid providers.put(typeDescriptor, schemaProvider); } - private @Nullable SchemaProvider schemaProviderFor(TypeDescriptor typeDescriptor) { + @Override + public @Nullable Schema schemaFor(TypeDescriptor typeDescriptor) { TypeDescriptor type = typeDescriptor; do { SchemaProvider schemaProvider = providers.get(type); if (schemaProvider != null) { - return schemaProvider; + return schemaProvider.schemaFor(type); } Class superClass = type.getRawType().getSuperclass(); if (superClass == null || superClass.equals(Object.class)) { @@ -91,24 +92,38 @@ void registerProvider(TypeDescriptor typeDescriptor, SchemaProvider schemaProvid } while (true); } - @Override - public @Nullable Schema schemaFor(TypeDescriptor typeDescriptor) { - @Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor); - return schemaProvider != null ? schemaProvider.schemaFor(typeDescriptor) : null; - } - @Override public @Nullable SerializableFunction toRowFunction( TypeDescriptor typeDescriptor) { - @Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor); - return schemaProvider != null ? schemaProvider.toRowFunction(typeDescriptor) : null; + TypeDescriptor type = typeDescriptor; + do { + SchemaProvider schemaProvider = providers.get(type); + if (schemaProvider != null) { + return (SerializableFunction) schemaProvider.toRowFunction(type); + } + Class superClass = type.getRawType().getSuperclass(); + if (superClass == null || superClass.equals(Object.class)) { + return null; + } + type = TypeDescriptor.of(superClass); + } while (true); } @Override public @Nullable SerializableFunction fromRowFunction( TypeDescriptor typeDescriptor) { - @Nullable SchemaProvider schemaProvider = schemaProviderFor(typeDescriptor); - return schemaProvider != null ? 
schemaProvider.fromRowFunction(typeDescriptor) : null; + TypeDescriptor type = typeDescriptor; + do { + SchemaProvider schemaProvider = providers.get(type); + if (schemaProvider != null) { + return (SerializableFunction) schemaProvider.fromRowFunction(type); + } + Class superClass = type.getRawType().getSuperclass(); + if (superClass == null || superClass.equals(Object.class)) { + return null; + } + type = TypeDescriptor.of(superClass); + } while (true); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java index c3a71bbb454b..54e2a595fa71 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/providers/JavaRowUdf.java @@ -160,8 +160,7 @@ public FunctionAndType(Type outputType, Function function) { public FunctionAndType(TypeDescriptor outputType, Function function) { this( - StaticSchemaInference.fieldFromType( - outputType, new EmptyFieldValueTypeSupplier(), Collections.emptyMap()), + StaticSchemaInference.fieldFromType(outputType, new EmptyFieldValueTypeSupplier()), function); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java index 74e97bad4f0f..d7fddd8abfed 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java @@ -53,7 +53,6 @@ import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; import net.bytebuddy.jar.asm.ClassWriter; import net.bytebuddy.matcher.ElementMatchers; -import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.FieldValueTypeInformation; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.SchemaUserTypeCreator; @@ -64,7 +63,6 @@ import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.checkerframework.checker.nullness.qual.Nullable; /** Utilities for managing AutoValue schemas. */ @@ -72,7 +70,6 @@ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" }) -@Internal public class AutoValueUtils { public static TypeDescriptor getBaseAutoValueClass(TypeDescriptor typeDescriptor) { // AutoValue extensions may be nested @@ -164,7 +161,7 @@ private static boolean matchConstructor( // Verify that constructor parameters match (name and type) the inferred schema. 
for (Parameter parameter : constructor.getParameters()) { FieldValueTypeInformation type = typeMap.get(parameter.getName()); - if (type == null || !type.getRawType().equals(parameter.getType())) { + if (type == null || type.getRawType() != parameter.getType()) { valid = false; break; } @@ -181,7 +178,7 @@ private static boolean matchConstructor( } name = name.substring(0, name.length() - 1); FieldValueTypeInformation type = typeMap.get(name); - if (type == null || !type.getRawType().equals(parameter.getType())) { + if (type == null || type.getRawType() != parameter.getType()) { return false; } } @@ -199,12 +196,11 @@ private static boolean matchConstructor( return null; } - Map boundTypes = ReflectUtils.getAllBoundTypes(TypeDescriptor.of(builderClass)); - Map setterTypes = Maps.newHashMap(); - ReflectUtils.getMethods(builderClass).stream() - .filter(ReflectUtils::isSetter) - .map(m -> FieldValueTypeInformation.forSetter(m, boundTypes)) - .forEach(fv -> setterTypes.putIfAbsent(fv.getName(), fv)); + Map setterTypes = + ReflectUtils.getMethods(builderClass).stream() + .filter(ReflectUtils::isSetter) + .map(FieldValueTypeInformation::forSetter) + .collect(Collectors.toMap(FieldValueTypeInformation::getName, Function.identity())); List setterMethods = Lists.newArrayList(); // The builder methods to call in order. @@ -325,7 +321,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { Duplication.SINGLE, typeConversionsFactory .createSetterConversions(readParameter) - .convert(TypeDescriptor.of(parameter.getParameterizedType())), + .convert(TypeDescriptor.of(parameter.getType())), MethodInvocation.invoke(new ForLoadedMethod(setterMethod)), Removal.SINGLE); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java index 3b2428ebb999..540f09b7b553 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java @@ -345,22 +345,19 @@ protected Type convertArray(TypeDescriptor type) { @Override protected Type convertCollection(TypeDescriptor type) { - TypeDescriptor ret = - createCollectionType(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); + TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type)); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertList(TypeDescriptor type) { - TypeDescriptor ret = - createCollectionType(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); + TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type)); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertIterable(TypeDescriptor type) { - TypeDescriptor ret = - createIterableType(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); + TypeDescriptor ret = createIterableType(ReflectUtils.getIterableComponentType(type)); return returnRawTypes ? 
ret.getRawType() : ret.getType(); } @@ -691,8 +688,7 @@ protected StackManipulation convertArray(TypeDescriptor type) { @Override protected StackManipulation convertIterable(TypeDescriptor type) { - TypeDescriptor componentType = - ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); + TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); @@ -712,8 +708,7 @@ protected StackManipulation convertIterable(TypeDescriptor type) { @Override protected StackManipulation convertCollection(TypeDescriptor type) { - TypeDescriptor componentType = - ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); + TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { @@ -732,8 +727,7 @@ protected StackManipulation convertCollection(TypeDescriptor type) { @Override protected StackManipulation convertList(TypeDescriptor type) { - TypeDescriptor componentType = - ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); + TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { @@ -752,8 +746,8 @@ protected StackManipulation convertList(TypeDescriptor type) { @Override protected StackManipulation convertMap(TypeDescriptor type) { - final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0, Collections.emptyMap()); - final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1, Collections.emptyMap()); + final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0); + final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1); Type convertedKeyType = getFactory().createTypeConversion(true).convert(keyType); Type convertedValueType = getFactory().createTypeConversion(true).convert(valueType); @@ -1045,9 +1039,8 @@ protected StackManipulation convertIterable(TypeDescriptor type) { Type rowElementType = getFactory() .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); - final TypeDescriptor iterableElementType = - ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); + .convert(ReflectUtils.getIterableComponentType(type)); + final TypeDescriptor iterableElementType = ReflectUtils.getIterableComponentType(type); if (!iterableElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = new ForLoadedType( @@ -1068,9 +1061,8 @@ protected StackManipulation convertCollection(TypeDescriptor type) { Type rowElementType = getFactory() .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); - final TypeDescriptor collectionElementType = - ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); + .convert(ReflectUtils.getIterableComponentType(type)); + final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type); if (!collectionElementType.hasUnresolvedParameters()) { ForLoadedType 
conversionFunction = @@ -1092,9 +1084,8 @@ protected StackManipulation convertList(TypeDescriptor type) { Type rowElementType = getFactory() .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type, Collections.emptyMap())); - final TypeDescriptor collectionElementType = - ReflectUtils.getIterableComponentType(type, Collections.emptyMap()); + .convert(ReflectUtils.getIterableComponentType(type)); + final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type); StackManipulation readTrasformedValue = readValue; if (!collectionElementType.hasUnresolvedParameters()) { @@ -1123,17 +1114,11 @@ protected StackManipulation convertList(TypeDescriptor type) { @Override protected StackManipulation convertMap(TypeDescriptor type) { Type rowKeyType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getMapType(type, 0, Collections.emptyMap())); - final TypeDescriptor keyElementType = - ReflectUtils.getMapType(type, 0, Collections.emptyMap()); + getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 0)); + final TypeDescriptor keyElementType = ReflectUtils.getMapType(type, 0); Type rowValueType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getMapType(type, 1, Collections.emptyMap())); - final TypeDescriptor valueElementType = - ReflectUtils.getMapType(type, 1, Collections.emptyMap()); + getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 1)); + final TypeDescriptor valueElementType = ReflectUtils.getMapType(type, 1); StackManipulation readTrasformedValue = readValue; if (!keyElementType.hasUnresolvedParameters() @@ -1491,7 +1476,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { Parameter parameter = parameters.get(i); ForLoadedType convertedType = new ForLoadedType( - (Class) convertType.convert(TypeDescriptor.of(parameter.getParameterizedType()))); + (Class) convertType.convert(TypeDescriptor.of(parameter.getType()))); // The instruction to read the parameter. Use the fieldMapping to reorder parameters as // necessary. 
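The ByteBuddyUtils hunks above drop the boundTypes argument from ReflectUtils.getIterableComponentType and ReflectUtils.getMapType, so element and key/value types are now read directly off the parameterized supertype rather than being resolved through a TypeVariable-to-Type map. The following standalone sketch, with invented names (ComponentTypeSketch, Holder, iterableComponent) and no Beam dependencies, shows roughly what that plain-reflection lookup amounts to, assuming the declared field type is a concrete parameterized type; a type variable such as the T in List<T> would come back unresolved, which is the case the removed boundTypes machinery used to handle.

import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.List;

public class ComponentTypeSketch {

  /** Returns the single type argument of a parameterized declaration, or null for raw types. */
  static Type iterableComponent(Type declaredType) {
    if (declaredType instanceof ParameterizedType) {
      Type[] args = ((ParameterizedType) declaredType).getActualTypeArguments();
      if (args.length == 1) {
        return args[0]; // may still be a TypeVariable if the declaration was List<T>
      }
    }
    return null; // raw or non-generic declaration: nothing to inspect
  }

  /** Demo holder whose field has a fully bound generic element type. */
  static class Holder {
    public List<String> names;
  }

  public static void main(String[] args) throws Exception {
    Type declared = Holder.class.getField("names").getGenericType();
    System.out.println(iterableComponent(declared)); // prints: class java.lang.String
  }
}

When the returned argument is still a type variable, the converters in the diff above fall back to the hasUnresolvedParameters() guards visible in the surrounding context rather than resolving it.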
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java index e98a0b9495cf..7f2403035d97 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java @@ -22,7 +22,6 @@ import java.io.Serializable; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Type; -import java.util.Collections; import java.util.ServiceLoader; import net.bytebuddy.ByteBuddy; import net.bytebuddy.asm.AsmVisitorWrapper; @@ -37,7 +36,6 @@ import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; import net.bytebuddy.jar.asm.ClassWriter; import net.bytebuddy.matcher.ElementMatchers; -import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.JavaFieldSchema.JavaFieldTypeSupplier; import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.Schema; @@ -58,7 +56,6 @@ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" }) -@Internal public class ConvertHelpers { private static class SchemaInformationProviders { private static final ServiceLoader INSTANCE = @@ -151,8 +148,7 @@ public static SerializableFunction getConvertPrimitive( TypeDescriptor outputTypeDescriptor, TypeConversionsFactory typeConversionsFactory) { FieldType expectedFieldType = - StaticSchemaInference.fieldFromType( - outputTypeDescriptor, JavaFieldTypeSupplier.INSTANCE, Collections.emptyMap()); + StaticSchemaInference.fieldFromType(outputTypeDescriptor, JavaFieldTypeSupplier.INSTANCE); if (!expectedFieldType.equals(fieldType)) { throw new IllegalArgumentException( "Element argument type " diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java index 83f6b5c928d8..911f79f6eeed 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java @@ -22,7 +22,6 @@ import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; -import java.lang.reflect.Type; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -43,7 +42,6 @@ import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; import net.bytebuddy.jar.asm.ClassWriter; import net.bytebuddy.matcher.ElementMatchers; -import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; import org.apache.beam.sdk.schemas.FieldValueTypeInformation; @@ -63,15 +61,11 @@ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" }) -@Internal public class JavaBeanUtils { /** Create a {@link Schema} for a Java Bean class. 
*/ public static Schema schemaFromJavaBeanClass( - TypeDescriptor typeDescriptor, - FieldValueTypeSupplier fieldValueTypeSupplier, - Map boundTypes) { - return StaticSchemaInference.schemaFromClass( - typeDescriptor, fieldValueTypeSupplier, boundTypes); + TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { + return StaticSchemaInference.schemaFromClass(typeDescriptor, fieldValueTypeSupplier); } private static final String CONSTRUCTOR_HELP_STRING = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java index 1e60c9312cb3..571b9c690900 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java @@ -49,7 +49,6 @@ import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; import net.bytebuddy.jar.asm.ClassWriter; import net.bytebuddy.matcher.ElementMatchers; -import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; import org.apache.beam.sdk.schemas.FieldValueTypeInformation; @@ -71,15 +70,11 @@ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) }) -@Internal public class POJOUtils { public static Schema schemaFromPojoClass( - TypeDescriptor typeDescriptor, - FieldValueTypeSupplier fieldValueTypeSupplier, - Map boundTypes) { - return StaticSchemaInference.schemaFromClass( - typeDescriptor, fieldValueTypeSupplier, boundTypes); + TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { + return StaticSchemaInference.schemaFromClass(typeDescriptor, fieldValueTypeSupplier); } // Static ByteBuddy instance used by all helpers. @@ -306,7 +301,7 @@ public static SchemaUserTypeCreator createStaticCreator( field.getDeclaringClass(), typeConversionsFactory .createTypeConversion(false) - .convert(TypeDescriptor.of(field.getGenericType()))); + .convert(TypeDescriptor.of(field.getType()))); builder = implementGetterMethods(builder, field, typeInformation.getName(), typeConversionsFactory); try { @@ -388,7 +383,7 @@ private static FieldValueSetter createSetter( field.getDeclaringClass(), typeConversionsFactory .createTypeConversion(false) - .convert(TypeDescriptor.of(field.getGenericType()))); + .convert(TypeDescriptor.of(field.getType()))); builder = implementSetterMethods(builder, field, typeConversionsFactory); try { return builder @@ -496,7 +491,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Do any conversions necessary. typeConversionsFactory .createSetterConversions(readField) - .convert(TypeDescriptor.of(field.getGenericType())), + .convert(TypeDescriptor.of(field.getType())), // Now update the field and return void. FieldAccess.forField(new ForLoadedField(field)).write(), MethodReturn.VOID); @@ -551,8 +546,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { Field field = fields.get(i); ForLoadedType convertedType = - new ForLoadedType( - (Class) convertType.convert(TypeDescriptor.of(field.getGenericType()))); + new ForLoadedType((Class) convertType.convert(TypeDescriptor.of(field.getType()))); // The instruction to read the parameter. 
StackManipulation readParameter = @@ -569,7 +563,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Do any conversions necessary. typeConversionsFactory .createSetterConversions(readParameter) - .convert(TypeDescriptor.of(field.getGenericType())), + .convert(TypeDescriptor.of(field.getType())), // Now update the field. FieldAccess.forField(new ForLoadedField(field)).write()); stackManipulation = new StackManipulation.Compound(stackManipulation, updateField); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java index 32cfa5689193..4349a04c28ad 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java @@ -26,17 +26,16 @@ import java.lang.reflect.Modifier; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; -import java.lang.reflect.TypeVariable; import java.security.InvalidParameterException; import java.util.Arrays; import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.annotations.SchemaCreate; -import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; @@ -89,23 +88,14 @@ public static List getMethods(Class clazz) { return DECLARED_METHODS.computeIfAbsent( clazz, c -> { - List methods = Lists.newArrayList(); - do { - if (c.getPackage() != null && c.getPackage().getName().startsWith("java.")) { - break; // skip java built-in classes - } - Arrays.stream(c.getDeclaredMethods()) - .filter( - m -> - !m.isBridge()) // Covariant overloads insert bridge functions, which we must - // ignore. - .filter(m -> !Modifier.isPrivate(m.getModifiers())) - .filter(m -> !Modifier.isProtected(m.getModifiers())) - .filter(m -> !Modifier.isStatic(m.getModifiers())) - .forEach(methods::add); - c = c.getSuperclass(); - } while (c != null); - return methods; + return Arrays.stream(c.getDeclaredMethods()) + .filter( + m -> !m.isBridge()) // Covariant overloads insert bridge functions, which we must + // ignore. + .filter(m -> !Modifier.isPrivate(m.getModifiers())) + .filter(m -> !Modifier.isProtected(m.getModifiers())) + .filter(m -> !Modifier.isStatic(m.getModifiers())) + .collect(Collectors.toList()); }); } @@ -211,8 +201,7 @@ public static String stripSetterPrefix(String method) { } /** For an array T[] or a subclass of Iterable, return a TypeDescriptor describing T. 
*/ - public static @Nullable TypeDescriptor getIterableComponentType( - TypeDescriptor valueType, Map boundTypes) { + public static @Nullable TypeDescriptor getIterableComponentType(TypeDescriptor valueType) { TypeDescriptor componentType = null; if (valueType.isArray()) { Type component = valueType.getComponentType().getType(); @@ -226,7 +215,7 @@ public static String stripSetterPrefix(String method) { ParameterizedType ptype = (ParameterizedType) collection.getType(); java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); checkArgument(params.length == 1); - componentType = TypeDescriptor.of(resolveType(params[0], boundTypes)); + componentType = TypeDescriptor.of(params[0]); } else { throw new RuntimeException("Collection parameter is not parameterized!"); } @@ -234,15 +223,14 @@ public static String stripSetterPrefix(String method) { return componentType; } - public static TypeDescriptor getMapType( - TypeDescriptor valueType, int index, Map boundTypes) { + public static TypeDescriptor getMapType(TypeDescriptor valueType, int index) { TypeDescriptor mapType = null; if (valueType.isSubtypeOf(TypeDescriptor.of(Map.class))) { TypeDescriptor> map = valueType.getSupertype(Map.class); if (map.getType() instanceof ParameterizedType) { ParameterizedType ptype = (ParameterizedType) map.getType(); java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); - mapType = TypeDescriptor.of(resolveType(params[index], boundTypes)); + mapType = TypeDescriptor.of(params[index]); } else { throw new RuntimeException("Map type is not parameterized! " + map); } @@ -255,49 +243,4 @@ public static TypeDescriptor boxIfPrimitive(TypeDescriptor typeDescriptor) { ? TypeDescriptor.of(Primitives.wrap(typeDescriptor.getRawType())) : typeDescriptor; } - - /** - * If this (or a base class)is a paremeterized type, return a map of all TypeVariable->Type - * bindings. This allows us to resolve types in any contained fields or methods. - */ - public static Map getAllBoundTypes(TypeDescriptor typeDescriptor) { - Map boundParameters = Maps.newHashMap(); - TypeDescriptor currentType = typeDescriptor; - do { - if (currentType.getType() instanceof ParameterizedType) { - ParameterizedType parameterizedType = (ParameterizedType) currentType.getType(); - TypeVariable[] typeVariables = currentType.getRawType().getTypeParameters(); - Type[] typeArguments = parameterizedType.getActualTypeArguments(); - ; - if (typeArguments.length != typeVariables.length) { - throw new RuntimeException("Unmatching arguments lengths in type " + typeDescriptor); - } - for (int i = 0; i < typeVariables.length; ++i) { - boundParameters.put(typeVariables[i], typeArguments[i]); - } - } - Type superClass = currentType.getRawType().getGenericSuperclass(); - if (superClass == null || superClass.equals(Object.class)) { - break; - } - currentType = TypeDescriptor.of(superClass); - } while (true); - return boundParameters; - } - - public static Type resolveType(Type type, Map boundTypes) { - TypeDescriptor typeDescriptor = TypeDescriptor.of(type); - if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(Iterable.class)) - || typeDescriptor.isSubtypeOf(TypeDescriptor.of(Map.class))) { - // Don't resolve these as we special case map and interable. 
- return type; - } - - if (type instanceof TypeVariable) { - TypeVariable typeVariable = (TypeVariable) type; - return Preconditions.checkArgumentNotNull(boundTypes.get(typeVariable)); - } else { - return type; - } - } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java index 275bc41be53d..196ee6f86593 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java @@ -19,7 +19,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import java.lang.reflect.Type; +import java.lang.reflect.ParameterizedType; import java.math.BigDecimal; import java.nio.ByteBuffer; import java.util.Arrays; @@ -29,12 +29,10 @@ import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; -import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.FieldValueTypeInformation; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; -import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.ReadableInstant; @@ -44,7 +42,6 @@ "nullness", // TODO(https://github.com/apache/beam/issues/20497) "rawtypes" }) -@Internal public class StaticSchemaInference { public static List sortBySchema( List types, Schema schema) { @@ -88,17 +85,14 @@ enum MethodType { * public getter methods, or special annotations on the class. */ public static Schema schemaFromClass( - TypeDescriptor typeDescriptor, - FieldValueTypeSupplier fieldValueTypeSupplier, - Map boundTypes) { - return schemaFromClass(typeDescriptor, fieldValueTypeSupplier, new HashMap<>(), boundTypes); + TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { + return schemaFromClass(typeDescriptor, fieldValueTypeSupplier, new HashMap<>()); } private static Schema schemaFromClass( TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier, - Map, Schema> alreadyVisitedSchemas, - Map boundTypes) { + Map, Schema> alreadyVisitedSchemas) { if (alreadyVisitedSchemas.containsKey(typeDescriptor)) { Schema existingSchema = alreadyVisitedSchemas.get(typeDescriptor); if (existingSchema == null) { @@ -112,7 +106,7 @@ private static Schema schemaFromClass( Schema.Builder builder = Schema.builder(); for (FieldValueTypeInformation type : fieldValueTypeSupplier.get(typeDescriptor)) { Schema.FieldType fieldType = - fieldFromType(type.getType(), fieldValueTypeSupplier, alreadyVisitedSchemas, boundTypes); + fieldFromType(type.getType(), fieldValueTypeSupplier, alreadyVisitedSchemas); Schema.Field f = type.isNullable() ? Schema.Field.nullable(type.getName(), fieldType) @@ -129,18 +123,15 @@ private static Schema schemaFromClass( /** Map a Java field type to a Beam Schema FieldType. 
*/ public static Schema.FieldType fieldFromType( - TypeDescriptor type, - FieldValueTypeSupplier fieldValueTypeSupplier, - Map boundTypes) { - return fieldFromType(type, fieldValueTypeSupplier, new HashMap<>(), boundTypes); + TypeDescriptor type, FieldValueTypeSupplier fieldValueTypeSupplier) { + return fieldFromType(type, fieldValueTypeSupplier, new HashMap<>()); } // TODO(https://github.com/apache/beam/issues/21567): support type inference for logical types private static Schema.FieldType fieldFromType( TypeDescriptor type, FieldValueTypeSupplier fieldValueTypeSupplier, - Map, Schema> alreadyVisitedSchemas, - Map boundTypes) { + Map, Schema> alreadyVisitedSchemas) { FieldType primitiveType = PRIMITIVE_TYPES.get(type.getRawType()); if (primitiveType != null) { return primitiveType; @@ -161,25 +152,27 @@ private static Schema.FieldType fieldFromType( } else { // Otherwise this is an array type. return FieldType.array( - fieldFromType(component, fieldValueTypeSupplier, alreadyVisitedSchemas, boundTypes)); + fieldFromType(component, fieldValueTypeSupplier, alreadyVisitedSchemas)); } } else if (type.isSubtypeOf(TypeDescriptor.of(Map.class))) { - FieldType keyType = - fieldFromType( - ReflectUtils.getMapType(type, 0, boundTypes), - fieldValueTypeSupplier, - alreadyVisitedSchemas, - boundTypes); - FieldType valueType = - fieldFromType( - ReflectUtils.getMapType(type, 1, boundTypes), - fieldValueTypeSupplier, - alreadyVisitedSchemas, - boundTypes); - checkArgument( - keyType.getTypeName().isPrimitiveType(), - "Only primitive types can be map keys. type: " + keyType.getTypeName()); - return FieldType.map(keyType, valueType); + TypeDescriptor> map = type.getSupertype(Map.class); + if (map.getType() instanceof ParameterizedType) { + ParameterizedType ptype = (ParameterizedType) map.getType(); + java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); + checkArgument(params.length == 2); + FieldType keyType = + fieldFromType( + TypeDescriptor.of(params[0]), fieldValueTypeSupplier, alreadyVisitedSchemas); + FieldType valueType = + fieldFromType( + TypeDescriptor.of(params[1]), fieldValueTypeSupplier, alreadyVisitedSchemas); + checkArgument( + keyType.getTypeName().isPrimitiveType(), + "Only primitive types can be map keys. type: " + keyType.getTypeName()); + return FieldType.map(keyType, valueType); + } else { + throw new RuntimeException("Cannot infer schema from unparameterized map."); + } } else if (type.isSubtypeOf(TypeDescriptor.of(CharSequence.class))) { return FieldType.STRING; } else if (type.isSubtypeOf(TypeDescriptor.of(ReadableInstant.class))) { @@ -187,22 +180,26 @@ private static Schema.FieldType fieldFromType( } else if (type.isSubtypeOf(TypeDescriptor.of(ByteBuffer.class))) { return FieldType.BYTES; } else if (type.isSubtypeOf(TypeDescriptor.of(Iterable.class))) { - FieldType elementType = - fieldFromType( - Preconditions.checkArgumentNotNull( - ReflectUtils.getIterableComponentType(type, boundTypes)), - fieldValueTypeSupplier, - alreadyVisitedSchemas, - boundTypes); - // TODO: should this be AbstractCollection? - if (type.isSubtypeOf(TypeDescriptor.of(Collection.class))) { - return FieldType.array(elementType); + TypeDescriptor> iterable = type.getSupertype(Iterable.class); + if (iterable.getType() instanceof ParameterizedType) { + ParameterizedType ptype = (ParameterizedType) iterable.getType(); + java.lang.reflect.Type[] params = ptype.getActualTypeArguments(); + checkArgument(params.length == 1); + // TODO: should this be AbstractCollection? 
+ if (type.isSubtypeOf(TypeDescriptor.of(Collection.class))) { + return FieldType.array( + fieldFromType( + TypeDescriptor.of(params[0]), fieldValueTypeSupplier, alreadyVisitedSchemas)); + } else { + return FieldType.iterable( + fieldFromType( + TypeDescriptor.of(params[0]), fieldValueTypeSupplier, alreadyVisitedSchemas)); + } } else { - return FieldType.iterable(elementType); + throw new RuntimeException("Cannot infer schema from unparameterized collection."); } } else { - return FieldType.row( - schemaFromClass(type, fieldValueTypeSupplier, alreadyVisitedSchemas, boundTypes)); + return FieldType.row(schemaFromClass(type, fieldValueTypeSupplier, alreadyVisitedSchemas)); } } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java index 49fd2bfe2259..d0ee623dea7c 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java @@ -28,7 +28,6 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -import java.util.Map; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; @@ -40,7 +39,6 @@ import org.apache.beam.sdk.schemas.utils.SchemaTestUtils; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.CaseFormat; import org.joda.time.DateTime; import org.joda.time.Instant; @@ -888,151 +886,4 @@ public void testSchema_SchemaFieldDescription() throws NoSuchSchemaException { assertEquals(FIELD_DESCRIPTION_SCHEMA.getField("lng"), schema.getField("lng")); assertEquals(FIELD_DESCRIPTION_SCHEMA.getField("str"), schema.getField("str")); } - - @AutoValue - @DefaultSchema(AutoValueSchema.class) - abstract static class ParameterizedAutoValue { - abstract W getValue1(); - - abstract T getValue2(); - - abstract V getValue3(); - - abstract X getValue4(); - } - - @Test - public void testAutoValueWithTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor> typeDescriptor = - new TypeDescriptor>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_SCHEMA) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @DefaultSchema(AutoValueSchema.class) - abstract static class ParameterizedAutoValueSubclass - extends ParameterizedAutoValue { - abstract T getValue5(); - } - - @Test - public void testAutoValueWithInheritedTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor> typeDescriptor = - new TypeDescriptor>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_SCHEMA) - .addInt16Field("value5") - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @AutoValue - @DefaultSchema(AutoValueSchema.class) - abstract static class 
NestedParameterizedCollectionAutoValue { - abstract Iterable getNested(); - - abstract Map getMap(); - } - - @Test - public void testAutoValueWithNestedCollectionTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - NestedParameterizedCollectionAutoValue< - ParameterizedAutoValue, String>> - typeDescriptor = - new TypeDescriptor< - NestedParameterizedCollectionAutoValue< - ParameterizedAutoValue, String>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder() - .addIterableField("nested", FieldType.row(expectedInnerSchema)) - .addMapField("map", FieldType.STRING, FieldType.row(expectedInnerSchema)) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testAutoValueWithDoublyNestedCollectionTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - NestedParameterizedCollectionAutoValue< - Iterable>, String>> - typeDescriptor = - new TypeDescriptor< - NestedParameterizedCollectionAutoValue< - Iterable>, - String>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder() - .addIterableField("nested", FieldType.iterable(FieldType.row(expectedInnerSchema))) - .addMapField( - "map", FieldType.STRING, FieldType.iterable(FieldType.row(expectedInnerSchema))) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @AutoValue - @DefaultSchema(AutoValueSchema.class) - abstract static class NestedParameterizedAutoValue { - abstract T getNested(); - } - - @Test - public void testAutoValueWithNestedTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - NestedParameterizedAutoValue< - ParameterizedAutoValue>> - typeDescriptor = - new TypeDescriptor< - NestedParameterizedAutoValue< - ParameterizedAutoValue>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder().addRowField("nested", expectedInnerSchema).build(); - assertTrue(expectedSchema.equivalent(schema)); - } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java index 2252c3aef0db..5313feb5c6c0 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java @@ -68,7 +68,6 @@ import org.apache.beam.sdk.schemas.utils.TestJavaBeans.SimpleBeanWithAnnotations; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -626,127 +625,4 @@ public void testSetterConstructionWithRenamedFields() throws NoSuchSchemaExcepti assertEquals( registry.getFromRowFunction(BeanWithCaseFormat.class).apply(row), beanWithCaseFormat); } - - @Test - public void testBeanWithTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor> - typeDescriptor = - new TypeDescriptor< - TestJavaBeans.SimpleParameterizedBean>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_BEAN_SCHEMA) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testBeanWithInheritedTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor> typeDescriptor = - new TypeDescriptor>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_BEAN_SCHEMA) - .addInt16Field("value5") - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testBeanWithNestedCollectionTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - TestJavaBeans.NestedParameterizedCollectionBean< - TestJavaBeans.SimpleParameterizedBean, String>> - typeDescriptor = - new TypeDescriptor< - TestJavaBeans.NestedParameterizedCollectionBean< - TestJavaBeans.SimpleParameterizedBean, - String>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_BEAN_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder() - .addIterableField("nested", Schema.FieldType.row(expectedInnerSchema)) - .addMapField("map", Schema.FieldType.STRING, Schema.FieldType.row(expectedInnerSchema)) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testBeanWithDoublyNestedCollectionTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - TestJavaBeans.NestedParameterizedCollectionBean< - Iterable>, - String>> - typeDescriptor = - new TypeDescriptor< - TestJavaBeans.NestedParameterizedCollectionBean< - Iterable< - TestJavaBeans.SimpleParameterizedBean>, - String>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_BEAN_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder() - .addIterableField( - "nested", Schema.FieldType.iterable(Schema.FieldType.row(expectedInnerSchema))) - .addMapField( - "map", - Schema.FieldType.STRING, - Schema.FieldType.iterable(Schema.FieldType.row(expectedInnerSchema))) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testBeanWithNestedTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = 
SchemaRegistry.createDefault(); - TypeDescriptor< - TestJavaBeans.NestedParameterizedBean< - TestJavaBeans.SimpleParameterizedBean>> - typeDescriptor = - new TypeDescriptor< - TestJavaBeans.NestedParameterizedBean< - TestJavaBeans.SimpleParameterizedBean>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_BEAN_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder().addRowField("nested", expectedInnerSchema).build(); - assertTrue(expectedSchema.equivalent(schema)); - } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java index 70bc3030924b..11bef79b26f7 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java @@ -76,7 +76,6 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -782,123 +781,4 @@ public void testCircularNestedPOJOThrows() throws NoSuchSchemaException { thrown.getMessage(), containsString("TestPOJOs$FirstCircularNestedPOJO")); } - - @Test - public void testPojoWithTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor> - typeDescriptor = - new TypeDescriptor< - TestPOJOs.SimpleParameterizedPOJO>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_POJO_SCHEMA) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testPojoWithInheritedTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor> typeDescriptor = - new TypeDescriptor>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_POJO_SCHEMA) - .addInt16Field("value5") - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testPojoWithNestedCollectionTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - TestPOJOs.NestedParameterizedCollectionPOJO< - TestPOJOs.SimpleParameterizedPOJO, String>> - typeDescriptor = - new TypeDescriptor< - TestPOJOs.NestedParameterizedCollectionPOJO< - TestPOJOs.SimpleParameterizedPOJO, - String>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_POJO_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder() - .addIterableField("nested", 
FieldType.row(expectedInnerSchema)) - .addMapField("map", FieldType.STRING, FieldType.row(expectedInnerSchema)) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testPojoWithDoublyNestedCollectionTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - TestPOJOs.NestedParameterizedCollectionPOJO< - Iterable>, - String>> - typeDescriptor = - new TypeDescriptor< - TestPOJOs.NestedParameterizedCollectionPOJO< - Iterable>, - String>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_POJO_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder() - .addIterableField("nested", FieldType.iterable(FieldType.row(expectedInnerSchema))) - .addMapField( - "map", FieldType.STRING, FieldType.iterable(FieldType.row(expectedInnerSchema))) - .build(); - assertTrue(expectedSchema.equivalent(schema)); - } - - @Test - public void testPojoWithNestedTypeParameter() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - TypeDescriptor< - TestPOJOs.NestedParameterizedPOJO< - TestPOJOs.SimpleParameterizedPOJO>> - typeDescriptor = - new TypeDescriptor< - TestPOJOs.NestedParameterizedPOJO< - TestPOJOs.SimpleParameterizedPOJO>>() {}; - Schema schema = registry.getSchema(typeDescriptor); - - final Schema expectedInnerSchema = - Schema.builder() - .addBooleanField("value1") - .addStringField("value2") - .addInt64Field("value3") - .addRowField("value4", SIMPLE_POJO_SCHEMA) - .build(); - final Schema expectedSchema = - Schema.builder().addRowField("nested", expectedInnerSchema).build(); - assertTrue(expectedSchema.equivalent(schema)); - } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java index e0a45c2c82fe..021e39b84849 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java @@ -34,7 +34,6 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -import java.util.Collections; import java.util.List; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; @@ -66,9 +65,7 @@ public class JavaBeanUtilsTest { public void testNullable() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); assertTrue(schema.getField("str").getType().getNullable()); assertFalse(schema.getField("anInt").getType().getNullable()); } @@ -77,9 +74,7 @@ public void testNullable() { public void testSimpleBean() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(SIMPLE_BEAN_SCHEMA, schema); } @@ -87,9 +82,7 @@ public void testSimpleBean() { public void testNestedBean() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new 
TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_BEAN_SCHEMA, schema); } @@ -97,9 +90,7 @@ public void testNestedBean() { public void testPrimitiveArray() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_ARRAY_BEAN_SCHEMA, schema); } @@ -107,9 +98,7 @@ public void testPrimitiveArray() { public void testNestedArray() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_ARRAY_BEAN_SCHEMA, schema); } @@ -117,9 +106,7 @@ public void testNestedArray() { public void testNestedCollection() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_COLLECTION_BEAN_SCHEMA, schema); } @@ -127,9 +114,7 @@ public void testNestedCollection() { public void testPrimitiveMap() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_MAP_BEAN_SCHEMA, schema); } @@ -137,9 +122,7 @@ public void testPrimitiveMap() { public void testNestedMap() { Schema schema = JavaBeanUtils.schemaFromJavaBeanClass( - new TypeDescriptor() {}, - GetterTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, GetterTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_MAP_BEAN_SCHEMA, schema); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java index 46c098dddaeb..723353ed8d15 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java @@ -35,7 +35,6 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -import java.util.Collections; import java.util.List; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.FieldValueSetter; @@ -72,9 +71,7 @@ public class POJOUtilsTest { public void testNullables() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); assertTrue(schema.getField("str").getType().getNullable()); assertFalse(schema.getField("anInt").getType().getNullable()); } @@ -83,9 +80,7 @@ public void testNullables() { public void testSimplePOJO() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); assertEquals(SIMPLE_POJO_SCHEMA, schema); } @@ -93,9 +88,7 @@ public void testSimplePOJO() { public void testNestedPOJO() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, 
JavaFieldTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_POJO_SCHEMA, schema); } @@ -104,8 +97,7 @@ public void testNestedPOJOWithSimplePOJO() { Schema schema = POJOUtils.schemaFromPojoClass( new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + JavaFieldTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_POJO_WITH_SIMPLE_POJO_SCHEMA, schema); } @@ -113,9 +105,7 @@ public void testNestedPOJOWithSimplePOJO() { public void testPrimitiveArray() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_ARRAY_POJO_SCHEMA, schema); } @@ -123,9 +113,7 @@ public void testPrimitiveArray() { public void testNestedArray() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_ARRAY_POJO_SCHEMA, schema); } @@ -133,9 +121,7 @@ public void testNestedArray() { public void testNestedCollection() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_COLLECTION_POJO_SCHEMA, schema); } @@ -143,9 +129,7 @@ public void testNestedCollection() { public void testPrimitiveMap() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(PRIMITIVE_MAP_POJO_SCHEMA, schema); } @@ -153,9 +137,7 @@ public void testPrimitiveMap() { public void testNestedMap() { Schema schema = POJOUtils.schemaFromPojoClass( - new TypeDescriptor() {}, - JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap()); + new TypeDescriptor() {}, JavaFieldTypeSupplier.INSTANCE); SchemaTestUtils.assertSchemaEquivalent(NESTED_MAP_POJO_SCHEMA, schema); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java index cbc976144971..b5ad6f989d9e 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java @@ -1397,95 +1397,4 @@ public void setValue(@Nullable Float value) { Schema.Field.nullable("value", FieldType.FLOAT) .withDescription("This value is the value stored in the object as a float.")) .build(); - - @DefaultSchema(JavaBeanSchema.class) - public static class SimpleParameterizedBean { - @Nullable private W value1; - @Nullable private T value2; - @Nullable private V value3; - @Nullable private X value4; - - public W getValue1() { - return value1; - } - - public void setValue1(W value1) { - this.value1 = value1; - } - - public T getValue2() { - return value2; - } - - public void setValue2(T value2) { - this.value2 = value2; - } - - public V getValue3() { - return value3; - } - - public void setValue3(V value3) { - this.value3 = value3; - } - - public X getValue4() { - return value4; - } - - public void setValue4(X value4) { - this.value4 = value4; - } - } - - @DefaultSchema(JavaBeanSchema.class) - public 
static class SimpleParameterizedBeanSubclass - extends SimpleParameterizedBean { - @Nullable private T value5; - - public SimpleParameterizedBeanSubclass() {} - - public T getValue5() { - return value5; - } - - public void setValue5(T value5) { - this.value5 = value5; - } - } - - @DefaultSchema(JavaBeanSchema.class) - public static class NestedParameterizedCollectionBean { - private Iterable nested; - private Map map; - - public Iterable getNested() { - return nested; - } - - public Map getMap() { - return map; - } - - public void setNested(Iterable nested) { - this.nested = nested; - } - - public void setMap(Map map) { - this.map = map; - } - } - - @DefaultSchema(JavaBeanSchema.class) - public static class NestedParameterizedBean { - private T nested; - - public T getNested() { - return nested; - } - - public void setNested(T nested) { - this.nested = nested; - } - } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java index ce7409365d09..789de02adee8 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java @@ -495,125 +495,6 @@ public int hashCode() { .addStringField("stringBuilder") .build(); - @DefaultSchema(JavaFieldSchema.class) - public static class SimpleParameterizedPOJO { - public W value1; - public T value2; - public V value3; - public X value4; - - public SimpleParameterizedPOJO() {} - - public SimpleParameterizedPOJO(W value1, T value2, V value3, X value4) { - this.value1 = value1; - this.value2 = value2; - this.value3 = value3; - this.value4 = value4; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof SimpleParameterizedPOJO)) { - return false; - } - SimpleParameterizedPOJO that = (SimpleParameterizedPOJO) o; - return Objects.equals(value1, that.value1) - && Objects.equals(value2, that.value2) - && Objects.equals(value3, that.value3) - && Objects.equals(value4, that.value4); - } - - @Override - public int hashCode() { - return Objects.hash(value1, value2, value3, value4); - } - } - - @DefaultSchema(JavaFieldSchema.class) - public static class SimpleParameterizedPOJOSubclass - extends SimpleParameterizedPOJO { - public T value5; - - public SimpleParameterizedPOJOSubclass() {} - - public SimpleParameterizedPOJOSubclass(T value5) { - this.value5 = value5; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof SimpleParameterizedPOJOSubclass)) { - return false; - } - SimpleParameterizedPOJOSubclass that = (SimpleParameterizedPOJOSubclass) o; - return Objects.equals(value5, that.value5); - } - - @Override - public int hashCode() { - return Objects.hash(value4); - } - } - - @DefaultSchema(JavaFieldSchema.class) - public static class NestedParameterizedCollectionPOJO { - public Iterable nested; - public Map map; - - public NestedParameterizedCollectionPOJO(Iterable nested, Map map) { - this.nested = nested; - this.map = map; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof NestedParameterizedCollectionPOJO)) { - return false; - } - NestedParameterizedCollectionPOJO that = (NestedParameterizedCollectionPOJO) o; - return Objects.equals(nested, that.nested) && Objects.equals(map, that.map); - } - - @Override - public int hashCode() { - return Objects.hash(nested, map); - 
} - } - - @DefaultSchema(JavaFieldSchema.class) - public static class NestedParameterizedPOJO { - public T nested; - - public NestedParameterizedPOJO(T nested) { - this.nested = nested; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof NestedParameterizedPOJO)) { - return false; - } - NestedParameterizedPOJO that = (NestedParameterizedPOJO) o; - return Objects.equals(nested, that.nested); - } - - @Override - public int hashCode() { - return Objects.hash(nested); - } - } /** A POJO containing a nested class. * */ @DefaultSchema(JavaFieldSchema.class) public static class NestedPOJO { @@ -1006,7 +887,7 @@ public boolean equals(@Nullable Object o) { if (this == o) { return true; } - if (!(o instanceof PojoWithIterable)) { + if (!(o instanceof PojoWithNestedArray)) { return false; } PojoWithIterable that = (PojoWithIterable) o; diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java index 1a530a3f6ca5..0a82663c1771 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroByteBuddyUtils.java @@ -78,8 +78,8 @@ private static SchemaUserTypeCreator createCreator(Class clazz, Schema sc // Generate a method call to create and invoke the SpecificRecord's constructor. . MethodCall construct = MethodCall.construct(baseConstructor); - for (int i = 0; i < baseConstructor.getGenericParameterTypes().length; ++i) { - Type baseType = baseConstructor.getGenericParameterTypes()[i]; + for (int i = 0; i < baseConstructor.getParameterTypes().length; ++i) { + Class baseType = baseConstructor.getParameterTypes()[i]; construct = construct.with(readAndConvertParameter(baseType, i), baseType); } @@ -110,7 +110,7 @@ private static SchemaUserTypeCreator createCreator(Class clazz, Schema sc } private static StackManipulation readAndConvertParameter( - Type constructorParameterType, int index) { + Class constructorParameterType, int index) { TypeConversionsFactory typeConversionsFactory = new AvroUtils.AvroTypeConversionFactory(); // The types in the AVRO-generated constructor might be the types returned by Beam's Row class, diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java index 1324d254e44e..1b1c45969307 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java @@ -814,9 +814,6 @@ public List get(TypeDescriptor typeDescriptor) { @Override public List get(TypeDescriptor typeDescriptor, Schema schema) { - Map boundTypes = - ReflectUtils.getAllBoundTypes(typeDescriptor); - Map mapping = getMapping(schema); List methods = ReflectUtils.getMethods(typeDescriptor.getRawType()); List types = Lists.newArrayList(); @@ -824,7 +821,7 @@ public List get(TypeDescriptor typeDescriptor, Sch Method method = methods.get(i); if (ReflectUtils.isGetter(method)) { FieldValueTypeInformation fieldValueTypeInformation = - FieldValueTypeInformation.forGetter(method, i, boundTypes); + 
FieldValueTypeInformation.forGetter(method, i); String name = mapping.get(fieldValueTypeInformation.getName()); if (name != null) { types.add(fieldValueTypeInformation.withName(name)); @@ -868,16 +865,13 @@ private Map getMapping(Schema schema) { private static final class AvroPojoFieldValueTypeSupplier implements FieldValueTypeSupplier { @Override public List get(TypeDescriptor typeDescriptor) { - Map boundTypes = - ReflectUtils.getAllBoundTypes(typeDescriptor); List classFields = ReflectUtils.getFields(typeDescriptor.getRawType()); Map types = Maps.newHashMap(); for (int i = 0; i < classFields.size(); ++i) { java.lang.reflect.Field f = classFields.get(i); if (!f.isAnnotationPresent(AvroIgnore.class)) { - FieldValueTypeInformation typeInformation = - FieldValueTypeInformation.forField(f, i, boundTypes); + FieldValueTypeInformation typeInformation = FieldValueTypeInformation.forField(f, i); AvroName avroname = f.getAnnotation(AvroName.class); if (avroname != null) { typeInformation = typeInformation.withName(avroname.value()); diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java index fcfc40403b43..d159e9de44a8 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java @@ -39,7 +39,6 @@ import java.lang.reflect.Modifier; import java.lang.reflect.Type; import java.util.Arrays; -import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -1046,8 +1045,7 @@ FieldValueSetter getProtoFieldValueSetter( } else { Method method = getProtoSetter(methods, field.getName(), field.getType()); return JavaBeanUtils.createSetter( - FieldValueTypeInformation.forSetter( - method, protoSetterPrefix(field.getType()), Collections.emptyMap()), + FieldValueTypeInformation.forSetter(method, protoSetterPrefix(field.getType())), new ProtoTypeConversionsFactory()); } } diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java index 4b8d51abdea6..faf3ad407af5 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java @@ -23,7 +23,6 @@ import com.google.protobuf.DynamicMessage; import com.google.protobuf.Message; import java.lang.reflect.Method; -import java.util.Collections; import java.util.List; import java.util.Map; import org.apache.beam.sdk.extensions.protobuf.ProtoByteBuddyUtils.ProtoTypeConversionsFactory; @@ -73,8 +72,7 @@ public List get(TypeDescriptor typeDescriptor, Sch Method method = getProtoGetter(methods, oneOfField.getName(), oneOfField.getType()); oneOfTypes.put( oneOfField.getName(), - FieldValueTypeInformation.forGetter(method, i, Collections.emptyMap()) - .withName(field.getName())); + FieldValueTypeInformation.forGetter(method, i).withName(field.getName())); } // Add an entry that encapsulates information about all possible getters. types.add( @@ -84,9 +82,7 @@ public List get(TypeDescriptor typeDescriptor, Sch } else { // This is a simple field. 
Add the getter. Method method = getProtoGetter(methods, field.getName(), field.getType()); - types.add( - FieldValueTypeInformation.forGetter(method, i, Collections.emptyMap()) - .withName(field.getName())); + types.add(FieldValueTypeInformation.forGetter(method, i).withName(field.getName())); } } return types; diff --git a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java index 64f600903d87..d5f1745a9a2c 100644 --- a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java +++ b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java @@ -25,7 +25,6 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -390,8 +389,7 @@ private Schema generateSchemaDirectly( fieldName, StaticSchemaInference.fieldFromType( TypeDescriptor.of(field.getClass()), - JavaFieldSchema.JavaFieldTypeSupplier.INSTANCE, - Collections.emptyMap())); + JavaFieldSchema.JavaFieldTypeSupplier.INSTANCE)); } counter++; diff --git a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java index 73b3709da832..5f4e195f227f 100644 --- a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java +++ b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java @@ -242,11 +242,10 @@ private FieldValueTypeInformation fieldValueTypeInfo(Class type, String field if (factoryMethods.size() > 1) { throw new IllegalStateException("Overloaded factory methods: " + factoryMethods); } - return FieldValueTypeInformation.forSetter(factoryMethods.get(0), "", Collections.emptyMap()); + return FieldValueTypeInformation.forSetter(factoryMethods.get(0), ""); } else { try { - return FieldValueTypeInformation.forField( - type.getDeclaredField(fieldName), 0, Collections.emptyMap()); + return FieldValueTypeInformation.forField(type.getDeclaredField(fieldName), 0); } catch (NoSuchFieldException e) { throw new IllegalArgumentException(e); } From 62b083095966fb11be3121abdebe18ab954339d7 Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Mon, 21 Oct 2024 14:34:16 -0400 Subject: [PATCH 044/181] Remove Python 3.8 Support from Apache Beam (#32643) * Remove Python 3.8 Support from Apache Beam * Remove artifact build/publishing * Address comments --- .../workflows/beam_Publish_Beam_SDK_Snapshots.yml | 1 - .github/workflows/build_wheels.yml | 4 ++-- .test-infra/jenkins/metrics_report/tox.ini | 2 +- .test-infra/mock-apis/pyproject.toml | 2 +- .test-infra/tools/python_installer.sh | 2 +- .../org/apache/beam/gradle/BeamModulePlugin.groovy | 1 - contributor-docs/release-guide.md | 2 +- gradle.properties | 2 +- local-env-setup.sh | 4 ++-- .../cloudbuild/playground_cd_examples.sh | 10 +++++----- .../cloudbuild/playground_ci_examples.sh | 10 +++++----- release/src/main/Dockerfile | 3 +-- .../main/python-release/python_release_automation.sh | 2 +- sdks/python/apache_beam/__init__.py | 8 +------- .../ml/inference/test_resources/vllm.dockerfile | 2 +- sdks/python/expansion-service-container/Dockerfile | 2 +- sdks/python/setup.py | 12 +++--------- 17 files 
changed, 27 insertions(+), 42 deletions(-) diff --git a/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml b/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml index 7107385c1722..e3791119be90 100644 --- a/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml +++ b/.github/workflows/beam_Publish_Beam_SDK_Snapshots.yml @@ -66,7 +66,6 @@ jobs: - "java:container:java11" - "java:container:java17" - "java:container:java21" - - "python:container:py38" - "python:container:py39" - "python:container:py310" - "python:container:py311" diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index 0a15ba9d150c..828a6328c0cd 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -49,7 +49,7 @@ jobs: env: EVENT_NAME: ${{ github.event_name }} # Keep in sync with py_version matrix value below - if changed, change that as well. - PY_VERSIONS_FULL: "cp38-* cp39-* cp310-* cp311-* cp312-*" + PY_VERSIONS_FULL: "cp39-* cp310-* cp311-* cp312-*" outputs: gcp-variables-set: ${{ steps.check_gcp_variables.outputs.gcp-variables-set }} py-versions-full: ${{ steps.set-py-versions.outputs.py-versions-full }} @@ -229,7 +229,7 @@ jobs: {"os": "ubuntu-20.04", "runner": [self-hosted, ubuntu-20.04, main], "python": "${{ needs.check_env_variables.outputs.py-versions-test }}", arch: "aarch64" } ] # Keep in sync (remove asterisks) with PY_VERSIONS_FULL env var above - if changed, change that as well. - py_version: ["cp38-", "cp39-", "cp310-", "cp311-", "cp312-"] + py_version: ["cp39-", "cp310-", "cp311-", "cp312-"] steps: - name: Download python source distribution from artifacts uses: actions/download-artifact@v4.1.8 diff --git a/.test-infra/jenkins/metrics_report/tox.ini b/.test-infra/jenkins/metrics_report/tox.ini index 026db5dc4860..d143a0dcf59c 100644 --- a/.test-infra/jenkins/metrics_report/tox.ini +++ b/.test-infra/jenkins/metrics_report/tox.ini @@ -17,7 +17,7 @@ ; TODO(https://github.com/apache/beam/issues/20209): Don't hardcode Py3.8 version. [tox] skipsdist = True -envlist = py38-test,py38-generate-report +envlist = py39-test,py39-generate-report [testenv] commands_pre = diff --git a/.test-infra/mock-apis/pyproject.toml b/.test-infra/mock-apis/pyproject.toml index 680bf489ba13..c98d9152cfb9 100644 --- a/.test-infra/mock-apis/pyproject.toml +++ b/.test-infra/mock-apis/pyproject.toml @@ -27,7 +27,7 @@ packages = [ ] [tool.poetry.dependencies] -python = "^3.8" +python = "^3.9" google = "^3.0.0" grpcio = "^1.53.0" grpcio-tools = "^1.53.0" diff --git a/.test-infra/tools/python_installer.sh b/.test-infra/tools/python_installer.sh index b1b05e597cb3..04e10555243a 100644 --- a/.test-infra/tools/python_installer.sh +++ b/.test-infra/tools/python_installer.sh @@ -20,7 +20,7 @@ set -euo pipefail # Variable containing the python versions to install -python_versions_arr=("3.8.16" "3.9.16" "3.10.10" "3.11.4") +python_versions_arr=("3.9.16" "3.10.10" "3.11.4", "3.12.6") # Install pyenv dependencies. 
pyenv_dep(){ diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 576b8defb40b..8a094fd56217 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -3152,7 +3152,6 @@ class BeamModulePlugin implements Plugin { mustRunAfter = [ ":runners:flink:${project.ext.latestFlinkVersion}:job-server:shadowJar", ':runners:spark:3:job-server:shadowJar', - ':sdks:python:container:py38:docker', ':sdks:python:container:py39:docker', ':sdks:python:container:py310:docker', ':sdks:python:container:py311:docker', diff --git a/contributor-docs/release-guide.md b/contributor-docs/release-guide.md index d351049c96cd..51f06adf50e4 100644 --- a/contributor-docs/release-guide.md +++ b/contributor-docs/release-guide.md @@ -507,7 +507,7 @@ with tags: `${RELEASE_VERSION}rc${RC_NUM}` Verify that third party licenses are included in Docker. You can do this with a simple script: RC_TAG=${RELEASE_VERSION}rc${RC_NUM} - for pyver in 3.8 3.9 3.10 3.11; do + for pyver in 3.9 3.10 3.11 3.12; do docker run --rm --entrypoint sh \ apache/beam_python${pyver}_sdk:${RC_TAG} \ -c 'ls -al /opt/apache/beam/third_party_licenses/ | wc -l' diff --git a/gradle.properties b/gradle.properties index f6e143690a34..4b3a752f0633 100644 --- a/gradle.properties +++ b/gradle.properties @@ -41,4 +41,4 @@ docker_image_default_repo_prefix=beam_ # supported flink versions flink_versions=1.15,1.16,1.17,1.18,1.19 # supported python versions -python_versions=3.8,3.9,3.10,3.11,3.12 +python_versions=3.9,3.10,3.11,3.12 diff --git a/local-env-setup.sh b/local-env-setup.sh index f13dc88432a6..ba30813b2bcc 100755 --- a/local-env-setup.sh +++ b/local-env-setup.sh @@ -55,7 +55,7 @@ if [ "$kernelname" = "Linux" ]; then exit fi - for ver in 3.8 3.9 3.10 3.11 3.12 3; do + for ver in 3.9 3.10 3.11 3.12 3; do apt install --yes python$ver-venv done @@ -89,7 +89,7 @@ elif [ "$kernelname" = "Darwin" ]; then echo "Installing openjdk@8" brew install openjdk@8 fi - for ver in 3.8 3.9 3.10 3.11 3.12; do + for ver in 3.9 3.10 3.11 3.12; do if brew ls --versions python@$ver > /dev/null; then echo "python@$ver already installed. Skipping" brew info python@$ver diff --git a/playground/infrastructure/cloudbuild/playground_cd_examples.sh b/playground/infrastructure/cloudbuild/playground_cd_examples.sh index d05773656b30..e571bc9fc9d9 100644 --- a/playground/infrastructure/cloudbuild/playground_cd_examples.sh +++ b/playground/infrastructure/cloudbuild/playground_cd_examples.sh @@ -97,15 +97,15 @@ LogOutput "Installing python and dependencies." 
export DEBIAN_FRONTEND=noninteractive apt install -y apt-transport-https ca-certificates software-properties-common curl unzip apt-utils > /dev/null 2>&1 add-apt-repository -y ppa:deadsnakes/ppa > /dev/null 2>&1 && apt update > /dev/null 2>&1 -apt install -y python3.8 python3.8-distutils python3-pip > /dev/null 2>&1 -apt install -y --reinstall python3.8-distutils > /dev/null 2>&1 +apt install -y python3.9 python3-distutils python3-pip > /dev/null 2>&1 +apt install -y --reinstall python3-distutils > /dev/null 2>&1 apt install -y python3-virtualenv virtualenv play_venv source play_venv/bin/activate pip install --upgrade google-api-python-client > /dev/null 2>&1 -python3.8 -m pip install pip --upgrade > /dev/null 2>&1 -ln -s /usr/bin/python3.8 /usr/bin/python > /dev/null 2>&1 -apt install -y python3.8-venv > /dev/null 2>&1 +python3.9 -m pip install pip --upgrade > /dev/null 2>&1 +ln -s /usr/bin/python3.9 /usr/bin/python > /dev/null 2>&1 +apt install -y python3.9-venv > /dev/null 2>&1 LogOutput "Installing Python packages from beam/playground/infrastructure/requirements.txt" cd $BEAM_ROOT_DIR diff --git a/playground/infrastructure/cloudbuild/playground_ci_examples.sh b/playground/infrastructure/cloudbuild/playground_ci_examples.sh index 437cc337faf7..2a63382615a5 100755 --- a/playground/infrastructure/cloudbuild/playground_ci_examples.sh +++ b/playground/infrastructure/cloudbuild/playground_ci_examples.sh @@ -94,12 +94,12 @@ export DEBIAN_FRONTEND=noninteractive LogOutput "Installing Python environment" apt-get install -y apt-transport-https ca-certificates software-properties-common curl unzip apt-utils > /dev/null add-apt-repository -y ppa:deadsnakes/ppa > /dev/null && apt update > /dev/null -apt install -y python3.8 python3.8-distutils python3-pip > /dev/null -apt install --reinstall python3.8-distutils > /dev/null +apt install -y python3.9 python3-distutils python3-pip > /dev/null +apt install --reinstall python3-distutils > /dev/null pip install --upgrade google-api-python-client > /dev/null -python3.8 -m pip install pip --upgrade > /dev/null -ln -s /usr/bin/python3.8 /usr/bin/python > /dev/null -apt install python3.8-venv > /dev/null +python3.9 -m pip install pip --upgrade > /dev/null +ln -s /usr/bin/python3.9 /usr/bin/python > /dev/null +apt install python3.9-venv > /dev/null LogOutput "Installing Python packages from beam/playground/infrastructure/requirements.txt" pip install -r $BEAM_ROOT_DIR/playground/infrastructure/requirements.txt diff --git a/release/src/main/Dockerfile b/release/src/main/Dockerfile index 14fe6fdb5a49..6503c5c42ba8 100644 --- a/release/src/main/Dockerfile +++ b/release/src/main/Dockerfile @@ -42,12 +42,11 @@ RUN curl https://pyenv.run | bash && \ echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> /root/.bashrc && \ echo ''eval "$(pyenv init -)"'' >> /root/.bashrc && \ source /root/.bashrc && \ - pyenv install 3.8.9 && \ pyenv install 3.9.4 && \ pyenv install 3.10.7 && \ pyenv install 3.11.3 && \ pyenv install 3.12.3 && \ - pyenv global 3.8.9 3.9.4 3.10.7 3.11.3 3.12.3 + pyenv global 3.9.4 3.10.7 3.11.3 3.12.3 # Install a Go version >= 1.16 so we can bootstrap higher # Go versions diff --git a/release/src/main/python-release/python_release_automation.sh b/release/src/main/python-release/python_release_automation.sh index 2f6986885a96..248bdd9b65ac 100755 --- a/release/src/main/python-release/python_release_automation.sh +++ b/release/src/main/python-release/python_release_automation.sh @@ -19,7 +19,7 @@ source 
release/src/main/python-release/run_release_candidate_python_quickstart.sh source release/src/main/python-release/run_release_candidate_python_mobile_gaming.sh -for version in 3.8 3.9 3.10 3.11 3.12 +for version in 3.9 3.10 3.11 3.12 do run_release_candidate_python_quickstart "tar" "python${version}" run_release_candidate_python_mobile_gaming "tar" "python${version}" diff --git a/sdks/python/apache_beam/__init__.py b/sdks/python/apache_beam/__init__.py index 6e08083bc0de..af88934b0e71 100644 --- a/sdks/python/apache_beam/__init__.py +++ b/sdks/python/apache_beam/__init__.py @@ -70,17 +70,11 @@ import warnings if sys.version_info.major == 3: - if sys.version_info.minor <= 7 or sys.version_info.minor >= 13: + if sys.version_info.minor <= 8 or sys.version_info.minor >= 13: warnings.warn( 'This version of Apache Beam has not been sufficiently tested on ' 'Python %s.%s. You may encounter bugs or missing features.' % (sys.version_info.major, sys.version_info.minor)) - elif sys.version_info.minor == 8: - warnings.warn( - 'Python 3.8 reaches EOL in October 2024 and support will ' - 'be removed from Apache Beam in version 2.61.0. See ' - 'https://github.com/apache/beam/issues/31192 for more ' - 'information.') pass else: raise RuntimeError( diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile index 5abbffdc5a2a..f27abbfd0051 100644 --- a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile +++ b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile @@ -40,7 +40,7 @@ RUN pip install openai vllm RUN apt install libcairo2-dev pkg-config python3-dev -y RUN pip install pycairo -# Copy the Apache Beam worker dependencies from the Beam Python 3.8 SDK image. +# Copy the Apache Beam worker dependencies from the Beam Python 3.12 SDK image. COPY --from=apache/beam_python3.12_sdk:2.58.1 /opt/apache/beam /opt/apache/beam # Set the entrypoint to Apache Beam SDK worker launcher. diff --git a/sdks/python/expansion-service-container/Dockerfile b/sdks/python/expansion-service-container/Dockerfile index 5a5ef0f410bc..4e82165f594c 100644 --- a/sdks/python/expansion-service-container/Dockerfile +++ b/sdks/python/expansion-service-container/Dockerfile @@ -17,7 +17,7 @@ ############################################################################### # We just need to support one Python version supported by Beam. -# Picking the current default Beam Python version which is Python 3.8. +# Picking the current default Beam Python version which is Python 3.9. FROM python:3.9-bookworm as expansion-service LABEL Author "Apache Beam " ARG TARGETOS diff --git a/sdks/python/setup.py b/sdks/python/setup.py index cac27db69803..9ae5d3153f51 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -155,7 +155,7 @@ def cythonize(*args, **kwargs): # Exclude 1.5.0 and 1.5.1 because of # https://github.com/pandas-dev/pandas/issues/45725 dataframe_dependency = [ - 'pandas>=1.4.3,!=1.5.0,!=1.5.1,<2.3;python_version>="3.8"', + 'pandas>=1.4.3,!=1.5.0,!=1.5.1,<2.3', ] @@ -271,18 +271,13 @@ def get_portability_package_data(): return files -python_requires = '>=3.8' +python_requires = '>=3.9' -if sys.version_info.major == 3 and sys.version_info.minor >= 12: +if sys.version_info.major == 3 and sys.version_info.minor >= 13: warnings.warn( 'This version of Apache Beam has not been sufficiently tested on ' 'Python %s.%s. You may encounter bugs or missing features.' 
% (sys.version_info.major, sys.version_info.minor)) -elif sys.version_info.major == 3 and sys.version_info.minor == 8: - warnings.warn('Python 3.8 reaches EOL in October 2024 and support will ' - 'be removed from Apache Beam in version 2.61.0. See ' - 'https://github.com/apache/beam/issues/31192 for more ' - 'information.') if __name__ == '__main__': # In order to find the tree of proto packages, the directory @@ -534,7 +529,6 @@ def get_portability_package_data(): 'Intended Audience :: End Users/Desktop', 'License :: OSI Approved :: Apache Software License', 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', From 1a936c5d3ba2cae60fcb6874f49b5115d4838ead Mon Sep 17 00:00:00 2001 From: claudevdm <33973061+claudevdm@users.noreply.github.com> Date: Mon, 21 Oct 2024 15:06:21 -0400 Subject: [PATCH 045/181] Move biquery enrichment transform notebook to examples/notebooks/beam-ml (#32888) Co-authored-by: Claude --- .../notebooks/beam-ml/bigquery_enrichment_transform.ipynb | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename bigquery_enrichment_transform.ipynb => examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb (100%) diff --git a/bigquery_enrichment_transform.ipynb b/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb similarity index 100% rename from bigquery_enrichment_transform.ipynb rename to examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb From 56696ecd21ddd3b65c6e8bc70c0023bdf991dbe3 Mon Sep 17 00:00:00 2001 From: Damon Date: Mon, 21 Oct 2024 13:06:18 -0700 Subject: [PATCH 046/181] Enable BuildKit on gradle docker task (#32875) * Enable BuildKit on gradle docker task * Revert setting dockerTag --- .../main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy | 1 + 1 file changed, 1 insertion(+) diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy index cd46c1270f83..388069a03983 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy @@ -130,6 +130,7 @@ class BeamDockerPlugin implements Plugin { group = 'Docker' description = 'Builds Docker image.' dependsOn prepare + environment 'DOCKER_BUILDKIT', '1' }) Task tag = project.tasks.create('dockerTag', { From 3767eda41a00d3db5044e7b339fe17d64e5585ca Mon Sep 17 00:00:00 2001 From: Robert Burke Date: Mon, 21 Oct 2024 15:50:39 -0700 Subject: [PATCH 047/181] [prism] Dev prism builds for python and Python Direct Runner fallbacks. 
(#32876) --- .../runners/direct/direct_runner.py | 56 +++++++++ .../runners/portability/prism_runner.py | 112 +++++++++++++----- 2 files changed, 140 insertions(+), 28 deletions(-) diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 49b6622816ce..8b8937653688 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -110,6 +110,38 @@ def visit_transform(self, applied_ptransform): if timer.time_domain == TimeDomain.REAL_TIME: self.supported_by_fnapi_runner = False + class _PrismRunnerSupportVisitor(PipelineVisitor): + """Visitor determining if a Pipeline can be run on the PrismRunner.""" + def accept(self, pipeline): + self.supported_by_prism_runner = True + pipeline.visit(self) + return self.supported_by_prism_runner + + def visit_transform(self, applied_ptransform): + transform = applied_ptransform.transform + # Python SDK assumes the direct runner TestStream implementation is + # being used. + if isinstance(transform, TestStream): + self.supported_by_prism_runner = False + if isinstance(transform, beam.ParDo): + dofn = transform.dofn + # It's uncertain if the Prism Runner supports execution of CombineFns + # with deferred side inputs. + if isinstance(dofn, CombineValuesDoFn): + args, kwargs = transform.raw_side_inputs + args_to_check = itertools.chain(args, kwargs.values()) + if any(isinstance(arg, ArgumentPlaceholder) + for arg in args_to_check): + self.supported_by_prism_runner = False + if userstate.is_stateful_dofn(dofn): + # https://github.com/apache/beam/issues/32786 - + # Remove once Real time clock is used. + _, timer_specs = userstate.get_dofn_specs(dofn) + for timer in timer_specs: + if timer.time_domain == TimeDomain.REAL_TIME: + self.supported_by_prism_runner = False + + tryingPrism = False # Check whether all transforms used in the pipeline are supported by the # FnApiRunner, and the pipeline was not meant to be run as streaming. if _FnApiRunnerSupportVisitor().accept(pipeline): @@ -122,9 +154,33 @@ def visit_transform(self, applied_ptransform): beam_provision_api_pb2.ProvisionInfo( pipeline_options=encoded_options)) runner = fn_runner.FnApiRunner(provision_info=provision_info) + elif _PrismRunnerSupportVisitor().accept(pipeline): + _LOGGER.info('Running pipeline with PrismRunner.') + from apache_beam.runners.portability import prism_runner + runner = prism_runner.PrismRunner() + tryingPrism = True else: runner = BundleBasedDirectRunner() + if tryingPrism: + try: + pr = runner.run_pipeline(pipeline, options) + # This is non-blocking, so if the state is *already* finished, something + # probably failed on job submission. + if pr.state.is_terminal() and pr.state != PipelineState.DONE: + _LOGGER.info( + 'Pipeline failed on PrismRunner, falling back toDirectRunner.') + runner = BundleBasedDirectRunner() + else: + return pr + except Exception as e: + # If prism fails in Preparing the portable job, then the PortableRunner + # code raises an exception. Catch it, log it, and use the Direct runner + # instead. 
+ _LOGGER.info('Exception with PrismRunner:\n %s\n' % (e)) + _LOGGER.info('Falling back to DirectRunner') + runner = BundleBasedDirectRunner() + return runner.run_pipeline(pipeline, options) diff --git a/sdks/python/apache_beam/runners/portability/prism_runner.py b/sdks/python/apache_beam/runners/portability/prism_runner.py index eeccaf5748ce..77dc8a214e8e 100644 --- a/sdks/python/apache_beam/runners/portability/prism_runner.py +++ b/sdks/python/apache_beam/runners/portability/prism_runner.py @@ -27,6 +27,7 @@ import platform import shutil import stat +import subprocess import typing import urllib import zipfile @@ -167,38 +168,93 @@ def construct_download_url(self, root_tag: str, sys: str, mach: str) -> str: def path_to_binary(self) -> str: if self._path is not None: - if not os.path.exists(self._path): - url = urllib.parse.urlparse(self._path) - if not url.scheme: - raise ValueError( - 'Unable to parse binary URL "%s". If using a full URL, make ' - 'sure the scheme is specified. If using a local file xpath, ' - 'make sure the file exists; you may have to first build prism ' - 'using `go build `.' % (self._path)) - - # We have a URL, see if we need to construct a valid file name. - if self._path.startswith(GITHUB_DOWNLOAD_PREFIX): - # If this URL starts with the download prefix, let it through. - return self._path - # The only other valid option is a github release page. - if not self._path.startswith(GITHUB_TAG_PREFIX): - raise ValueError( - 'Provided --prism_location URL is not an Apache Beam Github ' - 'Release page URL or download URL: %s' % (self._path)) - # Get the root tag for this URL - root_tag = os.path.basename(os.path.normpath(self._path)) - return self.construct_download_url( - root_tag, platform.system(), platform.machine()) - return self._path - else: - if '.dev' in self._version: + # The path is overidden, check various cases. + if os.path.exists(self._path): + # The path is local and exists, use directly. + return self._path + + # Check if the path is a URL. + url = urllib.parse.urlparse(self._path) + if not url.scheme: + raise ValueError( + 'Unable to parse binary URL "%s". If using a full URL, make ' + 'sure the scheme is specified. If using a local file xpath, ' + 'make sure the file exists; you may have to first build prism ' + 'using `go build `.' % (self._path)) + + # We have a URL, see if we need to construct a valid file name. + if self._path.startswith(GITHUB_DOWNLOAD_PREFIX): + # If this URL starts with the download prefix, let it through. + return self._path + # The only other valid option is a github release page. + if not self._path.startswith(GITHUB_TAG_PREFIX): raise ValueError( - 'Unable to derive URL for dev versions "%s". Please provide an ' - 'alternate version to derive the release URL with the ' - '--prism_beam_version_override flag.' % (self._version)) + 'Provided --prism_location URL is not an Apache Beam Github ' + 'Release page URL or download URL: %s' % (self._path)) + # Get the root tag for this URL + root_tag = os.path.basename(os.path.normpath(self._path)) + return self.construct_download_url( + root_tag, platform.system(), platform.machine()) + + if '.dev' not in self._version: + # Not a development version, so construct the production download URL return self.construct_download_url( self._version, platform.system(), platform.machine()) + # This is a development version! Assume Go is installed. + # Set the install directory to the cache location. 
+ envdict = {**os.environ, "GOBIN": self.BIN_CACHE} + PRISMPKG = "github.com/apache/beam/sdks/v2/go/cmd/prism" + + process = subprocess.run(["go", "install", PRISMPKG], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=envdict, + check=False) + if process.returncode == 0: + # Successfully installed + return '%s/prism' % (self.BIN_CACHE) + + # We failed to build for some reason. + output = process.stdout.decode("utf-8") + if ("not in a module" not in output) and ( + "no required module provides" not in output): + # This branch handles two classes of failures: + # 1. Go isn't installed, so it needs to be installed by the Beam SDK + # developer. + # 2. Go is installed, and they are building in a local version of Prism, + # but there was a compile error that the developer should address. + # Either way, the @latest fallback either would fail, or hide the error, + # so fail now. + _LOGGER.info(output) + raise ValueError( + 'Unable to install a local of Prism: "%s";\n' + 'Likely Go is not installed, or a local change to Prism did not ' + 'compile.\nPlease install Go (see https://go.dev/doc/install) to ' + 'enable automatic local builds.\n' + 'Alternatively provide a binary with the --prism_location flag.' + '\nCaptured output:\n %s' % (self._version, output)) + + # Go is installed and claims we're not in a Go module that has access to + # the Prism package. + + # Fallback to using the @latest version of prism, which works everywhere. + process = subprocess.run(["go", "install", PRISMPKG + "@latest"], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=envdict, + check=False) + + if process.returncode == 0: + return '%s/prism' % (self.BIN_CACHE) + + output = process.stdout.decode("utf-8") + raise ValueError( + 'We were unable to execute the subprocess "%s" to automatically ' + 'build prism.\nAlternatively provide an alternate binary with the ' + '--prism_location flag.' + '\nCaptured output:\n %s' % (process.args, output)) + def subprocess_cmd_and_endpoint( self) -> typing.Tuple[typing.List[typing.Any], str]: bin_path = self.local_bin( From 9fa9a4d0c147f70a3dd39985b05e3f8487483dfb Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Mon, 21 Oct 2024 21:33:34 -0400 Subject: [PATCH 048/181] Revert "Remove beam logging in playground" This reverts commit f151824100fcfd6375403c0ee46e60346df82411. 
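For readability, the escaped `addLogHandlerCode` string restored by this revert renders to roughly the following Python prelude, which the Playground preparer injects ahead of user code. This is a sketch reconstructed from the string literal in the hunk below (indentation approximated); the next commit in the series only switches the level from INFO to ERROR.

    import logging

    # Write log records to a logs.log file alongside the executed snippet.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler("logs.log"),
        ]
    )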
--- playground/backend/internal/preparers/python_preparers.go | 2 +- playground/backend/internal/preparers/python_preparers_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/playground/backend/internal/preparers/python_preparers.go b/playground/backend/internal/preparers/python_preparers.go index 4b3d556af861..f050237492b1 100644 --- a/playground/backend/internal/preparers/python_preparers.go +++ b/playground/backend/internal/preparers/python_preparers.go @@ -26,7 +26,7 @@ import ( ) const ( - addLogHandlerCode = "" + addLogHandlerCode = "import logging\nlogging.basicConfig(\n level=logging.INFO,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" oneIndentation = " " findWithPipelinePattern = `(\s*)with.+Pipeline.+as (.+):` indentationPattern = `^(%s){0,1}\w+` diff --git a/playground/backend/internal/preparers/python_preparers_test.go b/playground/backend/internal/preparers/python_preparers_test.go index 549fe8431783..b2cfa7eccaac 100644 --- a/playground/backend/internal/preparers/python_preparers_test.go +++ b/playground/backend/internal/preparers/python_preparers_test.go @@ -53,7 +53,7 @@ func TestGetPythonPreparers(t *testing.T) { } func Test_addCodeToFile(t *testing.T) { - wantCode := pyCode + wantCode := "import logging\nlogging.basicConfig(\n level=logging.INFO,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" + pyCode type args struct { args []interface{} From 0fc6f3e9e8b18c5acd634b41592d82396aa5ddbf Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Mon, 21 Oct 2024 21:34:42 -0400 Subject: [PATCH 049/181] Set logging level of playground to ERROR --- playground/backend/internal/preparers/python_preparers.go | 2 +- playground/backend/internal/preparers/python_preparers_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/playground/backend/internal/preparers/python_preparers.go b/playground/backend/internal/preparers/python_preparers.go index f050237492b1..96a4ed32910a 100644 --- a/playground/backend/internal/preparers/python_preparers.go +++ b/playground/backend/internal/preparers/python_preparers.go @@ -26,7 +26,7 @@ import ( ) const ( - addLogHandlerCode = "import logging\nlogging.basicConfig(\n level=logging.INFO,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" + addLogHandlerCode = "import logging\nlogging.basicConfig(\n level=logging.ERROR,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" oneIndentation = " " findWithPipelinePattern = `(\s*)with.+Pipeline.+as (.+):` indentationPattern = `^(%s){0,1}\w+` diff --git a/playground/backend/internal/preparers/python_preparers_test.go b/playground/backend/internal/preparers/python_preparers_test.go index b2cfa7eccaac..f333a1639b7c 100644 --- a/playground/backend/internal/preparers/python_preparers_test.go +++ b/playground/backend/internal/preparers/python_preparers_test.go @@ -53,7 +53,7 @@ func TestGetPythonPreparers(t *testing.T) { } func Test_addCodeToFile(t *testing.T) { - wantCode := "import logging\nlogging.basicConfig(\n level=logging.INFO,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n logging.FileHandler(\"logs.log\"),\n ]\n)\n" + pyCode + wantCode := "import logging\nlogging.basicConfig(\n level=logging.ERROR,\n format=\"%(asctime)s [%(levelname)s] %(message)s\",\n handlers=[\n 
logging.FileHandler(\"logs.log\"),\n ]\n)\n" + pyCode type args struct { args []interface{} From 8d22fc2f72e6d0781eea465e773c542b5907686d Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Tue, 22 Oct 2024 00:57:48 -0700 Subject: [PATCH 050/181] Remove experiments guarding isolated channels enablement based on jobsettings (#32782) --- .../worker/StreamingDataflowWorker.java | 7 +------ .../client/grpc/GrpcDispatcherClient.java | 19 +++++-------------- .../client/grpc/GrpcDispatcherClientTest.java | 15 +-------------- 3 files changed, 7 insertions(+), 34 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index ecdba404151e..524906023722 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -140,8 +140,6 @@ public final class StreamingDataflowWorker { private static final int DEFAULT_STATUS_PORT = 8081; private static final Random CLIENT_ID_GENERATOR = new Random(); private static final String CHANNELZ_PATH = "/channelz"; - public static final String STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL = - "streaming_engine_use_job_settings_for_heartbeat_pool"; private final WindmillStateCache stateCache; private final StreamingWorkerStatusPages statusPages; @@ -249,10 +247,7 @@ private StreamingDataflowWorker( GET_DATA_STREAM_TIMEOUT, windmillServer::getDataStream); getDataClient = new StreamPoolGetDataClient(getDataMetricTracker, getDataStreamPool); - // Experiment gates the logic till backend changes are rollback safe - if (!DataflowRunner.hasExperiment( - options, STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_HEARTBEAT_POOL) - || options.getUseSeparateWindmillHeartbeatStreams() != null) { + if (options.getUseSeparateWindmillHeartbeatStreams() != null) { heartbeatSender = StreamPoolHeartbeatSender.Create( Boolean.TRUE.equals(options.getUseSeparateWindmillHeartbeatStreams()) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java index f96464150d4a..6bae84483d16 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClient.java @@ -30,7 +30,6 @@ import java.util.concurrent.atomic.AtomicReference; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; -import org.apache.beam.runners.dataflow.DataflowRunner; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; @@ -53,8 +52,6 @@ public class GrpcDispatcherClient { private static final Logger LOG = LoggerFactory.getLogger(GrpcDispatcherClient.class); - static final String 
STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS = - "streaming_engine_use_job_settings_for_isolated_channels"; private final CountDownLatch onInitializedEndpoints; /** @@ -80,18 +77,12 @@ private GrpcDispatcherClient( DispatcherStubs initialDispatcherStubs, Random rand) { this.windmillStubFactoryFactory = windmillStubFactoryFactory; - if (DataflowRunner.hasExperiment( - options, STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS)) { - if (options.getUseWindmillIsolatedChannels() != null) { - this.useIsolatedChannels.set(options.getUseWindmillIsolatedChannels()); - this.reactToIsolatedChannelsJobSetting = false; - } else { - this.useIsolatedChannels.set(false); - this.reactToIsolatedChannelsJobSetting = true; - } - } else { - this.useIsolatedChannels.set(Boolean.TRUE.equals(options.getUseWindmillIsolatedChannels())); + if (options.getUseWindmillIsolatedChannels() != null) { + this.useIsolatedChannels.set(options.getUseWindmillIsolatedChannels()); this.reactToIsolatedChannelsJobSetting = false; + } else { + this.useIsolatedChannels.set(false); + this.reactToIsolatedChannelsJobSetting = true; } this.windmillStubFactory.set( windmillStubFactoryFactory.makeWindmillStubFactory(useIsolatedChannels.get())); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java index 3f746d91a868..c04456906ea2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDispatcherClientTest.java @@ -34,7 +34,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillStubFactoryFactoryImpl; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.hamcrest.Matcher; import org.junit.Test; @@ -55,9 +54,6 @@ public static class RespectsJobSettingTest { public void createsNewStubWhenIsolatedChannelsConfigIsChanged() { DataflowWorkerHarnessOptions options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); - options.setExperiments( - Lists.newArrayList( - GrpcDispatcherClient.STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS)); GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.create(options, new WindmillStubFactoryFactoryImpl(options)); // Create first time with Isolated channels disabled @@ -91,27 +87,18 @@ public static class RespectsPipelineOptionsTest { public static Collection data() { List list = new ArrayList<>(); for (Boolean pipelineOption : new Boolean[] {true, false}) { - list.add(new Object[] {/*experimentEnabled=*/ false, pipelineOption}); - list.add(new Object[] {/*experimentEnabled=*/ true, pipelineOption}); + list.add(new Object[] {pipelineOption}); } return list; } @Parameter(0) - public Boolean experimentEnabled; - - @Parameter(1) public Boolean pipelineOption; @Test public void ignoresIsolatedChannelsConfigWithPipelineOption() { DataflowWorkerHarnessOptions options = 
PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); - if (experimentEnabled) { - options.setExperiments( - Lists.newArrayList( - GrpcDispatcherClient.STREAMING_ENGINE_USE_JOB_SETTINGS_FOR_ISOLATED_CHANNELS)); - } options.setUseWindmillIsolatedChannels(pipelineOption); GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.create(options, new WindmillStubFactoryFactoryImpl(options)); From f28ca3ca10647ae5d98cfe65a196929c6972711d Mon Sep 17 00:00:00 2001 From: Damon Date: Tue, 22 Oct 2024 09:58:16 -0700 Subject: [PATCH 051/181] Add target parameter to BeamDockerPlugin (#32890) --- .../groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy index 388069a03983..b3949223f074 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamDockerPlugin.groovy @@ -59,6 +59,7 @@ class BeamDockerPlugin implements Plugin { boolean load = false boolean push = false String builder = null + String target = null File resolvedDockerfile = null File resolvedDockerComposeTemplate = null @@ -289,6 +290,9 @@ class BeamDockerPlugin implements Plugin { } else { buildCommandLine.addAll(['-t', "${-> ext.name}", '.']) } + if (ext.target != null && ext.target != "") { + buildCommandLine.addAll '--target', ext.target + } logger.debug("${buildCommandLine}" as String) return buildCommandLine } From 3ce219434a36cba3aca3b85808311f9573a820bb Mon Sep 17 00:00:00 2001 From: liferoad Date: Tue, 22 Oct 2024 15:27:30 -0400 Subject: [PATCH 052/181] Update generate_pydoc.sh Try this idea: https://stackoverflow.com/questions/68709496/searching-takes-forever-on-readthedocs-when-the-phrase-is-not-present-on-th --- sdks/python/scripts/generate_pydoc.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh index 827df30861cb..21561e1bf6a9 100755 --- a/sdks/python/scripts/generate_pydoc.sh +++ b/sdks/python/scripts/generate_pydoc.sh @@ -124,7 +124,6 @@ extensions = [ ] master_doc = 'index' html_theme = 'sphinx_rtd_theme' -html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] project = 'Apache Beam' version = beam_version.__version__ release = version From 4f4853e2a43354299e8a0c5822f00b3888cb485d Mon Sep 17 00:00:00 2001 From: pablo rodriguez defino Date: Tue, 22 Oct 2024 14:48:20 -0700 Subject: [PATCH 053/181] Support Map in BQ for StorageWrites API for Beam Rows (#32512) --- .../bigquery/BeamRowToStorageApiProto.java | 84 ++++++++-- .../BeamRowToStorageApiProtoTest.java | 152 +++++++++++++++++- .../io/gcp/bigquery/BigQueryUtilsTest.java | 12 ++ 3 files changed, 229 insertions(+), 19 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java index 7a5aa2408d2e..d7ca787feea3 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java @@ -31,7 +31,6 @@ import java.time.LocalTime; import java.time.temporal.ChronoUnit; import java.util.Collections; -import java.util.List; import 
java.util.Map; import java.util.function.BiFunction; import java.util.function.Function; @@ -221,11 +220,18 @@ private static TableFieldSchema fieldDescriptorFromBeamField(Field field) { case ITERABLE: @Nullable FieldType elementType = field.getType().getCollectionElementType(); if (elementType == null) { - throw new RuntimeException("Unexpected null element type!"); + throw new RuntimeException("Unexpected null element type on " + field.getName()); } + TypeName containedTypeName = + Preconditions.checkNotNull( + elementType.getTypeName(), + "Null type name found in contained type at " + field.getName()); Preconditions.checkState( - !Preconditions.checkNotNull(elementType.getTypeName()).isCollectionType(), - "Nested arrays not supported by BigQuery."); + !(containedTypeName.isCollectionType() || containedTypeName.isMapType()), + "Nested container types are not supported by BigQuery. Field " + + field.getName() + + " contains a type " + + containedTypeName.name()); TableFieldSchema elementFieldSchema = fieldDescriptorFromBeamField(Field.of(field.getName(), elementType)); builder = builder.setType(elementFieldSchema.getType()); @@ -244,7 +250,24 @@ private static TableFieldSchema fieldDescriptorFromBeamField(Field field) { builder = builder.setType(type); break; case MAP: - throw new RuntimeException("Map types not supported by BigQuery."); + @Nullable FieldType keyType = field.getType().getMapKeyType(); + @Nullable FieldType valueType = field.getType().getMapValueType(); + if (keyType == null) { + throw new RuntimeException( + "Unexpected null element type for the map's key on " + field.getName()); + } + if (valueType == null) { + throw new RuntimeException( + "Unexpected null element type for the map's value on " + field.getName()); + } + + builder = + builder + .setType(TableFieldSchema.Type.STRUCT) + .addFields(fieldDescriptorFromBeamField(Field.of("key", keyType))) + .addFields(fieldDescriptorFromBeamField(Field.of("value", valueType))) + .setMode(TableFieldSchema.Mode.REPEATED); + break; default: @Nullable TableFieldSchema.Type primitiveType = PRIMITIVE_TYPES.get(field.getType().getTypeName()); @@ -289,25 +312,34 @@ private static Object toProtoValue( case ROW: return messageFromBeamRow(fieldDescriptor.getMessageType(), (Row) value, null, -1); case ARRAY: - List list = (List) value; - @Nullable FieldType arrayElementType = beamFieldType.getCollectionElementType(); - if (arrayElementType == null) { - throw new RuntimeException("Unexpected null element type!"); - } - return list.stream() - .map(v -> toProtoValue(fieldDescriptor, arrayElementType, v)) - .collect(Collectors.toList()); case ITERABLE: Iterable iterable = (Iterable) value; @Nullable FieldType iterableElementType = beamFieldType.getCollectionElementType(); if (iterableElementType == null) { - throw new RuntimeException("Unexpected null element type!"); + throw new RuntimeException("Unexpected null element type: " + fieldDescriptor.getName()); } + return StreamSupport.stream(iterable.spliterator(), false) .map(v -> toProtoValue(fieldDescriptor, iterableElementType, v)) .collect(Collectors.toList()); case MAP: - throw new RuntimeException("Map types not supported by BigQuery."); + Map map = (Map) value; + @Nullable FieldType keyType = beamFieldType.getMapKeyType(); + @Nullable FieldType valueType = beamFieldType.getMapValueType(); + if (keyType == null) { + throw new RuntimeException("Unexpected null for key type: " + fieldDescriptor.getName()); + } + if (valueType == null) { + throw new RuntimeException( + "Unexpected null for 
value type: " + fieldDescriptor.getName()); + } + + return map.entrySet().stream() + .map( + (Map.Entry entry) -> + mapEntryToProtoValue( + fieldDescriptor.getMessageType(), keyType, valueType, entry)) + .collect(Collectors.toList()); default: return scalarToProtoValue(beamFieldType, value); } @@ -337,6 +369,28 @@ static Object scalarToProtoValue(FieldType beamFieldType, Object value) { } } + static Object mapEntryToProtoValue( + Descriptor descriptor, + FieldType keyFieldType, + FieldType valueFieldType, + Map.Entry entryValue) { + DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor); + FieldDescriptor keyFieldDescriptor = + Preconditions.checkNotNull(descriptor.findFieldByName("key")); + @Nullable Object key = toProtoValue(keyFieldDescriptor, keyFieldType, entryValue.getKey()); + if (key != null) { + builder.setField(keyFieldDescriptor, key); + } + FieldDescriptor valueFieldDescriptor = + Preconditions.checkNotNull(descriptor.findFieldByName("value")); + @Nullable + Object value = toProtoValue(valueFieldDescriptor, valueFieldType, entryValue.getValue()); + if (value != null) { + builder.setField(valueFieldDescriptor, value); + } + return builder.build(); + } + static ByteString serializeBigDecimalToNumeric(BigDecimal o) { return serializeBigDecimal(o, NUMERIC_SCALE, MAX_NUMERIC_VALUE, MIN_NUMERIC_VALUE, "Numeric"); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java index 4013f0018553..d8c580a0cd18 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import com.google.protobuf.ByteString; import com.google.protobuf.DescriptorProtos.DescriptorProto; @@ -36,8 +37,11 @@ import java.time.LocalTime; import java.time.temporal.ChronoUnit; import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.function.Supplier; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; @@ -284,12 +288,14 @@ public class BeamRowToStorageApiProtoTest { .addField("nested", FieldType.row(BASE_SCHEMA).withNullable(true)) .addField("nestedArray", FieldType.array(FieldType.row(BASE_SCHEMA))) .addField("nestedIterable", FieldType.iterable(FieldType.row(BASE_SCHEMA))) + .addField("nestedMap", FieldType.map(FieldType.STRING, FieldType.row(BASE_SCHEMA))) .build(); private static final Row NESTED_ROW = Row.withSchema(NESTED_SCHEMA) .withFieldValue("nested", BASE_ROW) .withFieldValue("nestedArray", ImmutableList.of(BASE_ROW, BASE_ROW)) .withFieldValue("nestedIterable", ImmutableList.of(BASE_ROW, BASE_ROW)) + .withFieldValue("nestedMap", ImmutableMap.of("key1", BASE_ROW, "key2", BASE_ROW)) .build(); @Test @@ -347,12 +353,12 @@ public void testNestedFromSchema() { .collect( Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getLabel)); - assertEquals(3, types.size()); + assertEquals(4, types.size()); Map nestedTypes = 
descriptor.getNestedTypeList().stream() .collect(Collectors.toMap(DescriptorProto::getName, Functions.identity())); - assertEquals(3, nestedTypes.size()); + assertEquals(4, nestedTypes.size()); assertEquals(Type.TYPE_MESSAGE, types.get("nested")); assertEquals(Label.LABEL_OPTIONAL, typeLabels.get("nested")); String nestedTypeName1 = typeNames.get("nested"); @@ -379,6 +385,87 @@ public void testNestedFromSchema() { .collect( Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); assertEquals(expectedBaseTypes, nestedTypes3); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestedmap")); + assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmap")); + String nestedTypeName4 = typeNames.get("nestedmap"); + // expects 2 fields in the nested map, key and value + assertEquals(2, nestedTypes.get(nestedTypeName4).getFieldList().size()); + Supplier> stream = + () -> nestedTypes.get(nestedTypeName4).getFieldList().stream(); + assertTrue(stream.get().anyMatch(fdp -> fdp.getName().equals("key"))); + assertTrue(stream.get().anyMatch(fdp -> fdp.getName().equals("value"))); + + Map nestedTypes4 = + nestedTypes.get(nestedTypeName4).getNestedTypeList().stream() + .flatMap(vdesc -> vdesc.getFieldList().stream()) + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + assertEquals(expectedBaseTypes, nestedTypes4); + } + + @Test + public void testParticularMapsFromSchemas() { + Schema nestedMapSchemaVariations = + Schema.builder() + .addField( + "nestedMultiMap", + FieldType.map(FieldType.STRING, FieldType.array(FieldType.STRING))) + .addField( + "nestedMapNullable", + FieldType.map(FieldType.STRING, FieldType.DOUBLE).withNullable(true)) + .build(); + + DescriptorProto descriptor = + TableRowToStorageApiProto.descriptorSchemaFromTableSchema( + BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema((nestedMapSchemaVariations)), + true, + false); + + Map types = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType)); + Map typeNames = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getTypeName)); + Map typeLabels = + descriptor.getFieldList().stream() + .collect( + Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getLabel)); + + Map nestedTypes = + descriptor.getNestedTypeList().stream() + .collect(Collectors.toMap(DescriptorProto::getName, Functions.identity())); + assertEquals(2, nestedTypes.size()); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestedmultimap")); + assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmultimap")); + String nestedMultiMapName = typeNames.get("nestedmultimap"); + // expects 2 fields for the nested array of maps, key and value + assertEquals(2, nestedTypes.get(nestedMultiMapName).getFieldList().size()); + Supplier> stream = + () -> nestedTypes.get(nestedMultiMapName).getFieldList().stream(); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("key")).count() == 1); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("value")).count() == 1); + assertTrue( + stream + .get() + .filter(fdp -> fdp.getName().equals("value")) + .filter(fdp -> fdp.getLabel().equals(Label.LABEL_REPEATED)) + .count() + == 1); + + assertEquals(Type.TYPE_MESSAGE, types.get("nestedmapnullable")); + // even though the field is marked as optional in the row we will should see repeated in proto + assertEquals(Label.LABEL_REPEATED, 
typeLabels.get("nestedmapnullable")); + String nestedMapNullableName = typeNames.get("nestedmapnullable"); + // expects 2 fields in the nullable maps, key and value + assertEquals(2, nestedTypes.get(nestedMapNullableName).getFieldList().size()); + stream = () -> nestedTypes.get(nestedMapNullableName).getFieldList().stream(); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("key")).count() == 1); + assertTrue(stream.get().filter(fdp -> fdp.getName().equals("value")).count() == 1); } private void assertBaseRecord(DynamicMessage msg) { @@ -395,7 +482,7 @@ public void testMessageFromTableRow() throws Exception { BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema(NESTED_SCHEMA), true, false); DynamicMessage msg = BeamRowToStorageApiProto.messageFromBeamRow(descriptor, NESTED_ROW, null, -1); - assertEquals(3, msg.getAllFields().size()); + assertEquals(4, msg.getAllFields().size()); Map fieldDescriptors = descriptor.getFields().stream() @@ -404,6 +491,63 @@ public void testMessageFromTableRow() throws Exception { assertBaseRecord(nestedMsg); } + @Test + public void testMessageFromTableRowForArraysAndMaps() throws Exception { + Schema nestedMapSchemaVariations = + Schema.builder() + .addField("nestedArrayNullable", FieldType.array(FieldType.STRING).withNullable(true)) + .addField("nestedMap", FieldType.map(FieldType.STRING, FieldType.STRING)) + .addField( + "nestedMultiMap", + FieldType.map(FieldType.STRING, FieldType.iterable(FieldType.STRING))) + .addField( + "nestedMapNullable", + FieldType.map(FieldType.STRING, FieldType.DOUBLE).withNullable(true)) + .build(); + + Row nestedRow = + Row.withSchema(nestedMapSchemaVariations) + .withFieldValue("nestedArrayNullable", null) + .withFieldValue("nestedMap", ImmutableMap.of("key1", "value1")) + .withFieldValue( + "nestedMultiMap", + ImmutableMap.of("multikey1", ImmutableList.of("multivalue1", "multivalue2"))) + .withFieldValue("nestedMapNullable", null) + .build(); + + Descriptor descriptor = + TableRowToStorageApiProto.getDescriptorFromTableSchema( + BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema(nestedMapSchemaVariations), + true, + false); + DynamicMessage msg = + BeamRowToStorageApiProto.messageFromBeamRow(descriptor, nestedRow, null, -1); + + Map fieldDescriptors = + descriptor.getFields().stream() + .collect(Collectors.toMap(FieldDescriptor::getName, Functions.identity())); + + DynamicMessage nestedMapEntryMsg = + (DynamicMessage) msg.getRepeatedField(fieldDescriptors.get("nestedmap"), 0); + String value = + (String) + nestedMapEntryMsg.getField( + fieldDescriptors.get("nestedmap").getMessageType().findFieldByName("value")); + assertEquals("value1", value); + + DynamicMessage nestedMultiMapEntryMsg = + (DynamicMessage) msg.getRepeatedField(fieldDescriptors.get("nestedmultimap"), 0); + List values = + (List) + nestedMultiMapEntryMsg.getField( + fieldDescriptors.get("nestedmultimap").getMessageType().findFieldByName("value")); + assertTrue(values.size() == 2); + assertEquals("multivalue1", values.get(0)); + + assertTrue(msg.getRepeatedFieldCount(fieldDescriptors.get("nestedarraynullable")) == 0); + assertTrue(msg.getRepeatedFieldCount(fieldDescriptors.get("nestedmapnullable")) == 0); + } + @Test public void testCdcFields() throws Exception { Descriptor descriptor = @@ -413,7 +557,7 @@ public void testCdcFields() throws Exception { assertNotNull(descriptor.findFieldByName(StorageApiCDC.CHANGE_SQN_COLUMN)); DynamicMessage msg = BeamRowToStorageApiProto.messageFromBeamRow(descriptor, NESTED_ROW, "UPDATE", 42); - 
assertEquals(5, msg.getAllFields().size()); + assertEquals(6, msg.getAllFields().size()); Map fieldDescriptors = descriptor.getFields().stream() diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java index e26348b7b478..8b65e58a4601 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtilsTest.java @@ -698,6 +698,18 @@ public void testToTableSchema_map() { assertThat(field.getFields(), containsInAnyOrder(MAP_KEY, MAP_VALUE)); } + @Test + public void testToTableSchema_map_array() { + TableSchema schema = toTableSchema(MAP_ARRAY_TYPE); + + assertThat(schema.getFields().size(), equalTo(1)); + TableFieldSchema field = schema.getFields().get(0); + assertThat(field.getName(), equalTo("map")); + assertThat(field.getType(), equalTo(StandardSQLTypeName.STRUCT.toString())); + assertThat(field.getMode(), equalTo(Mode.REPEATED.toString())); + assertThat(field.getFields(), containsInAnyOrder(MAP_KEY, MAP_VALUE)); + } + @Test public void testToTableRow_flat() { TableRow row = toTableRow().apply(FLAT_ROW); From 078f87c09ca9f7bdbe3ed5bda25e1cf2dae4198d Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Tue, 22 Oct 2024 18:53:45 -0400 Subject: [PATCH 054/181] Fix java sdk container dependency for Python PostCommit (#32900) * Fix java sdk container dependency for Python PostCommit * trigger tests --- .github/trigger_files/beam_PostCommit_Python.json | 2 +- sdks/python/test-suites/portable/common.gradle | 5 +++-- .../site/content/en/documentation/runtime/environments.md | 8 ++++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index 9e1d1e1b80dd..30ee463ad4e9 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,5 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run.", - "modification": 4 + "modification": 2 } diff --git a/sdks/python/test-suites/portable/common.gradle b/sdks/python/test-suites/portable/common.gradle index 5fd1b182a471..fbd65a1657cb 100644 --- a/sdks/python/test-suites/portable/common.gradle +++ b/sdks/python/test-suites/portable/common.gradle @@ -23,6 +23,7 @@ import org.apache.tools.ant.taskdefs.condition.Os def pythonRootDir = "${rootDir}/sdks/python" def pythonVersionSuffix = project.ext.pythonVersion.replace('.', '') def latestFlinkVersion = project.ext.latestFlinkVersion +def currentJavaVersion = project.ext.currentJavaVersion ext { pythonContainerTask = ":sdks:python:container:py${pythonVersionSuffix}:docker" @@ -369,7 +370,7 @@ project.tasks.register("postCommitPy${pythonVersionSuffix}IT") { 'setupVirtualenv', 'installGcpTest', ":runners:flink:${latestFlinkVersion}:job-server:shadowJar", - ':sdks:java:container:java8:docker', + ":sdks:java:container:${currentJavaVersion}:docker", ':sdks:java:testing:kafka-service:buildTestKafkaServiceJar', ':sdks:java:io:expansion-service:shadowJar', ':sdks:java:io:google-cloud-platform:expansion-service:shadowJar', @@ -420,7 +421,7 @@ project.tasks.register("xlangSpannerIOIT") { 'setupVirtualenv', 'installGcpTest', ":runners:flink:${latestFlinkVersion}:job-server:shadowJar", - 
':sdks:java:container:java8:docker', + ":sdks:java:container:${currentJavaVersion}:docker", ':sdks:java:io:expansion-service:shadowJar', ':sdks:java:io:google-cloud-platform:expansion-service:shadowJar', ':sdks:java:io:kinesis:expansion-service:shadowJar', diff --git a/website/www/site/content/en/documentation/runtime/environments.md b/website/www/site/content/en/documentation/runtime/environments.md index d9a42db29e24..a048c21046ba 100644 --- a/website/www/site/content/en/documentation/runtime/environments.md +++ b/website/www/site/content/en/documentation/runtime/environments.md @@ -111,14 +111,14 @@ This method requires building image artifacts from Beam source. For additional i cd $BEAM_WORKDIR # The default repository of each SDK - ./gradlew :sdks:java:container:java8:docker ./gradlew :sdks:java:container:java11:docker ./gradlew :sdks:java:container:java17:docker + ./gradlew :sdks:java:container:java21:docker ./gradlew :sdks:go:container:docker - ./gradlew :sdks:python:container:py38:docker ./gradlew :sdks:python:container:py39:docker ./gradlew :sdks:python:container:py310:docker ./gradlew :sdks:python:container:py311:docker + ./gradlew :sdks:python:container:py312:docker # Shortcut for building all Python SDKs ./gradlew :sdks:python:container:buildAll @@ -168,9 +168,9 @@ builds the Python 3.6 container and tags it as `example-repo/beam_python3.6_sdk: From Beam 2.21.0 and later, a `docker-pull-licenses` flag was introduced to add licenses/notices for third party dependencies to the docker images. For example: ``` -./gradlew :sdks:java:container:java8:docker -Pdocker-pull-licenses +./gradlew :sdks:java:container:java11:docker -Pdocker-pull-licenses ``` -creates a Java 8 SDK image with appropriate licenses in `/opt/apache/beam/third_party_licenses/`. +creates a Java 11 SDK image with appropriate licenses in `/opt/apache/beam/third_party_licenses/`. By default, no licenses/notices are added to the docker images. 
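For context on the commands touched in this hunk, a minimal end-to-end sketch of building, tagging, and publishing a custom Java 11 SDK container could look like the block below. The `example-repo` repository root, the `2.x.y-custom` tag, and the resulting `beam_java11_sdk` image name are placeholder assumptions here, mirroring the `example-repo/beam_python3.6_sdk` custom-repository example referenced in the hunk context above; adjust them to your own registry.

```
# Build the Java 11 SDK container, bundling third-party licenses/notices,
# into a custom repository root with a custom tag (placeholder values).
./gradlew :sdks:java:container:java11:docker \
  -Pdocker-pull-licenses \
  -Pdocker-repository-root=example-repo \
  -Pdocker-tag=2.x.y-custom

# Push the image so remote workers can pull it at pipeline runtime.
docker push example-repo/beam_java11_sdk:2.x.y-custom
```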
From 5ad00572b7dd50db5aa56e5075878a54cd75931c Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Wed, 23 Oct 2024 13:03:04 -0400 Subject: [PATCH 055/181] Remove extraneous/incorrect ValidatesRunner classification from examples unit tests --- .../test/java/org/apache/beam/examples/WordCountTest.java | 3 --- .../beam/examples/complete/TopWikipediaSessionsTest.java | 3 --- .../apache/beam/examples/complete/game/GameStatsTest.java | 3 --- .../beam/examples/complete/game/HourlyTeamScoreTest.java | 3 --- .../apache/beam/examples/complete/game/UserScoreTest.java | 5 ----- .../beam/examples/cookbook/BigQueryTornadoesTest.java | 6 ------ .../apache/beam/examples/cookbook/DistinctExampleTest.java | 4 ---- .../apache/beam/examples/cookbook/FilterExamplesTest.java | 4 ---- .../org/apache/beam/examples/cookbook/JoinExamplesTest.java | 3 --- .../beam/examples/cookbook/MaxPerKeyExamplesTest.java | 4 ---- .../examples/cookbook/MinimalBigQueryTornadoesTest.java | 6 ------ .../apache/beam/examples/cookbook/TriggerExampleTest.java | 3 --- .../beam/examples/kotlin/cookbook/DistinctExampleTest.kt | 2 -- .../beam/examples/kotlin/cookbook/FilterExamplesTest.kt | 2 -- .../beam/examples/kotlin/cookbook/JoinExamplesTest.kt | 1 - .../beam/examples/kotlin/cookbook/MaxPerKeyExamplesTest.kt | 2 -- 16 files changed, 54 deletions(-) diff --git a/examples/java/src/test/java/org/apache/beam/examples/WordCountTest.java b/examples/java/src/test/java/org/apache/beam/examples/WordCountTest.java index 6b9916e87271..43b06347a39a 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/WordCountTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/WordCountTest.java @@ -25,7 +25,6 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.MapElements; @@ -33,7 +32,6 @@ import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -66,7 +64,6 @@ public void testExtractWordsFn() throws Exception { /** Example test that tests a PTransform by using an in-memory input and inspecting the output. 
*/ @Test - @Category(ValidatesRunner.class) public void testCountWords() throws Exception { PCollection input = p.apply(Create.of(WORDS).withCoder(StringUtf8Coder.of())); diff --git a/examples/java/src/test/java/org/apache/beam/examples/complete/TopWikipediaSessionsTest.java b/examples/java/src/test/java/org/apache/beam/examples/complete/TopWikipediaSessionsTest.java index 96d10a1f72ed..614d289d2d60 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/complete/TopWikipediaSessionsTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/complete/TopWikipediaSessionsTest.java @@ -21,12 +21,10 @@ import java.util.Arrays; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -37,7 +35,6 @@ public class TopWikipediaSessionsTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testComputeTopUsers() { PCollection output = diff --git a/examples/java/src/test/java/org/apache/beam/examples/complete/game/GameStatsTest.java b/examples/java/src/test/java/org/apache/beam/examples/complete/game/GameStatsTest.java index 33d3c5699477..9c99c3aafdcc 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/complete/game/GameStatsTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/complete/game/GameStatsTest.java @@ -24,13 +24,11 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -69,7 +67,6 @@ public class GameStatsTest implements Serializable { /** Test the calculation of 'spammy users'. 
*/ @Test - @Category(ValidatesRunner.class) public void testCalculateSpammyUsers() throws Exception { PCollection> input = p.apply(Create.of(USER_SCORES)); PCollection> output = input.apply(new CalculateSpammyUsers()); diff --git a/examples/java/src/test/java/org/apache/beam/examples/complete/game/HourlyTeamScoreTest.java b/examples/java/src/test/java/org/apache/beam/examples/complete/game/HourlyTeamScoreTest.java index 1d89351adcf8..46d7b41746ab 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/complete/game/HourlyTeamScoreTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/complete/game/HourlyTeamScoreTest.java @@ -26,7 +26,6 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.Filter; import org.apache.beam.sdk.transforms.MapElements; @@ -37,7 +36,6 @@ import org.joda.time.Instant; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -90,7 +88,6 @@ public class HourlyTeamScoreTest implements Serializable { /** Test the filtering. */ @Test - @Category(ValidatesRunner.class) public void testUserScoresFilter() throws Exception { final Instant startMinTimestamp = new Instant(1447965680000L); diff --git a/examples/java/src/test/java/org/apache/beam/examples/complete/game/UserScoreTest.java b/examples/java/src/test/java/org/apache/beam/examples/complete/game/UserScoreTest.java index 04aa122054bd..22fe98a50304 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/complete/game/UserScoreTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/complete/game/UserScoreTest.java @@ -26,7 +26,6 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; @@ -36,7 +35,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -114,7 +112,6 @@ public void testParseEventFn() throws Exception { /** Tests ExtractAndSumScore("user"). */ @Test - @Category(ValidatesRunner.class) public void testUserScoreSums() throws Exception { PCollection input = p.apply(Create.of(GAME_EVENTS)); @@ -133,7 +130,6 @@ public void testUserScoreSums() throws Exception { /** Tests ExtractAndSumScore("team"). */ @Test - @Category(ValidatesRunner.class) public void testTeamScoreSums() throws Exception { PCollection input = p.apply(Create.of(GAME_EVENTS)); @@ -152,7 +148,6 @@ public void testTeamScoreSums() throws Exception { /** Test that bad input data is dropped appropriately. 
*/ @Test - @Category(ValidatesRunner.class) public void testUserScoresBadInput() throws Exception { PCollection input = p.apply(Create.of(GAME_EVENTS2).withCoder(StringUtf8Coder.of())); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesTest.java index 2bd37a3caa52..110349c353e8 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/BigQueryTornadoesTest.java @@ -22,7 +22,6 @@ import org.apache.beam.examples.cookbook.BigQueryTornadoes.FormatCountsFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; @@ -31,7 +30,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -41,7 +39,6 @@ public class BigQueryTornadoesTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testExtractTornadoes() { TableRow row = new TableRow().set("month", "6").set("tornado", true); PCollection input = p.apply(Create.of(ImmutableList.of(row))); @@ -51,7 +48,6 @@ public void testExtractTornadoes() { } @Test - @Category(ValidatesRunner.class) public void testNoTornadoes() { TableRow row = new TableRow().set("month", 6).set("tornado", false); PCollection inputs = p.apply(Create.of(ImmutableList.of(row))); @@ -61,7 +57,6 @@ public void testNoTornadoes() { } @Test - @Category(ValidatesRunner.class) public void testEmpty() { PCollection> inputs = p.apply(Create.empty(new TypeDescriptor>() {})); @@ -71,7 +66,6 @@ public void testEmpty() { } @Test - @Category(ValidatesRunner.class) public void testFormatCounts() { PCollection> inputs = p.apply(Create.of(KV.of(3, 0L), KV.of(4, Long.MAX_VALUE), KV.of(5, Long.MIN_VALUE))); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/DistinctExampleTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/DistinctExampleTest.java index 988a492ad4a9..7ec889f0d2b4 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/DistinctExampleTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/DistinctExampleTest.java @@ -22,13 +22,11 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.Distinct; import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -39,7 +37,6 @@ public class DistinctExampleTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testDistinct() { List strings = Arrays.asList("k1", "k5", "k5", "k2", "k1", "k2", "k3"); @@ -52,7 +49,6 @@ public void testDistinct() { } @Test - @Category(ValidatesRunner.class) public void testDistinctEmpty() 
{ List strings = Arrays.asList(); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/FilterExamplesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/FilterExamplesTest.java index dedc0e313350..4eeb5c4b7dd0 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/FilterExamplesTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/FilterExamplesTest.java @@ -22,13 +22,11 @@ import org.apache.beam.examples.cookbook.FilterExamples.ProjectionFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -68,7 +66,6 @@ public class FilterExamplesTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testProjectionFn() { PCollection input = p.apply(Create.of(row1, row2, row3)); @@ -79,7 +76,6 @@ public void testProjectionFn() { } @Test - @Category(ValidatesRunner.class) public void testFilterSingleMonthDataFn() { PCollection input = p.apply(Create.of(outRow1, outRow2, outRow3)); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/JoinExamplesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/JoinExamplesTest.java index b91ca985ddcd..d27572667752 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/JoinExamplesTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/JoinExamplesTest.java @@ -24,14 +24,12 @@ import org.apache.beam.examples.cookbook.JoinExamples.ExtractEventDataFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -107,7 +105,6 @@ public void testExtractCountryInfoFn() throws Exception { } @Test - @Category(ValidatesRunner.class) public void testJoin() throws java.lang.Exception { PCollection input1 = p.apply("CreateEvent", Create.of(EVENT_ARRAY)); PCollection input2 = p.apply("CreateCC", Create.of(CC_ARRAY)); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/MaxPerKeyExamplesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/MaxPerKeyExamplesTest.java index 5c32f36660d6..410b151ed32f 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/MaxPerKeyExamplesTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/MaxPerKeyExamplesTest.java @@ -23,7 +23,6 @@ import org.apache.beam.examples.cookbook.MaxPerKeyExamples.FormatMaxesFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; @@ -31,7 +30,6 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -78,7 +76,6 @@ public class MaxPerKeyExamplesTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testExtractTempFn() { PCollection> results = p.apply(Create.of(TEST_ROWS)).apply(ParDo.of(new ExtractTempFn())); @@ -87,7 +84,6 @@ public void testExtractTempFn() { } @Test - @Category(ValidatesRunner.class) public void testFormatMaxesFn() { PCollection results = p.apply(Create.of(TEST_KVS)).apply(ParDo.of(new FormatMaxesFn())); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/MinimalBigQueryTornadoesTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/MinimalBigQueryTornadoesTest.java index 7e922bc87965..fb08730a9f54 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/MinimalBigQueryTornadoesTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/MinimalBigQueryTornadoesTest.java @@ -22,7 +22,6 @@ import org.apache.beam.examples.cookbook.MinimalBigQueryTornadoes.FormatCountsFn; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; @@ -31,7 +30,6 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -41,7 +39,6 @@ public class MinimalBigQueryTornadoesTest { @Rule public TestPipeline p = TestPipeline.create(); @Test - @Category(ValidatesRunner.class) public void testExtractTornadoes() { TableRow row = new TableRow().set("month", "6").set("tornado", true); PCollection input = p.apply(Create.of(ImmutableList.of(row))); @@ -51,7 +48,6 @@ public void testExtractTornadoes() { } @Test - @Category(ValidatesRunner.class) public void testNoTornadoes() { TableRow row = new TableRow().set("month", 6).set("tornado", false); PCollection inputs = p.apply(Create.of(ImmutableList.of(row))); @@ -61,7 +57,6 @@ public void testNoTornadoes() { } @Test - @Category(ValidatesRunner.class) public void testEmpty() { PCollection> inputs = p.apply(Create.empty(new TypeDescriptor>() {})); @@ -71,7 +66,6 @@ public void testEmpty() { } @Test - @Category(ValidatesRunner.class) public void testFormatCounts() { PCollection> inputs = p.apply(Create.of(KV.of(3, 0L), KV.of(4, Long.MAX_VALUE), KV.of(5, Long.MIN_VALUE))); diff --git a/examples/java/src/test/java/org/apache/beam/examples/cookbook/TriggerExampleTest.java b/examples/java/src/test/java/org/apache/beam/examples/cookbook/TriggerExampleTest.java index 8f076a9d8d89..19c83c6eb73c 100644 --- a/examples/java/src/test/java/org/apache/beam/examples/cookbook/TriggerExampleTest.java +++ b/examples/java/src/test/java/org/apache/beam/examples/cookbook/TriggerExampleTest.java @@ -27,7 +27,6 @@ import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.ValidatesRunner; import org.apache.beam.sdk.transforms.Create; import 
org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; @@ -42,7 +41,6 @@ import org.joda.time.Instant; import org.junit.Rule; import org.junit.Test; -import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -118,7 +116,6 @@ public void testExtractTotalFlow() { } @Test - @Category(ValidatesRunner.class) public void testTotalFlow() { PCollection> flow = pipeline diff --git a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/DistinctExampleTest.kt b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/DistinctExampleTest.kt index 56702a3a1746..514727878a44 100644 --- a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/DistinctExampleTest.kt +++ b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/DistinctExampleTest.kt @@ -39,7 +39,6 @@ class DistinctExampleTest { fun pipeline(): TestPipeline = pipeline @Test - @Category(ValidatesRunner::class) fun testDistinct() { val strings = listOf("k1", "k5", "k5", "k2", "k1", "k2", "k3") val input = pipeline.apply(Create.of(strings).withCoder(StringUtf8Coder.of())) @@ -49,7 +48,6 @@ class DistinctExampleTest { } @Test - @Category(ValidatesRunner::class) fun testDistinctEmpty() { val strings = listOf() val input = pipeline.apply(Create.of(strings).withCoder(StringUtf8Coder.of())) diff --git a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/FilterExamplesTest.kt b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/FilterExamplesTest.kt index d5cb544a7606..8cfffe15fc65 100644 --- a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/FilterExamplesTest.kt +++ b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/FilterExamplesTest.kt @@ -64,7 +64,6 @@ class FilterExamplesTest { fun pipeline(): TestPipeline = pipeline @Test - @Category(ValidatesRunner::class) fun testProjectionFn() { val input = pipeline.apply(Create.of(row1, row2, row3)) val results = input.apply(ParDo.of(ProjectionFn())) @@ -73,7 +72,6 @@ class FilterExamplesTest { } @Test - @Category(ValidatesRunner::class) fun testFilterSingleMonthDataFn() { val input = pipeline.apply(Create.of(outRow1, outRow2, outRow3)) val results = input.apply(ParDo.of(FilterSingleMonthDataFn(7))) diff --git a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/JoinExamplesTest.kt b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/JoinExamplesTest.kt index 8728a827229b..6bb818f5efa9 100644 --- a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/JoinExamplesTest.kt +++ b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/JoinExamplesTest.kt @@ -93,7 +93,6 @@ class JoinExamplesTest { } @Test - @Category(ValidatesRunner::class) fun testJoin() { val input1 = pipeline.apply("CreateEvent", Create.of(EVENT_ARRAY)) val input2 = pipeline.apply("CreateCC", Create.of(CC_ARRAY)) diff --git a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/MaxPerKeyExamplesTest.kt b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/MaxPerKeyExamplesTest.kt index 7995d9c1c795..409434e02686 100644 --- a/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/MaxPerKeyExamplesTest.kt +++ b/examples/kotlin/src/test/java/org/apache/beam/examples/kotlin/cookbook/MaxPerKeyExamplesTest.kt @@ -69,7 +69,6 @@ class MaxPerKeyExamplesTest { fun pipeline(): TestPipeline = 
pipeline @Test - @Category(ValidatesRunner::class) fun testExtractTempFn() { val results = pipeline.apply(Create.of(testRows)).apply(ParDo.of>(MaxPerKeyExamples.ExtractTempFn())) PAssert.that(results).containsInAnyOrder(ImmutableList.of(kv1, kv2, kv3)) @@ -77,7 +76,6 @@ class MaxPerKeyExamplesTest { } @Test - @Category(ValidatesRunner::class) fun testFormatMaxesFn() { val results = pipeline.apply(Create.of(testKvs)).apply(ParDo.of, TableRow>(MaxPerKeyExamples.FormatMaxesFn())) PAssert.that(results).containsInAnyOrder(resultRow1, resultRow2, resultRow3) From 68b600dc779e05ffced112bcb2f354beab5b772f Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Wed, 23 Oct 2024 14:11:30 -0400 Subject: [PATCH 056/181] Doc fixes after 2.60 release (#32908) --- .github/workflows/README.md | 1 + contributor-docs/release-guide.md | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/README.md b/.github/workflows/README.md index d386f4dc40f9..971bfd857b27 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -285,6 +285,7 @@ Additional PreCommit jobs running basic SDK unit test on a matrices of operating | [Java Tests](https://github.com/apache/beam/actions/workflows/java_tests.yml) | [![.github/workflows/java_tests.yml](https://github.com/apache/beam/actions/workflows/java_tests.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/java_tests.yml?query=event%3Aschedule) | | [Python Tests](https://github.com/apache/beam/actions/workflows/python_tests.yml) | [![.github/workflows/python_tests.yml](https://github.com/apache/beam/actions/workflows/python_tests.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/python_tests.yml?query=event%3Aschedule) | | [TypeScript Tests](https://github.com/apache/beam/actions/workflows/typescript_tests.yml) | [![.github/workflows/typescript_tests.yml](https://github.com/apache/beam/actions/workflows/typescript_tests.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/typescript_tests.yml?query=event%3Aschedule) | +| [Build Wheels](https://github.com/apache/beam/actions/workflows/build_wheels.yml) | [![.github/workflows/build_wheels.yml](https://github.com/apache/beam/actions/workflows/build_wheels.yml/badge.svg?event=schedule)](https://github.com/apache/beam/actions/workflows/build_wheels.yml?query=event%3Aschedule) | ### PostCommit Jobs diff --git a/contributor-docs/release-guide.md b/contributor-docs/release-guide.md index 51f06adf50e4..df7f45cc8179 100644 --- a/contributor-docs/release-guide.md +++ b/contributor-docs/release-guide.md @@ -554,10 +554,10 @@ to PyPI with an `rc` suffix. __Attention:__ Verify that: - [ ] The File names version include ``rc-#`` suffix -- [ ] [Download Files](https://pypi.org/project/apache-beam/#files) have: - - [ ] All wheels uploaded as artifacts - - [ ] Release source's zip published - - [ ] Signatures and hashes do not need to be uploaded +- [Download Files](https://pypi.org/project/apache-beam/#files) have: +- [ ] All wheels uploaded as artifacts +- [ ] Release source's zip published +- [ ] Signatures and hashes do not need to be uploaded ### Propose pull requests for website updates @@ -1148,7 +1148,7 @@ All wheels should be published, in addition to the zip of the release source. 
### Merge Website pull requests Merge all of the website pull requests -- [listing the release](/get-started/downloads/) +- [listing the release](https://beam.apache.org/get-started/downloads/) - publishing the [Python API reference manual](https://beam.apache.org/releases/pydoc/) and the [Java API reference manual](https://beam.apache.org/releases/javadoc/), and - adding the release blog post. From 0ee13b2865b9f9bd4a5113c883ea812b327f0bfa Mon Sep 17 00:00:00 2001 From: Naireen Hussain Date: Wed, 23 Oct 2024 12:06:22 -0700 Subject: [PATCH 057/181] Kafka metrics (#32402) * Add kafka poll latency metrics * Address Sam's comments [Dataflow Streaming] Use isolated windmill streams based on job settings (#32503) * Add kafka poll latency metrics * address comments * Ensure this is disabled for now until flag to enable it is explicitly passed --------- Co-authored-by: Naireen --- .../worker/build.gradle | 1 + ...icsToPerStepNamespaceMetricsConverter.java | 9 +- .../worker/StreamingDataflowWorker.java | 5 + .../dataflow/worker/streaming/StageInfo.java | 5 +- .../StreamingStepMetricsContainerTest.java | 39 +++++- .../beam/sdk/io/kafka/KafkaMetrics.java | 131 ++++++++++++++++++ .../beam/sdk/io/kafka/KafkaSinkMetrics.java | 89 ++++++++++++ .../sdk/io/kafka/KafkaUnboundedReader.java | 32 ++++- .../beam/sdk/io/kafka/KafkaMetricsTest.java | 129 +++++++++++++++++ .../sdk/io/kafka/KafkaSinkMetricsTest.java | 43 ++++++ 10 files changed, 476 insertions(+), 7 deletions(-) create mode 100644 sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaMetrics.java create mode 100644 sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetrics.java create mode 100644 sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMetricsTest.java create mode 100644 sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetricsTest.java diff --git a/runners/google-cloud-dataflow-java/worker/build.gradle b/runners/google-cloud-dataflow-java/worker/build.gradle index b7e6e981effe..92beccd067e2 100644 --- a/runners/google-cloud-dataflow-java/worker/build.gradle +++ b/runners/google-cloud-dataflow-java/worker/build.gradle @@ -54,6 +54,7 @@ def sdk_provided_project_dependencies = [ ":runners:google-cloud-dataflow-java", ":sdks:java:extensions:avro", ":sdks:java:extensions:google-cloud-platform-core", + ":sdks:java:io:kafka", // For metric propagation into worker ":sdks:java:io:google-cloud-platform", ] diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java index 30e920119120..77f867793ae2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/MetricsToPerStepNamespaceMetricsConverter.java @@ -32,13 +32,15 @@ import java.util.Map.Entry; import java.util.Optional; import org.apache.beam.sdk.io.gcp.bigquery.BigQuerySinkMetrics; +import org.apache.beam.sdk.io.kafka.KafkaSinkMetrics; import org.apache.beam.sdk.metrics.LabeledMetricNameUtils; import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.sdk.util.HistogramData; /** * Converts metric updates to {@link PerStepNamespaceMetrics} protos. 
Currently we only support - * converting metrics from {@link BigQuerySinkMetrics} with this converter. + * converting metrics from {@link BigQuerySinkMetrics} and from {@link KafkaSinkMetrics} with this + * converter. */ public class MetricsToPerStepNamespaceMetricsConverter { @@ -65,7 +67,10 @@ private static Optional convertCounterToMetricValue( MetricName metricName, Long value, Map parsedPerWorkerMetricsCache) { - if (value == 0 || !metricName.getNamespace().equals(BigQuerySinkMetrics.METRICS_NAMESPACE)) { + + if (value == 0 + || (!metricName.getNamespace().equals(BigQuerySinkMetrics.METRICS_NAMESPACE) + && !metricName.getNamespace().equals(KafkaSinkMetrics.METRICS_NAMESPACE))) { return Optional.empty(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 524906023722..c478341c1c39 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -93,6 +93,7 @@ import org.apache.beam.sdk.fn.JvmInitializers; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.gcp.bigquery.BigQuerySinkMetrics; +import org.apache.beam.sdk.io.kafka.KafkaSinkMetrics; import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.util.construction.CoderTranslation; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; @@ -663,6 +664,10 @@ public static void main(String[] args) throws Exception { enableBigQueryMetrics(); } + if (DataflowRunner.hasExperiment(options, "enable_kafka_metrics")) { + KafkaSinkMetrics.setSupportKafkaMetrics(true); + } + JvmInitializers.runBeforeProcessing(options); worker.startStatusPages(); worker.start(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java index a18ca8cfd6dc..525464ef2e1f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/StageInfo.java @@ -35,6 +35,7 @@ import org.apache.beam.runners.dataflow.worker.counters.DataflowCounterUpdateExtractor; import org.apache.beam.runners.dataflow.worker.counters.NameContext; import org.apache.beam.sdk.io.gcp.bigquery.BigQuerySinkMetrics; +import org.apache.beam.sdk.io.kafka.KafkaSinkMetrics; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; /** Contains a few of the stage specific fields. E.g. metrics container registry, counters etc. 
*/ @@ -118,7 +119,9 @@ public List extractPerWorkerMetricValues() { private void translateKnownPerWorkerCounters(List metrics) { for (PerStepNamespaceMetrics perStepnamespaceMetrics : metrics) { if (!BigQuerySinkMetrics.METRICS_NAMESPACE.equals( - perStepnamespaceMetrics.getMetricsNamespace())) { + perStepnamespaceMetrics.getMetricsNamespace()) + && !KafkaSinkMetrics.METRICS_NAMESPACE.equals( + perStepnamespaceMetrics.getMetricsNamespace())) { continue; } for (MetricValue metric : perStepnamespaceMetrics.getMetricValues()) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java index 2d5a8d8266ae..37c5ad261280 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingStepMetricsContainerTest.java @@ -366,7 +366,6 @@ public void testExtractPerWorkerMetricUpdates_populatedMetrics() { .setMetricsNamespace("BigQuerySink") .setMetricValues(Collections.singletonList(expectedCounter)); - // Expected histogram metric List bucketCounts = Collections.singletonList(1L); Linear linearOptions = new Linear().setNumberOfBuckets(10).setWidth(10.0).setStart(0.0); @@ -393,6 +392,44 @@ public void testExtractPerWorkerMetricUpdates_populatedMetrics() { assertThat(updates, containsInAnyOrder(histograms, counters)); } + @Test + public void testExtractPerWorkerMetricUpdatesKafka_populatedMetrics() { + StreamingStepMetricsContainer.setEnablePerWorkerMetrics(true); + + MetricName histogramMetricName = MetricName.named("KafkaSink", "histogram"); + HistogramData.LinearBuckets linearBuckets = HistogramData.LinearBuckets.of(0, 10, 10); + c2.getPerWorkerHistogram(histogramMetricName, linearBuckets).update(5.0); + + Iterable updates = + StreamingStepMetricsContainer.extractPerWorkerMetricUpdates(registry); + + // Expected histogram metric + List bucketCounts = Collections.singletonList(1L); + + Linear linearOptions = new Linear().setNumberOfBuckets(10).setWidth(10.0).setStart(0.0); + BucketOptions bucketOptions = new BucketOptions().setLinear(linearOptions); + + DataflowHistogramValue linearHistogram = + new DataflowHistogramValue() + .setCount(1L) + .setBucketOptions(bucketOptions) + .setBucketCounts(bucketCounts); + + MetricValue expectedHistogram = + new MetricValue() + .setMetric("histogram") + .setMetricLabels(new HashMap<>()) + .setValueHistogram(linearHistogram); + + PerStepNamespaceMetrics histograms = + new PerStepNamespaceMetrics() + .setOriginalStep("s2") + .setMetricsNamespace("KafkaSink") + .setMetricValues(Collections.singletonList(expectedHistogram)); + + assertThat(updates, containsInAnyOrder(histograms)); + } + @Test public void testExtractPerWorkerMetricUpdates_emptyMetrics() { StreamingStepMetricsContainer.setEnablePerWorkerMetrics(true); diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaMetrics.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaMetrics.java new file mode 100644 index 000000000000..147a30dcdd1a --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaMetrics.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import com.google.auto.value.AutoValue; +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.beam.sdk.metrics.Histogram; +import org.apache.beam.sdk.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Stores and exports metrics for a batch of Kafka Client RPCs. */ +public interface KafkaMetrics { + + void updateSuccessfulRpcMetrics(String topic, Duration elapsedTime); + + void updateKafkaMetrics(); + + /** No-op implementation of {@code KafkaResults}. */ + class NoOpKafkaMetrics implements KafkaMetrics { + private NoOpKafkaMetrics() {} + + @Override + public void updateSuccessfulRpcMetrics(String topic, Duration elapsedTime) {} + + @Override + public void updateKafkaMetrics() {} + + private static NoOpKafkaMetrics singleton = new NoOpKafkaMetrics(); + + static NoOpKafkaMetrics getInstance() { + return singleton; + } + } + + /** + * Metrics of a batch of RPCs. Member variables are thread safe; however, this class does not have + * atomicity across member variables. + * + *
<p>
Expected usage: A number of threads record metrics in an instance of this class with the + * member methods. Afterwards, a single thread should call {@code updateStreamingInsertsMetrics} + * which will export all counters metrics and RPC latency distribution metrics to the underlying + * {@code perWorkerMetrics} container. Afterwards, metrics should not be written/read from this + * object. + */ + @AutoValue + abstract class KafkaMetricsImpl implements KafkaMetrics { + + private static final Logger LOG = LoggerFactory.getLogger(KafkaMetricsImpl.class); + + static HashMap latencyHistograms = new HashMap(); + + abstract HashMap> perTopicRpcLatencies(); + + abstract AtomicBoolean isWritable(); + + public static KafkaMetricsImpl create() { + return new AutoValue_KafkaMetrics_KafkaMetricsImpl( + new HashMap>(), new AtomicBoolean(true)); + } + + /** Record the rpc status and latency of a successful Kafka poll RPC call. */ + @Override + public void updateSuccessfulRpcMetrics(String topic, Duration elapsedTime) { + if (isWritable().get()) { + ConcurrentLinkedQueue latencies = perTopicRpcLatencies().get(topic); + if (latencies == null) { + latencies = new ConcurrentLinkedQueue(); + latencies.add(elapsedTime); + perTopicRpcLatencies().put(topic, latencies); + } else { + latencies.add(elapsedTime); + } + } + } + + /** Record rpc latency histogram metrics for all recorded topics. */ + private void recordRpcLatencyMetrics() { + for (Map.Entry> topicLatencies : + perTopicRpcLatencies().entrySet()) { + Histogram topicHistogram; + if (latencyHistograms.containsKey(topicLatencies.getKey())) { + topicHistogram = latencyHistograms.get(topicLatencies.getKey()); + } else { + topicHistogram = + KafkaSinkMetrics.createRPCLatencyHistogram( + KafkaSinkMetrics.RpcMethod.POLL, topicLatencies.getKey()); + latencyHistograms.put(topicLatencies.getKey(), topicHistogram); + } + // update all the latencies + for (Duration d : topicLatencies.getValue()) { + Preconditions.checkArgumentNotNull(topicHistogram); + topicHistogram.update(d.toMillis()); + } + } + } + + /** + * Export all metrics recorded in this instance to the underlying {@code perWorkerMetrics} + * containers. This function will only report metrics once per instance. Subsequent calls to + * this function will no-op. + */ + @Override + public void updateKafkaMetrics() { + if (!isWritable().compareAndSet(true, false)) { + LOG.warn("Updating stale Kafka metrics container"); + return; + } + recordRpcLatencyMetrics(); + } + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetrics.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetrics.java new file mode 100644 index 000000000000..f71926f97d27 --- /dev/null +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetrics.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import org.apache.beam.sdk.metrics.DelegatingHistogram; +import org.apache.beam.sdk.metrics.Histogram; +import org.apache.beam.sdk.metrics.LabeledMetricNameUtils; +import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.util.HistogramData; + +/** + * Helper class to create per worker metrics for Kafka Sink stages. + * + *
<p>
Metrics will be in the namespace 'KafkaSink' and have their name formatted as: + * + *
<p>
'{baseName}-{metricLabelKey1}:{metricLabelVal1};...{metricLabelKeyN}:{metricLabelValN};' ???? + */ + +// TODO, refactor out common parts for BQ sink, so it can be reused with other sinks, eg, GCS? +// @SuppressWarnings("unused") +public class KafkaSinkMetrics { + private static boolean supportKafkaMetrics = false; + + public static final String METRICS_NAMESPACE = "KafkaSink"; + + // Base Metric names + private static final String RPC_LATENCY = "RpcLatency"; + + // Kafka Consumer Method names + enum RpcMethod { + POLL, + } + + // Metric labels + private static final String TOPIC_LABEL = "topic_name"; + private static final String RPC_METHOD = "rpc_method"; + + /** + * Creates an Histogram metric to record RPC latency. Metric will have name. + * + *
<p>
'RpcLatency*rpc_method:{method};topic_name:{topic};' + * + * @param method Kafka method associated with this metric. + * @param topic Kafka topic associated with this metric. + * @return Histogram with exponential buckets with a sqrt(2) growth factor. + */ + public static Histogram createRPCLatencyHistogram(RpcMethod method, String topic) { + LabeledMetricNameUtils.MetricNameBuilder nameBuilder = + LabeledMetricNameUtils.MetricNameBuilder.baseNameBuilder(RPC_LATENCY); + nameBuilder.addLabel(RPC_METHOD, method.toString()); + nameBuilder.addLabel(TOPIC_LABEL, topic); + + MetricName metricName = nameBuilder.build(METRICS_NAMESPACE); + HistogramData.BucketType buckets = HistogramData.ExponentialBuckets.of(1, 17); + + return new DelegatingHistogram(metricName, buckets, false, true); + } + + /** + * Returns a container to store metrics for Kafka metrics in Unbounded Readed. If these metrics + * are disabled, then we return a no-op container. + */ + static KafkaMetrics kafkaMetrics() { + if (supportKafkaMetrics) { + return KafkaMetrics.KafkaMetricsImpl.create(); + } else { + return KafkaMetrics.NoOpKafkaMetrics.getInstance(); + } + } + + public static void setSupportKafkaMetrics(boolean supportKafkaMetrics) { + KafkaSinkMetrics.supportKafkaMetrics = supportKafkaMetrics; + } +} diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java index fed03047cf16..6ce6c7d5d233 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java @@ -53,6 +53,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Stopwatch; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; @@ -144,7 +145,6 @@ public boolean start() throws IOException { offsetFetcherThread.scheduleAtFixedRate( this::updateLatestOffsets, 0, OFFSET_UPDATE_INTERVAL_SECONDS, TimeUnit.SECONDS); - return advance(); } @@ -158,6 +158,9 @@ public boolean advance() throws IOException { */ while (true) { if (curBatch.hasNext()) { + // Initalize metrics container. + kafkaResults = KafkaSinkMetrics.kafkaMetrics(); + PartitionState pState = curBatch.next(); if (!pState.recordIter.hasNext()) { // -- (c) @@ -228,8 +231,10 @@ public boolean advance() throws IOException { for (Map.Entry backlogSplit : perPartitionBacklogMetrics.entrySet()) { backlogBytesOfSplit.set(backlogSplit.getValue()); } - return true; + // Pass metrics to container. + kafkaResults.updateKafkaMetrics(); + return true; } else { // -- (b) nextBatch(); @@ -377,6 +382,7 @@ public long getSplitBacklogBytes() { .setDaemon(true) .setNameFormat("KafkaConsumerPoll-thread") .build()); + private AtomicReference consumerPollException = new AtomicReference<>(); private final SynchronousQueue> availableRecordsQueue = new SynchronousQueue<>(); @@ -399,6 +405,11 @@ public long getSplitBacklogBytes() { /** watermark before any records have been read. 
*/ private static Instant initialWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; + // Created in each next batch, and updated at the end. + public KafkaMetrics kafkaResults = KafkaSinkMetrics.kafkaMetrics(); + private Stopwatch stopwatch = Stopwatch.createUnstarted(); + private String kafkaTopic = ""; + @Override public String toString() { return name; @@ -509,6 +520,13 @@ String name() { List partitions = Preconditions.checkArgumentNotNull(source.getSpec().getTopicPartitions()); + + // Each source has a single unique topic. + for (TopicPartition topicPartition : partitions) { + this.kafkaTopic = topicPartition.topic(); + break; + } + List> states = new ArrayList<>(partitions.size()); if (checkpointMark != null) { @@ -568,7 +586,16 @@ private void consumerPollLoop() { while (!closed.get()) { try { if (records.isEmpty()) { + // Each source has a single unique topic. + List topicPartitions = source.getSpec().getTopicPartitions(); + Preconditions.checkStateNotNull(topicPartitions); + + stopwatch.start(); records = consumer.poll(KAFKA_POLL_TIMEOUT.getMillis()); + stopwatch.stop(); + kafkaResults.updateSuccessfulRpcMetrics( + kafkaTopic, java.time.Duration.ofMillis(stopwatch.elapsed(TimeUnit.MILLISECONDS))); + } else if (availableRecordsQueue.offer( records, RECORDS_ENQUEUE_POLL_TIMEOUT.getMillis(), TimeUnit.MILLISECONDS)) { records = ConsumerRecords.empty(); @@ -592,7 +619,6 @@ private void consumerPollLoop() { private void commitCheckpointMark() { KafkaCheckpointMark checkpointMark = finalizedCheckpointMark.getAndSet(null); - if (checkpointMark != null) { LOG.debug("{}: Committing finalized checkpoint {}", this, checkpointMark); Consumer consumer = Preconditions.checkStateNotNull(this.consumer); diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMetricsTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMetricsTest.java new file mode 100644 index 000000000000..b84e143be773 --- /dev/null +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMetricsTest.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.kafka; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.beam.runners.core.metrics.MetricsContainerImpl; +import org.apache.beam.sdk.metrics.Histogram; +import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.metrics.MetricsEnvironment; +import org.apache.beam.sdk.util.HistogramData; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link KafkaSinkMetrics}. */ +// TODO:Naireen - Refactor to remove duplicate code between the two sinks +@RunWith(JUnit4.class) +public class KafkaMetricsTest { + public static class TestHistogram implements Histogram { + public List values = Lists.newArrayList(); + private MetricName metricName = MetricName.named("KafkaSink", "name"); + + @Override + public void update(double value) { + values.add(value); + } + + @Override + public MetricName getName() { + return metricName; + } + } + + public static class TestMetricsContainer extends MetricsContainerImpl { + public ConcurrentHashMap, TestHistogram> + perWorkerHistograms = + new ConcurrentHashMap, TestHistogram>(); + + public TestMetricsContainer() { + super("TestStep"); + } + + @Override + public Histogram getPerWorkerHistogram( + MetricName metricName, HistogramData.BucketType bucketType) { + perWorkerHistograms.computeIfAbsent(KV.of(metricName, bucketType), kv -> new TestHistogram()); + return perWorkerHistograms.get(KV.of(metricName, bucketType)); + } + + @Override + public void reset() { + perWorkerHistograms.clear(); + } + } + + @Test + public void testNoOpKafkaMetrics() throws Exception { + TestMetricsContainer testContainer = new TestMetricsContainer(); + MetricsEnvironment.setCurrentContainer(testContainer); + + KafkaMetrics results = KafkaMetrics.NoOpKafkaMetrics.getInstance(); + results.updateSuccessfulRpcMetrics("test-topic", Duration.ofMillis(10)); + + results.updateKafkaMetrics(); + + assertThat(testContainer.perWorkerHistograms.size(), equalTo(0)); + } + + @Test + public void testKafkaRPCLatencyMetrics() throws Exception { + TestMetricsContainer testContainer = new TestMetricsContainer(); + MetricsEnvironment.setCurrentContainer(testContainer); + + KafkaSinkMetrics.setSupportKafkaMetrics(true); + + KafkaMetrics results = KafkaSinkMetrics.kafkaMetrics(); + + results.updateSuccessfulRpcMetrics("test-topic", Duration.ofMillis(10)); + + results.updateKafkaMetrics(); + // RpcLatency*rpc_method:POLL;topic_name:test-topic + MetricName histogramName = + MetricName.named("KafkaSink", "RpcLatency*rpc_method:POLL;topic_name:test-topic;"); + HistogramData.BucketType bucketType = HistogramData.ExponentialBuckets.of(1, 17); + + assertThat(testContainer.perWorkerHistograms.size(), equalTo(1)); + assertThat( + testContainer.perWorkerHistograms.get(KV.of(histogramName, bucketType)).values, + containsInAnyOrder(Double.valueOf(10.0))); + } + + @Test + public void testKafkaRPCLatencyMetricsAreNotRecorded() throws Exception { + TestMetricsContainer testContainer = new TestMetricsContainer(); + MetricsEnvironment.setCurrentContainer(testContainer); + + KafkaSinkMetrics.setSupportKafkaMetrics(false); + + KafkaMetrics results = KafkaSinkMetrics.kafkaMetrics(); + 
+ results.updateSuccessfulRpcMetrics("test-topic", Duration.ofMillis(10)); + + results.updateKafkaMetrics(); + assertThat(testContainer.perWorkerHistograms.size(), equalTo(0)); + } +} diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetricsTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetricsTest.java new file mode 100644 index 000000000000..625a75c5316b --- /dev/null +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSinkMetricsTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +import org.apache.beam.sdk.metrics.Histogram; +import org.apache.beam.sdk.metrics.MetricName; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link KafkaSinkMetrics}. */ +// TODO:Naireen - Refactor to remove duplicate code between the Kafka and BigQuery sinks +@RunWith(JUnit4.class) +public class KafkaSinkMetricsTest { + @Test + public void testCreatingHistogram() throws Exception { + + Histogram histogram = + KafkaSinkMetrics.createRPCLatencyHistogram(KafkaSinkMetrics.RpcMethod.POLL, "topic1"); + + MetricName histogramName = + MetricName.named("KafkaSink", "RpcLatency*rpc_method:POLL;topic_name:topic1;"); + assertThat(histogram.getName(), equalTo(histogramName)); + } +} From d3478f1e3866db1e0c1315c6899b31c8dacd39d6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 23 Oct 2024 12:51:29 -0700 Subject: [PATCH 058/181] Bump google.golang.org/api from 0.199.0 to 0.202.0 in /sdks (#32906) Bumps [google.golang.org/api](https://github.com/googleapis/google-api-go-client) from 0.199.0 to 0.202.0. - [Release notes](https://github.com/googleapis/google-api-go-client/releases) - [Changelog](https://github.com/googleapis/google-api-go-client/blob/main/CHANGES.md) - [Commits](https://github.com/googleapis/google-api-go-client/compare/v0.199.0...v0.202.0) --- updated-dependencies: - dependency-name: google.golang.org/api dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- sdks/go.mod | 15 +++++++-------- sdks/go.sum | 36 ++++++++++++++++++------------------ 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/sdks/go.mod b/sdks/go.mod index 706be73f97f6..5de614a2f503 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -58,8 +58,8 @@ require ( golang.org/x/sync v0.8.0 golang.org/x/sys v0.26.0 golang.org/x/text v0.19.0 - google.golang.org/api v0.199.0 - google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 + google.golang.org/api v0.202.0 + google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53 google.golang.org/grpc v1.67.1 google.golang.org/protobuf v1.35.1 gopkg.in/yaml.v2 v2.4.0 @@ -70,13 +70,12 @@ require ( github.com/avast/retry-go/v4 v4.6.0 github.com/fsouza/fake-gcs-server v1.49.2 github.com/golang-cz/devslog v0.0.11 - github.com/golang/protobuf v1.5.4 golang.org/x/exp v0.0.0-20231006140011-7918f672742d ) require ( cel.dev/expr v0.16.1 // indirect - cloud.google.com/go/auth v0.9.5 // indirect + cloud.google.com/go/auth v0.9.8 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect cloud.google.com/go/monitoring v1.21.1 // indirect dario.cat/mergo v1.0.0 // indirect @@ -118,12 +117,12 @@ require ( go.opentelemetry.io/otel/sdk v1.29.0 // indirect go.opentelemetry.io/otel/sdk/metric v1.29.0 // indirect go.opentelemetry.io/otel/trace v1.29.0 // indirect - golang.org/x/time v0.6.0 // indirect + golang.org/x/time v0.7.0 // indirect google.golang.org/grpc/stats/opentelemetry v0.0.0-20240907200651-3ffb98b2c93a // indirect ) require ( - cloud.google.com/go v0.115.1 // indirect + cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/compute/metadata v0.5.2 // indirect cloud.google.com/go/iam v1.2.1 // indirect cloud.google.com/go/longrunning v0.6.1 // indirect @@ -195,6 +194,6 @@ require ( golang.org/x/mod v0.20.0 // indirect golang.org/x/tools v0.24.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53 // indirect ) diff --git a/sdks/go.sum b/sdks/go.sum index fa3c75bd3395..542ccb44d6f8 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -38,8 +38,8 @@ cloud.google.com/go v0.104.0/go.mod h1:OO6xxXdJyvuJPcEPBLN9BJPD+jep5G1+2U5B5gkRY cloud.google.com/go v0.105.0/go.mod h1:PrLgOJNe5nfE9UMxKxgXj4mD3voiP+YQ6gdt6KMFOKM= cloud.google.com/go v0.107.0/go.mod h1:wpc2eNrD7hXUTy8EKS10jkxpZBjASrORK7goS+3YX2I= cloud.google.com/go v0.110.0/go.mod h1:SJnCLqQ0FCFGSZMUNUf84MV3Aia54kn7pi8st7tMzaY= -cloud.google.com/go v0.115.1 h1:Jo0SM9cQnSkYfp44+v+NQXHpcHqlnRJk2qxh6yvxxxQ= -cloud.google.com/go v0.115.1/go.mod h1:DuujITeaufu3gL68/lOFIirVNJwQeyf5UXyi+Wbgknc= +cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= +cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= cloud.google.com/go/accessapproval v1.4.0/go.mod h1:zybIuC3KpDOvotz59lFe5qxRZx6C75OtwbisN56xYB4= cloud.google.com/go/accessapproval v1.5.0/go.mod h1:HFy3tuiGvMdcd/u+Cu5b9NkO1pEICJ46IR82PoUdplw= cloud.google.com/go/accessapproval v1.6.0/go.mod h1:R0EiYnwV5fsRFiKZkPHr6mwyk2wxUJ30nL4j2pcFY2E= @@ -101,8 +101,8 @@ 
cloud.google.com/go/assuredworkloads v1.7.0/go.mod h1:z/736/oNmtGAyU47reJgGN+KVo cloud.google.com/go/assuredworkloads v1.8.0/go.mod h1:AsX2cqyNCOvEQC8RMPnoc0yEarXQk6WEKkxYfL6kGIo= cloud.google.com/go/assuredworkloads v1.9.0/go.mod h1:kFuI1P78bplYtT77Tb1hi0FMxM0vVpRC7VVoJC3ZoT0= cloud.google.com/go/assuredworkloads v1.10.0/go.mod h1:kwdUQuXcedVdsIaKgKTp9t0UJkE5+PAVNhdQm4ZVq2E= -cloud.google.com/go/auth v0.9.5 h1:4CTn43Eynw40aFVr3GpPqsQponx2jv0BQpjvajsbbzw= -cloud.google.com/go/auth v0.9.5/go.mod h1:Xo0n7n66eHyOWWCnitop6870Ilwo3PiZyodVkkH1xWM= +cloud.google.com/go/auth v0.9.8 h1:+CSJ0Gw9iVeSENVCKJoLHhdUykDgXSc4Qn+gu2BRtR8= +cloud.google.com/go/auth v0.9.8/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI= cloud.google.com/go/auth/oauth2adapt v0.2.4 h1:0GWE/FUsXhf6C+jAkWgYm7X9tK8cuEIfy19DBn6B6bY= cloud.google.com/go/auth/oauth2adapt v0.2.4/go.mod h1:jC/jOpwFP6JBxhB3P5Rr0a9HLMC/Pe3eaL4NmdvqPtc= cloud.google.com/go/automl v1.5.0/go.mod h1:34EjfoFGMZ5sgJ9EoLsRtdPSNZLcfflJR39VbVNS2M0= @@ -348,8 +348,8 @@ cloud.google.com/go/kms v1.8.0/go.mod h1:4xFEhYFqvW+4VMELtZyxomGSYtSQKzM178ylFW4 cloud.google.com/go/kms v1.9.0/go.mod h1:qb1tPTgfF9RQP8e1wq4cLFErVuTJv7UsSC915J8dh3w= cloud.google.com/go/kms v1.10.0/go.mod h1:ng3KTUtQQU9bPX3+QGLsflZIHlkbn8amFAMY63m8d24= cloud.google.com/go/kms v1.10.1/go.mod h1:rIWk/TryCkR59GMC3YtHtXeLzd634lBbKenvyySAyYI= -cloud.google.com/go/kms v1.19.1 h1:NPE8zjJuMpECvHsx8lsMwQuWWIdJc6iIDHLJGC/J4bw= -cloud.google.com/go/kms v1.19.1/go.mod h1:GRbd2v6e9rAVs+IwOIuePa3xcCm7/XpGNyWtBwwOdRc= +cloud.google.com/go/kms v1.20.0 h1:uKUvjGqbBlI96xGE669hcVnEMw1Px/Mvfa62dhM5UrY= +cloud.google.com/go/kms v1.20.0/go.mod h1:/dMbFF1tLLFnQV44AoI2GlotbjowyUfgVwezxW291fM= cloud.google.com/go/language v1.4.0/go.mod h1:F9dRpNFQmJbkaop6g0JhSBXCNlO90e1KWx5iDdxbWic= cloud.google.com/go/language v1.6.0/go.mod h1:6dJ8t3B+lUYfStgls25GusK04NLh3eDLQnWM3mdEbhI= cloud.google.com/go/language v1.7.0/go.mod h1:DJ6dYN/W+SQOjF8e1hLQXMF21AkH2w9wiPzPCJa2MIE= @@ -582,8 +582,8 @@ cloud.google.com/go/trace v1.3.0/go.mod h1:FFUE83d9Ca57C+K8rDl/Ih8LwOzWIV1krKgxg cloud.google.com/go/trace v1.4.0/go.mod h1:UG0v8UBqzusp+z63o7FK74SdFE+AXpCLdFb1rshXG+Y= cloud.google.com/go/trace v1.8.0/go.mod h1:zH7vcsbAhklH8hWFig58HvxcxyQbaIqMarMg9hn5ECA= cloud.google.com/go/trace v1.9.0/go.mod h1:lOQqpE5IaWY0Ixg7/r2SjixMuc6lfTFeO4QGM4dQWOk= -cloud.google.com/go/trace v1.11.0 h1:UHX6cOJm45Zw/KIbqHe4kII8PupLt/V5tscZUkeiJVI= -cloud.google.com/go/trace v1.11.0/go.mod h1:Aiemdi52635dBR7o3zuc9lLjXo3BwGaChEjCa3tJNmM= +cloud.google.com/go/trace v1.11.1 h1:UNqdP+HYYtnm6lb91aNA5JQ0X14GnxkABGlfz2PzPew= +cloud.google.com/go/trace v1.11.1/go.mod h1:IQKNQuBzH72EGaXEodKlNJrWykGZxet2zgjtS60OtjA= cloud.google.com/go/translate v1.3.0/go.mod h1:gzMUwRjvOqj5i69y/LYLd8RrNQk+hOmIXTi9+nb3Djs= cloud.google.com/go/translate v1.4.0/go.mod h1:06Dn/ppvLD6WvA5Rhdp029IX2Mi3Mn7fpMRLPvXT5Wg= cloud.google.com/go/translate v1.5.0/go.mod h1:29YDSYveqqpA1CQFD7NQuP49xymq17RXNaUDdc0mNu0= @@ -1560,8 +1560,8 @@ golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20220922220347-f3bd1da661af/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= -golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.7.0 
h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= +golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -1705,8 +1705,8 @@ google.golang.org/api v0.108.0/go.mod h1:2Ts0XTHNVWxypznxWOYUeI4g3WdP9Pk2Qk58+a/ google.golang.org/api v0.110.0/go.mod h1:7FC4Vvx1Mooxh8C5HWjzZHcavuS2f6pmJpZx60ca7iI= google.golang.org/api v0.111.0/go.mod h1:qtFHvU9mhgTJegR31csQ+rwxyUTHOKFqCKWp1J0fdw0= google.golang.org/api v0.114.0/go.mod h1:ifYI2ZsFK6/uGddGfAD5BMxlnkBqCmqHSDUVi45N5Yg= -google.golang.org/api v0.199.0 h1:aWUXClp+VFJmqE0JPvpZOK3LDQMyFKYIow4etYd9qxs= -google.golang.org/api v0.199.0/go.mod h1:ohG4qSztDJmZdjK/Ar6MhbAmb/Rpi4JHOqagsh90K28= +google.golang.org/api v0.202.0 h1:y1iuVHMqokQbimW79ZqPZWo4CiyFu6HcCYHwSNyzlfo= +google.golang.org/api v0.202.0/go.mod h1:3Jjeq7M/SFblTNCp7ES2xhq+WvGL0KeXI0joHQBfwTQ= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1846,12 +1846,12 @@ google.golang.org/genproto v0.0.0-20230323212658-478b75c54725/go.mod h1:UUQDJDOl google.golang.org/genproto v0.0.0-20230330154414-c0448cd141ea/go.mod h1:UUQDJDOlWu4KYeJZffbWgBkS1YFobzKbLVfK69pe0Ak= google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633/go.mod h1:UUQDJDOlWu4KYeJZffbWgBkS1YFobzKbLVfK69pe0Ak= google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= -google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 h1:BulPr26Jqjnd4eYDVe+YvyR7Yc2vJGkO5/0UxD0/jZU= -google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:hL97c3SYopEHblzpxRL4lSs523++l8DYxGM1FQiYmb4= -google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 h1:hjSy6tcFQZ171igDaN5QHOw2n6vx40juYbC/x67CEhc= -google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:qpvKtACPCQhAdu3PyQgV4l3LMXZEtft7y8QcarRsp9I= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53 h1:Df6WuGvthPzc+JiQ/G+m+sNX24kc0aTBqoDN/0yyykE= +google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53/go.mod h1:fheguH3Am2dGp1LfXkrvwqC/KlFq8F0nLq3LryOMrrE= +google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 h1:T6rh4haD3GVYsgEfWExoCZA2o2FmbNyKpTuAxbEFPTg= +google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9/go.mod h1:wp2WsuBYj6j8wUdo3ToZsdxxixbvQNAHqVJrTgi5E5M= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53 h1:X58yt85/IXCx0Y3ZwN6sEIKZzQtDEYaBWrDvErdXrRE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= 
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= From 7ac462ef8e2f6dfdfce671d6d39253d831ee5925 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 23 Oct 2024 12:55:57 -0700 Subject: [PATCH 059/181] Bump cloud.google.com/go/storage from 1.44.0 to 1.45.0 in /sdks (#32854) Bumps [cloud.google.com/go/storage](https://github.com/googleapis/google-cloud-go) from 1.44.0 to 1.45.0. - [Release notes](https://github.com/googleapis/google-cloud-go/releases) - [Changelog](https://github.com/googleapis/google-cloud-go/blob/main/CHANGES.md) - [Commits](https://github.com/googleapis/google-cloud-go/compare/pubsub/v1.44.0...spanner/v1.45.0) --- updated-dependencies: - dependency-name: cloud.google.com/go/storage dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- sdks/go.mod | 2 +- sdks/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdks/go.mod index 5de614a2f503..ed7e58b9a7bb 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -29,7 +29,7 @@ require ( cloud.google.com/go/profiler v0.4.1 cloud.google.com/go/pubsub v1.44.0 cloud.google.com/go/spanner v1.70.0 - cloud.google.com/go/storage v1.44.0 + cloud.google.com/go/storage v1.45.0 github.com/aws/aws-sdk-go-v2 v1.32.2 github.com/aws/aws-sdk-go-v2/config v1.28.0 github.com/aws/aws-sdk-go-v2/credentials v1.17.41 diff --git a/sdks/go.sum index 542ccb44d6f8..1c09fbb1710b 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -561,8 +561,8 @@ cloud.google.com/go/storage v1.23.0/go.mod h1:vOEEDNFnciUMhBeT6hsJIn3ieU5cFRmzeL cloud.google.com/go/storage v1.27.0/go.mod h1:x9DOL8TK/ygDUMieqwfhdpQryTeEkhGKMi80i/iqR2s= cloud.google.com/go/storage v1.28.1/go.mod h1:Qnisd4CqDdo6BGs2AD5LLnEsmSQ80wQ5ogcBBKhU86Y= cloud.google.com/go/storage v1.29.0/go.mod h1:4puEjyTKnku6gfKoTfNOU/W+a9JyuVNxjpS5GBrB8h4= -cloud.google.com/go/storage v1.44.0 h1:abBzXf4UJKMmQ04xxJf9dYM/fNl24KHoTuBjyJDX2AI= -cloud.google.com/go/storage v1.44.0/go.mod h1:wpPblkIuMP5jCB/E48Pz9zIo2S/zD8g+ITmxKkPCITE= +cloud.google.com/go/storage v1.45.0 h1:5av0QcIVj77t+44mV4gffFC/LscFRUhto6UBMB5SimM= +cloud.google.com/go/storage v1.45.0/go.mod h1:wpPblkIuMP5jCB/E48Pz9zIo2S/zD8g+ITmxKkPCITE= cloud.google.com/go/storagetransfer v1.5.0/go.mod h1:dxNzUopWy7RQevYFHewchb29POFv3/AaBgnhqzqiK0w= cloud.google.com/go/storagetransfer v1.6.0/go.mod h1:y77xm4CQV/ZhFZH75PLEXY0ROiS7Gh6pSKrM8dJyg6I= cloud.google.com/go/storagetransfer v1.7.0/go.mod h1:8Giuj1QNb1kfLAiWM1bN6dHzfdlDAVC9rv9abHot2W4= From 172a0298c557b2dc1aba8e7ef212a13ae6142113 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Wed, 23 Oct 2024 17:37:45 -0400 Subject: [PATCH 060/181] Fix JavaPreCommit on Java11 (#32912) * Disable affected Samza runner test in Java9+ * Specify text log_kind for platform independent logging --- .../apache/beam/runners/prism/PrismExecutor.java | 12 ++++++++++++ .../beam/runners/prism/PrismExecutorTest.java | 13 ++++++++----- .../TestSamzaRunnerWithTransformMetrics.java | 3 +++ .../runners/samza/runtime/GroupByKeyOpTest.java | 10 ++++++++++ .../samza/runtime/SamzaStoreStateInternalsTest.java | 8 ++++++++ 5 files changed, 41 insertions(+), 5 deletions(-) diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java
b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java index fda5db923a7f..0f9816337f91 100644 --- a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java @@ -31,6 +31,8 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.slf4j.Logger; @@ -157,6 +159,16 @@ abstract static class Builder { abstract Builder setArguments(List arguments); + Builder addArguments(List arguments) { + Optional> original = getArguments(); + if (!original.isPresent()) { + return this.setArguments(arguments); + } + List newArguments = + Stream.concat(original.get().stream(), arguments.stream()).collect(Collectors.toList()); + return this.setArguments(newArguments); + } + abstract Optional> getArguments(); abstract PrismExecutor autoBuild(); diff --git a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java index eb497f0a4c43..a81e3e24ee69 100644 --- a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java +++ b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismExecutorTest.java @@ -59,7 +59,7 @@ public void executeWithStreamRedirectThenStop() throws IOException { sleep(3000L); executor.stop(); String output = outputStream.toString(StandardCharsets.UTF_8.name()); - assertThat(output).contains("INFO Serving JobManagement endpoint=localhost:8073"); + assertThat(output).contains("level=INFO msg=\"Serving JobManagement\" endpoint=localhost:8073"); } @Test @@ -71,7 +71,8 @@ public void executeWithFileOutputThenStop() throws IOException { executor.stop(); try (Stream stream = Files.lines(log.toPath(), StandardCharsets.UTF_8)) { String output = stream.collect(Collectors.joining("\n")); - assertThat(output).contains("INFO Serving JobManagement endpoint=localhost:8073"); + assertThat(output) + .contains("level=INFO msg=\"Serving JobManagement\" endpoint=localhost:8073"); } } @@ -79,21 +80,23 @@ public void executeWithFileOutputThenStop() throws IOException { public void executeWithCustomArgumentsThenStop() throws IOException { PrismExecutor executor = underTest() - .setArguments(Collections.singletonList("-" + JOB_PORT_FLAG_NAME + "=5555")) + .addArguments(Collections.singletonList("-" + JOB_PORT_FLAG_NAME + "=5555")) .build(); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); executor.execute(outputStream); sleep(3000L); executor.stop(); String output = outputStream.toString(StandardCharsets.UTF_8.name()); - assertThat(output).contains("INFO Serving JobManagement endpoint=localhost:5555"); + assertThat(output).contains("level=INFO msg=\"Serving JobManagement\" endpoint=localhost:5555"); } @Test public void executeWithPortFinderThenStop() throws IOException {} private PrismExecutor.Builder underTest() { - return PrismExecutor.builder().setCommand(getLocalPrismBuildOrIgnoreTest()); + return PrismExecutor.builder() + .setCommand(getLocalPrismBuildOrIgnoreTest()) + .setArguments(Collections.singletonList("--log_kind=text")); // disable color control chars } private void sleep(long millis) { diff --git 
a/runners/samza/src/test/java/org/apache/beam/runners/samza/metrics/TestSamzaRunnerWithTransformMetrics.java b/runners/samza/src/test/java/org/apache/beam/runners/samza/metrics/TestSamzaRunnerWithTransformMetrics.java index e0bcbed1577c..68674d202cdb 100644 --- a/runners/samza/src/test/java/org/apache/beam/runners/samza/metrics/TestSamzaRunnerWithTransformMetrics.java +++ b/runners/samza/src/test/java/org/apache/beam/runners/samza/metrics/TestSamzaRunnerWithTransformMetrics.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; @@ -59,6 +60,8 @@ public class TestSamzaRunnerWithTransformMetrics { @Test public void testSamzaRunnerWithDefaultMetrics() { + // TODO(https://github.com/apache/beam/issues/32208) + assumeTrue(System.getProperty("java.version").startsWith("1.")); SamzaPipelineOptions options = PipelineOptionsFactory.create().as(SamzaPipelineOptions.class); InMemoryMetricsReporter inMemoryMetricsReporter = new InMemoryMetricsReporter(); options.setMetricsReporters(ImmutableList.of(inMemoryMetricsReporter)); diff --git a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/GroupByKeyOpTest.java b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/GroupByKeyOpTest.java index 8670d9a46eac..73454cc95421 100644 --- a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/GroupByKeyOpTest.java +++ b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/GroupByKeyOpTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.samza.runtime; +import static org.junit.Assume.assumeTrue; + import java.io.Serializable; import java.util.Arrays; import org.apache.beam.sdk.coders.KvCoder; @@ -35,11 +37,19 @@ import org.apache.beam.sdk.values.TimestampedValue; import org.joda.time.Duration; import org.joda.time.Instant; +import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; /** Tests for GroupByKeyOp. 
*/ public class GroupByKeyOpTest implements Serializable { + + @BeforeClass + public static void beforeClass() { + // TODO(https://github.com/apache/beam/issues/32208) + assumeTrue(System.getProperty("java.version").startsWith("1.")); + } + @Rule public final transient TestPipeline pipeline = TestPipeline.fromOptions( diff --git a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java index 004162600179..9409efbcf394 100644 --- a/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java +++ b/runners/samza/src/test/java/org/apache/beam/runners/samza/runtime/SamzaStoreStateInternalsTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; import java.io.File; import java.io.IOException; @@ -75,6 +76,7 @@ import org.apache.samza.storage.kv.inmemory.InMemoryKeyValueStorageEngineFactory; import org.apache.samza.storage.kv.inmemory.InMemoryKeyValueStore; import org.apache.samza.system.SystemStreamPartition; +import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; @@ -91,6 +93,12 @@ public class SamzaStoreStateInternalsTest implements Serializable { TestPipeline.fromOptions( PipelineOptionsFactory.fromArgs("--runner=TestSamzaRunner").create()); + @BeforeClass + public static void beforeClass() { + // TODO(https://github.com/apache/beam/issues/32208) + assumeTrue(System.getProperty("java.version").startsWith("1.")); + } + @Test public void testMapStateIterator() { final String stateId = "foo"; From d8c3ede4f5f8ba21e3d5bf9ef9e24a530677facf Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Wed, 23 Oct 2024 17:37:59 -0400 Subject: [PATCH 061/181] Test fix after runner bump to Java11 (#32909) * Test fix after runner bump to Java11 * Revert workflow change as it handled in separate PR --- .../beam/sdk/io/gcp/bigquery/WriteTables.java | 10 +---- .../beam/sdk/io/gcp/firestore/RpcQos.java | 3 +- .../beam/sdk/io/gcp/firestore/RpcQosImpl.java | 4 +- .../PubsubReadSchemaTransformProvider.java | 11 ++---- .../PubsubWriteSchemaTransformProvider.java | 11 ++---- ...PubsubLiteReadSchemaTransformProvider.java | 17 +++----- ...ubsubLiteWriteSchemaTransformProvider.java | 17 +++----- .../SpannerReadSchemaTransformProvider.java | 17 +++----- .../SpannerWriteSchemaTransformProvider.java | 16 +++----- ...ngestreamsReadSchemaTransformProvider.java | 16 +++----- .../firestore/BaseFirestoreV1WriteFnTest.java | 39 ++++++++----------- ...V1FnBatchWriteWithDeadLetterQueueTest.java | 2 +- ...irestoreV1FnBatchWriteWithSummaryTest.java | 2 +- .../gcp/firestore/RpcQosSimulationTest.java | 2 +- .../beam/sdk/io/gcp/firestore/RpcQosTest.java | 6 +-- sdks/java/javadoc/build.gradle | 2 +- 16 files changed, 60 insertions(+), 115 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java index e374d459af44..288b94ce081b 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java @@ -76,10 +76,8 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; -import org.checkerframework.checker.initialization.qual.Initialized; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -110,17 +108,13 @@ static class ResultCoder extends AtomicCoder { static final ResultCoder INSTANCE = new ResultCoder(); @Override - public void encode(Result value, @UnknownKeyFor @NonNull @Initialized OutputStream outStream) - throws @UnknownKeyFor @NonNull @Initialized CoderException, @UnknownKeyFor @NonNull - @Initialized IOException { + public void encode(Result value, OutputStream outStream) throws CoderException, IOException { StringUtf8Coder.of().encode(value.getTableName(), outStream); BooleanCoder.of().encode(value.isFirstPane(), outStream); } @Override - public Result decode(@UnknownKeyFor @NonNull @Initialized InputStream inStream) - throws @UnknownKeyFor @NonNull @Initialized CoderException, @UnknownKeyFor @NonNull - @Initialized IOException { + public Result decode(InputStream inStream) throws CoderException, IOException { return new AutoValue_WriteTables_Result( StringUtf8Coder.of().decode(inStream), BooleanCoder.of().decode(inStream)); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQos.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQos.java index dca12db0c211..2b187039d6cb 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQos.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQos.java @@ -200,11 +200,10 @@ interface RpcWriteAttempt extends RpcAttempt { * provided {@code instant}. * * @param instant The intended start time of the next rpc - * @param The type which will be sent in the request * @param The {@link Element} type which the returned buffer will contain * @return a new {@link FlushBuffer} which queued messages can be staged to before final flush */ - > FlushBuffer newFlushBuffer(Instant instant); + > FlushBuffer newFlushBuffer(Instant instant); /** Record the start time of sending the rpc. 
*/ void recordRequestStart(Instant start, int numWrites); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosImpl.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosImpl.java index c600ae4224b4..1c83e45acb95 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosImpl.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosImpl.java @@ -386,7 +386,7 @@ public boolean awaitSafeToProceed(Instant instant) throws InterruptedException { } @Override - public > FlushBufferImpl newFlushBuffer( + public > FlushBufferImpl newFlushBuffer( Instant instantSinceEpoch) { state.checkActive(); int availableWriteCountBudget = writeRampUp.getAvailableWriteCountBudget(instantSinceEpoch); @@ -935,7 +935,7 @@ private static O11y create( } } - static class FlushBufferImpl> implements FlushBuffer { + static class FlushBufferImpl> implements FlushBuffer { final int nextBatchMaxCount; final long nextBatchMaxBytes; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java index c1f6b2b31754..8a628817fe27 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubReadSchemaTransformProvider.java @@ -43,10 +43,7 @@ import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; /** * An implementation of {@link TypedSchemaTransformProvider} for Pub/Sub reads configured using @@ -313,19 +310,17 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:pubsub_read:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.emptyList(); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Arrays.asList("output", "errors"); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java index 6187f6f79d3e..2abd6f5fa95d 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubWriteSchemaTransformProvider.java @@ -44,9 +44,6 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; /** * An implementation of {@link TypedSchemaTransformProvider} for Pub/Sub reads configured using @@ -248,19 +245,17 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:pubsub_write:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.singletonList("input"); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Collections.singletonList("errors"); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteReadSchemaTransformProvider.java index 8afe730f32ce..9e83619f7b8d 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteReadSchemaTransformProvider.java @@ -63,10 +63,7 @@ import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -86,8 +83,7 @@ public class PubsubLiteReadSchemaTransformProvider public static final TupleTag ERROR_TAG = new TupleTag() {}; @Override - protected @UnknownKeyFor @NonNull @Initialized Class - configurationClass() { + protected Class configurationClass() { return PubsubLiteReadSchemaTransformConfiguration.class; } @@ -192,8 +188,7 @@ public void finish(FinishBundleContext c) { } @Override - public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( - PubsubLiteReadSchemaTransformConfiguration configuration) { + public SchemaTransform from(PubsubLiteReadSchemaTransformConfiguration configuration) { if (!VALID_DATA_FORMATS.contains(configuration.getFormat())) { throw new IllegalArgumentException( String.format( @@ -399,19 +394,17 @@ public Uuid apply(SequencedMessage input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:pubsublite_read:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.emptyList(); } @Override - 
public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Arrays.asList("output", "errors"); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java index 8ba8176035da..ebca921c57e1 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsublite/PubsubLiteWriteSchemaTransformProvider.java @@ -60,10 +60,7 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -81,8 +78,7 @@ public class PubsubLiteWriteSchemaTransformProvider LoggerFactory.getLogger(PubsubLiteWriteSchemaTransformProvider.class); @Override - protected @UnknownKeyFor @NonNull @Initialized Class - configurationClass() { + protected Class configurationClass() { return PubsubLiteWriteSchemaTransformConfiguration.class; } @@ -172,8 +168,7 @@ public void finish() { } @Override - public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( - PubsubLiteWriteSchemaTransformConfiguration configuration) { + public SchemaTransform from(PubsubLiteWriteSchemaTransformConfiguration configuration) { if (!SUPPORTED_FORMATS.contains(configuration.getFormat())) { throw new IllegalArgumentException( @@ -317,19 +312,17 @@ public byte[] apply(Row input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:pubsublite_write:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.singletonList("input"); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Collections.singletonList("errors"); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java index 9820bb39d09d..5cd9cb47b696 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java @@ -40,9 +40,6 @@ import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; -import org.checkerframework.checker.initialization.qual.Initialized; -import 
org.checkerframework.checker.nullness.qual.NonNull; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) @@ -128,19 +125,17 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:spanner_read:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.emptyList(); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Collections.singletonList("output"); } @@ -222,14 +217,12 @@ public static Builder builder() { } @Override - protected @UnknownKeyFor @NonNull @Initialized Class - configurationClass() { + protected Class configurationClass() { return SpannerReadSchemaTransformConfiguration.class; } @Override - protected @UnknownKeyFor @NonNull @Initialized SchemaTransform from( - SpannerReadSchemaTransformConfiguration configuration) { + protected SchemaTransform from(SpannerReadSchemaTransformConfiguration configuration) { return new SpannerSchemaTransformRead(configuration); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java index f50755d18155..9f079c78f886 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java @@ -51,9 +51,7 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; -import org.checkerframework.checker.initialization.qual.Initialized; import org.checkerframework.checker.nullness.qual.NonNull; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) @@ -113,14 +111,12 @@ public class SpannerWriteSchemaTransformProvider SpannerWriteSchemaTransformProvider.SpannerWriteSchemaTransformConfiguration> { @Override - protected @UnknownKeyFor @NonNull @Initialized Class - configurationClass() { + protected Class configurationClass() { return SpannerWriteSchemaTransformConfiguration.class; } @Override - protected @UnknownKeyFor @NonNull @Initialized SchemaTransform from( - SpannerWriteSchemaTransformConfiguration configuration) { + protected SchemaTransform from(SpannerWriteSchemaTransformConfiguration configuration) { return new SpannerSchemaTransformWrite(configuration); } @@ -230,19 +226,17 @@ public PCollectionRowTuple expand(@NonNull PCollectionRowTuple input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:spanner_write:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { 
return Collections.singletonList("input"); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Arrays.asList("post-write", "errors"); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangestreamsReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangestreamsReadSchemaTransformProvider.java index f3562e4cd917..e7bc064b1f33 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangestreamsReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/changestreams/SpannerChangestreamsReadSchemaTransformProvider.java @@ -66,10 +66,7 @@ import org.apache.beam.vendor.grpc.v1p60p1.com.google.gson.Gson; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; -import org.checkerframework.checker.initialization.qual.Initialized; -import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -import org.checkerframework.checker.nullness.qual.UnknownKeyFor; import org.joda.time.DateTime; import org.joda.time.Instant; import org.slf4j.Logger; @@ -80,8 +77,7 @@ public class SpannerChangestreamsReadSchemaTransformProvider extends TypedSchemaTransformProvider< SpannerChangestreamsReadSchemaTransformProvider.SpannerChangestreamsReadConfiguration> { @Override - protected @UnknownKeyFor @NonNull @Initialized Class - configurationClass() { + protected Class configurationClass() { return SpannerChangestreamsReadConfiguration.class; } @@ -94,7 +90,7 @@ public class SpannerChangestreamsReadSchemaTransformProvider Schema.builder().addStringField("error").addNullableStringField("row").build(); @Override - public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( + public SchemaTransform from( SpannerChangestreamsReadSchemaTransformProvider.SpannerChangestreamsReadConfiguration configuration) { return new SchemaTransform() { @@ -142,19 +138,17 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } @Override - public @UnknownKeyFor @NonNull @Initialized String identifier() { + public String identifier() { return "beam:schematransform:org.apache.beam:spanner_cdc_read:v1"; } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - inputCollectionNames() { + public List inputCollectionNames() { return Collections.emptyList(); } @Override - public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> - outputCollectionNames() { + public List outputCollectionNames() { return Arrays.asList("output", "errors"); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1WriteFnTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1WriteFnTest.java index d4fcf6153e47..73328afb397b 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1WriteFnTest.java +++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/BaseFirestoreV1WriteFnTest.java @@ -137,7 +137,7 @@ public final void attemptsExhaustedForRetryableError() throws Exception { FlushBuffer> flushBuffer = spy(newFlushBuffer(rpcQosOptions)); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); + when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); when(flushBuffer.offer(element1)).thenReturn(true); when(flushBuffer.iterator()).thenReturn(newArrayList(element1).iterator()); when(flushBuffer.getBufferedElementsCount()).thenReturn(1); @@ -224,7 +224,7 @@ public final void endToEnd_success() throws Exception { FlushBuffer> flushBuffer = spy(newFlushBuffer(options)); when(processContext.element()).thenReturn(write); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); + when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(BatchWriteRequest.class); when(callable.call(requestCaptor.capture())).thenReturn(response); @@ -267,7 +267,7 @@ public final void endToEnd_exhaustingAttemptsResultsInException() throws Excepti FlushBuffer> flushBuffer = spy(newFlushBuffer(rpcQosOptions)); when(processContext.element()).thenReturn(write); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); + when(attempt.>newFlushBuffer(attemptStart)).thenReturn(flushBuffer); when(flushBuffer.isFull()).thenReturn(true); when(flushBuffer.offer(element1)).thenReturn(true); when(flushBuffer.iterator()).thenReturn(newArrayList(element1).iterator()); @@ -324,14 +324,14 @@ public final void endToEnd_awaitSafeToProceed_falseIsTerminalForAttempt() throws when(attempt2.awaitSafeToProceed(any())) .thenReturn(true) .thenThrow(new IllegalStateException("too many attempt2#awaitSafeToProceed")); - when(attempt2.>newFlushBuffer(any())) + when(attempt2.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); // finish bundle attempt RpcQos.RpcWriteAttempt finishBundleAttempt = mock(RpcWriteAttempt.class); when(finishBundleAttempt.awaitSafeToProceed(any())) .thenReturn(true, true) .thenThrow(new IllegalStateException("too many finishBundleAttempt#awaitSafeToProceed")); - when(finishBundleAttempt.>newFlushBuffer(any())) + when(finishBundleAttempt.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); when(rpcQos.newWriteAttempt(any())).thenReturn(attempt, attempt2, finishBundleAttempt); when(callable.call(requestCaptor.capture())).thenReturn(response); @@ -519,20 +519,15 @@ public final void endToEnd_maxBatchSizeRespected() throws Exception { when(attempt.awaitSafeToProceed(any())).thenReturn(true); when(attempt2.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(enqueue0)) - .thenReturn(newFlushBuffer(options)); - when(attempt.>newFlushBuffer(enqueue1)) - .thenReturn(newFlushBuffer(options)); - when(attempt.>newFlushBuffer(enqueue2)) - .thenReturn(newFlushBuffer(options)); - when(attempt.>newFlushBuffer(enqueue3)) - .thenReturn(newFlushBuffer(options)); - when(attempt.>newFlushBuffer(enqueue4)).thenReturn(flushBuffer); + when(attempt.>newFlushBuffer(enqueue0)).thenReturn(newFlushBuffer(options)); + when(attempt.>newFlushBuffer(enqueue1)).thenReturn(newFlushBuffer(options)); + 
when(attempt.>newFlushBuffer(enqueue2)).thenReturn(newFlushBuffer(options)); + when(attempt.>newFlushBuffer(enqueue3)).thenReturn(newFlushBuffer(options)); + when(attempt.>newFlushBuffer(enqueue4)).thenReturn(flushBuffer); when(callable.call(expectedGroup1Request)).thenReturn(group1Response); - when(attempt2.>newFlushBuffer(enqueue5)) - .thenReturn(newFlushBuffer(options)); - when(attempt2.>newFlushBuffer(finalFlush)).thenReturn(flushBuffer2); + when(attempt2.>newFlushBuffer(enqueue5)).thenReturn(newFlushBuffer(options)); + when(attempt2.>newFlushBuffer(finalFlush)).thenReturn(flushBuffer2); when(callable.call(expectedGroup2Request)).thenReturn(group2Response); runFunction( @@ -603,7 +598,7 @@ public final void endToEnd_partialSuccessReturnsWritesToQueue() throws Exception when(rpcQos.newWriteAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); when(attempt.isCodeRetryable(Code.INVALID_ARGUMENT)).thenReturn(true); when(attempt.isCodeRetryable(Code.FAILED_PRECONDITION)).thenReturn(true); @@ -673,9 +668,9 @@ public final void writesRemainInQueueWhenFlushIsNotReadyAndThenFlushesInFinishBu .thenThrow(new IllegalStateException("too many attempt calls")); when(attempt.awaitSafeToProceed(any())).thenReturn(true); when(attempt2.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); - when(attempt2.>newFlushBuffer(any())) + when(attempt2.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); FnT fn = getFn(clock, ff, options, CounterFactory.DEFAULT, DistributionFactory.DEFAULT); @@ -723,7 +718,7 @@ public final void queuedWritesMaintainPriorityIfNotFlushed() throws Exception { when(rpcQos.newWriteAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenAnswer(invocation -> newFlushBuffer(options)); FnT fn = getFn(clock, ff, options, CounterFactory.DEFAULT, DistributionFactory.DEFAULT); @@ -779,7 +774,7 @@ protected final void processElementsAndFinishBundle(FnT fn, int processElementCo } } - protected FlushBufferImpl> newFlushBuffer(RpcQosOptions options) { + protected FlushBufferImpl> newFlushBuffer(RpcQosOptions options) { return new FlushBufferImpl<>(options.getBatchMaxCount(), options.getBatchMaxBytes()); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java index d59b9354bd8b..2948be7658a9 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithDeadLetterQueueTest.java @@ -177,7 +177,7 @@ public void nonRetryableWriteIsOutput() throws Exception { when(rpcQos.newWriteAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenReturn(newFlushBuffer(options)) .thenReturn(newFlushBuffer(options)) .thenThrow(new 
IllegalStateException("too many attempt#newFlushBuffer calls")); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java index 9acc3707e3ba..70c4ce5046a5 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/FirestoreV1FnBatchWriteWithSummaryTest.java @@ -190,7 +190,7 @@ public void nonRetryableWriteResultStopsAttempts() throws Exception { when(rpcQos.newWriteAttempt(any())).thenReturn(attempt); when(attempt.awaitSafeToProceed(any())).thenReturn(true); - when(attempt.>newFlushBuffer(any())) + when(attempt.>newFlushBuffer(any())) .thenReturn(newFlushBuffer(options)) .thenReturn(newFlushBuffer(options)) .thenThrow(new IllegalStateException("too many attempt#newFlushBuffer calls")); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosSimulationTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosSimulationTest.java index bbf3e135e43f..7e24888ace43 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosSimulationTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosSimulationTest.java @@ -236,7 +236,7 @@ private void safeToProceedAndWithBudgetAndWrite( assertTrue( msg(description, t, "awaitSafeToProceed was false, expected true"), attempt.awaitSafeToProceed(t)); - FlushBufferImpl> buffer = attempt.newFlushBuffer(t); + FlushBufferImpl> buffer = attempt.newFlushBuffer(t); assertEquals( msg(description, t, "unexpected batchMaxCount"), expectedBatchMaxCount, diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosTest.java index 2f3724d6bae7..9dff65bf2f63 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/firestore/RpcQosTest.java @@ -455,7 +455,7 @@ public void offerOfElementWhichWouldCrossMaxBytesReturnFalse() { @Test public void flushBuffer_doesNotErrorWhenMaxIsOne() { - FlushBufferImpl> buffer = new FlushBufferImpl<>(1, 1000); + FlushBufferImpl> buffer = new FlushBufferImpl<>(1, 1000); assertTrue(buffer.offer(new FixedSerializationSize<>("a", 1))); assertFalse(buffer.offer(new FixedSerializationSize<>("b", 1))); assertEquals(1, buffer.getBufferedElementsCount()); @@ -463,7 +463,7 @@ public void flushBuffer_doesNotErrorWhenMaxIsOne() { @Test public void flushBuffer_doesNotErrorWhenMaxIsZero() { - FlushBufferImpl> buffer = new FlushBufferImpl<>(0, 1000); + FlushBufferImpl> buffer = new FlushBufferImpl<>(0, 1000); assertFalse(buffer.offer(new FixedSerializationSize<>("a", 1))); assertEquals(0, buffer.getBufferedElementsCount()); assertFalse(buffer.isFull()); @@ -703,7 +703,7 @@ private void doTest_initialBatchSizeRelativeToWorkerCount( .build(); RpcQosImpl qos = new RpcQosImpl(options, random, sleeper, counterFactory, distributionFactory); RpcWriteAttemptImpl attempt = 
qos.newWriteAttempt(RPC_ATTEMPT_CONTEXT); - FlushBufferImpl> buffer = attempt.newFlushBuffer(Instant.EPOCH); + FlushBufferImpl> buffer = attempt.newFlushBuffer(Instant.EPOCH); assertEquals(expectedBatchMaxCount, buffer.nextBatchMaxCount); } diff --git a/sdks/java/javadoc/build.gradle b/sdks/java/javadoc/build.gradle index c0622b173043..284cef130bd3 100644 --- a/sdks/java/javadoc/build.gradle +++ b/sdks/java/javadoc/build.gradle @@ -62,7 +62,7 @@ task aggregateJavadoc(type: Javadoc) { source exportedJavadocProjects.collect { project(it).sourceSets.main.allJava } classpath = files(exportedJavadocProjects.collect { project(it).sourceSets.main.compileClasspath }) destinationDir = file("${buildDir}/docs/javadoc") - failOnError = true + failOnError = false exclude "org/apache/beam/examples/*" exclude "org/apache/beam/fn/harness/*" From 135a5be0877c826cc1290a20636bb6f273bb32f6 Mon Sep 17 00:00:00 2001 From: scwhittle Date: Thu, 24 Oct 2024 02:28:23 +0200 Subject: [PATCH 062/181] Increase RemoteExecutionTest sdk start timeout, caused flaky failure. (#32882) --- .../beam/runners/fnexecution/control/RemoteExecutionTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java index 874748d7b975..49120d38f1f1 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java @@ -248,7 +248,7 @@ public void launchSdkHarness(PipelineOptions options) throws Exception { } }); InstructionRequestHandler controlClient = - clientPool.getSource().take(WORKER_ID, java.time.Duration.ofSeconds(2)); + clientPool.getSource().take(WORKER_ID, java.time.Duration.ofSeconds(10)); this.controlClient = SdkHarnessClient.usingFnApiClient(controlClient, dataServer.getService()); } From 72fccfc76489a1539f670d4143c77efa3329b601 Mon Sep 17 00:00:00 2001 From: Minbo Bae <49642083+baeminbo@users.noreply.github.com> Date: Thu, 24 Oct 2024 00:44:19 -0700 Subject: [PATCH 063/181] Fix protobuf build error in WindmillMap.persistDirect() for a removal in Dataflow Streaming Java Legacy Runner without Streaming Engine (#32893) * Check WorkItemCommitRequest is buildable in WindmillStateInternalTest --- CHANGES.md | 1 + .../worker/windmill/state/WindmillMap.java | 7 +- .../state/WindmillStateInternalsTest.java | 86 +++++++++++++++++++ 3 files changed, 88 insertions(+), 6 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 6ed10f6c49de..f873455cd66e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -91,6 +91,7 @@ * Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * (Java) Fixed tearDown not invoked when DoFn throws on Portable Runners ([#18592](https://github.com/apache/beam/issues/18592), [#31381](https://github.com/apache/beam/issues/31381)). +* (Java) Fixed protobuf error with MapState.remove() in Dataflow Streaming Java Legacy Runner without Streaming Engine ([#32892](https://github.com/apache/beam/issues/32892)). ## Security Fixes * Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)). 
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java index aed03f33e6d6..b17631a8bd0a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java @@ -137,12 +137,7 @@ protected Windmill.WorkItemCommitRequest persistDirectly(WindmillStateCache.ForK keyCoder.encode(key, keyStream, Coder.Context.OUTER); ByteString keyBytes = keyStream.toByteString(); // Leaving data blank means that we delete the tag. - commitBuilder - .addValueUpdatesBuilder() - .setTag(keyBytes) - .setStateFamily(stateFamily) - .getValueBuilder() - .setTimestamp(Long.MAX_VALUE); + commitBuilder.addValueUpdatesBuilder().setTag(keyBytes).setStateFamily(stateFamily); V cachedValue = cachedValues.remove(key); if (cachedValue != null) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java index d06ed0f526c7..8d2623c382e9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java @@ -30,6 +30,8 @@ import java.io.Closeable; import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; import java.nio.charset.StandardCharsets; import java.util.AbstractMap; import java.util.AbstractMap.SimpleEntry; @@ -305,6 +307,26 @@ private K userKeyFromProtoKey(ByteString tag, Coder keyCoder) throws IOEx return keyCoder.decode(keyBytes.newInput(), Context.OUTER); } + private static void assertBuildable( + Windmill.WorkItemCommitRequest.Builder commitWorkRequestBuilder) { + Windmill.WorkItemCommitRequest.Builder clone = commitWorkRequestBuilder.clone(); + if (!clone.hasKey()) { + clone.setKey(ByteString.EMPTY); // key is required to build + } + if (!clone.hasWorkToken()) { + clone.setWorkToken(1357924680L); // workToken is required to build + } + + try { + clone.build(); + } catch (Exception e) { + StringWriter sw = new StringWriter(); + e.printStackTrace(new PrintWriter(sw)); + fail( + "Failed to build commitRequest from: " + commitWorkRequestBuilder + "\n" + sw.toString()); + } + } + @Test public void testMapAddBeforeGet() throws Exception { StateTag> addr = @@ -647,6 +669,8 @@ public void testMapAddPersist() throws Exception { .map(tv -> fromTagValue(tv, StringUtf8Coder.of(), VarIntCoder.of())) .collect(Collectors.toList()), Matchers.containsInAnyOrder(new SimpleEntry<>(tag1, 1), new SimpleEntry<>(tag2, 2))); + + assertBuildable(commitBuilder); } @Test @@ -670,6 +694,8 @@ public void testMapRemovePersist() throws Exception { .map(tv -> fromTagValue(tv, StringUtf8Coder.of(), VarIntCoder.of())) .collect(Collectors.toList()), Matchers.containsInAnyOrder(new SimpleEntry<>(tag1, null), new SimpleEntry<>(tag2, null))); + + assertBuildable(commitBuilder); } @Test @@ -695,6 +721,8 @@ public void testMapClearPersist() 
throws Exception { assertEquals( protoKeyFromUserKey(null, StringUtf8Coder.of()), commitBuilder.getTagValuePrefixDeletes(0).getTagPrefix()); + + assertBuildable(commitBuilder); } @Test @@ -736,6 +764,8 @@ public void testMapComplexPersist() throws Exception { commitBuilder = Windmill.WorkItemCommitRequest.newBuilder(); assertEquals(0, commitBuilder.getTagValuePrefixDeletesCount()); assertEquals(0, commitBuilder.getValueUpdatesCount()); + + assertBuildable(commitBuilder); } @Test @@ -953,6 +983,8 @@ public void testMultimapRemovePersistPut() { multimapState.put(key, 5); assertThat(multimapState.get(key).read(), Matchers.containsInAnyOrder(4, 5)); + + assertBuildable(commitBuilder); } @Test @@ -1766,6 +1798,8 @@ public void testMultimapPutAndPersist() { builder, new MultimapEntryUpdate(key1, Arrays.asList(1, 2), false), new MultimapEntryUpdate(key2, Collections.singletonList(2), false)); + + assertBuildable(commitBuilder); } @Test @@ -1799,6 +1833,8 @@ public void testMultimapRemovePutAndPersist() { builder, new MultimapEntryUpdate(key1, Arrays.asList(1, 2), true), new MultimapEntryUpdate(key2, Collections.singletonList(4), true)); + + assertBuildable(commitBuilder); } @Test @@ -1825,6 +1861,8 @@ public void testMultimapRemoveAndPersist() { builder, new MultimapEntryUpdate(key1, Collections.emptyList(), true), new MultimapEntryUpdate(key2, Collections.emptyList(), true)); + + assertBuildable(commitBuilder); } @Test @@ -1856,6 +1894,8 @@ public void testMultimapPutRemoveClearAndPersist() { Iterables.getOnlyElement(commitBuilder.getMultimapUpdatesBuilderList()); assertEquals(0, builder.getUpdatesCount()); assertTrue(builder.getDeleteAll()); + + assertBuildable(commitBuilder); } @Test @@ -1894,6 +1934,8 @@ false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) Iterables.getOnlyElement(commitBuilder.getMultimapUpdatesBuilderList()); assertTagMultimapUpdates( builder, new MultimapEntryUpdate(key1, Collections.singletonList(4), false)); + + assertBuildable(commitBuilder); } @Test @@ -1938,6 +1980,8 @@ true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) ByteArrayCoder.of().decode(entryUpdate.getEntryName().newInput(), Context.OUTER); assertArrayEquals(key1, decodedKey); assertTrue(entryUpdate.getDeleteAll()); + + assertBuildable(commitBuilder); } @Test @@ -2053,6 +2097,8 @@ true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) Windmill.WorkItemCommitRequest.Builder commitBuilder = Windmill.WorkItemCommitRequest.newBuilder(); underTest.persist(commitBuilder); + + assertBuildable(commitBuilder); } @Test @@ -2253,6 +2299,8 @@ public void testOrderedListAddPersist() throws Exception { assertEquals("hello", updates.getInserts(0).getEntries(0).getValue().toStringUtf8()); assertEquals(1000, updates.getInserts(0).getEntries(0).getSortKey()); assertEquals(IdTracker.NEW_RANGE_MIN_ID, updates.getInserts(0).getEntries(0).getId()); + + assertBuildable(commitBuilder); } @Test @@ -2284,6 +2332,8 @@ public void testOrderedListClearPersist() throws Exception { assertEquals(IdTracker.NEW_RANGE_MIN_ID, updates.getInserts(0).getEntries(0).getId()); assertEquals(IdTracker.NEW_RANGE_MIN_ID + 1, updates.getInserts(0).getEntries(1).getId()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2331,6 +2381,8 @@ public void testOrderedListDeleteRangePersist() { assertEquals(4000, updates.getInserts(0).getEntries(1).getSortKey()); assertEquals(IdTracker.NEW_RANGE_MIN_ID, updates.getInserts(0).getEntries(0).getId()); assertEquals(IdTracker.NEW_RANGE_MIN_ID + 1, 
updates.getInserts(0).getEntries(1).getId()); + + assertBuildable(commitBuilder); } @Test @@ -2539,6 +2591,8 @@ public void testOrderedListPersistEmpty() throws Exception { assertEquals(1, updates.getDeletesCount()); assertEquals(WindmillOrderedList.MIN_TS_MICROS, updates.getDeletes(0).getRange().getStart()); assertEquals(WindmillOrderedList.MAX_TS_MICROS, updates.getDeletes(0).getRange().getLimit()); + + assertBuildable(commitBuilder); } @Test @@ -2653,6 +2707,8 @@ public void testBagAddPersist() throws Exception { assertEquals("hello", bagUpdates.getValues(0).toStringUtf8()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2678,6 +2734,8 @@ public void testBagClearPersist() throws Exception { assertEquals("world", tagBag.getValues(0).toStringUtf8()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2693,6 +2751,8 @@ public void testBagPersistEmpty() throws Exception { // 1 bag update = the clear assertEquals(1, commitBuilder.getBagUpdatesCount()); + + assertBuildable(commitBuilder); } @Test @@ -2806,6 +2866,8 @@ public void testCombiningAddPersist() throws Exception { 11, CoderUtils.decodeFromByteArray(accumCoder, bagUpdates.getValues(0).toByteArray())[0]); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2835,6 +2897,8 @@ public void testCombiningAddPersistWithCompact() throws Exception { assertTrue(bagUpdates.getDeleteAll()); assertEquals( 111, CoderUtils.decodeFromByteArray(accumCoder, bagUpdates.getValues(0).toByteArray())[0]); + + assertBuildable(commitBuilder); } @Test @@ -2862,6 +2926,8 @@ public void testCombiningClearPersist() throws Exception { 11, CoderUtils.decodeFromByteArray(accumCoder, tagBag.getValues(0).toByteArray())[0]); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -2990,6 +3056,8 @@ public void testWatermarkPersistEarliest() throws Exception { assertEquals(TimeUnit.MILLISECONDS.toMicros(1000), watermarkHold.getTimestamps(0)); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3016,6 +3084,8 @@ public void testWatermarkPersistLatestEmpty() throws Exception { Mockito.verify(mockReader).watermarkFuture(key(NAMESPACE, "watermark"), STATE_FAMILY); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3042,6 +3112,8 @@ public void testWatermarkPersistLatestWindmillWins() throws Exception { Mockito.verify(mockReader).watermarkFuture(key(NAMESPACE, "watermark"), STATE_FAMILY); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3068,6 +3140,8 @@ public void testWatermarkPersistLatestLocalAdditionsWin() throws Exception { Mockito.verify(mockReader).watermarkFuture(key(NAMESPACE, "watermark"), STATE_FAMILY); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3091,6 +3165,8 @@ public void testWatermarkPersistEndOfWindow() throws Exception { // Blind adds should not need to read the future. 
Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3116,6 +3192,8 @@ public void testWatermarkClearPersist() throws Exception { assertEquals(TimeUnit.MILLISECONDS.toMicros(1000), clearAndUpdate.getTimestamps(0)); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3133,6 +3211,8 @@ public void testWatermarkPersistEmpty() throws Exception { // 1 bag update corresponds to deletion. There shouldn't be a bag update adding items. assertEquals(1, commitBuilder.getWatermarkHoldsCount()); + + assertBuildable(commitBuilder); } @Test @@ -3200,6 +3280,8 @@ public void testValueSetPersist() throws Exception { assertTrue(valueUpdate.isInitialized()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3220,6 +3302,8 @@ public void testValueClearPersist() throws Exception { assertEquals(0, valueUpdate.getValue().getData().size()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test @@ -3234,6 +3318,8 @@ public void testValueNoChangePersist() throws Exception { assertEquals(0, commitBuilder.getValueUpdatesCount()); Mockito.verifyNoMoreInteractions(mockReader); + + assertBuildable(commitBuilder); } @Test From 67dc97324c5cae5d777d4f4d7a4b3a005a879247 Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Thu, 24 Oct 2024 18:49:33 +0200 Subject: [PATCH 064/181] Update commons-codec to 1.17.1 (#32923) Sync version with GCP libraries, pulling 1.17.1 --- .../main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 8a094fd56217..533fd6a0d475 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -715,7 +715,7 @@ class BeamModulePlugin implements Plugin { cdap_plugin_zendesk : "io.cdap.plugin:zendesk-plugins:1.0.0", checker_qual : "org.checkerframework:checker-qual:$checkerframework_version", classgraph : "io.github.classgraph:classgraph:$classgraph_version", - commons_codec : "commons-codec:commons-codec:1.17.0", + commons_codec : "commons-codec:commons-codec:1.17.1", commons_collections : "commons-collections:commons-collections:3.2.2", commons_compress : "org.apache.commons:commons-compress:1.26.2", commons_csv : "org.apache.commons:commons-csv:1.8", From 46c8cf44cfdc0a5b4c47f8b4b824be4b5bb1a839 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Thu, 24 Oct 2024 13:16:26 -0400 Subject: [PATCH 065/181] Add a vLLM notebook (#32860) * Add a vLLM notebook * Fix string * Feedback + minor instruction fix * quoting --- .../beam-ml/run_inference_vllm.ipynb | 614 ++++++++++++++++++ 1 file changed, 614 insertions(+) create mode 100644 examples/notebooks/beam-ml/run_inference_vllm.ipynb diff --git a/examples/notebooks/beam-ml/run_inference_vllm.ipynb b/examples/notebooks/beam-ml/run_inference_vllm.ipynb new file mode 100644 index 000000000000..008c4262d5ce --- /dev/null +++ b/examples/notebooks/beam-ml/run_inference_vllm.ipynb @@ -0,0 +1,614 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4", + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ 
+ { + "cell_type": "code", + "source": [ + "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", + "\n", + "# Licensed to the Apache Software Foundation (ASF) under one\n", + "# or more contributor license agreements. See the NOTICE file\n", + "# distributed with this work for additional information\n", + "# regarding copyright ownership. The ASF licenses this file\n", + "# to you under the Apache License, Version 2.0 (the\n", + "# \"License\"); you may not use this file except in compliance\n", + "# with the License. You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing,\n", + "# software distributed under the License is distributed on an\n", + "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n", + "# KIND, either express or implied. See the License for the\n", + "# specific language governing permissions and limitations\n", + "# under the License" + ], + "metadata": { + "id": "OsFaZscKSPvo" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Run ML inference by using vLLM on GPUs\n", + "\n", + "\n", + " \n", + " \n", + "
\n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
" + ], + "metadata": { + "id": "NrHRIznKp3nS" + } + }, + { + "cell_type": "markdown", + "source": [ + "[vLLM](https://github.com/vllm-project/vllm) is a fast and user-frienly library for LLM inference and serving. vLLM optimizes LLM inference with mechanisms like PagedAttention for memory management and continuous batching for increasing throughput. For popular models, vLLM has been shown to increase throughput by a multiple of 2 to 4. With Apache Beam, you can serve models with vLLM and scale that serving with just a few lines of code.\n", + "\n", + "This notebook demonstrates how to run machine learning inference by using vLLM and GPUs in three ways:\n", + "\n", + "* locally without Apache Beam\n", + "* locally with the Apache Beam local runner\n", + "* remotely with the Dataflow runner\n", + "\n", + "It also shows how to swap in a different model without modifying your pipeline structure by changing the configuration." + ], + "metadata": { + "id": "H0ZFs9rDvtJm" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Requirements\n", + "\n", + "This notebook assumes that a GPU is enabled in Colab. If this setting isn't enabled, the locally executed sections of this notebook might not work. To enable a GPU, in the Colab menu, click **Runtime** > **Change runtime type**. For **Hardware accelerator**, choose a GPU accelerator. If you can't access a GPU in Colab, you can run the Dataflow section of this notebook.\n", + "\n", + "To run the Dataflow section, you need access to the following resources:\n", + "\n", + "- a computer with Docker installed\n", + "- a [Google Cloud](https://cloud.google.com/) account" + ], + "metadata": { + "id": "6x41tnbTvQM1" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Install dependencies\n", + "\n", + "Before creating your pipeline, download and install the dependencies required to develop with Apache Beam and vLLM. vLLM is supported in Apache Beam versions 2.60.0 and later." + ], + "metadata": { + "id": "8PSjyDIavRcn" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "irCKNe42p22r" + }, + "outputs": [], + "source": [ + "!pip install openai>=1.52.2\n", + "!pip install vllm>=0.6.3\n", + "!pip install apache-beam[gcp]==2.60.0\n", + "!pip check" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Run locally without Apache Beam\n", + "\n", + "In this section, you run a vLLM server without using Apache Beam. Use the `facebook/opt-125m` model. This model is small enough to fit in Colab memory and doesn't require any extra authentication.\n", + "\n", + "First, start the vLLM server. This step might take a minute or two, because the model needs to download before vLLM starts running inference." + ], + "metadata": { + "id": "3xz8zuA7vcS4" + } + }, + { + "cell_type": "code", + "source": [ + "! python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m" + ], + "metadata": { + "id": "GbJGzINNt5sG" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next, while the vLLM server is running, open a separate terminal to communicate with the vLLM serving process. To open a terminal in Colab, in the sidebar, click **Terminal**. 
In the terminal, run the following commands.\n", + "\n", + "```\n", + "pip install openai\n", + "python\n", + "\n", + "from openai import OpenAI\n", + "\n", + "# Modify OpenAI's API key and API base to use vLLM's API server.\n", + "openai_api_key = \"EMPTY\"\n", + "openai_api_base = \"http://localhost:8000/v1\"\n", + "client = OpenAI(\n", + " api_key=openai_api_key,\n", + " base_url=openai_api_base,\n", + ")\n", + "completion = client.completions.create(model=\"facebook/opt-125m\",\n", + " prompt=\"San Francisco is a\")\n", + "print(\"Completion result:\", completion)\n", + "```\n", + "\n", + "This code runs against the server running in the cell. You can experiment with different prompts." + ], + "metadata": { + "id": "n35LXTS3uzIC" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Run locally with Apache Beam\n", + "\n", + "In this section, you set up an Apache Beam pipeline to run a job with an embedded vLLM instance.\n", + "\n", + "First, define the `VllmCompletionsModelHandler` object. This configuration object gives Apache Beam the information that it needs to create a dedicated vLLM process in the middle of the pipeline. Apache Beam then provides examples to the pipeline. No additional code is needed." + ], + "metadata": { + "id": "Hbxi83BfwbBa" + } + }, + { + "cell_type": "code", + "source": [ + "from apache_beam.ml.inference.base import RunInference\n", + "from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler\n", + "from apache_beam.ml.inference.base import PredictionResult\n", + "import apache_beam as beam\n", + "\n", + "model_handler = VLLMCompletionsModelHandler('facebook/opt-125m')" + ], + "metadata": { + "id": "sUqjOzw3wpI4" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next, define examples to run inference against, and define a helper function to print out the inference results." + ], + "metadata": { + "id": "N06lXRKRxCz5" + } + }, + { + "cell_type": "code", + "source": [ + "class FormatOutput(beam.DoFn):\n", + " def process(self, element, *args, **kwargs):\n", + " yield \"Input: {input}, Output: {output}\".format(input=element.example, output=element.inference)\n", + "\n", + "prompts = [\n", + " \"Hello, my name is\",\n", + " \"The president of the United States is\",\n", + " \"The capital of France is\",\n", + " \"The future of AI is\",\n", + " \"Emperor penguins are\",\n", + "]" + ], + "metadata": { + "id": "3a1PznmtxNR_" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Finally, run the pipeline.\n", + "\n", + "This step might take a minute or two, because the model needs to download before Apache Beam can start running inference." + ], + "metadata": { + "id": "Njl0QfrLxQ0m" + } + }, + { + "cell_type": "code", + "source": [ + "with beam.Pipeline() as p:\n", + " _ = (p | beam.Create(prompts) # Create a PCollection of the prompts.\n", + " | RunInference(model_handler) # Send the prompts to the model and get responses.\n", + " | beam.ParDo(FormatOutput()) # Format the output.\n", + " | beam.Map(print) # Print the formatted output.\n", + " )" + ], + "metadata": { + "id": "9yXbzV0ZmZcJ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Run remotely on Dataflow\n", + "\n", + "After you validate that the pipeline can run against a vLLM locally, you can productionalize the workflow on a remote runner. This notebook runs the pipeline on the Dataflow runner." 
+ ], + "metadata": { + "id": "Jv7be6Pk9Hlx" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Build a Docker image\n", + "\n", + "To run a pipeline with vLLM on Dataflow, you must create a Docker image that contains your dependencies and is compatible with a GPU runtime. For more information about building GPU compatible Dataflow containers, see [Build a custom container image](https://cloud.google.com/dataflow/docs/gpu/use-gpus#custom-container) in the Datafow documentation.\n", + "\n", + "First, define and save your Dockerfile. This file uses an Nvidia GPU-compatible base image. In the Dockerfile, install the Python dependencies needed to run the job.\n", + "\n", + "Before proceeding, make sure that your configuration meets the following requirements:\n", + "\n", + "- The Python version in the following cell matches the Python version defined in the Dockerfile.\n", + "- The Apache Beam version defined in your dependencies matches the Apache Beam version defined in the Dockerfile." + ], + "metadata": { + "id": "J1LMrl1Yy6QB" + } + }, + { + "cell_type": "code", + "source": [ + "!python --version" + ], + "metadata": { + "id": "jCQ6-D55gqfl" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "cell_str='''\n", + "FROM nvidia/cuda:12.4.1-devel-ubuntu22.04\n", + "\n", + "RUN apt update\n", + "RUN apt install software-properties-common -y\n", + "RUN add-apt-repository ppa:deadsnakes/ppa\n", + "RUN apt update\n", + "RUN apt-get update\n", + "\n", + "ARG DEBIAN_FRONTEND=noninteractive\n", + "\n", + "RUN apt install python3.10-full -y\n", + "# RUN apt install python3.10-venv -y\n", + "# RUN apt install python3.10-dev -y\n", + "RUN rm /usr/bin/python3\n", + "RUN ln -s python3.10 /usr/bin/python3\n", + "RUN python3 --version\n", + "RUN apt-get install -y curl\n", + "RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 && pip install --upgrade pip\n", + "\n", + "# Copy the Apache Beam worker dependencies from the Beam Python 3.10 SDK image.\n", + "COPY --from=apache/beam_python3.10_sdk:2.60.0 /opt/apache/beam /opt/apache/beam\n", + "\n", + "RUN pip install --no-cache-dir -vvv apache-beam[gcp]==2.60.0\n", + "RUN pip install openai>=1.52.2 vllm>=0.6.3\n", + "\n", + "RUN apt install libcairo2-dev pkg-config python3-dev -y\n", + "RUN pip install pycairo\n", + "\n", + "# Set the entrypoint to Apache Beam SDK worker launcher.\n", + "ENTRYPOINT [ \"/opt/apache/beam/boot\" ]\n", + "'''\n", + "\n", + "with open('VllmDockerfile', 'w') as f:\n", + " f.write(cell_str)" + ], + "metadata": { + "id": "7QyNq_gygHLO" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "After you save the Dockerfile, build and push your Docker image. Because Docker is not accessible from Colab, you need to complete this step in a separate environment.\n", + "\n", + "1. In the sidebar, click **Files** to open the **Files** pane.\n", + "2. In an environment with Docker installed, download the file **VllmDockerfile** file to an empty folder.\n", + "3. Run the following commands. 
Replace `` with a valid [Artifact Registry](https://cloud.google.com/artifact-registry/docs/overview) repository.\n", + "\n", + " ```\n", + " docker build -t \":latest\" -f VllmDockerfile ./\n", + " docker image push \":latest\"\n", + " ```" + ], + "metadata": { + "id": "zWma0YetiEn5" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Define and run the pipeline\n", + "\n", + "When you have a working Docker image, define and run your pipeline.\n", + "\n", + "First, define the pipeline options that you want to use to launch the Dataflow job. Before running the next cell, replace the following variables:\n", + "\n", + "- ``: the name of a valid [Google Cloud Storage](https://cloud.google.com/storage?e=48754805&hl=en) bucket. Don't include a `gs://` prefix or trailing slashes.\n", + "- ``: the name of the Google Artifact Registry repository that you used in the previous step. Don't include the `latest` tag, because this tag is automatically appended as part of the cell.\n", + "- ``: the name of the Google Cloud project that you created your bucket and Artifact Registry repository in.\n", + "\n", + "This workflow uses the following Dataflow service option: `worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx`. When you use this service option, Dataflow to installs a T4 GPU that uses a `5xx` series Nvidia driver on each worker machine. The 5xx driver is required to run vLLM jobs." + ], + "metadata": { + "id": "NjZyRjte0g0Q" + } + }, + { + "cell_type": "code", + "source": [ + "\n", + "from apache_beam.options.pipeline_options import GoogleCloudOptions\n", + "from apache_beam.options.pipeline_options import PipelineOptions\n", + "from apache_beam.options.pipeline_options import SetupOptions\n", + "from apache_beam.options.pipeline_options import StandardOptions\n", + "from apache_beam.options.pipeline_options import WorkerOptions\n", + "\n", + "\n", + "options = PipelineOptions()\n", + "\n", + "BUCKET_NAME = '' # Replace with your bucket name.\n", + "CONTAINER_LOCATION = '' # Replace with your container location ( from the previous step)\n", + "PROJECT_NAME = '' # Replace with your GCP project\n", + "\n", + "options.view_as(GoogleCloudOptions).project = PROJECT_NAME\n", + "\n", + "# Provide required pipeline options for the Dataflow Runner.\n", + "options.view_as(StandardOptions).runner = \"DataflowRunner\"\n", + "\n", + "# Set the Google Cloud region that you want to run Dataflow in.\n", + "options.view_as(GoogleCloudOptions).region = 'us-central1'\n", + "\n", + "# IMPORTANT: Replace BUCKET_NAME with the the name of your Cloud Storage bucket.\n", + "dataflow_gcs_location = \"gs://%s/dataflow\" % BUCKET_NAME\n", + "\n", + "# The Dataflow staging location. This location is used to stage the Dataflow pipeline and the SDK binary.\n", + "options.view_as(GoogleCloudOptions).staging_location = '%s/staging' % dataflow_gcs_location\n", + "\n", + "\n", + "# The Dataflow staging location. This location is used to stage the Dataflow pipeline and the SDK binary.\n", + "options.view_as(GoogleCloudOptions).staging_location = '%s/staging' % dataflow_gcs_location\n", + "\n", + "# The Dataflow temp location. This location is used to store temporary files or intermediate results before outputting to the sink.\n", + "options.view_as(GoogleCloudOptions).temp_location = '%s/temp' % dataflow_gcs_location\n", + "\n", + "# Enable GPU runtime. 
Make sure to enable 5xx driver since vLLM only works with 5xx drivers, not 4xx\n", + "options.view_as(GoogleCloudOptions).dataflow_service_options = [\"worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx\"]\n", + "\n", + "options.view_as(SetupOptions).save_main_session = True\n", + "\n", + "# Choose a machine type compatible with GPU type\n", + "options.view_as(WorkerOptions).machine_type = \"n1-standard-4\"\n", + "\n", + "options.view_as(WorkerOptions).worker_harness_container_image = '%s:latest' % CONTAINER_LOCATION" + ], + "metadata": { + "id": "kXy9FRYVCSjq" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next, authenticate Colab so that it can to submit a job on your behalf." + ], + "metadata": { + "id": "xPhe597P1-QJ" + } + }, + { + "cell_type": "code", + "source": [ + "def auth_to_colab():\n", + " from google.colab import auth\n", + " auth.authenticate_user()\n", + "\n", + "auth_to_colab()" + ], + "metadata": { + "id": "Xkf6yIVlFB8-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Finally, run the pipeline on Dataflow. The pipeline definition is almost exactly the same as the definition used for local execution. The pipeline options are the only change to the pipeline.\n", + "\n", + "The following code creates a Dataflow job in your project. You can view the results in Colab or in the Google Cloud console. Creating a Dataflow job and downloading the model might take a few minutes. After the job starts performing inference, it quickly runs through the inputs." + ], + "metadata": { + "id": "MJtEI6Ux2eza" + } + }, + { + "cell_type": "code", + "source": [ + "import logging\n", + "from apache_beam.ml.inference.base import RunInference\n", + "from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler\n", + "from apache_beam.ml.inference.base import PredictionResult\n", + "import apache_beam as beam\n", + "\n", + "class FormatOutput(beam.DoFn):\n", + " def process(self, element, *args, **kwargs):\n", + " yield \"Input: {input}, Output: {output}\".format(input=element.example, output=element.inference)\n", + "\n", + "prompts = [\n", + " \"Hello, my name is\",\n", + " \"The president of the United States is\",\n", + " \"The capital of France is\",\n", + " \"The future of AI is\",\n", + " \"John cena is\",\n", + "]\n", + "\n", + "# Specify the model handler, providing a path and the custom inference function.\n", + "model_handler = VLLMCompletionsModelHandler('facebook/opt-125m')\n", + "\n", + "with beam.Pipeline(options=options) as p:\n", + " _ = (p | beam.Create(prompts) # Create a PCollection of the prompts.\n", + " | RunInference(model_handler) # Send the prompts to the model and get responses.\n", + " | beam.ParDo(FormatOutput()) # Format the output.\n", + " | beam.Map(logging.info) # Print the formatted output.\n", + " )" + ], + "metadata": { + "id": "8gjDdru_9Dii" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Run vLLM with a Gemma model\n", + "\n", + "After you configure your pipeline, switching the model used by the pipeline is relatively straightforward. You can run the same pipeline, but switch the model name defined in the model handler. This example runs the pipeline created previously but uses a [Gemma](https://ai.google.dev/gemma) model.\n", + "\n", + "Before you start, sign in to HuggingFace, and make sure that you can access the Gemma models. 
To access Gemma models, you must accept the terms and conditions.\n", + "\n", + "1. Navigate to the [Gemma Model Card](https://huggingface.co/google/gemma-2b).\n", + "2. Sign in, or sign up for a free HuggingFace account.\n", + "3. Follow the prompts to agree to the conditions\n", + "\n", + "When you complete these steps, the following message appears on the model card page: `You have been granted access to this model`.\n", + "\n", + "Next, sign in to your account from this notebook by running the following code and then following the prompts." + ], + "metadata": { + "id": "22cEHPCc28fH" + } + }, + { + "cell_type": "code", + "source": [ + "! huggingface-cli login" + ], + "metadata": { + "id": "JHwIsFI9kd9j" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Verify that the notebook can now access the Gemma model. Run the following code, which starts a vLLM server to serve the Gemma 2b model. Because the default T4 Colab runtime doesn't support the full data type precision needed to run Gemma models, the `--dtype=half` parameter is required.\n", + "\n", + "When successful, the following cell runs indefinitely. After it starts the server process, you can shut it down. When the server process starts, the Gemma 2b model is successfully downloaded, and the server is ready to serve traffic." + ], + "metadata": { + "id": "IjX2If8rnCol" + } + }, + { + "cell_type": "code", + "source": [ + "! python -m vllm.entrypoints.openai.api_server --model google/gemma-2b --dtype=half" + ], + "metadata": { + "id": "LH_oCFWMiwFs" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "To run the pipeline in Apache Beam, run the following code. Update the `VLLMCompletionsModelHandler` object with the new parameters, which match the command from the previous cell. Reuse all of the pipeline logic from the previous pipelines." + ], + "metadata": { + "id": "31BmdDUAn-SW" + } + }, + { + "cell_type": "code", + "source": [ + "model_handler = VLLMCompletionsModelHandler('google/gemma-2b', vllm_server_kwargs={'dtype': 'half'})\n", + "\n", + "with beam.Pipeline() as p:\n", + " _ = (p | beam.Create(prompts) # Create a PCollection of the prompts.\n", + " | RunInference(model_handler) # Send the prompts to the model and get responses.\n", + " | beam.ParDo(FormatOutput()) # Format the output.\n", + " | beam.Map(print) # Print the formatted output.\n", + " )" + ], + "metadata": { + "id": "DyC2ikXg237p" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Run Gemma on Dataflow\n", + "\n", + "As a next step, run this pipeline on Dataflow. Follow the same steps described in the \"Run remotely on Dataflow\" section of this page:\n", + "\n", + "1. Construct a Dockerfile and push a new Docker image. You can use the same Dockerfile that you created previously, but you need to add a step to set your HuggingFace authentication key. In your Dockerfile, add the following line before the entrypoint:\n", + "\n", + " ```\n", + " RUN python3 -c 'from huggingface_hub import HfFolder; HfFolder.save_token(\"\")'\n", + " ```\n", + "\n", + "2. Set pipeline options. You can reuse the options defined in this notebook. Replace the Docker image location with your new Docker image.\n", + "3. Run the pipeline. 
Copy the pipeline that you ran on Dataflow, and replace the pipeline options with the pipeline options that you just defined.\n", + "\n" + ], + "metadata": { + "id": "C6OYfub6ovFK" + } + } + ] +} From bc98bf11477895882ea7a172e927971c13f2df44 Mon Sep 17 00:00:00 2001 From: Kenneth Knowles Date: Thu, 24 Oct 2024 15:15:13 -0400 Subject: [PATCH 066/181] Fix web page references to Flink 1.19 support since it is not released yet --- .../introduction-concepts/runner-concepts/description.md | 8 ++++---- .../www/site/content/en/documentation/runners/flink.md | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md b/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md index c0d7b37725ac..71abe616f1ad 100644 --- a/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md +++ b/learning/tour-of-beam/learning-content/introduction/introduction-concepts/runner-concepts/description.md @@ -191,8 +191,8 @@ $ wordcount --input gs://dataflow-samples/shakespeare/kinglear.txt \ {{if (eq .Sdk "java")}} ##### Portable -1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.17`, `Flink 1.18`, `Flink 1.19`. -2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.19_job_server:latest` +1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.16`, `Flink 1.17`, `Flink 1.18`. +2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.18_job_server:latest` 3. Submit the pipeline to the above endpoint by using the PortableRunner, job_endpoint set to localhost:8099 (this is the default address of the JobService). Optionally set environment_type set to LOOPBACK. For example: ``` @@ -233,8 +233,8 @@ mvn exec:java -Dexec.mainClass=org.apache.beam.examples.WordCount \ {{end}} {{if (eq .Sdk "python")}} -1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.17`, `Flink 1.18`, `Flink 1.19`. -2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.19_job_server:latest` +1. Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: `Flink 1.16`, `Flink 1.17`, `Flink 1.18`. +2. Start the JobService endpoint: `docker run --net=host apache/beam_flink1.18_job_server:latest` 3. Submit the pipeline to the above endpoint by using the PortableRunner, job_endpoint set to localhost:8099 (this is the default address of the JobService). Optionally set environment_type set to LOOPBACK. For example: ``` diff --git a/website/www/site/content/en/documentation/runners/flink.md b/website/www/site/content/en/documentation/runners/flink.md index fb897805cfd6..e9522d76e832 100644 --- a/website/www/site/content/en/documentation/runners/flink.md +++ b/website/www/site/content/en/documentation/runners/flink.md @@ -93,7 +93,7 @@ from the [compatibility table](#flink-version-compatibility) below. 
For example: {{< highlight java >}} org.apache.beam - beam-runners-flink-1.19 + beam-runners-flink-1.18 {{< param release_latest >}} {{< /highlight >}} @@ -196,9 +196,9 @@ The optional `flink_version` option may be required as well for older versions o {{< paragraph class="language-portable" >}} Starting with Beam 2.18.0, pre-built Flink Job Service Docker images are available at Docker Hub: +[Flink 1.16](https://hub.docker.com/r/apache/beam_flink1.16_job_server). [Flink 1.17](https://hub.docker.com/r/apache/beam_flink1.17_job_server). [Flink 1.18](https://hub.docker.com/r/apache/beam_flink1.18_job_server). -[Flink 1.19](https://hub.docker.com/r/apache/beam_flink1.19_job_server). {{< /paragraph >}} @@ -311,8 +311,8 @@ reference. ## Flink Version Compatibility The Flink cluster version has to match the minor version used by the FlinkRunner. -The minor version is the first two numbers in the version string, e.g. in `1.19.0` the -minor version is `1.19`. +The minor version is the first two numbers in the version string, e.g. in `1.18.0` the +minor version is `1.18`. We try to track the latest version of Apache Flink at the time of the Beam release. A Flink version is supported by Beam for the time it is supported by the Flink community. From 5a6b10a57b5d0142e3446e1a93f0af1bc0222d8b Mon Sep 17 00:00:00 2001 From: Robert Burke Date: Thu, 24 Oct 2024 14:36:37 -0700 Subject: [PATCH 067/181] [#28187] Add a Java gradle task to run validates runner tests on Prism. (#32919) --- ...beam_PreCommit_Java_PVR_Prism_Loopback.yml | 114 +++++++++ runners/prism/build.gradle | 3 + runners/prism/java/build.gradle | 241 ++++++++++++++++++ .../beam/runners/prism/PrismExecutor.java | 1 + .../beam/runners/prism/PrismLocator.java | 6 + .../runners/prism/PrismPipelineOptions.java | 6 + .../beam/runners/prism/PrismRunner.java | 6 +- .../beam/runners/prism/PrismLocatorTest.java | 5 +- sdks/go/cmd/prism/prism.go | 19 +- 9 files changed, 393 insertions(+), 8 deletions(-) create mode 100644 .github/workflows/beam_PreCommit_Java_PVR_Prism_Loopback.yml diff --git a/.github/workflows/beam_PreCommit_Java_PVR_Prism_Loopback.yml b/.github/workflows/beam_PreCommit_Java_PVR_Prism_Loopback.yml new file mode 100644 index 000000000000..ea5cf9b5578e --- /dev/null +++ b/.github/workflows/beam_PreCommit_Java_PVR_Prism_Loopback.yml @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +name: PreCommit Java PVR Prism Loopback + +on: + push: + tags: ['v*'] + branches: ['master', 'release-*'] + paths: + - 'model/**' + - 'sdks/go/pkg/beam/runners/prism/**' + - 'sdks/go/cmd/prism/**' + - 'runners/prism/**' + - 'runners/java-fn-execution/**' + - 'sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/**' + - '.github/workflows/beam_PreCommit_Java_PVR_Prism_Loopback.yml' + pull_request_target: + branches: ['master', 'release-*'] + paths: + - 'model/**' + - 'sdks/go/pkg/beam/runners/prism/**' + - 'sdks/go/cmd/prism/**' + - 'runners/prism/**' + - 'runners/java-fn-execution/**' + - 'sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/**' + - 'release/trigger_all_tests.json' + - '.github/trigger_files/beam_PreCommit_Java_PVR_Prism_Loopback.json' + issue_comment: + types: [created] + schedule: + - cron: '22 2/6 * * *' + workflow_dispatch: + +# This allows a subsequently queued workflow run to interrupt previous runs +concurrency: + group: '${{ github.workflow }} @ ${{ github.event.issue.number || github.event.pull_request.head.label || github.sha || github.head_ref || github.ref }}-${{ github.event.schedule || github.event.comment.id || github.event.sender.login }}' + cancel-in-progress: true + +#Setting explicit permissions for the action to avoid the default permissions which are `write-all` in case of pull_request_target event +permissions: + actions: write + pull-requests: write + checks: write + contents: read + deployments: read + id-token: none + issues: write + discussions: read + packages: read + pages: read + repository-projects: read + security-events: read + statuses: read + +env: + DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} + GRADLE_ENTERPRISE_CACHE_USERNAME: ${{ secrets.GE_CACHE_USERNAME }} + GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} + +jobs: + beam_PreCommit_Java_PVR_Prism_Loopback: + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + strategy: + matrix: + job_name: ["beam_PreCommit_Java_PVR_Prism_Loopback"] + job_phrase: ["Run Java_PVR_Prism_Loopback PreCommit"] + timeout-minutes: 240 + runs-on: [self-hosted, ubuntu-20.04] + if: | + github.event_name == 'push' || + github.event_name == 'pull_request_target' || + (github.event_name == 'schedule' && github.repository == 'apache/beam') || + github.event_name == 'workflow_dispatch' || + github.event.comment.body == 'Run Java_PVR_Prism_Loopback PreCommit' + steps: + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + - name: Setup environment + uses: ./.github/actions/setup-environment-action + - name: run prismLoopbackValidatesRunnerTests script + uses: ./.github/actions/gradle-command-self-hosted-action + with: + gradle-command: :runners:prism:java:prismLoopbackValidatesRunnerTests + - name: Archive JUnit Test Results + uses: actions/upload-artifact@v4 + if: ${{ !success() }} + with: + name: JUnit Test Results + path: "**/build/reports/tests/" + - name: Upload test report + uses: actions/upload-artifact@v4 + with: + name: java-code-coverage-report + path: "**/build/test-results/**/*.xml" diff --git a/runners/prism/build.gradle b/runners/prism/build.gradle index 711a1aa2dd75..1009b9856e71 100644 --- a/runners/prism/build.gradle +++ b/runners/prism/build.gradle @@ -42,6 +42,9 @@ ext.set('buildTarget', buildTarget) def buildTask = tasks.named("build") { // goPrepare 
is a task registered in applyGoNature. dependsOn("goPrepare") + // Allow Go to manage the caching, not gradle. + outputs.cacheIf { false } + outputs.upToDateWhen { false } doLast { exec { workingDir = modDir diff --git a/runners/prism/java/build.gradle b/runners/prism/java/build.gradle index de9a30ad8189..f6655900f624 100644 --- a/runners/prism/java/build.gradle +++ b/runners/prism/java/build.gradle @@ -16,6 +16,8 @@ * limitations under the License. */ +import groovy.json.JsonOutput + plugins { id 'org.apache.beam.module' } applyJavaNature( @@ -43,3 +45,242 @@ tasks.test { var prismBuildTask = dependsOn(':runners:prism:build') systemProperty 'prism.buildTarget', prismBuildTask.project.property('buildTarget').toString() } + +// Below is configuration to support running the Java Validates Runner tests. + +configurations { + validatesRunner +} + +dependencies { + implementation project(path: ":sdks:java:core", configuration: "shadow") + implementation library.java.hamcrest + permitUnusedDeclared library.java.hamcrest + implementation library.java.joda_time + implementation library.java.slf4j_api + implementation library.java.vendored_guava_32_1_2_jre + + testImplementation library.java.hamcrest + testImplementation library.java.junit + testImplementation library.java.mockito_core + testImplementation library.java.slf4j_jdk14 + + validatesRunner project(path: ":sdks:java:core", configuration: "shadowTest") + validatesRunner project(path: ":runners:core-java", configuration: "testRuntimeMigration") + validatesRunner project(path: project.path, configuration: "testRuntimeMigration") +} + +project.evaluationDependsOn(":sdks:java:core") +project.evaluationDependsOn(":runners:core-java") + +def sickbayTests = [ + // PortableMetrics doesn't implement "getCommitedOrNull" from Metrics + // Preventing Prism from passing these tests. + // In particular, it doesn't subclass MetricResult with an override, and + // it explicilty passes "false" to commited supported in create. + // + // There is not currently a category for excluding these _only_ in committed mode + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testAllCommittedMetrics', + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testCommittedCounterMetrics', + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testCommittedDistributionMetrics', + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testCommittedStringSetMetrics', + 'org.apache.beam.sdk.metrics.MetricsTest$CommittedMetricTests.testCommittedGaugeMetrics', + + // Triggers / Accumulation modes not yet implemented in prism. + // https://github.com/apache/beam/issues/31438 + 'org.apache.beam.sdk.transforms.CombineTest$WindowingTests.testGlobalCombineWithDefaultsAndTriggers', + 'org.apache.beam.sdk.transforms.CombineTest$BasicTests.testHotKeyCombiningWithAccumulationMode', + 'org.apache.beam.sdk.transforms.windowing.WindowTest.testNoWindowFnDoesNotReassignWindows', + 'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testAfterProcessingTimeContinuationTriggerUsingState', + 'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testCombiningAccumulatingProcessingTime', + 'org.apache.beam.sdk.transforms.GroupByKeyTest$BasicTests.testAfterProcessingTimeContinuationTriggerEarly', + 'org.apache.beam.sdk.transforms.ParDoTest$BundleInvariantsTests.testWatermarkUpdateMidBundle', + 'org.apache.beam.sdk.transforms.ViewTest.testTriggeredLatestSingleton', + // Requires Allowed Lateness, among others. 
+ 'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testEventTimeTimerSetWithinAllowedLateness', + 'org.apache.beam.sdk.testing.TestStreamTest.testFirstElementLate', + 'org.apache.beam.sdk.testing.TestStreamTest.testDiscardingMode', + 'org.apache.beam.sdk.testing.TestStreamTest.testEarlyPanesOfWindow', + 'org.apache.beam.sdk.testing.TestStreamTest.testElementsAtAlmostPositiveInfinity', + 'org.apache.beam.sdk.testing.TestStreamTest.testLateDataAccumulating', + 'org.apache.beam.sdk.testing.TestStreamTest.testMultipleStreams', + 'org.apache.beam.sdk.testing.TestStreamTest.testProcessingTimeTrigger', + + // Coding error somehow: short write: reached end of stream after reading 5 bytes; 98 bytes expected + 'org.apache.beam.sdk.testing.TestStreamTest.testMultiStage', + + // Prism not firing sessions correctly (seems to be merging inapppropriately) + 'org.apache.beam.sdk.transforms.CombineTest$WindowingTests.testSessionsCombine', + 'org.apache.beam.sdk.transforms.CombineTest$WindowingTests.testSessionsCombineWithContext', + + // Java side dying during execution. + // https://github.com/apache/beam/issues/32930 + 'org.apache.beam.sdk.transforms.FlattenTest.testFlattenMultipleCoders', + // Stream corruption error java side: failed:java.io.StreamCorruptedException: invalid stream header: 206E6F74 + // Likely due to prism't coder changes. + 'org.apache.beam.sdk.transforms.FlattenTest.testFlattenWithDifferentInputAndOutputCoders2', + + // java.lang.IllegalStateException: Output with tag Tag must have a schema in order to call getRowReceiver + // Ultimately because getRoeReceiver code path SDK side isn't friendly to LengthPrefix wrapping of row coders. + // https://github.com/apache/beam/issues/32931 + 'org.apache.beam.sdk.transforms.ParDoSchemaTest.testReadAndWrite', + 'org.apache.beam.sdk.transforms.ParDoSchemaTest.testReadAndWriteMultiOutput', + 'org.apache.beam.sdk.transforms.ParDoSchemaTest.testReadAndWriteWithSchemaRegistry', + + // Technically these tests "succeed" + // the test is just complaining that an AssertionException isn't a RuntimeException + // + // java.lang.RuntimeException: test error in finalize + 'org.apache.beam.sdk.transforms.ParDoTest$LifecycleTests.testParDoWithErrorInFinishBatch', + // java.lang.RuntimeException: test error in process + 'org.apache.beam.sdk.transforms.ParDoTest$LifecycleTests.testParDoWithErrorInProcessElement', + // java.lang.RuntimeException: test error in initialize + 'org.apache.beam.sdk.transforms.ParDoTest$LifecycleTests.testParDoWithErrorInStartBatch', + + // Only known window fns supported, not general window merging + // Custom window fns not yet implemented in prism. + // https://github.com/apache/beam/issues/31921 + 'org.apache.beam.sdk.transforms.windowing.WindowTest.testMergingCustomWindows', + 'org.apache.beam.sdk.transforms.windowing.WindowTest.testMergingCustomWindowsKeyedCollection', + 'org.apache.beam.sdk.transforms.windowing.WindowTest.testMergingCustomWindowsWithoutCustomWindowTypes', + 'org.apache.beam.sdk.transforms.windowing.WindowingTest.testMergingWindowing', + 'org.apache.beam.sdk.transforms.windowing.WindowingTest.testNonPartitioningWindowing', + 'org.apache.beam.sdk.transforms.GroupByKeyTest$WindowTests.testGroupByKeyMergingWindows', + + // Possibly a different error being hidden behind the main error. 
+ // org.apache.beam.sdk.util.WindowedValue$ValueInGlobalWindow cannot be cast to class java.lang.String + // TODO(https://github.com/apache/beam/issues/29973) + 'org.apache.beam.sdk.transforms.ReshuffleTest.testReshufflePreservesMetadata', + // TODO(https://github.com/apache/beam/issues/31231) + 'org.apache.beam.sdk.transforms.RedistributeTest.testRedistributePreservesMetadata', + + // Prism isn't handling Java's side input views properly. + // https://github.com/apache/beam/issues/32932 + // java.lang.IllegalArgumentException: PCollection with more than one element accessed as a singleton view. + // Consider using Combine.globally().asSingleton() to combine the PCollection into a single value + 'org.apache.beam.sdk.transforms.ViewTest.testDiscardingNonSingletonSideInput', + // java.util.NoSuchElementException: Empty PCollection accessed as a singleton view. + 'org.apache.beam.sdk.transforms.ViewTest.testDiscardingNonSingletonSideInput', + // ava.lang.IllegalArgumentException: Duplicate values for a + 'org.apache.beam.sdk.transforms.ViewTest.testMapSideInputWithNullValuesCatchesDuplicates', + // java.lang.IllegalArgumentException: PCollection with more than one element accessed as a singleton view.... + 'org.apache.beam.sdk.transforms.ViewTest.testNonSingletonSideInput', + // java.util.NoSuchElementException: Empty PCollection accessed as a singleton view. + 'org.apache.beam.sdk.transforms.ViewTest.testEmptySingletonSideInput', + // Prism side encoding error. + // java.lang.IllegalStateException: java.io.EOFException + 'org.apache.beam.sdk.transforms.ViewTest.testSideInputWithNestedIterables', + + // Requires Time Sorted Input + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testRequiresTimeSortedInput', + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testRequiresTimeSortedInputWithTestStream', + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testRequiresTimeSortedInputWithLateDataAndAllowedLateness', + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testTwoRequiresTimeSortedInputWithLateData', + 'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testRequiresTimeSortedInputWithLateData', + + // Timer race condition/ordering issue in Prism. + 'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testTwoTimersSettingEachOtherWithCreateAsInputUnbounded', + + // Missing output due to timer skew. + 'org.apache.beam.sdk.transforms.ParDoTest$TimestampTests.testProcessElementSkew', + + // TestStream + BundleFinalization. + // Tests seem to assume individual element bundles from test stream, but prism will aggregate them, preventing + // a subsequent firing. Tests ultimately hang until timeout. + // Either a test problem, or a misunderstanding of how test stream must work problem in prism. + // Biased to test problem, due to how they are constructed. + 'org.apache.beam.sdk.transforms.ParDoTest$BundleFinalizationTests.testBundleFinalization', + 'org.apache.beam.sdk.transforms.ParDoTest$BundleFinalizationTests.testBundleFinalizationWithSideInputs', + + // Filtered by PortableRunner tests. 
+ // Teardown not called in exceptions + // https://github.com/apache/beam/issues/20372 + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundle', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInFinishBundleStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElement', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInProcessElementStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetup', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInSetupStateful', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundle', + 'org.apache.beam.sdk.transforms.ParDoLifecycleTest.testTeardownCalledAfterExceptionInStartBundleStateful', +] + +/** + * Runs Java ValidatesRunner tests against the Prism Runner + * with the specified environment type. + */ +def createPrismValidatesRunnerTask = { name, environmentType -> + Task vrTask = tasks.create(name: name, type: Test, group: "Verification") { + description "PrismRunner Java $environmentType ValidatesRunner suite" + classpath = configurations.validatesRunner + + var prismBuildTask = dependsOn(':runners:prism:build') + systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ + "--runner=TestPrismRunner", + "--experiments=beam_fn_api", + "--defaultEnvironmentType=${environmentType}", + "--prismLogLevel=warn", + "--prismLocation=${prismBuildTask.project.property('buildTarget').toString()}", + "--enableWebUI=false", + ]) + testClassesDirs = files(project(":sdks:java:core").sourceSets.test.output.classesDirs) + useJUnit { + includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner' + // Should be run only in a properly configured SDK harness environment + excludeCategories 'org.apache.beam.sdk.testing.UsesExternalService' + excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment' + + // Not yet implemented in Prism + // https://github.com/apache/beam/issues/32211 + excludeCategories 'org.apache.beam.sdk.testing.UsesOnWindowExpiration' + // https://github.com/apache/beam/issues/32929 + excludeCategories 'org.apache.beam.sdk.testing.UsesOrderedListState' + + // Not supported in Portable Java SDK yet. + // https://github.com/apache/beam/issues?q=is%3Aissue+is%3Aopen+MultimapState + excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState' + } + filter { + // Hangs forever with prism. Put here instead of sickbay to allow sickbay runs to terminate. 
+ // https://github.com/apache/beam/issues/32222 + excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testEventTimeTimerOrderingWithCreate' + + for (String test : sickbayTests) { + excludeTestsMatching test + } + } + } + return vrTask +} + +tasks.register("validatesRunnerSickbay", Test) { + group = "Verification" + description "Validates Prism local runner (Sickbay Tests)" + + var prismBuildTask = dependsOn(':runners:prism:build') + systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ + "--runner=TestPrismRunner", + "--experiments=beam_fn_api", + "--enableWebUI=false", + "--prismLogLevel=warn", + "--prismLocation=${prismBuildTask.project.property('buildTarget').toString()}" + ]) + + classpath = configurations.validatesRunner + testClassesDirs = files(project(":sdks:java:core").sourceSets.test.output.classesDirs) + + filter { + for (String test : sickbayTests) { + includeTestsMatching test + } + } +} + +task prismDockerValidatesRunner { + Task vrTask = createPrismValidatesRunnerTask("prismDockerValidatesRunnerTests", "DOCKER") + vrTask.dependsOn ":sdks:java:container:java8:docker" +} + +task prismLoopbackValidatesRunner { + dependsOn createPrismValidatesRunnerTask("prismLoopbackValidatesRunnerTests", "LOOPBACK") +} diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java index 0f9816337f91..111d937fcbf6 100644 --- a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismExecutor.java @@ -50,6 +50,7 @@ abstract class PrismExecutor { static final String IDLE_SHUTDOWN_TIMEOUT = "-idle_shutdown_timeout=%s"; static final String JOB_PORT_FLAG_TEMPLATE = "-job_port=%s"; static final String SERVE_HTTP_FLAG_TEMPLATE = "-serve_http=%s"; + static final String LOG_LEVEL_FLAG_TEMPLATE = "-log_level=%s"; protected @MonotonicNonNull Process process; protected ExecutorService executorService = Executors.newSingleThreadExecutor(); diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java index 27aea3f64df0..b32f03e78e6a 100644 --- a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismLocator.java @@ -110,6 +110,12 @@ String resolveSource() { String resolve() throws IOException { String from = resolveSource(); + // If the location is set, and it's not an http request or a zip, + // use the binary directly. 
+ if (!from.startsWith("http") && !from.endsWith("zip") && Files.exists(Paths.get(from))) { + return from; + } + String fromFileName = getNameWithoutExtension(from); Path to = Paths.get(userHome(), PRISM_BIN_PATH, fromFileName); diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java index 9b280d0a70d4..ceec1ad8268a 100644 --- a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismPipelineOptions.java @@ -59,4 +59,10 @@ public interface PrismPipelineOptions extends PortablePipelineOptions { String getIdleShutdownTimeout(); void setIdleShutdownTimeout(String idleShutdownTimeout); + + @Description("Sets the log level for Prism. Can be set to 'debug', 'info', 'warn', or 'error'.") + @Default.String("warn") + String getPrismLogLevel(); + + void setPrismLogLevel(String prismLogLevel); } diff --git a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java index 6099db4b63ee..ac1e68237faf 100644 --- a/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java +++ b/runners/prism/java/src/main/java/org/apache/beam/runners/prism/PrismRunner.java @@ -101,13 +101,17 @@ PrismExecutor startPrism() throws IOException { String idleShutdownTimeoutFlag = String.format( PrismExecutor.IDLE_SHUTDOWN_TIMEOUT, prismPipelineOptions.getIdleShutdownTimeout()); + String logLevelFlag = + String.format( + PrismExecutor.LOG_LEVEL_FLAG_TEMPLATE, prismPipelineOptions.getPrismLogLevel()); String endpoint = "localhost:" + port; prismPipelineOptions.setJobEndpoint(endpoint); String command = locator.resolve(); PrismExecutor executor = PrismExecutor.builder() .setCommand(command) - .setArguments(Arrays.asList(portFlag, serveHttpFlag, idleShutdownTimeoutFlag)) + .setArguments( + Arrays.asList(portFlag, serveHttpFlag, idleShutdownTimeoutFlag, logLevelFlag)) .build(); executor.execute(); checkState(executor.isAlive()); diff --git a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java index 095d3c9bde61..fa5ba6d37203 100644 --- a/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java +++ b/runners/prism/java/src/test/java/org/apache/beam/runners/prism/PrismLocatorTest.java @@ -134,7 +134,10 @@ public void givenFilePrismLocationOption_thenResolves() throws IOException { PrismLocator underTest = new PrismLocator(options); String got = underTest.resolve(); - assertThat(got).contains(DESTINATION_DIRECTORY.toString()); + // Local file overrides should use the local binary in place, not copy + // to the cache. Doing so prevents using a locally built version. + assertThat(got).doesNotContain(DESTINATION_DIRECTORY.toString()); + assertThat(got).contains(options.getPrismLocation()); Path gotPath = Paths.get(got); assertThat(Files.exists(gotPath)).isTrue(); } diff --git a/sdks/go/cmd/prism/prism.go b/sdks/go/cmd/prism/prism.go index 070d2f023b74..5e3f42a9e5a5 100644 --- a/sdks/go/cmd/prism/prism.go +++ b/sdks/go/cmd/prism/prism.go @@ -44,10 +44,10 @@ var ( // Logging flags var ( - debug = flag.Bool("debug", false, - "Enables full verbosity debug logging from the runner by default. 
Used to build SDKs or debug Prism itself.") logKind = flag.String("log_kind", "dev", "Determines the format of prism's logging to std err: valid values are `dev', 'json', or 'text'. Default is `dev`.") + logLevelFlag = flag.String("log_level", "info", + "Sets the minimum log level of Prism. Valid options are 'debug', 'info', 'warn', and 'error'. Default is 'info'. Debug adds prism source lines.") ) var logLevel = new(slog.LevelVar) @@ -59,13 +59,20 @@ func main() { var logHandler slog.Handler loggerOutput := os.Stderr handlerOpts := &slog.HandlerOptions{ - Level: logLevel, - AddSource: *debug, + Level: logLevel, } - if *debug { + switch strings.ToLower(*logLevelFlag) { + case "debug": logLevel.Set(slog.LevelDebug) - // Print the Prism source line for a log in debug mode. handlerOpts.AddSource = true + case "info": logLevel.Set(slog.LevelInfo) + case "warn": logLevel.Set(slog.LevelWarn) + case "error": logLevel.Set(slog.LevelError) + default: + log.Fatalf("Invalid value for log_level: %v, must be 'debug', 'info', 'warn', or 'error'", *logLevelFlag) } switch strings.ToLower(*logKind) { case "dev": From 5cd261ac43e42f2dd05346eeb41b540b3f64b010 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Thu, 24 Oct 2024 18:05:57 -0400 Subject: [PATCH 068/181] Fix DebeziumIO JmsIO PreCommit (#32926) * Fix DebeziumIO JmsIO PreCommit * Update sdks/go/test/integration/io/xlang/debezium/debezium_test.go --------- Co-authored-by: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> --- sdks/go/test/integration/io/xlang/debezium/debezium_test.go | 2 +- .../beam/io/debezium/DebeziumIOPostgresSqlConnectorIT.java | 2 +- .../beam/io/debezium/DebeziumReadSchemaTransformTest.java | 2 +- .../jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java | 2 +- sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sdks/go/test/integration/io/xlang/debezium/debezium_test.go b/sdks/go/test/integration/io/xlang/debezium/debezium_test.go index 24c2b513b2b2..208a062f9436 100644 --- a/sdks/go/test/integration/io/xlang/debezium/debezium_test.go +++ b/sdks/go/test/integration/io/xlang/debezium/debezium_test.go @@ -34,7 +34,7 @@ import ( ) const ( - debeziumImage = "debezium/example-postgres:latest" + debeziumImage = "quay.io/debezium/example-postgres:latest" debeziumPort = "5432/tcp" maxRetries = 5 ) diff --git a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumIOPostgresSqlConnectorIT.java b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumIOPostgresSqlConnectorIT.java index 2bfa694aebc0..970d9483850c 100644 --- a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumIOPostgresSqlConnectorIT.java +++ b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumIOPostgresSqlConnectorIT.java @@ -56,7 +56,7 @@ public class DebeziumIOPostgresSqlConnectorIT { @ClassRule public static final PostgreSQLContainer POSTGRES_SQL_CONTAINER = new PostgreSQLContainer<>( - DockerImageName.parse("debezium/example-postgres:latest") + DockerImageName.parse("quay.io/debezium/example-postgres:latest") .asCompatibleSubstituteFor("postgres")) .withPassword("dbz") .withUsername("debezium") diff --git a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumReadSchemaTransformTest.java b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumReadSchemaTransformTest.java index c75621040913..c4b5d2d1f890 100644 ---
a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumReadSchemaTransformTest.java +++ b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/DebeziumReadSchemaTransformTest.java @@ -46,7 +46,7 @@ public class DebeziumReadSchemaTransformTest { @ClassRule public static final PostgreSQLContainer POSTGRES_SQL_CONTAINER = new PostgreSQLContainer<>( - DockerImageName.parse("debezium/example-postgres:latest") + DockerImageName.parse("quay.io/debezium/example-postgres:latest") .asCompatibleSubstituteFor("postgres")) .withPassword("dbz") .withUsername("debezium") diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java index eddcb0de5561..266d04342d1f 100644 --- a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java +++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOIT.java @@ -215,7 +215,7 @@ public void testPublishingThenReadingAll() throws IOException, JMSException { int unackRecords = countRemain(QUEUE); assertTrue( String.format("Too many unacknowledged messages: %d", unackRecords), - unackRecords < OPTIONS.getNumberOfRecords() * 0.002); + unackRecords < OPTIONS.getNumberOfRecords() * 0.003); // acknowledged records int ackRecords = OPTIONS.getNumberOfRecords() - unackRecords; diff --git a/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py b/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py index 7829baba8b69..abe9530787e8 100644 --- a/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py +++ b/sdks/python/apache_beam/io/external/xlang_debeziumio_it_test.py @@ -107,7 +107,7 @@ def start_db_container(self, retries): for i in range(retries): try: self.db = PostgresContainer( - 'debezium/example-postgres:latest', + 'quay.io/debezium/example-postgres:latest', user=self.username, password=self.password, dbname=self.database) From 9c8b79af16053b8392e41f8d2b58495d8d86f662 Mon Sep 17 00:00:00 2001 From: Robert Burke Date: Thu, 24 Oct 2024 16:31:20 -0700 Subject: [PATCH 069/181] Correct some gradle java8's (#32943) Co-authored-by: lostluck <13907733+lostluck@users.noreply.github.com> --- runners/portability/java/build.gradle | 2 +- runners/prism/java/build.gradle | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/runners/portability/java/build.gradle b/runners/portability/java/build.gradle index b684299c3174..a82759b4e4a0 100644 --- a/runners/portability/java/build.gradle +++ b/runners/portability/java/build.gradle @@ -253,7 +253,7 @@ tasks.register("validatesRunnerSickbay", Test) { } task ulrDockerValidatesRunner { - dependsOn createUlrValidatesRunnerTask("ulrDockerValidatesRunnerTests", "DOCKER", ":sdks:java:container:java8:docker") + dependsOn createUlrValidatesRunnerTask("ulrDockerValidatesRunnerTests", "DOCKER", ":sdks:java:container:${project.ext.currentJavaVersion}:docker") } task ulrLoopbackValidatesRunner { diff --git a/runners/prism/java/build.gradle b/runners/prism/java/build.gradle index f6655900f624..f2dfa2bb1a28 100644 --- a/runners/prism/java/build.gradle +++ b/runners/prism/java/build.gradle @@ -278,7 +278,7 @@ tasks.register("validatesRunnerSickbay", Test) { task prismDockerValidatesRunner { Task vrTask = createPrismValidatesRunnerTask("prismDockerValidatesRunnerTests", "DOCKER") - vrTask.dependsOn ":sdks:java:container:java8:docker" + vrTask.dependsOn ":sdks:java:container:${project.ext.currentJavaVersion}:docker" } task prismLoopbackValidatesRunner { From 
3cc29099924f603e2094e1a246a9449b641dc761 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 24 Oct 2024 16:38:23 -0700 Subject: [PATCH 070/181] Modernize type hints for bundle_processor.py. (#32871) --- .../runners/worker/bundle_processor.py | 777 ++++++++---------- 1 file changed, 360 insertions(+), 417 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 0f1700f52486..89c137fe4366 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -45,6 +45,7 @@ from typing import Iterator from typing import List from typing import Mapping +from typing import MutableMapping from typing import Optional from typing import Set from typing import Tuple @@ -130,18 +131,16 @@ class RunnerIOOperation(operations.Operation): """Common baseclass for runner harness IO operations.""" - - def __init__(self, - name_context, # type: common.NameContext - step_name, # type: Any - consumers, # type: Mapping[Any, Iterable[operations.Operation]] - counter_factory, # type: counters.CounterFactory - state_sampler, # type: statesampler.StateSampler - windowed_coder, # type: coders.Coder - transform_id, # type: str - data_channel # type: data_plane.DataChannel - ): - # type: (...) -> None + def __init__( + self, + name_context: common.NameContext, + step_name: Any, + consumers: Mapping[Any, Iterable[operations.Operation]], + counter_factory: counters.CounterFactory, + state_sampler: statesampler.StateSampler, + windowed_coder: coders.Coder, + transform_id: str, + data_channel: data_plane.DataChannel) -> None: super().__init__(name_context, None, counter_factory, state_sampler) self.windowed_coder = windowed_coder self.windowed_coder_impl = windowed_coder.get_impl() @@ -157,36 +156,32 @@ def __init__(self, class DataOutputOperation(RunnerIOOperation): """A sink-like operation that gathers outputs to be sent back to the runner. """ - def set_output_stream(self, output_stream): - # type: (data_plane.ClosableOutputStream) -> None + def set_output_stream( + self, output_stream: data_plane.ClosableOutputStream) -> None: self.output_stream = output_stream - def process(self, windowed_value): - # type: (windowed_value.WindowedValue) -> None + def process(self, windowed_value: windowed_value.WindowedValue) -> None: self.windowed_coder_impl.encode_to_stream( windowed_value, self.output_stream, True) self.output_stream.maybe_flush() - def finish(self): - # type: () -> None + def finish(self) -> None: super().finish() self.output_stream.close() class DataInputOperation(RunnerIOOperation): """A source-like operation that gathers input from the runner.""" - - def __init__(self, - operation_name, # type: common.NameContext - step_name, - consumers, # type: Mapping[Any, List[operations.Operation]] - counter_factory, # type: counters.CounterFactory - state_sampler, # type: statesampler.StateSampler - windowed_coder, # type: coders.Coder - transform_id, - data_channel # type: data_plane.GrpcClientDataChannel - ): - # type: (...) 
-> None + def __init__( + self, + operation_name: common.NameContext, + step_name, + consumers: Mapping[Any, List[operations.Operation]], + counter_factory: counters.CounterFactory, + state_sampler: statesampler.StateSampler, + windowed_coder: coders.Coder, + transform_id, + data_channel: data_plane.GrpcClientDataChannel) -> None: super().__init__( operation_name, step_name, @@ -217,18 +212,15 @@ def setup(self, data_sampler=None): producer_batch_converter=self.get_output_batch_converter()) ] - def start(self): - # type: () -> None + def start(self) -> None: super().start() with self.splitting_lock: self.started = True - def process(self, windowed_value): - # type: (windowed_value.WindowedValue) -> None + def process(self, windowed_value: windowed_value.WindowedValue) -> None: self.output(windowed_value) - def process_encoded(self, encoded_windowed_values): - # type: (bytes) -> None + def process_encoded(self, encoded_windowed_values: bytes) -> None: input_stream = coder_impl.create_InputStream(encoded_windowed_values) while input_stream.size() > 0: with self.splitting_lock: @@ -244,8 +236,9 @@ def process_encoded(self, encoded_windowed_values): str(self.windowed_coder)) from exn self.output(decoded_value) - def monitoring_infos(self, transform_id, tag_to_pcollection_id): - # type: (str, Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo] + def monitoring_infos( + self, transform_id: str, tag_to_pcollection_id: Dict[str, str] + ) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]: all_monitoring_infos = super().monitoring_infos( transform_id, tag_to_pcollection_id) read_progress_info = monitoring_infos.int64_counter( @@ -259,8 +252,13 @@ def monitoring_infos(self, transform_id, tag_to_pcollection_id): # TODO(https://github.com/apache/beam/issues/19737): typing not compatible # with super type def try_split( # type: ignore[override] - self, fraction_of_remainder, total_buffer_size, allowed_split_points): - # type: (...) -> Optional[Tuple[int, Iterable[operations.SdfSplitResultsPrimary], Iterable[operations.SdfSplitResultsResidual], int]] + self, fraction_of_remainder, total_buffer_size, allowed_split_points + ) -> Optional[ + Tuple[ + int, + Iterable[operations.SdfSplitResultsPrimary], + Iterable[operations.SdfSplitResultsResidual], + int]]: with self.splitting_lock: if not self.started: return None @@ -314,9 +312,10 @@ def is_valid_split_point(index): # try splitting at the current element. 
if (keep_of_element_remainder < 1 and is_valid_split_point(index) and is_valid_split_point(index + 1)): - split = try_split( - keep_of_element_remainder - ) # type: Optional[Tuple[Iterable[operations.SdfSplitResultsPrimary], Iterable[operations.SdfSplitResultsResidual]]] + split: Optional[Tuple[ + Iterable[operations.SdfSplitResultsPrimary], + Iterable[operations.SdfSplitResultsResidual]]] = try_split( + keep_of_element_remainder) if split: element_primaries, element_residuals = split return index - 1, element_primaries, element_residuals, index + 1 @@ -343,15 +342,13 @@ def is_valid_split_point(index): else: return None - def finish(self): - # type: () -> None + def finish(self) -> None: super().finish() with self.splitting_lock: self.index += 1 self.started = False - def reset(self): - # type: () -> None + def reset(self) -> None: with self.splitting_lock: self.index = -1 self.stop = float('inf') @@ -359,12 +356,12 @@ def reset(self): class _StateBackedIterable(object): - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - state_key, # type: beam_fn_api_pb2.StateKey - coder_or_impl, # type: Union[coders.Coder, coder_impl.CoderImpl] - ): - # type: (...) -> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + state_key: beam_fn_api_pb2.StateKey, + coder_or_impl: Union[coders.Coder, coder_impl.CoderImpl], + ) -> None: self._state_handler = state_handler self._state_key = state_key if isinstance(coder_or_impl, coders.Coder): @@ -372,8 +369,7 @@ def __init__(self, else: self._coder_impl = coder_or_impl - def __iter__(self): - # type: () -> Iterator[Any] + def __iter__(self) -> Iterator[Any]: return iter( self._state_handler.blocking_get(self._state_key, self._coder_impl)) @@ -391,15 +387,15 @@ class StateBackedSideInputMap(object): _BULK_READ_FULLY = "fully" _BULK_READ_PARTIALLY = "partially" - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - transform_id, # type: str - tag, # type: Optional[str] - side_input_data, # type: pvalue.SideInputData - coder, # type: WindowedValueCoder - use_bulk_read = False, # type: bool - ): - # type: (...) -> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + transform_id: str, + tag: Optional[str], + side_input_data: pvalue.SideInputData, + coder: WindowedValueCoder, + use_bulk_read: bool = False, + ) -> None: self._state_handler = state_handler self._transform_id = transform_id self._tag = tag @@ -407,7 +403,7 @@ def __init__(self, self._element_coder = coder.wrapped_value_coder self._target_window_coder = coder.window_coder # TODO(robertwb): Limit the cache size. - self._cache = {} # type: Dict[BoundedWindow, Any] + self._cache: Dict[BoundedWindow, Any] = {} self._use_bulk_read = use_bulk_read def __getitem__(self, window): @@ -503,14 +499,12 @@ def __reduce__(self): self._cache[target_window] = self._side_input_data.view_fn(raw_view) return self._cache[target_window] - def is_globally_windowed(self): - # type: () -> bool + def is_globally_windowed(self) -> bool: return ( self._side_input_data.window_mapping_fn == sideinputs._global_window_mapping_fn) - def reset(self): - # type: () -> None + def reset(self) -> None: # TODO(BEAM-5428): Cross-bundle caching respecting cache tokens. 
self._cache = {} @@ -519,26 +513,28 @@ class ReadModifyWriteRuntimeState(userstate.ReadModifyWriteRuntimeState): def __init__(self, underlying_bag_state): self._underlying_bag_state = underlying_bag_state - def read(self): # type: () -> Any + def read(self) -> Any: values = list(self._underlying_bag_state.read()) if not values: return None return values[0] - def write(self, value): # type: (Any) -> None + def write(self, value: Any) -> None: self.clear() self._underlying_bag_state.add(value) - def clear(self): # type: () -> None + def clear(self) -> None: self._underlying_bag_state.clear() - def commit(self): # type: () -> None + def commit(self) -> None: self._underlying_bag_state.commit() class CombiningValueRuntimeState(userstate.CombiningValueRuntimeState): - def __init__(self, underlying_bag_state, combinefn): - # type: (userstate.AccumulatingRuntimeState, core.CombineFn) -> None + def __init__( + self, + underlying_bag_state: userstate.AccumulatingRuntimeState, + combinefn: core.CombineFn) -> None: self._combinefn = combinefn self._combinefn.setup() self._underlying_bag_state = underlying_bag_state @@ -552,12 +548,10 @@ def _read_accumulator(self, rewrite=True): self._underlying_bag_state.add(merged_accumulator) return merged_accumulator - def read(self): - # type: () -> Iterable[Any] + def read(self) -> Iterable[Any]: return self._combinefn.extract_output(self._read_accumulator()) - def add(self, value): - # type: (Any) -> None + def add(self, value: Any) -> None: # Prefer blind writes, but don't let them grow unboundedly. # This should be tuned to be much lower, but for now exercise # both paths well. @@ -569,8 +563,7 @@ def add(self, value): self._underlying_bag_state.add( self._combinefn.add_input(accumulator, value)) - def clear(self): - # type: () -> None + def clear(self) -> None: self._underlying_bag_state.clear() def commit(self): @@ -587,13 +580,11 @@ class _ConcatIterable(object): Unlike itertools.chain, this allows reiteration. """ - def __init__(self, first, second): - # type: (Iterable[Any], Iterable[Any]) -> None + def __init__(self, first: Iterable[Any], second: Iterable[Any]) -> None: self.first = first self.second = second - def __iter__(self): - # type: () -> Iterator[Any] + def __iter__(self) -> Iterator[Any]: for elem in self.first: yield elem for elem in self.second: @@ -604,38 +595,32 @@ def __iter__(self): class SynchronousBagRuntimeState(userstate.BagRuntimeState): - - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - state_key, # type: beam_fn_api_pb2.StateKey - value_coder # type: coders.Coder - ): - # type: (...) 
-> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + state_key: beam_fn_api_pb2.StateKey, + value_coder: coders.Coder) -> None: self._state_handler = state_handler self._state_key = state_key self._value_coder = value_coder self._cleared = False - self._added_elements = [] # type: List[Any] + self._added_elements: List[Any] = [] - def read(self): - # type: () -> Iterable[Any] + def read(self) -> Iterable[Any]: return _ConcatIterable([] if self._cleared else cast( 'Iterable[Any]', _StateBackedIterable( self._state_handler, self._state_key, self._value_coder)), self._added_elements) - def add(self, value): - # type: (Any) -> None + def add(self, value: Any) -> None: self._added_elements.append(value) - def clear(self): - # type: () -> None + def clear(self) -> None: self._cleared = True self._added_elements = [] - def commit(self): - # type: () -> None + def commit(self) -> None: to_await = None if self._cleared: to_await = self._state_handler.clear(self._state_key) @@ -648,18 +633,16 @@ def commit(self): class SynchronousSetRuntimeState(userstate.SetRuntimeState): - - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - state_key, # type: beam_fn_api_pb2.StateKey - value_coder # type: coders.Coder - ): - # type: (...) -> None + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + state_key: beam_fn_api_pb2.StateKey, + value_coder: coders.Coder) -> None: self._state_handler = state_handler self._state_key = state_key self._value_coder = value_coder self._cleared = False - self._added_elements = set() # type: Set[Any] + self._added_elements: Set[Any] = set() def _compact_data(self, rewrite=True): accumulator = set( @@ -679,12 +662,10 @@ def _compact_data(self, rewrite=True): return accumulator - def read(self): - # type: () -> Set[Any] + def read(self) -> Set[Any]: return self._compact_data(rewrite=False) - def add(self, value): - # type: (Any) -> None + def add(self, value: Any) -> None: if self._cleared: # This is a good time explicitly clear. 
self._state_handler.clear(self._state_key) @@ -694,13 +675,11 @@ def add(self, value): if random.random() > 0.5: self._compact_data() - def clear(self): - # type: () -> None + def clear(self) -> None: self._cleared = True self._added_elements = set() - def commit(self): - # type: () -> None + def commit(self) -> None: to_await = None if self._cleared: to_await = self._state_handler.clear(self._state_key) @@ -887,16 +866,16 @@ def commit(self) -> None: class OutputTimer(userstate.BaseTimer): - def __init__(self, - key, - window, # type: BoundedWindow - timestamp, # type: timestamp.Timestamp - paneinfo, # type: windowed_value.PaneInfo - time_domain, # type: str - timer_family_id, # type: str - timer_coder_impl, # type: coder_impl.TimerCoderImpl - output_stream # type: data_plane.ClosableOutputStream - ): + def __init__( + self, + key, + window: BoundedWindow, + timestamp: timestamp.Timestamp, + paneinfo: windowed_value.PaneInfo, + time_domain: str, + timer_family_id: str, + timer_coder_impl: coder_impl.TimerCoderImpl, + output_stream: data_plane.ClosableOutputStream): self._key = key self._window = window self._input_timestamp = timestamp @@ -942,15 +921,13 @@ def __init__(self, timer_coder_impl, output_stream=None): class FnApiUserStateContext(userstate.UserStateContext): """Interface for state and timers from SDK to Fn API servicer of state..""" - - def __init__(self, - state_handler, # type: sdk_worker.CachingStateHandler - transform_id, # type: str - key_coder, # type: coders.Coder - window_coder, # type: coders.Coder - ): - # type: (...) -> None - + def __init__( + self, + state_handler: sdk_worker.CachingStateHandler, + transform_id: str, + key_coder: coders.Coder, + window_coder: coders.Coder, + ) -> None: """Initialize a ``FnApiUserStateContext``. Args: @@ -964,11 +941,10 @@ def __init__(self, self._key_coder = key_coder self._window_coder = window_coder # A mapping of {timer_family_id: TimerInfo} - self._timers_info = {} # type: Dict[str, TimerInfo] - self._all_states = {} # type: Dict[tuple, FnApiUserRuntimeStateTypes] + self._timers_info: Dict[str, TimerInfo] = {} + self._all_states: Dict[tuple, FnApiUserRuntimeStateTypes] = {} - def add_timer_info(self, timer_family_id, timer_info): - # type: (str, TimerInfo) -> None + def add_timer_info(self, timer_family_id: str, timer_info: TimerInfo) -> None: self._timers_info[timer_family_id] = timer_info def get_timer( @@ -987,19 +963,15 @@ def get_timer( timer_coder_impl, output_stream) - def get_state(self, *args): - # type: (*Any) -> FnApiUserRuntimeStateTypes + def get_state(self, *args: Any) -> FnApiUserRuntimeStateTypes: state_handle = self._all_states.get(args) if state_handle is None: state_handle = self._all_states[args] = self._create_state(*args) return state_handle - def _create_state(self, - state_spec, # type: userstate.StateSpec - key, - window # type: BoundedWindow - ): - # type: (...) 
-> FnApiUserRuntimeStateTypes + def _create_state( + self, state_spec: userstate.StateSpec, key, + window: BoundedWindow) -> FnApiUserRuntimeStateTypes: if isinstance(state_spec, (userstate.BagStateSpec, userstate.CombiningValueStateSpec, @@ -1046,13 +1018,11 @@ def _create_state(self, else: raise NotImplementedError(state_spec) - def commit(self): - # type: () -> None + def commit(self) -> None: for state in self._all_states.values(): state.commit() - def reset(self): - # type: () -> None + def reset(self) -> None: for state in self._all_states.values(): state.finalize() self._all_states = {} @@ -1071,14 +1041,12 @@ def wrapper(*args): return wrapper -def only_element(iterable): - # type: (Iterable[T]) -> T +def only_element(iterable: Iterable[T]) -> T: element, = iterable return element -def _environments_compatible(submission, runtime): - # type: (str, str) -> bool +def _environments_compatible(submission: str, runtime: str) -> bool: if submission == runtime: return True if 'rc' in submission and runtime in submission: @@ -1088,8 +1056,8 @@ def _environments_compatible(submission, runtime): return False -def _verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor): - # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None +def _verify_descriptor_created_in_a_compatible_env( + process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor) -> None: runtime_sdk = environments.sdk_base_version_capability() for t in process_bundle_descriptor.transforms.values(): @@ -1111,16 +1079,14 @@ def _verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor): class BundleProcessor(object): """ A class for processing bundles of elements. """ - - def __init__(self, - runner_capabilities, # type: FrozenSet[str] - process_bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor - state_handler, # type: sdk_worker.CachingStateHandler - data_channel_factory, # type: data_plane.DataChannelFactory - data_sampler=None, # type: Optional[data_sampler.DataSampler] - ): - # type: (...) -> None - + def __init__( + self, + runner_capabilities: FrozenSet[str], + process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, + state_handler: sdk_worker.CachingStateHandler, + data_channel_factory: data_plane.DataChannelFactory, + data_sampler: Optional[data_sampler.DataSampler] = None, + ) -> None: """Initialize a bundle processor. Args: @@ -1136,7 +1102,7 @@ def __init__(self, self.state_handler = state_handler self.data_channel_factory = data_channel_factory self.data_sampler = data_sampler - self.current_instruction_id = None # type: Optional[str] + self.current_instruction_id: Optional[str] = None # Represents whether the SDK is consuming received data. self.consuming_received_data = False @@ -1155,7 +1121,7 @@ def __init__(self, # {(transform_id, timer_family_id): TimerInfo} # The mapping is empty when there is no timer_family_specs in the # ProcessBundleDescriptor. - self.timers_info = {} # type: Dict[Tuple[str, str], TimerInfo] + self.timers_info: Dict[Tuple[str, str], TimerInfo] = {} # TODO(robertwb): Figure out the correct prefix to use for output counters # from StateSampler. @@ -1170,10 +1136,8 @@ def __init__(self, self.splitting_lock = threading.Lock() def create_execution_tree( - self, - descriptor # type: beam_fn_api_pb2.ProcessBundleDescriptor - ): - # type: (...) 
-> collections.OrderedDict[str, operations.DoOperation] + self, descriptor: beam_fn_api_pb2.ProcessBundleDescriptor + ) -> collections.OrderedDict[str, operations.DoOperation]: transform_factory = BeamTransformFactory( self.runner_capabilities, descriptor, @@ -1192,16 +1156,14 @@ def is_side_input(transform_proto, tag): transform_proto.spec.payload, beam_runner_api_pb2.ParDoPayload).side_inputs - pcoll_consumers = collections.defaultdict( - list) # type: DefaultDict[str, List[str]] + pcoll_consumers: DefaultDict[str, List[str]] = collections.defaultdict(list) for transform_id, transform_proto in descriptor.transforms.items(): for tag, pcoll_id in transform_proto.inputs.items(): if not is_side_input(transform_proto, tag): pcoll_consumers[pcoll_id].append(transform_id) @memoize - def get_operation(transform_id): - # type: (str) -> operations.Operation + def get_operation(transform_id: str) -> operations.Operation: transform_consumers = { tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]] for tag, @@ -1218,8 +1180,7 @@ def get_operation(transform_id): # Operations must be started (hence returned) in order. @memoize - def topological_height(transform_id): - # type: (str) -> int + def topological_height(transform_id: str) -> int: return 1 + max([0] + [ topological_height(consumer) for pcoll in descriptor.transforms[transform_id].outputs.values() @@ -1232,18 +1193,18 @@ def topological_height(transform_id): get_operation(transform_id))) for transform_id in sorted( descriptor.transforms, key=topological_height, reverse=True)]) - def reset(self): - # type: () -> None + def reset(self) -> None: self.counter_factory.reset() self.state_sampler.reset() # Side input caches. for op in self.ops.values(): op.reset() - def process_bundle(self, instruction_id): - # type: (str) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool] + def process_bundle( + self, instruction_id: str + ) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool]: - expected_input_ops = [] # type: List[DataInputOperation] + expected_input_ops: List[DataInputOperation] = [] for op in self.ops.values(): if isinstance(op, DataOutputOperation): @@ -1269,9 +1230,10 @@ def process_bundle(self, instruction_id): # both data input and timer input. The data input is identied by # transform_id. The data input is identified by # (transform_id, timer_family_id). - data_channels = collections.defaultdict( - list - ) # type: DefaultDict[data_plane.DataChannel, List[Union[str, Tuple[str, str]]]] + data_channels: DefaultDict[data_plane.DataChannel, + List[Union[str, Tuple[ + str, + str]]]] = collections.defaultdict(list) # Add expected data inputs for each data channel. 
input_op_by_transform_id = {} @@ -1337,18 +1299,17 @@ def process_bundle(self, instruction_id): self.current_instruction_id = None self.state_sampler.stop_if_still_running() - def finalize_bundle(self): - # type: () -> beam_fn_api_pb2.FinalizeBundleResponse + def finalize_bundle(self) -> beam_fn_api_pb2.FinalizeBundleResponse: for op in self.ops.values(): op.finalize_bundle() return beam_fn_api_pb2.FinalizeBundleResponse() - def requires_finalization(self): - # type: () -> bool + def requires_finalization(self) -> bool: return any(op.needs_finalization() for op in self.ops.values()) - def try_split(self, bundle_split_request): - # type: (beam_fn_api_pb2.ProcessBundleSplitRequest) -> beam_fn_api_pb2.ProcessBundleSplitResponse + def try_split( + self, bundle_split_request: beam_fn_api_pb2.ProcessBundleSplitRequest + ) -> beam_fn_api_pb2.ProcessBundleSplitResponse: split_response = beam_fn_api_pb2.ProcessBundleSplitResponse() with self.splitting_lock: if bundle_split_request.instruction_id != self.current_instruction_id: @@ -1386,20 +1347,18 @@ def try_split(self, bundle_split_request): return split_response - def delayed_bundle_application(self, - op, # type: operations.DoOperation - deferred_remainder # type: SplitResultResidual - ): - # type: (...) -> beam_fn_api_pb2.DelayedBundleApplication + def delayed_bundle_application( + self, op: operations.DoOperation, deferred_remainder: SplitResultResidual + ) -> beam_fn_api_pb2.DelayedBundleApplication: assert op.input_info is not None # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder. (element_and_restriction, current_watermark, deferred_timestamp) = ( deferred_remainder) if deferred_timestamp: assert isinstance(deferred_timestamp, timestamp.Duration) - proto_deferred_watermark = proto_utils.from_micros( - duration_pb2.Duration, - deferred_timestamp.micros) # type: Optional[duration_pb2.Duration] + proto_deferred_watermark: Optional[ + duration_pb2.Duration] = proto_utils.from_micros( + duration_pb2.Duration, deferred_timestamp.micros) else: proto_deferred_watermark = None return beam_fn_api_pb2.DelayedBundleApplication( @@ -1407,29 +1366,26 @@ def delayed_bundle_application(self, application=self.construct_bundle_application( op.input_info, current_watermark, element_and_restriction)) - def bundle_application(self, - op, # type: operations.DoOperation - primary # type: SplitResultPrimary - ): - # type: (...) -> beam_fn_api_pb2.BundleApplication + def bundle_application( + self, op: operations.DoOperation, + primary: SplitResultPrimary) -> beam_fn_api_pb2.BundleApplication: assert op.input_info is not None return self.construct_bundle_application( op.input_info, None, primary.primary_value) - def construct_bundle_application(self, - op_input_info, # type: operations.OpInputInfo - output_watermark, # type: Optional[timestamp.Timestamp] - element - ): - # type: (...) 
-> beam_fn_api_pb2.BundleApplication + def construct_bundle_application( + self, + op_input_info: operations.OpInputInfo, + output_watermark: Optional[timestamp.Timestamp], + element) -> beam_fn_api_pb2.BundleApplication: transform_id, main_input_tag, main_input_coder, outputs = op_input_info if output_watermark: proto_output_watermark = proto_utils.from_micros( timestamp_pb2.Timestamp, output_watermark.micros) - output_watermarks = { + output_watermarks: Optional[Dict[str, timestamp_pb2.Timestamp]] = { output: proto_output_watermark for output in outputs - } # type: Optional[Dict[str, timestamp_pb2.Timestamp]] + } else: output_watermarks = None return beam_fn_api_pb2.BundleApplication( @@ -1438,9 +1394,7 @@ def construct_bundle_application(self, output_watermarks=output_watermarks, element=main_input_coder.get_impl().encode_nested(element)) - def monitoring_infos(self): - # type: () -> List[metrics_pb2.MonitoringInfo] - + def monitoring_infos(self) -> List[metrics_pb2.MonitoringInfo]: """Returns the list of MonitoringInfos collected processing this bundle.""" # Construct a new dict first to remove duplicates. all_monitoring_infos_dict = {} @@ -1452,8 +1406,7 @@ def monitoring_infos(self): return list(all_monitoring_infos_dict.values()) - def shutdown(self): - # type: () -> None + def shutdown(self) -> None: for op in self.ops.values(): op.teardown() @@ -1474,15 +1427,16 @@ class ExecutionContext: class BeamTransformFactory(object): """Factory for turning transform_protos into executable operations.""" - def __init__(self, - runner_capabilities, # type: FrozenSet[str] - descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor - data_channel_factory, # type: data_plane.DataChannelFactory - counter_factory, # type: counters.CounterFactory - state_sampler, # type: statesampler.StateSampler - state_handler, # type: sdk_worker.CachingStateHandler - data_sampler, # type: Optional[data_sampler.DataSampler] - ): + def __init__( + self, + runner_capabilities: FrozenSet[str], + descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, + data_channel_factory: data_plane.DataChannelFactory, + counter_factory: counters.CounterFactory, + state_sampler: statesampler.StateSampler, + state_handler: sdk_worker.CachingStateHandler, + data_sampler: Optional[data_sampler.DataSampler], + ): self.runner_capabilities = runner_capabilities self.descriptor = descriptor self.data_channel_factory = data_channel_factory @@ -1499,27 +1453,41 @@ def __init__(self, element_coder_impl)) self.data_sampler = data_sampler - _known_urns = { - } # type: Dict[str, Tuple[ConstructorFn, Union[Type[message.Message], Type[bytes], None]]] + _known_urns: Dict[str, + Tuple[ConstructorFn, + Union[Type[message.Message], Type[bytes], + None]]] = {} @classmethod def register_urn( - cls, - urn, # type: str - parameter_type # type: Optional[Type[T]] - ): - # type: (...) 
-> Callable[[Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]], Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]] + cls, urn: str, parameter_type: Optional[Type[T]] + ) -> Callable[[ + Callable[[ + BeamTransformFactory, + str, + beam_runner_api_pb2.PTransform, + T, + Dict[str, List[operations.Operation]] + ], + operations.Operation] + ], + Callable[[ + BeamTransformFactory, + str, + beam_runner_api_pb2.PTransform, + T, + Dict[str, List[operations.Operation]] + ], + operations.Operation]]: def wrapper(func): cls._known_urns[urn] = func, parameter_type return func return wrapper - def create_operation(self, - transform_id, # type: str - consumers # type: Dict[str, List[operations.Operation]] - ): - # type: (...) -> operations.Operation + def create_operation( + self, transform_id: str, + consumers: Dict[str, List[operations.Operation]]) -> operations.Operation: transform_proto = self.descriptor.transforms[transform_id] if not transform_proto.unique_name: _LOGGER.debug("No unique name set for transform %s" % transform_id) @@ -1529,8 +1497,7 @@ def create_operation(self, transform_proto.spec.payload, parameter_type) return creator(self, transform_id, transform_proto, payload, consumers) - def extract_timers_info(self): - # type: () -> Dict[Tuple[str, str], TimerInfo] + def extract_timers_info(self) -> Dict[Tuple[str, str], TimerInfo]: timers_info = {} for transform_id, transform_proto in self.descriptor.transforms.items(): if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn: @@ -1545,8 +1512,7 @@ def extract_timers_info(self): timer_coder_impl=timer_coder_impl) return timers_info - def get_coder(self, coder_id): - # type: (str) -> coders.Coder + def get_coder(self, coder_id: str) -> coders.Coder: if coder_id not in self.descriptor.coders: raise KeyError("No such coder: %s" % coder_id) coder_proto = self.descriptor.coders[coder_id] @@ -1557,8 +1523,7 @@ def get_coder(self, coder_id): return operation_specs.get_coder_from_spec( json.loads(coder_proto.spec.payload.decode('utf-8'))) - def get_windowed_coder(self, pcoll_id): - # type: (str) -> WindowedValueCoder + def get_windowed_coder(self, pcoll_id: str) -> WindowedValueCoder: coder = self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id) # TODO(robertwb): Remove this condition once all runners are consistent. 
if not isinstance(coder, WindowedValueCoder): @@ -1569,32 +1534,34 @@ def get_windowed_coder(self, pcoll_id): else: return coder - def get_output_coders(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> Dict[str, coders.Coder] + def get_output_coders( + self, transform_proto: beam_runner_api_pb2.PTransform + ) -> Dict[str, coders.Coder]: return { tag: self.get_windowed_coder(pcoll_id) for tag, pcoll_id in transform_proto.outputs.items() } - def get_only_output_coder(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> coders.Coder + def get_only_output_coder( + self, transform_proto: beam_runner_api_pb2.PTransform) -> coders.Coder: return only_element(self.get_output_coders(transform_proto).values()) - def get_input_coders(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> Dict[str, coders.WindowedValueCoder] + def get_input_coders( + self, transform_proto: beam_runner_api_pb2.PTransform + ) -> Dict[str, coders.WindowedValueCoder]: return { tag: self.get_windowed_coder(pcoll_id) for tag, pcoll_id in transform_proto.inputs.items() } - def get_only_input_coder(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> coders.Coder + def get_only_input_coder( + self, transform_proto: beam_runner_api_pb2.PTransform) -> coders.Coder: return only_element(list(self.get_input_coders(transform_proto).values())) - def get_input_windowing(self, transform_proto): - # type: (beam_runner_api_pb2.PTransform) -> Windowing + def get_input_windowing( + self, transform_proto: beam_runner_api_pb2.PTransform) -> Windowing: pcoll_id = only_element(transform_proto.inputs.values()) windowing_strategy_id = self.descriptor.pcollections[ pcoll_id].windowing_strategy_id @@ -1603,12 +1570,10 @@ def get_input_windowing(self, transform_proto): # TODO(robertwb): Update all operations to take these in the constructor. @staticmethod def augment_oldstyle_op( - op, # type: OperationT - step_name, # type: str - consumers, # type: Mapping[str, Iterable[operations.Operation]] - tag_list=None # type: Optional[List[str]] - ): - # type: (...) -> OperationT + op: OperationT, + step_name: str, + consumers: Mapping[str, Iterable[operations.Operation]], + tag_list: Optional[List[str]] = None) -> OperationT: op.step_name = step_name for tag, op_consumers in consumers.items(): for consumer in op_consumers: @@ -1619,13 +1584,11 @@ def augment_oldstyle_op( @BeamTransformFactory.register_urn( DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) def create_source_runner( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - grpc_port, # type: beam_fn_api_pb2.RemoteGrpcPort - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> DataInputOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + grpc_port: beam_fn_api_pb2.RemoteGrpcPort, + consumers: Dict[str, List[operations.Operation]]) -> DataInputOperation: output_coder = factory.get_coder(grpc_port.coder_id) return DataInputOperation( @@ -1642,13 +1605,11 @@ def create_source_runner( @BeamTransformFactory.register_urn( DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) def create_sink_runner( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - grpc_port, # type: beam_fn_api_pb2.RemoteGrpcPort - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) 
-> DataOutputOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + grpc_port: beam_fn_api_pb2.RemoteGrpcPort, + consumers: Dict[str, List[operations.Operation]]) -> DataOutputOperation: output_coder = factory.get_coder(grpc_port.coder_id) return DataOutputOperation( common.NameContext(transform_proto.unique_name, transform_id), @@ -1663,13 +1624,12 @@ def create_sink_runner( @BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_READ_URN, None) def create_source_java( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, parameter, - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.ReadOperation + consumers: Dict[str, + List[operations.Operation]]) -> operations.ReadOperation: # The Dataflow runner harness strips the base64 encoding. source = pickler.loads(base64.b64encode(parameter)) spec = operation_specs.WorkerRead( @@ -1688,13 +1648,12 @@ def create_source_java( @BeamTransformFactory.register_urn( common_urns.deprecated_primitives.READ.urn, beam_runner_api_pb2.ReadPayload) def create_deprecated_read( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ReadPayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.ReadOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ReadPayload, + consumers: Dict[str, + List[operations.Operation]]) -> operations.ReadOperation: source = iobase.BoundedSource.from_runner_api( parameter.source, factory.context) spec = operation_specs.WorkerRead( @@ -1713,13 +1672,12 @@ def create_deprecated_read( @BeamTransformFactory.register_urn( python_urns.IMPULSE_READ_TRANSFORM, beam_runner_api_pb2.ReadPayload) def create_read_from_impulse_python( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ReadPayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) 
-> operations.ImpulseReadOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ReadPayload, + consumers: Dict[str, List[operations.Operation]] +) -> operations.ImpulseReadOperation: return operations.ImpulseReadOperation( common.NameContext(transform_proto.unique_name, transform_id), factory.counter_factory, @@ -1731,12 +1689,11 @@ def create_read_from_impulse_python( @BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_PARDO_URN, None) def create_dofn_javasdk( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, serialized_fn, - consumers # type: Dict[str, List[operations.Operation]] -): + consumers: Dict[str, List[operations.Operation]]): return _create_pardo_operation( factory, transform_id, transform_proto, consumers, serialized_fn) @@ -1820,12 +1777,11 @@ def process(self, element_restriction, *args, **kwargs): common_urns.sdf_components.PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn, beam_runner_api_pb2.ParDoPayload) def create_process_sized_elements_and_restrictions( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ParDoPayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ParDoPayload, + consumers: Dict[str, List[operations.Operation]]): return _create_pardo_operation( factory, transform_id, @@ -1867,13 +1823,11 @@ def _create_sdf_operation( @BeamTransformFactory.register_urn( common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload) def create_par_do( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.ParDoPayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.DoOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.ParDoPayload, + consumers: Dict[str, List[operations.Operation]]) -> operations.DoOperation: return _create_pardo_operation( factory, transform_id, @@ -1885,14 +1839,13 @@ def create_par_do( def _create_pardo_operation( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, consumers, serialized_fn, - pardo_proto=None, # type: Optional[beam_runner_api_pb2.ParDoPayload] - operation_cls=operations.DoOperation -): + pardo_proto: Optional[beam_runner_api_pb2.ParDoPayload] = None, + operation_cls=operations.DoOperation): if pardo_proto and pardo_proto.side_inputs: input_tags_to_coders = factory.get_input_coders(transform_proto) @@ -1924,9 +1877,8 @@ def _create_pardo_operation( if not dofn_data[-1]: # Windowing not set. 
if pardo_proto: - other_input_tags = set.union( - set(pardo_proto.side_inputs), - set(pardo_proto.timer_family_specs)) # type: Container[str] + other_input_tags: Container[str] = set.union( + set(pardo_proto.side_inputs), set(pardo_proto.timer_family_specs)) else: other_input_tags = () pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items() @@ -1950,12 +1902,12 @@ def _create_pardo_operation( main_input_coder = found_input_coder if pardo_proto.timer_family_specs or pardo_proto.state_specs: - user_state_context = FnApiUserStateContext( - factory.state_handler, - transform_id, - main_input_coder.key_coder(), - main_input_coder.window_coder - ) # type: Optional[FnApiUserStateContext] + user_state_context: Optional[ + FnApiUserStateContext] = FnApiUserStateContext( + factory.state_handler, + transform_id, + main_input_coder.key_coder(), + main_input_coder.window_coder) else: user_state_context = None else: @@ -1989,12 +1941,13 @@ def _create_pardo_operation( return result -def _create_simple_pardo_operation(factory, # type: BeamTransformFactory - transform_id, - transform_proto, - consumers, - dofn, # type: beam.DoFn - ): +def _create_simple_pardo_operation( + factory: BeamTransformFactory, + transform_id, + transform_proto, + consumers, + dofn: beam.DoFn, +): serialized_fn = pickler.dumps((dofn, (), {}, [], None)) return _create_pardo_operation( factory, transform_id, transform_proto, consumers, serialized_fn) @@ -2004,12 +1957,11 @@ def _create_simple_pardo_operation(factory, # type: BeamTransformFactory common_urns.primitives.ASSIGN_WINDOWS.urn, beam_runner_api_pb2.WindowingStrategy) def create_assign_windows( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - parameter, # type: beam_runner_api_pb2.WindowingStrategy - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + parameter: beam_runner_api_pb2.WindowingStrategy, + consumers: Dict[str, List[operations.Operation]]): class WindowIntoDoFn(beam.DoFn): def __init__(self, windowing): self.windowing = windowing @@ -2036,13 +1988,12 @@ def process( @BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None) def create_identity_dofn( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, parameter, - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.FlattenOperation + consumers: Dict[str, List[operations.Operation]] +) -> operations.FlattenOperation: return factory.augment_oldstyle_op( operations.FlattenOperation( common.NameContext(transform_proto.unique_name, transform_id), @@ -2058,13 +2009,12 @@ def create_identity_dofn( common_urns.combine_components.COMBINE_PER_KEY_PRECOMBINE.urn, beam_runner_api_pb2.CombinePayload) def create_combine_per_key_precombine( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) 
-> operations.PGBKCVOperation + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, + List[operations.Operation]]) -> operations.PGBKCVOperation: serialized_combine_fn = pickler.dumps(( beam.CombineFn.from_runner_api(payload.combine_fn, factory.context), [], {})) @@ -2085,12 +2035,11 @@ def create_combine_per_key_precombine( common_urns.combine_components.COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn, beam_runner_api_pb2.CombinePayload) def create_combbine_per_key_merge_accumulators( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'merge') @@ -2099,12 +2048,11 @@ def create_combbine_per_key_merge_accumulators( common_urns.combine_components.COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn, beam_runner_api_pb2.CombinePayload) def create_combine_per_key_extract_outputs( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'extract') @@ -2113,12 +2061,11 @@ def create_combine_per_key_extract_outputs( common_urns.combine_components.COMBINE_PER_KEY_CONVERT_TO_ACCUMULATORS.urn, beam_runner_api_pb2.CombinePayload) def create_combine_per_key_convert_to_accumulators( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'convert') @@ -2127,19 +2074,18 @@ def create_combine_per_key_convert_to_accumulators( common_urns.combine_components.COMBINE_GROUPED_VALUES.urn, beam_runner_api_pb2.CombinePayload) def create_combine_grouped_values( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - payload, # type: beam_runner_api_pb2.CombinePayload - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + payload: beam_runner_api_pb2.CombinePayload, + consumers: Dict[str, List[operations.Operation]]): return _create_combine_phase_operation( factory, transform_id, transform_proto, payload, consumers, 'all') def _create_combine_phase_operation( - factory, transform_id, 
transform_proto, payload, consumers, phase): - # type: (...) -> operations.CombineOperation + factory, transform_id, transform_proto, payload, consumers, + phase) -> operations.CombineOperation: serialized_combine_fn = pickler.dumps(( beam.CombineFn.from_runner_api(payload.combine_fn, factory.context), [], {})) @@ -2158,13 +2104,12 @@ def _create_combine_phase_operation( @BeamTransformFactory.register_urn(common_urns.primitives.FLATTEN.urn, None) def create_flatten( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, payload, - consumers # type: Dict[str, List[operations.Operation]] -): - # type: (...) -> operations.FlattenOperation + consumers: Dict[str, List[operations.Operation]] +) -> operations.FlattenOperation: return factory.augment_oldstyle_op( operations.FlattenOperation( common.NameContext(transform_proto.unique_name, transform_id), @@ -2179,12 +2124,11 @@ def create_flatten( @BeamTransformFactory.register_urn( common_urns.primitives.MAP_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec) def create_map_windows( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, + consumers: Dict[str, List[operations.Operation]]): assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOW_MAPPING_FN window_mapping_fn = pickler.loads(mapping_fn_spec.payload) @@ -2200,12 +2144,11 @@ def process(self, element): @BeamTransformFactory.register_urn( common_urns.primitives.MERGE_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec) def create_merge_windows( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, + consumers: Dict[str, List[operations.Operation]]): assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOWFN window_fn = pickler.loads(mapping_fn_spec.payload) @@ -2213,24 +2156,25 @@ class MergeWindows(beam.DoFn): def process(self, element): nonce, windows = element - original_windows = set(windows) # type: Set[window.BoundedWindow] - merged_windows = collections.defaultdict( - set - ) # type: MutableMapping[window.BoundedWindow, Set[window.BoundedWindow]] # noqa: F821 + original_windows: Set[window.BoundedWindow] = set(windows) + merged_windows: MutableMapping[ + window.BoundedWindow, + Set[window.BoundedWindow]] = collections.defaultdict( + set) # noqa: F821 class RecordingMergeContext(window.WindowFn.MergeContext): def merge( self, - to_be_merged, # type: Iterable[window.BoundedWindow] - merge_result, # type: window.BoundedWindow - ): + to_be_merged: Iterable[window.BoundedWindow], + merge_result: window.BoundedWindow, + ): originals = merged_windows[merge_result] - for window in to_be_merged: - if window in original_windows: - originals.add(window) - original_windows.remove(window) + for w in to_be_merged: + if w in original_windows: 
+ originals.add(w) + original_windows.remove(w) else: - originals.update(merged_windows.pop(window)) + originals.update(merged_windows.pop(w)) window_fn.merge(RecordingMergeContext(windows)) yield nonce, (original_windows, merged_windows.items()) @@ -2241,12 +2185,11 @@ def merge( @BeamTransformFactory.register_urn(common_urns.primitives.TO_STRING.urn, None) def create_to_string_fn( - factory, # type: BeamTransformFactory - transform_id, # type: str - transform_proto, # type: beam_runner_api_pb2.PTransform - mapping_fn_spec, # type: beam_runner_api_pb2.FunctionSpec - consumers # type: Dict[str, List[operations.Operation]] -): + factory: BeamTransformFactory, + transform_id: str, + transform_proto: beam_runner_api_pb2.PTransform, + mapping_fn_spec: beam_runner_api_pb2.FunctionSpec, + consumers: Dict[str, List[operations.Operation]]): class ToString(beam.DoFn): def process(self, element): key, value = element From 2dfa281f04df4a00f669587e511a17e0e8c20912 Mon Sep 17 00:00:00 2001 From: Ahmed Abualsaud <65791736+ahmedabu98@users.noreply.github.com> Date: Fri, 25 Oct 2024 22:19:38 +0300 Subject: [PATCH 071/181] dont swallow errors (#32946) --- .../io/gcp/bigquery/WriteBundlesToFiles.java | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java index 9d84abbbbf1a..8c6893ef5798 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteBundlesToFiles.java @@ -297,19 +297,15 @@ public void finishBundle(FinishBundleContext c) throws Exception { } for (Map.Entry> entry : writers.entrySet()) { - try { - DestinationT destination = entry.getKey(); - BigQueryRowWriter writer = entry.getValue(); - BigQueryRowWriter.Result result = writer.getResult(); - BoundedWindow window = writerWindows.get(destination); - Preconditions.checkStateNotNull(window); - c.output( - new Result<>(result.resourceId.toString(), result.byteSize, destination), - window.maxTimestamp(), - window); - } catch (Exception e) { - exceptionList.add(e); - } + DestinationT destination = entry.getKey(); + BigQueryRowWriter writer = entry.getValue(); + BigQueryRowWriter.Result result = writer.getResult(); + BoundedWindow window = writerWindows.get(destination); + Preconditions.checkStateNotNull(window); + c.output( + new Result<>(result.resourceId.toString(), result.byteSize, destination), + window.maxTimestamp(), + window); } writers.clear(); } From 6e3e70dd0837d74991d2fba6e6f87f3fd3e83d73 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Fri, 25 Oct 2024 16:17:18 -0400 Subject: [PATCH 072/181] Fix GCSUtils IT after Gcs gRPC launch (#32927) --- .../trigger_files/beam_PostCommit_Java.json | 1 + .../google-cloud-platform-core/build.gradle | 2 -- .../sdk/extensions/gcp/util/GcsUtilIT.java | 26 +------------------ 3 files changed, 2 insertions(+), 27 deletions(-) create mode 100644 .github/trigger_files/beam_PostCommit_Java.json diff --git a/.github/trigger_files/beam_PostCommit_Java.json b/.github/trigger_files/beam_PostCommit_Java.json new file mode 100644 index 000000000000..9e26dfeeb6e6 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git 
a/sdks/java/extensions/google-cloud-platform-core/build.gradle b/sdks/java/extensions/google-cloud-platform-core/build.gradle index 6cb8d3248ac1..8d21df50006b 100644 --- a/sdks/java/extensions/google-cloud-platform-core/build.gradle +++ b/sdks/java/extensions/google-cloud-platform-core/build.gradle @@ -66,12 +66,10 @@ task integrationTestKms(type: Test) { group = "Verification" def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing' def gcpTempRoot = project.findProperty('gcpTempRootKms') ?: 'gs://temp-storage-for-end-to-end-tests-cmek' - def gcpGrpcTempRoot = project.findProperty('gcpGrpcTempRoot') ?: 'gs://gcs-grpc-team-apache-beam-testing' def dataflowKmsKey = project.findProperty('dataflowKmsKey') ?: "projects/apache-beam-testing/locations/global/keyRings/beam-it/cryptoKeys/test" systemProperty "beamTestPipelineOptions", JsonOutput.toJson([ "--project=${gcpProject}", "--tempRoot=${gcpTempRoot}", - "--grpcTempRoot=${gcpGrpcTempRoot}", "--dataflowKmsKey=${dataflowKmsKey}", ]) diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilIT.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilIT.java index 6f1e0e985c24..6477564f01a1 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilIT.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/GcsUtilIT.java @@ -21,7 +21,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; import com.google.protobuf.ByteString; import java.io.IOException; @@ -35,10 +34,7 @@ import org.apache.beam.sdk.extensions.gcp.util.GcsUtil.CreateOptions; import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath; import org.apache.beam.sdk.io.FileSystems; -import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.ExperimentalOptions; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestPipelineOptions; import org.apache.beam.sdk.testing.UsesKms; @@ -99,8 +95,6 @@ public void testWriteAndReadGcsWithGrpc() throws IOException { "%s/GcsUtilIT-%tF-% writeGcsTextFile(gcsUtil, wrongFilename, testContent)); - // Write a test file in a bucket with gRPC enabled. - GcsGrpcOptions grpcOptions = options.as(GcsGrpcOptions.class); - assertNotNull(grpcOptions.getGrpcTempRoot()); - String tempLocationWithGrpc = grpcOptions.getGrpcTempRoot() + "/temp"; + String tempLocationWithGrpc = options.getTempRoot() + "/temp"; String filename = String.format(outputPattern, tempLocationWithGrpc, new Date()); writeGcsTextFile(gcsUtil, filename, testContent); @@ -132,15 +117,6 @@ public void testWriteAndReadGcsWithGrpc() throws IOException { gcsUtil.remove(Collections.singletonList(filename)); } - public interface GcsGrpcOptions extends PipelineOptions { - /** Get tempRoot in a gRPC-enabled bucket. */ - @Description("TempRoot in a gRPC-enabled bucket") - String getGrpcTempRoot(); - - /** Set the tempRoot in a gRPC-enabled bucket. 
*/ - void setGrpcTempRoot(String grpcTempRoot); - } - void writeGcsTextFile(GcsUtil gcsUtil, String filename, String content) throws IOException { GcsPath gcsPath = GcsPath.fromUri(filename); try (WritableByteChannel channel = From 8cd3e5411ce4316620536c06ec01a2a98cfc5110 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Fri, 25 Oct 2024 16:30:34 -0400 Subject: [PATCH 073/181] Fix a few notebook typos (#32947) --- examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb | 2 +- .../notebooks/beam-ml/rag_usecase/opensearch_rag_pipeline.ipynb | 2 +- examples/notebooks/beam-ml/run_inference_vllm.ipynb | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb b/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb index 331ecb9ba93d..9326ed4db7a3 100644 --- a/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb +++ b/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb @@ -573,7 +573,7 @@ "In this example, you create two handlers:\n", "\n", "* One for customer data that specifies `table_name` and `row_restriction_template`\n", - "* One for for usage data that uses a custom aggregation query by using the `query_fn` function\n", + "* One for usage data that uses a custom aggregation query by using the `query_fn` function\n", "\n", "These handlers are used in the Enrichment transforms in this pipeline to fetch and join data from BigQuery with the streaming data." ], diff --git a/examples/notebooks/beam-ml/rag_usecase/opensearch_rag_pipeline.ipynb b/examples/notebooks/beam-ml/rag_usecase/opensearch_rag_pipeline.ipynb index cc31ff678fe4..aae86e31aa44 100644 --- a/examples/notebooks/beam-ml/rag_usecase/opensearch_rag_pipeline.ipynb +++ b/examples/notebooks/beam-ml/rag_usecase/opensearch_rag_pipeline.ipynb @@ -209,7 +209,7 @@ "\n", "3. Create the index.\n", "\n", - "4. Index creation is neeeded only once." + "4. Index creation is needed only once." ] }, { diff --git a/examples/notebooks/beam-ml/run_inference_vllm.ipynb b/examples/notebooks/beam-ml/run_inference_vllm.ipynb index 008c4262d5ce..fea953bc1e66 100644 --- a/examples/notebooks/beam-ml/run_inference_vllm.ipynb +++ b/examples/notebooks/beam-ml/run_inference_vllm.ipynb @@ -407,7 +407,7 @@ "# Set the Google Cloud region that you want to run Dataflow in.\n", "options.view_as(GoogleCloudOptions).region = 'us-central1'\n", "\n", - "# IMPORTANT: Replace BUCKET_NAME with the the name of your Cloud Storage bucket.\n", + "# IMPORTANT: Replace BUCKET_NAME with the name of your Cloud Storage bucket.\n", "dataflow_gcs_location = \"gs://%s/dataflow\" % BUCKET_NAME\n", "\n", "# The Dataflow staging location. 
This location is used to stage the Dataflow pipeline and the SDK binary.\n", From be7537a2abf77e834e4e7f4e68bb52a303c09bf6 Mon Sep 17 00:00:00 2001 From: martin trieu Date: Mon, 28 Oct 2024 04:14:50 -0600 Subject: [PATCH 074/181] don't log missing computation state as error (#32942) --- .../dataflow/worker/streaming/ComputationStateCache.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java index 199ad26aed00..4b4acb73f4a7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCache.java @@ -147,8 +147,10 @@ public Optional get(String computationId) { | ComputationStateNotFoundException e) { if (e.getCause() instanceof ComputationStateNotFoundException || e instanceof ComputationStateNotFoundException) { - LOG.error( - "Trying to fetch unknown computation={}, known computations are {}.", + LOG.warn( + "Computation {} is currently unknown, " + + "known computations are {}. " + + "This is transient and will get retried.", computationId, ImmutableSet.copyOf(computationCache.asMap().keySet())); } else { From 34eedad8d2932fbcf8e2a2dad76f2cf9d65b4d11 Mon Sep 17 00:00:00 2001 From: Tan Le Date: Mon, 28 Oct 2024 23:45:06 +1000 Subject: [PATCH 075/181] Add Python 3.12 to documentation (#32951) --- website/www/site/content/en/documentation/programming-guide.md | 2 +- website/www/site/content/en/get-started/quickstart-py.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/website/www/site/content/en/documentation/programming-guide.md b/website/www/site/content/en/documentation/programming-guide.md index f4058e604288..df6907f672f4 100644 --- a/website/www/site/content/en/documentation/programming-guide.md +++ b/website/www/site/content/en/documentation/programming-guide.md @@ -35,7 +35,7 @@ programming guide, take a look at the {{< language-switcher java py go typescript yaml >}} {{< paragraph class="language-py" >}} -The Python SDK supports Python 3.8, 3.9, 3.10, and 3.11. +The Python SDK supports Python 3.8, 3.9, 3.10, 3.11, and 3.12. {{< /paragraph >}} {{< paragraph class="language-go">}} diff --git a/website/www/site/content/en/get-started/quickstart-py.md b/website/www/site/content/en/get-started/quickstart-py.md index 3428f5346e02..d7f896153483 100644 --- a/website/www/site/content/en/get-started/quickstart-py.md +++ b/website/www/site/content/en/get-started/quickstart-py.md @@ -23,7 +23,7 @@ If you're interested in contributing to the Apache Beam Python codebase, see the {{< toc >}} -The Python SDK supports Python 3.8, 3.9, 3.10 and 3.11. Beam 2.48.0 was the last release with support for Python 3.7. +The Python SDK supports Python 3.8, 3.9, 3.10, 3.11 and 3.12. Beam 2.48.0 was the last release with support for Python 3.7. 
## Set up your environment From 0e2f3175020a23b3844523addfe72de06ffb30fc Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 28 Oct 2024 13:46:40 -0400 Subject: [PATCH 076/181] Remove duplicate button from notebook (#32954) --- .../beam-ml/bigquery_enrichment_transform.ipynb | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb b/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb index 9326ed4db7a3..182b88b9c72a 100644 --- a/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb +++ b/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb @@ -15,16 +15,6 @@ } }, "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, { "cell_type": "code", "source": [ @@ -778,4 +768,4 @@ "outputs": [] } ] -} \ No newline at end of file +} From 3332914560a2598748819bca1522accbdf4337c9 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 28 Oct 2024 13:50:16 -0400 Subject: [PATCH 077/181] Fix gemma notebook formatting (#32955) --- .../beam-ml/gemma_2_sentiment_and_summarization.ipynb | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb b/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb index d7b2b157f613..686c19da7f66 100644 --- a/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb +++ b/examples/notebooks/beam-ml/gemma_2_sentiment_and_summarization.ipynb @@ -174,7 +174,8 @@ "WORKDIR /workspace\n", "\n", "COPY gemma2 gemma2\n", - "RUN apt-get update -y && apt-get install -y cmake && apt-get install -y vim" + "RUN apt-get update -y && apt-get install -y cmake && apt-get install -y vim\n", + "```" ] }, { @@ -208,7 +209,8 @@ "apache_beam[gcp]==2.54.0\n", "keras_nlp==0.14.3\n", "keras==3.4.1\n", - "jax[cuda12]" + "jax[cuda12]\n", + "```" ] }, { @@ -261,7 +263,8 @@ "\n", "\n", "# Set the entrypoint to the Apache Beam SDK launcher.\n", - "ENTRYPOINT [\"/opt/apache/beam/boot\"]" + "ENTRYPOINT [\"/opt/apache/beam/boot\"]\n", + "```" ] }, { From 5c7fed2b77eedc535d0c0a4f8364e11b03c6a12f Mon Sep 17 00:00:00 2001 From: liferoad Date: Mon, 28 Oct 2024 16:43:54 -0400 Subject: [PATCH 078/181] updated the 2024 discussion doc --- contributor-docs/discussion-docs/2024.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/contributor-docs/discussion-docs/2024.md b/contributor-docs/discussion-docs/2024.md index baea7c9fc462..15e15963e1bd 100644 --- a/contributor-docs/discussion-docs/2024.md +++ b/contributor-docs/discussion-docs/2024.md @@ -35,11 +35,12 @@ limitations under the License. 
| 18 | Danny McCormick | [GSoC Proposal : Implement RAG Pipelines using Beam](https://docs.google.com/document/d/1M_8fvqKVBi68hQo_x1AMQ8iEkzeXTcSl0CwTH00cr80) | 2024-05-01 16:12:23 | | 19 | Kenneth Knowles | [DRAFT - Apache Beam Board Report - June 2024](https://s.apache.org/beam-draft-report-2024-06) | 2024-05-23 14:57:16 | | 20 | Jack McCluskey | [Embeddings in MLTransform](https://docs.google.com/document/d/1En4bfbTu4rvu7LWJIKV3G33jO-xJfTdbaSFSURmQw_s) | 2024-05-29 10:26:47 | -| 21 | Bartosz Zab��ocki | [[External] Solace IO - Read Connector](https://docs.google.com/document/d/1Gvq67VrcHCnlO8f_NzMM1Y4c7wCNSdvo6qqLWg8upfw) | 2024-05-29 12:00:23 | +| 21 | Bartosz Zabłocki | [[External] Solace IO - Read Connector](https://docs.google.com/document/d/1Gvq67VrcHCnlO8f_NzMM1Y4c7wCNSdvo6qqLWg8upfw) | 2024-05-29 12:00:23 | | 22 | Danny McCormick | [RunInference Timeouts](https://docs.google.com/document/d/19ves6iv-m_6DFmePJZqYpLm-bCooPu6wQ-Ti6kAl2Jo) | 2024-08-07 07:11:38 | | 23 | Jack McCluskey | [BatchElements in Beam Python](https://docs.google.com/document/d/1fOjIjIUH5dxllOGp5Z4ZmpM7BJhAJc2-hNjTnyChvgc) | 2024-08-15 14:56:26 | | 24 | XQ Hu | [[Public] Beam 3.0: a discussion doc](https://docs.google.com/document/d/13r4NvuvFdysqjCTzMHLuUUXjKTIEY3d7oDNIHT6guww) | 2024-08-19 17:17:26 | | 25 | Danny McCormick | [Beam Patch Release Process](https://docs.google.com/document/d/1o4UK444hCm1t5KZ9ufEu33e_o400ONAehXUR9A34qc8) | 2024-08-23 04:51:48 | | 26 | Jack McCluskey | [Beam Python Type Hinting](https://s.apache.org/beam-python-type-hinting-overview) | 2024-08-26 14:16:42 | | 27 | Ahmed Abualsaud | [Python Multi-language with SchemaTransforms](https://docs.google.com/document/d/1_embA3pGwoYG7sbHaYzAkg3hNxjTughhFCY8ThcoK_Q) | 2024-08-26 19:53:10 | -| 28 | Kenneth Knowles | [DRAFT - Apache Beam Board Report - September 2024](https://s.apache.org/beam-draft-report-2024-09) | 2024-09-11 15:01:55 | \ No newline at end of file +| 28 | Kenneth Knowles | [DRAFT - Apache Beam Board Report - September 2024](https://s.apache.org/beam-draft-report-2024-09) | 2024-09-11 15:01:55 | +| 29 | Jeff Kinard | [Beam YA(ML)^2 ](https://docs.google.com/document/d/1z9lNlSBfqDVdOP1frJNv_NJoMR1F1VBI29wn788x6IE/) | 2024-09-11 15:01:55 | From 8eae0d386d75637fbc592ee08184906e1477559e Mon Sep 17 00:00:00 2001 From: liferoad Date: Mon, 28 Oct 2024 16:44:34 -0400 Subject: [PATCH 079/181] updated the title --- contributor-docs/discussion-docs/2024.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contributor-docs/discussion-docs/2024.md b/contributor-docs/discussion-docs/2024.md index 15e15963e1bd..124fe8ef9bb7 100644 --- a/contributor-docs/discussion-docs/2024.md +++ b/contributor-docs/discussion-docs/2024.md @@ -43,4 +43,4 @@ limitations under the License. 
| 26 | Jack McCluskey | [Beam Python Type Hinting](https://s.apache.org/beam-python-type-hinting-overview) | 2024-08-26 14:16:42 | | 27 | Ahmed Abualsaud | [Python Multi-language with SchemaTransforms](https://docs.google.com/document/d/1_embA3pGwoYG7sbHaYzAkg3hNxjTughhFCY8ThcoK_Q) | 2024-08-26 19:53:10 | | 28 | Kenneth Knowles | [DRAFT - Apache Beam Board Report - September 2024](https://s.apache.org/beam-draft-report-2024-09) | 2024-09-11 15:01:55 | -| 29 | Jeff Kinard | [Beam YA(ML)^2 ](https://docs.google.com/document/d/1z9lNlSBfqDVdOP1frJNv_NJoMR1F1VBI29wn788x6IE/) | 2024-09-11 15:01:55 | +| 29 | Jeff Kinard | [Beam YA(ML)^2](https://docs.google.com/document/d/1z9lNlSBfqDVdOP1frJNv_NJoMR1F1VBI29wn788x6IE/) | 2024-09-11 15:01:55 | From 6b772b28e84d910ce2470db69850590e5c7a30d8 Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Tue, 29 Oct 2024 10:10:16 -0400 Subject: [PATCH 080/181] Remove remaining Python 3.8 Artifacts (#32913) * Remove remaining Python 3.8 Artifacts * Clear container files * Move TensorRT suite to 3.10 --- .../workflows/update_python_dependencies.yml | 1 - CHANGES.md | 1 + build.gradle.kts | 21 +- .../runners/dataflow/internal/apiclient.py | 2 +- .../dataflow/internal/apiclient_test.py | 16 +- sdks/python/container/build.gradle | 2 +- .../py38/base_image_requirements.txt | 172 -------------- sdks/python/container/py38/build.gradle | 28 --- .../python/test-suites/dataflow/common.gradle | 4 +- .../test-suites/dataflow/py38/build.gradle | 24 -- .../test-suites/direct/py38/build.gradle | 24 -- .../test-suites/portable/py38/build.gradle | 26 -- sdks/python/test-suites/tox/py38/build.gradle | 224 ------------------ 13 files changed, 27 insertions(+), 518 deletions(-) delete mode 100644 sdks/python/container/py38/base_image_requirements.txt delete mode 100644 sdks/python/container/py38/build.gradle delete mode 100644 sdks/python/test-suites/dataflow/py38/build.gradle delete mode 100644 sdks/python/test-suites/direct/py38/build.gradle delete mode 100644 sdks/python/test-suites/portable/py38/build.gradle delete mode 100644 sdks/python/test-suites/tox/py38/build.gradle diff --git a/.github/workflows/update_python_dependencies.yml b/.github/workflows/update_python_dependencies.yml index a91aff39f29a..0ab52e97b9f0 100644 --- a/.github/workflows/update_python_dependencies.yml +++ b/.github/workflows/update_python_dependencies.yml @@ -56,7 +56,6 @@ jobs: uses: ./.github/actions/setup-environment-action with: python-version: | - 3.8 3.9 3.10 3.11 diff --git a/CHANGES.md b/CHANGES.md index f873455cd66e..1a9d2045cbf6 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -85,6 +85,7 @@ ## Deprecations * Removed support for Flink 1.15 and 1.16 +* Removed support for Python 3.8 * X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)). 
## Bugfixes diff --git a/build.gradle.kts b/build.gradle.kts index 38b58b6979ee..d96e77a4c78c 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -501,23 +501,11 @@ tasks.register("pythonFormatterPreCommit") { dependsOn("sdks:python:test-suites:tox:pycommon:formatter") } -tasks.register("python38PostCommit") { - dependsOn(":sdks:python:test-suites:dataflow:py38:postCommitIT") - dependsOn(":sdks:python:test-suites:direct:py38:postCommitIT") - dependsOn(":sdks:python:test-suites:direct:py38:hdfsIntegrationTest") - dependsOn(":sdks:python:test-suites:direct:py38:azureIntegrationTest") - dependsOn(":sdks:python:test-suites:portable:py38:postCommitPy38") - // TODO: https://github.com/apache/beam/issues/22651 - // The default container uses Python 3.8. The goal here is to - // duild Docker images for TensorRT tests during run time for python versions - // other than 3.8 and add these tests in other python postcommit suites. - dependsOn(":sdks:python:test-suites:dataflow:py38:inferencePostCommitIT") - dependsOn(":sdks:python:test-suites:direct:py38:inferencePostCommitIT") -} - tasks.register("python39PostCommit") { dependsOn(":sdks:python:test-suites:dataflow:py39:postCommitIT") dependsOn(":sdks:python:test-suites:direct:py39:postCommitIT") + dependsOn(":sdks:python:test-suites:direct:py39:hdfsIntegrationTest") + dependsOn(":sdks:python:test-suites:direct:py39:azureIntegrationTest") dependsOn(":sdks:python:test-suites:portable:py39:postCommitPy39") // TODO (https://github.com/apache/beam/issues/23966) // Move this to Python 3.10 test suite once tfx-bsl has python 3.10 wheel. @@ -528,6 +516,11 @@ tasks.register("python310PostCommit") { dependsOn(":sdks:python:test-suites:dataflow:py310:postCommitIT") dependsOn(":sdks:python:test-suites:direct:py310:postCommitIT") dependsOn(":sdks:python:test-suites:portable:py310:postCommitPy310") + // TODO: https://github.com/apache/beam/issues/22651 + // The default container uses Python 3.10. The goal here is to + // duild Docker images for TensorRT tests during run time for python versions + // other than 3.10 and add these tests in other python postcommit suites. + dependsOn(":sdks:python:test-suites:dataflow:py310:inferencePostCommitIT") } tasks.register("python311PostCommit") { diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py index 20cae582f320..97996bd6cbb2 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py @@ -82,7 +82,7 @@ _LOGGER = logging.getLogger(__name__) -_PYTHON_VERSIONS_SUPPORTED_BY_DATAFLOW = ['3.8', '3.9', '3.10', '3.11', '3.12'] +_PYTHON_VERSIONS_SUPPORTED_BY_DATAFLOW = ['3.9', '3.10', '3.11', '3.12'] class Environment(object): diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py index 022136aae9a2..6587e619a500 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py @@ -1003,7 +1003,21 @@ def test_interpreter_version_check_passes_with_experiment(self): 'apache_beam.runners.dataflow.internal.apiclient.' 
'beam_version.__version__', '2.2.0') - def test_interpreter_version_check_passes_py38(self): + def test_interpreter_version_check_fails_py38(self): + pipeline_options = PipelineOptions([]) + self.assertRaises( + Exception, + apiclient._verify_interpreter_version_is_supported, + pipeline_options) + + @mock.patch( + 'apache_beam.runners.dataflow.internal.apiclient.sys.version_info', + (3, 9, 6)) + @mock.patch( + 'apache_beam.runners.dataflow.internal.apiclient.' + 'beam_version.__version__', + '2.2.0') + def test_interpreter_version_check_passes_py39(self): pipeline_options = PipelineOptions([]) apiclient._verify_interpreter_version_is_supported(pipeline_options) diff --git a/sdks/python/container/build.gradle b/sdks/python/container/build.gradle index f07b6f743fa4..14c08a3a539b 100644 --- a/sdks/python/container/build.gradle +++ b/sdks/python/container/build.gradle @@ -20,7 +20,7 @@ plugins { id 'org.apache.beam.module' } applyGoNature() description = "Apache Beam :: SDKs :: Python :: Container" -int min_python_version=8 +int min_python_version=9 int max_python_version=12 configurations { diff --git a/sdks/python/container/py38/base_image_requirements.txt b/sdks/python/container/py38/base_image_requirements.txt deleted file mode 100644 index 0a67a3666d25..000000000000 --- a/sdks/python/container/py38/base_image_requirements.txt +++ /dev/null @@ -1,172 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Autogenerated requirements file for Apache Beam py38 container image. -# Run ./gradlew :sdks:python:container:generatePythonRequirementsAll to update. -# Do not edit manually, adjust ../base_image_requirements_manual.txt or -# Apache Beam's setup.py instead, and regenerate the list. -# You will need Python interpreters for all versions supported by Beam, see: -# https://s.apache.org/beam-python-dev-wiki -# Reach out to a committer if you need help. 
- -annotated-types==0.7.0 -async-timeout==4.0.3 -attrs==24.2.0 -backports.tarfile==1.2.0 -beautifulsoup4==4.12.3 -bs4==0.0.2 -build==1.2.2 -cachetools==5.5.0 -certifi==2024.8.30 -cffi==1.17.1 -charset-normalizer==3.3.2 -click==8.1.7 -cloudpickle==2.2.1 -cramjam==2.8.4 -crcmod==1.7 -cryptography==43.0.1 -Cython==3.0.11 -Deprecated==1.2.14 -deprecation==2.1.0 -dill==0.3.1.1 -dnspython==2.6.1 -docker==7.1.0 -docopt==0.6.2 -docstring_parser==0.16 -exceptiongroup==1.2.2 -execnet==2.1.1 -fastavro==1.9.7 -fasteners==0.19 -freezegun==1.5.1 -future==1.0.0 -google-api-core==2.20.0 -google-api-python-client==2.147.0 -google-apitools==0.5.31 -google-auth==2.35.0 -google-auth-httplib2==0.2.0 -google-cloud-aiplatform==1.69.0 -google-cloud-bigquery==3.26.0 -google-cloud-bigquery-storage==2.26.0 -google-cloud-bigtable==2.26.0 -google-cloud-core==2.4.1 -google-cloud-datastore==2.20.1 -google-cloud-dlp==3.23.0 -google-cloud-language==2.14.0 -google-cloud-profiler==4.1.0 -google-cloud-pubsub==2.25.2 -google-cloud-pubsublite==1.11.1 -google-cloud-recommendations-ai==0.10.12 -google-cloud-resource-manager==1.12.5 -google-cloud-spanner==3.49.1 -google-cloud-storage==2.18.2 -google-cloud-videointelligence==2.13.5 -google-cloud-vision==3.7.4 -google-crc32c==1.5.0 -google-resumable-media==2.7.2 -googleapis-common-protos==1.65.0 -greenlet==3.1.1 -grpc-google-iam-v1==0.13.1 -grpc-interceptor==0.15.4 -grpcio==1.65.5 -grpcio-status==1.62.3 -guppy3==3.1.4.post1 -hdfs==2.7.3 -httplib2==0.22.0 -hypothesis==6.112.3 -idna==3.10 -importlib_metadata==8.4.0 -importlib_resources==6.4.5 -iniconfig==2.0.0 -jaraco.classes==3.4.0 -jaraco.context==6.0.1 -jaraco.functools==4.1.0 -jeepney==0.8.0 -Jinja2==3.1.4 -joblib==1.4.2 -jsonpickle==3.3.0 -jsonschema==4.23.0 -jsonschema-specifications==2023.12.1 -keyring==25.4.1 -keyrings.google-artifactregistry-auth==1.1.2 -MarkupSafe==2.1.5 -mmh3==5.0.1 -mock==5.1.0 -more-itertools==10.5.0 -nltk==3.9.1 -nose==1.3.7 -numpy==1.24.4 -oauth2client==4.1.3 -objsize==0.7.0 -opentelemetry-api==1.27.0 -opentelemetry-sdk==1.27.0 -opentelemetry-semantic-conventions==0.48b0 -orjson==3.10.7 -overrides==7.7.0 -packaging==24.1 -pandas==2.0.3 -parameterized==0.9.0 -pkgutil_resolve_name==1.3.10 -pluggy==1.5.0 -proto-plus==1.24.0 -protobuf==4.25.5 -psycopg2-binary==2.9.9 -pyarrow==16.1.0 -pyarrow-hotfix==0.6 -pyasn1==0.6.1 -pyasn1_modules==0.4.1 -pycparser==2.22 -pydantic==2.9.2 -pydantic_core==2.23.4 -pydot==1.4.2 -PyHamcrest==2.1.0 -pymongo==4.10.1 -PyMySQL==1.1.1 -pyparsing==3.1.4 -pyproject_hooks==1.2.0 -pytest==7.4.4 -pytest-timeout==2.3.1 -pytest-xdist==3.6.1 -python-dateutil==2.9.0.post0 -python-snappy==0.7.3 -pytz==2024.2 -PyYAML==6.0.2 -redis==5.1.1 -referencing==0.35.1 -regex==2024.9.11 -requests==2.32.3 -requests-mock==1.12.1 -rpds-py==0.20.0 -rsa==4.9 -scikit-learn==1.3.2 -scipy==1.10.1 -SecretStorage==3.3.3 -shapely==2.0.6 -six==1.16.0 -sortedcontainers==2.4.0 -soupsieve==2.6 -SQLAlchemy==2.0.35 -sqlparse==0.5.1 -tenacity==8.5.0 -testcontainers==3.7.1 -threadpoolctl==3.5.0 -tomli==2.0.2 -tqdm==4.66.5 -typing_extensions==4.12.2 -tzdata==2024.2 -uritemplate==4.1.1 -urllib3==2.2.3 -wrapt==1.16.0 -zipp==3.20.2 -zstandard==0.23.0 diff --git a/sdks/python/container/py38/build.gradle b/sdks/python/container/py38/build.gradle deleted file mode 100644 index 304895a83718..000000000000 --- a/sdks/python/container/py38/build.gradle +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -plugins { - id 'base' - id 'org.apache.beam.module' -} -applyDockerNature() -applyPythonNature() - -pythonVersion = '3.8' - -apply from: "../common.gradle" diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index 6bca904c1a64..845791e9c10f 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -543,8 +543,8 @@ task mockAPITests { } // add all RunInference E2E tests that run on DataflowRunner -// As of now, this test suite is enable in py38 suite as the base NVIDIA image used for Tensor RT -// contains Python 3.8. +// As of now, this test suite is enable in py310 suite as the base NVIDIA image used for Tensor RT +// contains Python 3.10. // TODO: https://github.com/apache/beam/issues/22651 project.tasks.register("inferencePostCommitIT") { dependsOn = [ diff --git a/sdks/python/test-suites/dataflow/py38/build.gradle b/sdks/python/test-suites/dataflow/py38/build.gradle deleted file mode 100644 index b3c3a5bfb8a6..000000000000 --- a/sdks/python/test-suites/dataflow/py38/build.gradle +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -apply plugin: org.apache.beam.gradle.BeamModulePlugin -applyPythonNature() - -// Required to setup a Python 3 virtualenv and task names. -pythonVersion = '3.8' -apply from: "../common.gradle" diff --git a/sdks/python/test-suites/direct/py38/build.gradle b/sdks/python/test-suites/direct/py38/build.gradle deleted file mode 100644 index edf86a7bf5a8..000000000000 --- a/sdks/python/test-suites/direct/py38/build.gradle +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -plugins { id 'org.apache.beam.module' } -applyPythonNature() - -// Required to setup a Python 3 virtualenv and task names. -pythonVersion = '3.8' -apply from: '../common.gradle' diff --git a/sdks/python/test-suites/portable/py38/build.gradle b/sdks/python/test-suites/portable/py38/build.gradle deleted file mode 100644 index e15443fa935f..000000000000 --- a/sdks/python/test-suites/portable/py38/build.gradle +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -apply plugin: org.apache.beam.gradle.BeamModulePlugin -applyPythonNature() - -addPortableWordCountTasks() - -// Required to setup a Python 3.8 virtualenv and task names. -pythonVersion = '3.8' -apply from: "../common.gradle" diff --git a/sdks/python/test-suites/tox/py38/build.gradle b/sdks/python/test-suites/tox/py38/build.gradle deleted file mode 100644 index 2ca82d3d9268..000000000000 --- a/sdks/python/test-suites/tox/py38/build.gradle +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Unit tests for Python 3.8 - */ - -plugins { id 'org.apache.beam.module' } -applyPythonNature() - -// Required to setup a Python 3 virtualenv and task names. -pythonVersion = '3.8' - -def posargs = project.findProperty("posargs") ?: "" - -apply from: "../common.gradle" - -toxTask "testPy38CloudCoverage", "py38-cloudcoverage", "${posargs}" -test.dependsOn "testPy38CloudCoverage" -project.tasks.register("preCommitPyCoverage") { - dependsOn = ["testPy38CloudCoverage"] -} - -// Dep Postcommit runs test suites that evaluate compatibility of particular -// dependencies. 
It is exercised on a single Python version. -// -// Should still leave at least one version in PreCommit unless the marked tests -// are also exercised by existing PreCommit -// e.g. pyarrow and pandas also run on PreCommit Dataframe and Coverage -project.tasks.register("postCommitPyDep") {} - -// Create a test task for supported major versions of pyarrow -// We should have a test for the lowest supported version and -// For versions that we would like to prioritize for testing, -// for example versions released in a timeframe of last 1-2 years. - -toxTask "testPy38pyarrow-3", "py38-pyarrow-3", "${posargs}" -test.dependsOn "testPy38pyarrow-3" -postCommitPyDep.dependsOn "testPy38pyarrow-3" - -toxTask "testPy38pyarrow-9", "py38-pyarrow-9", "${posargs}" -test.dependsOn "testPy38pyarrow-9" -postCommitPyDep.dependsOn "testPy38pyarrow-9" - -toxTask "testPy38pyarrow-10", "py38-pyarrow-10", "${posargs}" -test.dependsOn "testPy38pyarrow-10" -postCommitPyDep.dependsOn "testPy38pyarrow-10" - -toxTask "testPy38pyarrow-11", "py38-pyarrow-11", "${posargs}" -test.dependsOn "testPy38pyarrow-11" -postCommitPyDep.dependsOn "testPy38pyarrow-11" - -toxTask "testPy38pyarrow-12", "py38-pyarrow-12", "${posargs}" -test.dependsOn "testPy38pyarrow-12" -postCommitPyDep.dependsOn "testPy38pyarrow-12" - -toxTask "testPy38pyarrow-13", "py38-pyarrow-13", "${posargs}" -test.dependsOn "testPy38pyarrow-13" -postCommitPyDep.dependsOn "testPy38pyarrow-13" - -toxTask "testPy38pyarrow-14", "py38-pyarrow-14", "${posargs}" -test.dependsOn "testPy38pyarrow-14" -postCommitPyDep.dependsOn "testPy38pyarrow-14" - -toxTask "testPy38pyarrow-15", "py38-pyarrow-15", "${posargs}" -test.dependsOn "testPy38pyarrow-15" -postCommitPyDep.dependsOn "testPy38pyarrow-15" - -toxTask "testPy38pyarrow-16", "py38-pyarrow-16", "${posargs}" -test.dependsOn "testPy38pyarrow-16" -postCommitPyDep.dependsOn "testPy38pyarrow-16" - -// Create a test task for each supported minor version of pandas -toxTask "testPy38pandas-14", "py38-pandas-14", "${posargs}" -test.dependsOn "testPy38pandas-14" -postCommitPyDep.dependsOn "testPy38pandas-14" - -toxTask "testPy38pandas-15", "py38-pandas-15", "${posargs}" -test.dependsOn "testPy38pandas-15" -postCommitPyDep.dependsOn "testPy38pandas-15" - -toxTask "testPy38pandas-20", "py38-pandas-20", "${posargs}" -test.dependsOn "testPy38pandas-20" -postCommitPyDep.dependsOn "testPy38pandas-20" - -// TODO(https://github.com/apache/beam/issues/31192): Add below suites -// after dependency compat tests suite switches to Python 3.9 or we add -// Python 2.2 support. 
- -// toxTask "testPy39pandas-21", "py39-pandas-21", "${posargs}" -// test.dependsOn "testPy39pandas-21" -// postCommitPyDep.dependsOn "testPy39pandas-21" - -// toxTask "testPy39pandas-22", "py39-pandas-22", "${posargs}" -// test.dependsOn "testPy39pandas-22" -// postCommitPyDep.dependsOn "testPy39pandas-22" - -// TODO(https://github.com/apache/beam/issues/30908): Revise what are we testing - -// Create a test task for each minor version of pytorch -toxTask "testPy38pytorch-19", "py38-pytorch-19", "${posargs}" -test.dependsOn "testPy38pytorch-19" -postCommitPyDep.dependsOn "testPy38pytorch-19" - -toxTask "testPy38pytorch-110", "py38-pytorch-110", "${posargs}" -test.dependsOn "testPy38pytorch-110" -postCommitPyDep.dependsOn "testPy38pytorch-110" - -toxTask "testPy38pytorch-111", "py38-pytorch-111", "${posargs}" -test.dependsOn "testPy38pytorch-111" -postCommitPyDep.dependsOn "testPy38pytorch-111" - -toxTask "testPy38pytorch-112", "py38-pytorch-112", "${posargs}" -test.dependsOn "testPy38pytorch-112" -postCommitPyDep.dependsOn "testPy38pytorch-112" - -toxTask "testPy38pytorch-113", "py38-pytorch-113", "${posargs}" -test.dependsOn "testPy38pytorch-113" -postCommitPyDep.dependsOn "testPy38pytorch-113" - -// run on precommit -toxTask "testPy38pytorch-200", "py38-pytorch-200", "${posargs}" -test.dependsOn "testPy38pytorch-200" -postCommitPyDep.dependsOn "testPy38pytorch-200" - -toxTask "testPy38tft-113", "py38-tft-113", "${posargs}" -test.dependsOn "testPy38tft-113" -postCommitPyDep.dependsOn "testPy38tft-113" - -// TODO(https://github.com/apache/beam/issues/25796) - uncomment onnx tox task once onnx supports protobuf 4.x.x -// Create a test task for each minor version of onnx -// toxTask "testPy38onnx-113", "py38-onnx-113", "${posargs}" -// test.dependsOn "testPy38onnx-113" -// postCommitPyDep.dependsOn "testPy38onnx-113" - -// Create a test task for each minor version of tensorflow -toxTask "testPy38tensorflow-212", "py38-tensorflow-212", "${posargs}" -test.dependsOn "testPy38tensorflow-212" -postCommitPyDep.dependsOn "testPy38tensorflow-212" - -// Create a test task for each minor version of transformers -toxTask "testPy38transformers-428", "py38-transformers-428", "${posargs}" -test.dependsOn "testPy38transformers-428" -postCommitPyDep.dependsOn "testPy38transformers-428" - -toxTask "testPy38transformers-429", "py38-transformers-429", "${posargs}" -test.dependsOn "testPy38transformers-429" -postCommitPyDep.dependsOn "testPy38transformers-429" - -toxTask "testPy38transformers-430", "py38-transformers-430", "${posargs}" -test.dependsOn "testPy38transformers-430" -postCommitPyDep.dependsOn "testPy38transformers-430" - -toxTask "testPy38embeddingsMLTransform", "py38-embeddings", "${posargs}" -test.dependsOn "testPy38embeddingsMLTransform" -postCommitPyDep.dependsOn "testPy38embeddingsMLTransform" - -// Part of MLTransform embeddings test suite but requires tensorflow hub, which we need to test on -// mutliple versions so keeping this suite separate. 
-toxTask "testPy38TensorflowHubEmbeddings-014", "py38-TFHubEmbeddings-014", "${posargs}" -test.dependsOn "testPy38TensorflowHubEmbeddings-014" -postCommitPyDep.dependsOn "testPy38TensorflowHubEmbeddings-014" - -toxTask "testPy38TensorflowHubEmbeddings-015", "py38-TFHubEmbeddings-015", "${posargs}" -test.dependsOn "testPy38TensorflowHubEmbeddings-015" -postCommitPyDep.dependsOn "testPy38TensorflowHubEmbeddings-015" - -toxTask "whitespacelint", "whitespacelint", "${posargs}" - -task archiveFilesToLint(type: Zip) { - archiveFileName = "files-to-whitespacelint.zip" - destinationDirectory = file("$buildDir/dist") - - from ("$rootProject.projectDir") { - include "**/*.md" - include "**/build.gradle" - include '**/build.gradle.kts' - exclude '**/build/**' // intermediate build directory - exclude 'website/www/site/themes/docsy/**' // fork to google/docsy - exclude "**/node_modules/*" - exclude "**/.gogradle/*" - } -} - -task unpackFilesToLint(type: Copy) { - from zipTree("$buildDir/dist/files-to-whitespacelint.zip") - into "$buildDir/files-to-whitespacelint" -} - -whitespacelint.dependsOn archiveFilesToLint, unpackFilesToLint -unpackFilesToLint.dependsOn archiveFilesToLint -archiveFilesToLint.dependsOn cleanPython - -toxTask "jest", "jest", "${posargs}" - -toxTask "eslint", "eslint", "${posargs}" - -task copyTsSource(type: Copy) { - from ("$rootProject.projectDir") { - include "sdks/python/apache_beam/runners/interactive/extensions/**/*" - exclude "sdks/python/apache_beam/runners/interactive/extensions/**/lib/*" - exclude "sdks/python/apache_beam/runners/interactive/extensions/**/node_modules/*" - } - into "$buildDir/ts" -} - -jest.dependsOn copyTsSource -eslint.dependsOn copyTsSource -copyTsSource.dependsOn cleanPython From a5e09025dd053a5b893d086e3c15657f41970bb2 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Tue, 29 Oct 2024 11:57:15 -0400 Subject: [PATCH 081/181] Suppress future warning affecting Dataflow notebook (#32957) * Suppress future warning affecting Dataflow notebook * fix lint --- .../runners/interactive/display/pcoll_visualization.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py b/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py index 693abb2aeeee..d767a15a345d 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py +++ b/sdks/python/apache_beam/runners/interactive/display/pcoll_visualization.py @@ -26,6 +26,7 @@ import datetime import html import logging +import warnings from datetime import timedelta from typing import Optional @@ -350,7 +351,12 @@ def display(self, updating_pv=None): ] # String-ify the dictionaries for display because elements of type dict # cannot be ordered. - data = data.applymap(lambda x: str(x) if isinstance(x, dict) else x) + with warnings.catch_warnings(): + # TODO(yathu) switch to use DataFrame.map when dropped pandas<2.1 support + warnings.filterwarnings( + "ignore", message="DataFrame.applymap has been deprecated") + data = data.applymap(lambda x: str(x) if isinstance(x, dict) else x) + if updating_pv: # Only updates when data is not empty. Otherwise, consider it a bad # iteration and noop since there is nothing to be updated. 
From f2e1f941ce8b464234868aacb63071e6966c7d4c Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Tue, 29 Oct 2024 11:57:29 -0400 Subject: [PATCH 082/181] Use beam-vendor-grpc-1_60_1:0.3 (#32958) --- .../main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 533fd6a0d475..5af91ec2f056 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -909,7 +909,7 @@ class BeamModulePlugin implements Plugin { testcontainers_solace : "org.testcontainers:solace:$testcontainers_version", truth : "com.google.truth:truth:1.1.5", threetenbp : "org.threeten:threetenbp:1.6.8", - vendored_grpc_1_60_1 : "org.apache.beam:beam-vendor-grpc-1_60_1:0.2", + vendored_grpc_1_60_1 : "org.apache.beam:beam-vendor-grpc-1_60_1:0.3", vendored_guava_32_1_2_jre : "org.apache.beam:beam-vendor-guava-32_1_2-jre:0.1", vendored_calcite_1_28_0 : "org.apache.beam:beam-vendor-calcite-1_28_0:0.2", woodstox_core_asl : "org.codehaus.woodstox:woodstox-core-asl:4.4.1", From a0880a473c880ef1de11d6a2d5a3363fcf60ad74 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Tue, 29 Oct 2024 12:28:56 -0400 Subject: [PATCH 083/181] Revert "Override BQ load job location when necessary (#31986)" This reverts commit ea982127b60545164e0e280eb0d4140f35ae3156. --- .../apache_beam/io/gcp/bigquery_file_loads.py | 18 +----------------- .../io/gcp/bigquery_file_loads_test.py | 10 ---------- .../apache_beam/io/gcp/bigquery_tools.py | 8 ++------ 3 files changed, 3 insertions(+), 33 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py index a7311ad6d063..3145fb511068 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py @@ -777,26 +777,10 @@ def process( GlobalWindows.windowed_value((destination, job_reference))) def finish_bundle(self): - dataset_locations = {} - for windowed_value in self.pending_jobs: - table_ref = bigquery_tools.parse_table_reference(windowed_value.value[0]) - project_dataset = (table_ref.projectId, table_ref.datasetId) - job_ref = windowed_value.value[1] - # In some cases (e.g. when the load job op returns a 409 ALREADY_EXISTS), - # the returned job reference may not include a location. In such cases, - # we need to override with the dataset's location. 
- job_location = job_ref.location - if not job_location and project_dataset not in dataset_locations: - job_location = self.bq_wrapper.get_table_location( - table_ref.projectId, table_ref.datasetId, table_ref.tableId) - dataset_locations[project_dataset] = job_location - self.bq_wrapper.wait_for_bq_job( - job_ref, - sleep_duration_sec=_SLEEP_DURATION_BETWEEN_POLLS, - location=job_location) + job_ref, sleep_duration_sec=_SLEEP_DURATION_BETWEEN_POLLS) return self.pending_jobs diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py index e4c0e34d9c1f..10453d9c8baf 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py @@ -427,7 +427,6 @@ def test_records_traverse_transform_with_mocks(self): job_reference = bigquery_api.JobReference() job_reference.projectId = 'project1' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = bigquery_api.Job() result_job.jobReference = job_reference @@ -483,7 +482,6 @@ def test_load_job_id_used(self): job_reference = bigquery_api.JobReference() job_reference.projectId = 'loadJobProject' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = bigquery_api.Job() result_job.jobReference = job_reference @@ -521,7 +519,6 @@ def test_load_job_id_use_for_copy_job(self): job_reference = bigquery_api.JobReference() job_reference.projectId = 'loadJobProject' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = mock.Mock() result_job.jobReference = job_reference @@ -577,12 +574,10 @@ def test_wait_for_load_job_completion(self, sleep_mock): job_1.jobReference = bigquery_api.JobReference() job_1.jobReference.projectId = 'project1' job_1.jobReference.jobId = 'jobId1' - job_1.jobReference.location = 'US' job_2 = bigquery_api.Job() job_2.jobReference = bigquery_api.JobReference() job_2.jobReference.projectId = 'project1' job_2.jobReference.jobId = 'jobId2' - job_2.jobReference.location = 'US' job_1_waiting = mock.Mock() job_1_waiting.status.state = 'RUNNING' @@ -622,12 +617,10 @@ def test_one_load_job_failed_after_waiting(self, sleep_mock): job_1.jobReference = bigquery_api.JobReference() job_1.jobReference.projectId = 'project1' job_1.jobReference.jobId = 'jobId1' - job_1.jobReference.location = 'US' job_2 = bigquery_api.Job() job_2.jobReference = bigquery_api.JobReference() job_2.jobReference.projectId = 'project1' job_2.jobReference.jobId = 'jobId2' - job_2.jobReference.location = 'US' job_1_waiting = mock.Mock() job_1_waiting.status.state = 'RUNNING' @@ -664,7 +657,6 @@ def test_multiple_partition_files(self): job_reference = bigquery_api.JobReference() job_reference.projectId = 'project1' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = mock.Mock() result_job.jobReference = job_reference @@ -750,7 +742,6 @@ def test_multiple_partition_files_write_dispositions( job_reference = bigquery_api.JobReference() job_reference.projectId = 'project1' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = mock.Mock() result_job.jobReference = job_reference @@ -793,7 +784,6 @@ def test_triggering_frequency(self, is_streaming, with_auto_sharding): job_reference = bigquery_api.JobReference() job_reference.projectId = 'project1' job_reference.jobId = 'job_name1' - job_reference.location = 'US' result_job = bigquery_api.Job() result_job.jobReference = job_reference diff --git 
a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py index c7128e7899ec..a92f30ec35ce 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py @@ -631,8 +631,7 @@ def _start_query_job( return self._start_job(request) - def wait_for_bq_job( - self, job_reference, sleep_duration_sec=5, max_retries=0, location=None): + def wait_for_bq_job(self, job_reference, sleep_duration_sec=5, max_retries=0): """Poll job until it is DONE. Args: @@ -640,7 +639,6 @@ def wait_for_bq_job( sleep_duration_sec: Specifies the delay in seconds between retries. max_retries: The total number of times to retry. If equals to 0, the function waits forever. - location: Fall back on this location if job_reference doesn't have one. Raises: `RuntimeError`: If the job is FAILED or the number of retries has been @@ -650,9 +648,7 @@ def wait_for_bq_job( while True: retry += 1 job = self.get_job( - job_reference.projectId, - job_reference.jobId, - job_reference.location or location) + job_reference.projectId, job_reference.jobId, job_reference.location) _LOGGER.info('Job %s status: %s', job.id, job.status.state) if job.status.state == 'DONE' and job.status.errorResult: raise RuntimeError( From ec428e4d036bd0e337c9b9ee81835666b3cf2c5d Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Tue, 29 Oct 2024 13:46:54 -0400 Subject: [PATCH 084/181] Parse load job location in HTTPError content --- .../apache_beam/io/gcp/bigquery_tools.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py index a92f30ec35ce..c0ff29f3afe3 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py @@ -32,6 +32,7 @@ import io import json import logging +import re import sys import time import uuid @@ -558,6 +559,19 @@ def _insert_load_job( )) return self._start_job(request, stream=source_stream).jobReference + @staticmethod + def _parse_location_from_exc(content, job_id): + """Parse job location from Exception content.""" + if isinstance(content, bytes): + content = content.decode('ascii', 'replace') + # search for "Already Exists: Job :." + m = re.search(r"Already Exists: Job \S+\:(\S+)\." 
+ job_id, content) + if not m: + _LOGGER.warning( + "Not able to parse BigQuery load job location for {}", job_id) + return None + return m.group(1) + def _start_job( self, request, # type: bigquery.BigqueryJobsInsertRequest @@ -585,11 +599,17 @@ def _start_job( return response except HttpError as exn: if exn.status_code == 409: + jobId = request.job.jobReference.jobId _LOGGER.info( "BigQuery job %s already exists, will not retry inserting it: %s", request.job.jobReference, exn) - return request.job + job_location = self._parse_location_from_exc(exn.content, jobId) + response = request.job + if not response.jobReference.location and job_location: + # Request not constructed with location + response.jobReference.location = job_location + return response else: _LOGGER.info( "Failed to insert job %s: %s", request.job.jobReference, exn) From a0cbe74e7ba69121666041627e12483b38a43d9e Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:29:31 -0400 Subject: [PATCH 085/181] Disable the Py39embeddingsMLTransform tests in the Dependency Postcommit (#32966) --- sdks/python/test-suites/tox/py39/build.gradle | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sdks/python/test-suites/tox/py39/build.gradle b/sdks/python/test-suites/tox/py39/build.gradle index ea02e9d5b1e8..e9624f8e810e 100644 --- a/sdks/python/test-suites/tox/py39/build.gradle +++ b/sdks/python/test-suites/tox/py39/build.gradle @@ -168,7 +168,9 @@ postCommitPyDep.dependsOn "testPy39transformers-430" toxTask "testPy39embeddingsMLTransform", "py39-embeddings", "${posargs}" test.dependsOn "testPy39embeddingsMLTransform" -postCommitPyDep.dependsOn "testPy39embeddingsMLTransform" +// TODO(https://github.com/apache/beam/issues/32965): re-enable this suite for the dep +// postcommit once the sentence-transformers import error is debugged +// postCommitPyDep.dependsOn "testPy39embeddingsMLTransform" // Part of MLTransform embeddings test suite but requires tensorflow hub, which we need to test on // mutliple versions so keeping this suite separate. From d169006a43f5be23738321cb26ecce97f8df8696 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Tue, 29 Oct 2024 16:01:02 -0400 Subject: [PATCH 086/181] Fix indent --- sdks/python/apache_beam/io/gcp/bigquery_tools.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py index c0ff29f3afe3..b31f6449fe90 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py @@ -564,13 +564,13 @@ def _parse_location_from_exc(content, job_id): """Parse job location from Exception content.""" if isinstance(content, bytes): content = content.decode('ascii', 'replace') - # search for "Already Exists: Job :." - m = re.search(r"Already Exists: Job \S+\:(\S+)\." + job_id, content) - if not m: - _LOGGER.warning( - "Not able to parse BigQuery load job location for {}", job_id) - return None - return m.group(1) + # search for "Already Exists: Job :." + m = re.search(r"Already Exists: Job \S+\:(\S+)\." 
+ job_id, content) + if not m: + _LOGGER.warning( + "Not able to parse BigQuery load job location for %s", job_id) + return None + return m.group(1) def _start_job( self, From 48a99e2b61703e17ed1b8947949883a9392bd916 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 29 Oct 2024 16:47:41 -0400 Subject: [PATCH 087/181] Update MLTransform code to PEP 585 types --- sdks/python/apache_beam/ml/transforms/base.py | 63 +++++++++---------- .../apache_beam/ml/transforms/base_test.py | 14 ++--- .../ml/transforms/embeddings/huggingface.py | 14 ++--- .../transforms/embeddings/tensorflow_hub.py | 7 +-- .../ml/transforms/embeddings/vertex_ai.py | 14 ++--- .../apache_beam/ml/transforms/handlers.py | 43 +++++++------ .../ml/transforms/handlers_test.py | 19 +++--- sdks/python/apache_beam/ml/transforms/tft.py | 61 +++++++++--------- .../python/apache_beam/ml/transforms/utils.py | 3 +- 9 files changed, 111 insertions(+), 127 deletions(-) diff --git a/sdks/python/apache_beam/ml/transforms/base.py b/sdks/python/apache_beam/ml/transforms/base.py index 678ab0882d24..a963f602a06d 100644 --- a/sdks/python/apache_beam/ml/transforms/base.py +++ b/sdks/python/apache_beam/ml/transforms/base.py @@ -20,14 +20,11 @@ import os import tempfile import uuid +from collections.abc import Mapping +from collections.abc import Sequence from typing import Any -from typing import Dict from typing import Generic -from typing import List -from typing import Mapping from typing import Optional -from typing import Sequence -from typing import Tuple from typing import TypeVar from typing import Union @@ -67,7 +64,7 @@ def _convert_list_of_dicts_to_dict_of_lists( - list_of_dicts: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]: + list_of_dicts: Sequence[dict[str, Any]]) -> dict[str, list[Any]]: keys_to_element_list = collections.defaultdict(list) input_keys = list_of_dicts[0].keys() for d in list_of_dicts: @@ -83,9 +80,9 @@ def _convert_list_of_dicts_to_dict_of_lists( def _convert_dict_of_lists_to_lists_of_dict( - dict_of_lists: Dict[str, List[Any]]) -> List[Dict[str, Any]]: + dict_of_lists: dict[str, list[Any]]) -> list[dict[str, Any]]: batch_length = len(next(iter(dict_of_lists.values()))) - result: List[Dict[str, Any]] = [{} for _ in range(batch_length)] + result: list[dict[str, Any]] = [{} for _ in range(batch_length)] # all the values in the dict_of_lists should have same length for key, values in dict_of_lists.items(): assert len(values) == batch_length, ( @@ -140,7 +137,7 @@ def get_counter(self): class BaseOperation(Generic[OperationInputT, OperationOutputT], MLTransformProvider, abc.ABC): - def __init__(self, columns: List[str]) -> None: + def __init__(self, columns: list[str]) -> None: """ Base Opertation class data processing transformations. Args: @@ -150,7 +147,7 @@ def __init__(self, columns: List[str]) -> None: @abc.abstractmethod def apply_transform(self, data: OperationInputT, - output_column_name: str) -> Dict[str, OperationOutputT]: + output_column_name: str) -> dict[str, OperationOutputT]: """ Define any processing logic in the apply_transform() method. processing logics are applied on inputs and returns a transformed @@ -160,7 +157,7 @@ def apply_transform(self, data: OperationInputT, """ def __call__(self, data: OperationInputT, - output_column_name: str) -> Dict[str, OperationOutputT]: + output_column_name: str) -> dict[str, OperationOutputT]: """ This method is called when the instance of the class is called. This method will invoke the apply() method of the class. 
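As an illustration of the BaseOperation contract shown above (apply_transform returns a dict keyed by the output column name, and __call__ simply delegates to it), here is a minimal stand-alone sketch; the class name, the plain-list input, and the sample values are invented for the example and are not part of the Beam API:

from typing import Any

class LowercaseOperation:
  """Toy stand-in for a BaseOperation subclass operating on string columns."""
  def __init__(self, columns: list[str]) -> None:
    self.columns = columns

  def apply_transform(
      self, data: list[str], output_column_name: str) -> dict[str, Any]:
    # Processing logic lives here, keyed by the requested output column name.
    return {output_column_name: [value.lower() for value in data]}

  def __call__(
      self, data: list[str], output_column_name: str) -> dict[str, Any]:
    # __call__ only delegates to apply_transform, as the docstring above states.
    return self.apply_transform(data, output_column_name)

print(LowercaseOperation(['x'])(['Foo', 'BAR'], 'x_lower'))
# {'x_lower': ['foo', 'bar']}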
@@ -172,7 +169,7 @@ def __call__(self, data: OperationInputT, class ProcessHandler( beam.PTransform[beam.PCollection[ExampleT], Union[beam.PCollection[MLTransformOutputT], - Tuple[beam.PCollection[MLTransformOutputT], + tuple[beam.PCollection[MLTransformOutputT], beam.PCollection[beam.Row]]]], abc.ABC): """ @@ -190,10 +187,10 @@ def append_transform(self, transform: BaseOperation): class EmbeddingsManager(MLTransformProvider): def __init__( self, - columns: List[str], + columns: list[str], *, # common args for all ModelHandlers. - load_model_args: Optional[Dict[str, Any]] = None, + load_model_args: Optional[dict[str, Any]] = None, min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, large_model: bool = False, @@ -222,7 +219,7 @@ def get_columns_to_apply(self): class MLTransform( beam.PTransform[beam.PCollection[ExampleT], Union[beam.PCollection[MLTransformOutputT], - Tuple[beam.PCollection[MLTransformOutputT], + tuple[beam.PCollection[MLTransformOutputT], beam.PCollection[beam.Row]]]], Generic[ExampleT, MLTransformOutputT]): def __init__( @@ -230,7 +227,7 @@ def __init__( *, write_artifact_location: Optional[str] = None, read_artifact_location: Optional[str] = None, - transforms: Optional[List[MLTransformProvider]] = None): + transforms: Optional[list[MLTransformProvider]] = None): """ MLTransform is a Beam PTransform that can be used to apply transformations to the data. MLTransform is used to wrap the @@ -304,12 +301,12 @@ def __init__( self._counter = Metrics.counter( MLTransform, f'BeamML_{self.__class__.__name__}') self._with_exception_handling = False - self._exception_handling_args: Dict[str, Any] = {} + self._exception_handling_args: dict[str, Any] = {} def expand( self, pcoll: beam.PCollection[ExampleT] ) -> Union[beam.PCollection[MLTransformOutputT], - Tuple[beam.PCollection[MLTransformOutputT], + tuple[beam.PCollection[MLTransformOutputT], beam.PCollection[beam.Row]]]: """ This is the entrypoint for the MLTransform. This method will @@ -533,7 +530,7 @@ class _MLTransformToPTransformMapper: """ def __init__( self, - transforms: List[MLTransformProvider], + transforms: list[MLTransformProvider], artifact_location: str, artifact_mode: str = ArtifactMode.PRODUCE, pipeline_options: Optional[PipelineOptions] = None, @@ -595,7 +592,7 @@ class _EmbeddingHandler(ModelHandler): For example, if the original mode is used with RunInference to take a PCollection[E] to a PCollection[P], this ModelHandler would take a - PCollection[Dict[str, E]] to a PCollection[Dict[str, P]]. + PCollection[dict[str, E]] to a PCollection[dict[str, P]]. 
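A rough plain-Python illustration of the dict-shaped batches the embedding handlers work with, mirroring the _convert_list_of_dicts_to_dict_of_lists and _convert_dict_of_lists_to_lists_of_dict helpers earlier in this file; the sample data is made up:

batch = [{'text': 'a', 'id': 1}, {'text': 'b', 'id': 2}]

# list of dicts -> dict of lists (the shape handed to the underlying model)
dict_of_lists = {key: [row[key] for row in batch] for key in batch[0]}
assert dict_of_lists == {'text': ['a', 'b'], 'id': [1, 2]}

# dict of lists -> list of dicts (the shape emitted back to the pipeline)
batch_length = len(next(iter(dict_of_lists.values())))
restored = [
    {key: values[i] for key, values in dict_of_lists.items()}
    for i in range(batch_length)]
assert restored == batch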
_EmbeddingHandler will accept an EmbeddingsManager instance, which contains the details of the model to be loaded and the inference_fn to be @@ -619,7 +616,7 @@ def load_model(self): def _validate_column_data(self, batch): pass - def _validate_batch(self, batch: Sequence[Dict[str, Any]]): + def _validate_batch(self, batch: Sequence[dict[str, Any]]): if not batch or not isinstance(batch[0], dict): raise TypeError( 'Expected data to be dicts, got ' @@ -627,10 +624,10 @@ def _validate_batch(self, batch: Sequence[Dict[str, Any]]): def _process_batch( self, - dict_batch: Dict[str, List[Any]], + dict_batch: dict[str, list[Any]], model: ModelT, - inference_args: Optional[Dict[str, Any]]) -> Dict[str, List[Any]]: - result: Dict[str, List[Any]] = collections.defaultdict(list) + inference_args: Optional[dict[str, Any]]) -> dict[str, list[Any]]: + result: dict[str, list[Any]] = collections.defaultdict(list) input_keys = dict_batch.keys() missing_columns_in_data = set(self.columns) - set(input_keys) if missing_columns_in_data: @@ -653,10 +650,10 @@ def _process_batch( def run_inference( self, - batch: Sequence[Dict[str, List[str]]], + batch: Sequence[dict[str, list[str]]], model: ModelT, - inference_args: Optional[Dict[str, Any]] = None, - ) -> List[Dict[str, Union[List[float], List[str]]]]: + inference_args: Optional[dict[str, Any]] = None, + ) -> list[dict[str, Union[list[float], list[str]]]]: """ Runs inference on a batch of text inputs. The inputs are expected to be a list of dicts. Each dict should have the same keys, and the shape @@ -696,7 +693,7 @@ class _TextEmbeddingHandler(_EmbeddingHandler): For example, if the original mode is used with RunInference to take a PCollection[E] to a PCollection[P], this ModelHandler would take a - PCollection[Dict[str, E]] to a PCollection[Dict[str, P]]. + PCollection[dict[str, E]] to a PCollection[dict[str, P]]. _TextEmbeddingHandler will accept an EmbeddingsManager instance, which contains the details of the model to be loaded and the inference_fn to be @@ -713,8 +710,8 @@ class _TextEmbeddingHandler(_EmbeddingHandler): def _validate_column_data(self, batch): if not isinstance(batch[0], (str, bytes)): raise TypeError( - 'Embeddings can only be generated on Dict[str, str].' - f'Got Dict[str, {type(batch[0])}] instead.') + 'Embeddings can only be generated on dict[str, str].' + f'Got dict[str, {type(batch[0])}] instead.') def get_metrics_namespace(self) -> str: return ( @@ -730,7 +727,7 @@ class _ImageEmbeddingHandler(_EmbeddingHandler): For example, if the original mode is used with RunInference to take a PCollection[E] to a PCollection[P], this ModelHandler would take a - PCollection[Dict[str, E]] to a PCollection[Dict[str, P]]. + PCollection[dict[str, E]] to a PCollection[dict[str, P]]. _ImageEmbeddingHandler will accept an EmbeddingsManager instance, which contains the details of the model to be loaded and the inference_fn to be @@ -750,8 +747,8 @@ def _validate_column_data(self, batch): # here, so just catch columns of primatives for now. if isinstance(batch[0], (int, str, float, bool)): raise TypeError( - 'Embeddings can only be generated on Dict[str, Image].' - f'Got Dict[str, {type(batch[0])}] instead.') + 'Embeddings can only be generated on dict[str, Image].' 
+ f'Got dict[str, {type(batch[0])}] instead.') def get_metrics_namespace(self) -> str: return ( diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py index 743c3683ce8e..3320627dc794 100644 --- a/sdks/python/apache_beam/ml/transforms/base_test.py +++ b/sdks/python/apache_beam/ml/transforms/base_test.py @@ -21,11 +21,9 @@ import tempfile import typing import unittest +from collections.abc import Sequence from typing import Any -from typing import Dict -from typing import List from typing import Optional -from typing import Sequence import numpy as np from parameterized import param @@ -162,7 +160,7 @@ def test_ml_transform_on_list_dict(self): 'x': [1, 2, 3], 'y': [2.0, 3.0, 4.0] }], input_types={ - 'x': List[int], 'y': List[float] + 'x': list[int], 'y': list[float] }, expected_dtype={ 'x': typing.Sequence[np.float32], @@ -320,7 +318,7 @@ def test_read_mode_with_transforms(self): class FakeModel: - def __call__(self, example: List[str]) -> List[str]: + def __call__(self, example: list[str]) -> list[str]: for i in range(len(example)): if not isinstance(example[i], str): raise TypeError('Input must be a string') @@ -333,7 +331,7 @@ def run_inference( self, batch: Sequence[str], model: Any, - inference_args: Optional[Dict[str, Any]] = None): + inference_args: Optional[dict[str, Any]] = None): return model(batch) def load_model(self): @@ -508,7 +506,7 @@ def test_handler_with_inconsistent_keys(self): class FakeImageModel: - def __call__(self, example: List[PIL_Image]) -> List[PIL_Image]: + def __call__(self, example: list[PIL_Image]) -> list[PIL_Image]: for i in range(len(example)): if not isinstance(example[i], PIL_Image): raise TypeError('Input must be an Image') @@ -520,7 +518,7 @@ def run_inference( self, batch: Sequence[PIL_Image], model: Any, - inference_args: Optional[Dict[str, Any]] = None): + inference_args: Optional[dict[str, Any]] = None): return model(batch) def load_model(self): diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py index 46b4ef9cf7d6..2162ed050c42 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/huggingface.py @@ -18,13 +18,11 @@ import logging import os +from collections.abc import Callable +from collections.abc import Mapping +from collections.abc import Sequence from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Mapping from typing import Optional -from typing import Sequence import requests @@ -80,7 +78,7 @@ def run_inference( self, batch: Sequence[str], model: SentenceTransformer, - inference_args: Optional[Dict[str, Any]] = None, + inference_args: Optional[dict[str, Any]] = None, ): inference_args = inference_args or {} return model.encode(batch, **inference_args) @@ -113,7 +111,7 @@ class SentenceTransformerEmbeddings(EmbeddingsManager): def __init__( self, model_name: str, - columns: List[str], + columns: list[str], max_seq_length: Optional[int] = None, image_model: bool = False, **kwargs): @@ -216,7 +214,7 @@ class InferenceAPIEmbeddings(EmbeddingsManager): def __init__( self, hf_token: Optional[str], - columns: List[str], + columns: list[str], model_name: Optional[str] = None, # example: "sentence-transformers/all-MiniLM-l6-v2" # pylint: disable=line-too-long api_url: Optional[str] = None, **kwargs, diff --git 
a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py index f78ddf3ff04a..c14904df7c2c 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/tensorflow_hub.py @@ -14,8 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable -from typing import List +from collections.abc import Iterable from typing import Optional import apache_beam as beam @@ -90,7 +89,7 @@ def run_inference(self, batch, model, inference_args, model_id=None): class TensorflowHubTextEmbeddings(EmbeddingsManager): def __init__( self, - columns: List[str], + columns: list[str], hub_url: str, preprocessing_url: Optional[str] = None, **kwargs): @@ -136,7 +135,7 @@ def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: class TensorflowHubImageEmbeddings(EmbeddingsManager): - def __init__(self, columns: List[str], hub_url: str, **kwargs): + def __init__(self, columns: list[str], hub_url: str, **kwargs): """ Embedding config for tensorflow hub models. This config can be used with MLTransform to embed image data. Models are loaded using the RunInference diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py index fbefeec231f1..6fe8320e758b 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py @@ -19,12 +19,10 @@ # Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long # to install Vertex AI Python SDK. +from collections.abc import Iterable +from collections.abc import Sequence from typing import Any -from typing import Dict -from typing import Iterable -from typing import List from typing import Optional -from typing import Sequence from google.auth.credentials import Credentials @@ -80,7 +78,7 @@ def run_inference( self, batch: Sequence[str], model: Any, - inference_args: Optional[Dict[str, Any]] = None, + inference_args: Optional[dict[str, Any]] = None, ) -> Iterable: embeddings = [] batch_size = _BATCH_SIZE @@ -110,7 +108,7 @@ class VertexAITextEmbeddings(EmbeddingsManager): def __init__( self, model_name: str, - columns: List[str], + columns: list[str], title: Optional[str] = None, task_type: str = DEFAULT_TASK_TYPE, project: Optional[str] = None, @@ -179,7 +177,7 @@ def run_inference( self, batch: Sequence[Image], model: MultiModalEmbeddingModel, - inference_args: Optional[Dict[str, Any]] = None, + inference_args: Optional[dict[str, Any]] = None, ) -> Iterable: embeddings = [] # Maximum request size for muli-model embedding models is 1. 
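The run_inference implementations above chunk the incoming batch before each model call; a stand-alone sketch of that pattern follows, with the model call replaced by a placeholder and the _BATCH_SIZE value assumed rather than taken from the module:

_BATCH_SIZE = 5  # assumed value for the sketch; the real constant lives in vertex_ai.py

def run_inference_sketch(batch, embed_fn):
  embeddings = []
  for i in range(0, len(batch), _BATCH_SIZE):
    text_batch = batch[i:i + _BATCH_SIZE]
    # embed_fn stands in for the model's get_embeddings(...) call; for the
    # multimodal image model the effective request size drops to 1.
    embeddings.extend(embed_fn(text) for text in text_batch)
  return embeddings

print(run_inference_sketch(['a', 'b', 'c'], lambda t: f'embedding({t})'))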
@@ -204,7 +202,7 @@ class VertexAIImageEmbeddings(EmbeddingsManager): def __init__( self, model_name: str, - columns: List[str], + columns: list[str], dimension: Optional[int], project: Optional[str] = None, location: Optional[str] = None, diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py index 7a912f2d88ea..e732e92e14d5 100644 --- a/sdks/python/apache_beam/ml/transforms/handlers.py +++ b/sdks/python/apache_beam/ml/transforms/handlers.py @@ -19,12 +19,11 @@ import collections import copy import os +from collections.abc import Sequence import typing from typing import Any -from typing import Dict -from typing import List +from typing import NamedTuple from typing import Optional -from typing import Sequence from typing import Union import numpy as np @@ -71,18 +70,18 @@ np.str_: tf.string, } _primitive_types_to_typing_container_type = { - int: List[int], float: List[float], str: List[str], bytes: List[bytes] + int: list[int], float: list[float], str: list[str], bytes: list[bytes] } -tft_process_handler_input_type = typing.Union[typing.NamedTuple, +tft_process_handler_input_type = Union[NamedTuple, beam.Row, - Dict[str, - typing.Union[str, + dict[str, + Union[str, float, int, bytes, np.ndarray]]] -tft_process_handler_output_type = typing.Union[beam.Row, Dict[str, np.ndarray]] +tft_process_handler_output_type = Union[beam.Row, dict[str, np.ndarray]] class _DataCoder: @@ -131,15 +130,15 @@ def process( class _ConvertNamedTupleToDict( - beam.PTransform[beam.PCollection[typing.Union[beam.Row, typing.NamedTuple]], - beam.PCollection[Dict[str, + beam.PTransform[beam.PCollection[Union[beam.Row, NamedTuple]], + beam.PCollection[dict[str, common_types.InstanceDictType]]]): """ A PTransform that converts a collection of NamedTuples or Rows into a collection of dictionaries. """ def expand( - self, pcoll: beam.PCollection[typing.Union[beam.Row, typing.NamedTuple]] + self, pcoll: beam.PCollection[Union[beam.Row, NamedTuple]] ) -> beam.PCollection[common_types.InstanceDictType]: """ Args: @@ -163,7 +162,7 @@ def __init__( operations. """ self.transforms = transforms if transforms else [] - self.transformed_schema: Dict[str, type] = {} + self.transformed_schema: dict[str, type] = {} self.artifact_location = artifact_location self.artifact_mode = artifact_mode if artifact_mode not in ['produce', 'consume']: @@ -217,7 +216,7 @@ def _map_column_names_to_types_from_transforms(self): return column_type_mapping def get_raw_data_feature_spec( - self, input_types: Dict[str, type]) -> Dict[str, tf.io.VarLenFeature]: + self, input_types: dict[str, type]) -> dict[str, tf.io.VarLenFeature]: """ Return a DatasetMetadata object to be used with tft_beam.AnalyzeAndTransformDataset. @@ -265,7 +264,7 @@ def _get_raw_data_feature_spec_per_column( f"Union type is not supported for column: {col_name}. " f"Please pass a PCollection with valid schema for column " f"{col_name} by passing a single type " - "in container. For example, List[int].") + "in container. 
For example, list[int].") elif issubclass(typ, np.generic) or typ in _default_type_to_tensor_type_map: dtype = typ else: @@ -276,7 +275,7 @@ def _get_raw_data_feature_spec_per_column( return tf.io.VarLenFeature(_default_type_to_tensor_type_map[dtype]) def get_raw_data_metadata( - self, input_types: Dict[str, type]) -> dataset_metadata.DatasetMetadata: + self, input_types: dict[str, type]) -> dataset_metadata.DatasetMetadata: raw_data_feature_spec = self.get_raw_data_feature_spec(input_types) raw_data_feature_spec[_TEMP_KEY] = tf.io.VarLenFeature(dtype=tf.string) return self.convert_raw_data_feature_spec_to_dataset_metadata( @@ -305,8 +304,8 @@ def _fail_on_non_default_windowing(self, pcoll: beam.PCollection): "to convert your PCollection to GlobalWindow.") def process_data_fn( - self, inputs: Dict[str, common_types.ConsistentTensorType] - ) -> Dict[str, common_types.ConsistentTensorType]: + self, inputs: dict[str, common_types.ConsistentTensorType] + ) -> dict[str, common_types.ConsistentTensorType]: """ This method is used in the AnalyzeAndTransformDataset step. It applies the transforms to the `inputs` in sequential order on the columns @@ -335,11 +334,11 @@ def _get_transformed_data_schema( name = feature.name feature_type = feature.type if feature_type == schema_pb2.FeatureType.FLOAT: - transformed_types[name] = typing.Sequence[np.float32] + transformed_types[name] = Sequence[np.float32] elif feature_type == schema_pb2.FeatureType.INT: - transformed_types[name] = typing.Sequence[np.int64] # type: ignore[assignment] + transformed_types[name] = Sequence[np.int64] # type: ignore[assignment] else: - transformed_types[name] = typing.Sequence[bytes] # type: ignore[assignment] + transformed_types[name] = Sequence[bytes] # type: ignore[assignment] return transformed_types def expand( @@ -372,7 +371,7 @@ def expand( raw_data = ( raw_data | _ConvertNamedTupleToDict().with_output_types( - Dict[str, typing.Union[tuple(column_type_mapping.values())]])) # type: ignore + dict[str, Union[tuple(column_type_mapping.values())]])) # type: ignore # AnalyzeAndTransformDataset raise type hint since this is # schema'd PCollection and the current output type would be a # custom type(NamedTuple) or a beam.Row type. 
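For context on the typing changes above, a small sketch of the per-element NamedTuple-to-dict conversion that _ConvertNamedTupleToDict performs and of the builtin-generic output type expression used in expand(); the ExampleRow class and its values are invented:

from typing import NamedTuple
from typing import Union

class ExampleRow(NamedTuple):
  x: list[int]
  y: float

element = ExampleRow(x=[1, 2], y=3.0)
instance_dict = element._asdict()  # {'x': [1, 2], 'y': 3.0}

column_type_mapping = {'x': list[int], 'y': float}
output_type = dict[str, Union[tuple(column_type_mapping.values())]]
print(instance_dict, output_type)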
@@ -408,7 +407,7 @@ def expand( raw_data = ( raw_data | _ConvertNamedTupleToDict().with_output_types( - Dict[str, typing.Union[tuple(column_type_mapping.values())]])) # type: ignore + dict[str, Union[tuple(column_type_mapping.values())]])) # type: ignore feature_set = [feature.name for feature in raw_data_metadata.schema.feature] diff --git a/sdks/python/apache_beam/ml/transforms/handlers_test.py b/sdks/python/apache_beam/ml/transforms/handlers_test.py index 1331f1308087..4b53026c36a4 100644 --- a/sdks/python/apache_beam/ml/transforms/handlers_test.py +++ b/sdks/python/apache_beam/ml/transforms/handlers_test.py @@ -23,7 +23,6 @@ import typing import unittest import uuid -from typing import List from typing import NamedTuple from typing import Union @@ -65,7 +64,7 @@ class IntType(NamedTuple): class ListIntType(NamedTuple): - x: List[int] + x: list[int] class NumpyType(NamedTuple): @@ -111,7 +110,7 @@ def test_input_type_from_schema_named_tuple_pcoll(self): artifact_location=self.artifact_location) inferred_input_type = process_handler._map_column_names_to_types( element_type) - expected_input_type = dict(x=List[int]) + expected_input_type = dict(x=list[int]) self.assertEqual(inferred_input_type, expected_input_type) @@ -126,7 +125,7 @@ def test_input_type_from_schema_named_tuple_pcoll_list(self): artifact_location=self.artifact_location) inferred_input_type = process_handler._map_column_names_to_types( element_type) - expected_input_type = dict(x=List[int]) + expected_input_type = dict(x=list[int]) self.assertEqual(inferred_input_type, expected_input_type) def test_input_type_from_row_type_pcoll(self): @@ -140,7 +139,7 @@ def test_input_type_from_row_type_pcoll(self): artifact_location=self.artifact_location) inferred_input_type = process_handler._map_column_names_to_types( element_type) - expected_input_type = dict(x=List[int]) + expected_input_type = dict(x=list[int]) self.assertEqual(inferred_input_type, expected_input_type) def test_input_type_from_row_type_pcoll_list(self): @@ -149,14 +148,14 @@ def test_input_type_from_row_type_pcoll_list(self): data = ( p | beam.Create(data) | beam.Map(lambda ele: beam.Row(x=list(ele['x']))).with_output_types( - beam.row_type.RowTypeConstraint.from_fields([('x', List[int])]))) + beam.row_type.RowTypeConstraint.from_fields([('x', list[int])]))) element_type = data.element_type process_handler = handlers.TFTProcessHandler( artifact_location=self.artifact_location) inferred_input_type = process_handler._map_column_names_to_types( element_type) - expected_input_type = dict(x=List[int]) + expected_input_type = dict(x=list[int]) self.assertEqual(inferred_input_type, expected_input_type) def test_input_type_from_named_tuple_pcoll_numpy(self): @@ -190,8 +189,8 @@ def test_tensorflow_raw_data_metadata_primitive_types(self): self.assertIsInstance(feature_spec, tf.io.VarLenFeature) def test_tensorflow_raw_data_metadata_primitive_types_in_containers(self): - input_types = dict([("x", List[int]), ("y", List[float]), - ("k", List[bytes]), ("l", List[str])]) + input_types = dict([("x", list[int]), ("y", list[float]), + ("k", list[bytes]), ("l", list[str])]) process_handler = handlers.TFTProcessHandler( artifact_location=self.artifact_location) for col_name, typ in input_types.items(): @@ -211,7 +210,7 @@ def test_tensorflow_raw_data_metadata_primitive_native_container_types(self): self.assertIsInstance(feature_spec, tf.io.VarLenFeature) def test_tensorflow_raw_data_metadata_numpy_types(self): - input_types = dict(x=np.int64, y=np.float32, z=List[np.int64]) + 
input_types = dict(x=np.int64, y=np.float32, z=list[np.int64]) process_handler = handlers.TFTProcessHandler( artifact_location=self.artifact_location) for col_name, typ in input_types.items(): diff --git a/sdks/python/apache_beam/ml/transforms/tft.py b/sdks/python/apache_beam/ml/transforms/tft.py index 6903cca89419..bfe23757642b 100644 --- a/sdks/python/apache_beam/ml/transforms/tft.py +++ b/sdks/python/apache_beam/ml/transforms/tft.py @@ -34,12 +34,9 @@ # pytype: skip-file import logging +from collections.abc import Iterable from typing import Any -from typing import Dict -from typing import Iterable -from typing import List from typing import Optional -from typing import Tuple from typing import Union import apache_beam as beam @@ -67,7 +64,7 @@ # Register the expected input types for each operation # this will be used to determine schema for the tft.AnalyzeDataset -_EXPECTED_TYPES: Dict[str, Union[int, str, float]] = {} +_EXPECTED_TYPES: dict[str, Union[int, str, float]] = {} _LOGGER = logging.getLogger(__name__) @@ -84,7 +81,7 @@ def wrapper(fn): # Add support for outputting artifacts to a text file in human readable form. class TFTOperation(BaseOperation[common_types.TensorType, common_types.TensorType]): - def __init__(self, columns: List[str]) -> None: + def __init__(self, columns: list[str]) -> None: """ Base Operation class for TFT data processing transformations. Processing logic for the transformation is defined in the @@ -150,7 +147,7 @@ def _split_string_with_delimiter(self, data, delimiter): class ComputeAndApplyVocabulary(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], split_string_by_delimiter: Optional[str] = None, *, default_value: Any = -1, @@ -193,7 +190,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: if self.split_string_by_delimiter: data = self._split_string_with_delimiter( @@ -218,7 +215,7 @@ def apply_transform( class ScaleToZScore(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], *, elementwise: bool = False, name: Optional[str] = None): @@ -247,7 +244,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output_dict = { output_column_name: tft.scale_to_z_score( x=data, elementwise=self.elementwise, name=self.name) @@ -259,7 +256,7 @@ def apply_transform( class ScaleTo01(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], elementwise: bool = False, name: Optional[str] = None): """ @@ -287,7 +284,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output = tft.scale_to_0_1( x=data, elementwise=self.elementwise, name=self.name) @@ -299,7 +296,7 @@ def apply_transform( class ScaleToGaussian(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], elementwise: bool = False, name: Optional[str] = None): """ @@ -324,7 +321,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output_dict = { output_column_name: tft.scale_to_gaussian( x=data, 
elementwise=self.elementwise, name=self.name) @@ -336,7 +333,7 @@ def apply_transform( class ApplyBuckets(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], bucket_boundaries: Iterable[Union[int, float]], name: Optional[str] = None): """ @@ -359,7 +356,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output = { output_column_name: tft.apply_buckets( x=data, bucket_boundaries=self.bucket_boundaries, name=self.name) @@ -371,7 +368,7 @@ def apply_transform( class ApplyBucketsWithInterpolation(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], bucket_boundaries: Iterable[Union[int, float]], name: Optional[str] = None): """Interpolates values within the provided buckets and then normalizes to @@ -398,7 +395,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output = { output_column_name: tft.apply_buckets_with_interpolation( x=data, bucket_boundaries=self.bucket_boundaries, name=self.name) @@ -410,7 +407,7 @@ def apply_transform( class Bucketize(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], num_buckets: int, *, epsilon: Optional[float] = None, @@ -443,7 +440,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: output = { output_column_name: tft.bucketize( x=data, @@ -459,7 +456,7 @@ def apply_transform( class TFIDF(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], vocab_size: Optional[int] = None, smooth: bool = True, name: Optional[str] = None, @@ -530,7 +527,7 @@ def apply_transform( class ScaleByMinMax(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], min_value: float = 0.0, max_value: float = 1.0, name: Optional[str] = None): @@ -566,10 +563,10 @@ def apply_transform( class NGrams(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], split_string_by_delimiter: Optional[str] = None, *, - ngram_range: Tuple[int, int] = (1, 1), + ngram_range: tuple[int, int] = (1, 1), ngrams_separator: Optional[str] = None, name: Optional[str] = None): """ @@ -599,7 +596,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_column_name: str) -> Dict[str, common_types.TensorType]: + output_column_name: str) -> dict[str, common_types.TensorType]: if self.split_string_by_delimiter: data = self._split_string_with_delimiter( data, self.split_string_by_delimiter) @@ -611,10 +608,10 @@ def apply_transform( class BagOfWords(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], split_string_by_delimiter: Optional[str] = None, *, - ngram_range: Tuple[int, int] = (1, 1), + ngram_range: tuple[int, int] = (1, 1), ngrams_separator: Optional[str] = None, compute_word_count: bool = False, key_vocab_filename: Optional[str] = None, @@ -686,9 +683,9 @@ def count_unique_words( class HashStrings(TFTOperation): def __init__( self, - columns: List[str], + columns: list[str], hash_buckets: int, - key: Optional[Tuple[int, int]] = None, + key: Optional[tuple[int, int]] = None, name: Optional[str] = None): '''Hashes strings into the provided 
number of buckets. @@ -715,7 +712,7 @@ def __init__( def apply_transform( self, data: common_types.TensorType, - output_col_name: str) -> Dict[str, common_types.TensorType]: + output_col_name: str) -> dict[str, common_types.TensorType]: output_dict = { output_col_name: tft.hash_strings( strings=data, @@ -728,7 +725,7 @@ def apply_transform( @register_input_dtype(str) class DeduplicateTensorPerRow(TFTOperation): - def __init__(self, columns: List[str], name: Optional[str] = None): + def __init__(self, columns: list[str], name: Optional[str] = None): """ Deduplicates each row (0th dimension) of the provided tensor. Args: @@ -740,7 +737,7 @@ def __init__(self, columns: List[str], name: Optional[str] = None): def apply_transform( self, data: common_types.TensorType, - output_col_name: str) -> Dict[str, common_types.TensorType]: + output_col_name: str) -> dict[str, common_types.TensorType]: output_dict = { output_col_name: tft.deduplicate_tensor_per_row( input_tensor=data, name=self.name) diff --git a/sdks/python/apache_beam/ml/transforms/utils.py b/sdks/python/apache_beam/ml/transforms/utils.py index abf4c48fc642..023657895686 100644 --- a/sdks/python/apache_beam/ml/transforms/utils.py +++ b/sdks/python/apache_beam/ml/transforms/utils.py @@ -19,7 +19,6 @@ import os import tempfile -import typing from google.cloud.storage import Client from google.cloud.storage import transfer_manager @@ -72,7 +71,7 @@ def __init__(self, artifact_location: str): self._artifact_location = os.path.join(artifact_location, files[0]) self.transform_output = tft.TFTransformOutput(self._artifact_location) - def get_vocab_list(self, vocab_filename: str) -> typing.List[bytes]: + def get_vocab_list(self, vocab_filename: str) -> list[bytes]: """ Returns list of vocabulary terms created during MLTransform. 
""" From 88ada9dfee3c4602ff62b0f6ebdded29518760c9 Mon Sep 17 00:00:00 2001 From: martin trieu Date: Wed, 30 Oct 2024 04:19:36 -0600 Subject: [PATCH 088/181] fix silent failures in dispatch loop from stalling the pipeline (#32922) * use ExecutorService instead of ScheduledExecutorService which swallows exceptions into futures that were not examined Co-authored-by: Arun Pandian --- .../worker/DataflowWorkerHarnessHelper.java | 4 +- .../worker/StreamingDataflowWorker.java | 14 +-- .../WorkerUncaughtExceptionHandler.java | 10 +- .../FanOutStreamingEngineWorkerHarness.java | 2 +- .../harness/SingleSourceWorkerHarness.java | 2 +- .../StreamingApplianceWorkCommitter.java | 2 +- .../work/budget/GetWorkBudgetRefresher.java | 2 +- .../processing/StreamingWorkScheduler.java | 2 +- ...ingEngineComputationConfigFetcherTest.java | 1 - .../SingleSourceWorkerHarnessTest.java | 117 ++++++++++++++++++ 10 files changed, 136 insertions(+), 20 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarnessTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkerHarnessHelper.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkerHarnessHelper.java index 94c894608a47..a28a5e989c88 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkerHarnessHelper.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowWorkerHarnessHelper.java @@ -82,7 +82,9 @@ public static T initializeGlobalStateAn @SuppressWarnings("Slf4jIllegalPassedClass") public static void initializeLogging(Class workerHarnessClass) { - /* Set up exception handling tied to the workerHarnessClass. */ + // Set up exception handling for raw Threads tied to the workerHarnessClass. + // Does NOT handle exceptions thrown by threads created by + // ScheduledExecutors/ScheduledExecutorServices. Thread.setDefaultUncaughtExceptionHandler( new WorkerUncaughtExceptionHandler(LoggerFactory.getLogger(workerHarnessClass))); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index c478341c1c39..ff72add83e4d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -175,7 +175,7 @@ private StreamingDataflowWorker( StreamingCounters streamingCounters, MemoryMonitor memoryMonitor, GrpcWindmillStreamFactory windmillStreamFactory, - Function executorSupplier, + ScheduledExecutorService activeWorkRefreshExecutorFn, ConcurrentMap stageInfoMap) { // Register standard file systems. 
FileSystems.setDefaultPipelineOptions(options); @@ -285,7 +285,7 @@ private StreamingDataflowWorker( stuckCommitDurationMillis, computationStateCache::getAllPresentComputations, sampler, - executorSupplier.apply("RefreshWork"), + activeWorkRefreshExecutorFn, getDataMetricTracker::trackHeartbeats); this.statusPages = @@ -347,10 +347,7 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o .setSizeMb(options.getWorkerCacheMb()) .setSupportMapViaMultimap(options.isEnableStreamingEngine()) .build(); - Function executorSupplier = - threadName -> - Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder().setNameFormat(threadName).build()); + GrpcWindmillStreamFactory.Builder windmillStreamFactoryBuilder = createGrpcwindmillStreamFactoryBuilder(options, clientId); @@ -417,7 +414,8 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o streamingCounters, memoryMonitor, configFetcherComputationStateCacheAndWindmillClient.windmillStreamFactory(), - executorSupplier, + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder().setNameFormat("RefreshWork").build()), stageInfo); } @@ -595,7 +593,7 @@ static StreamingDataflowWorker forTesting( options.getWindmillServiceStreamingRpcHealthCheckPeriodMs()) .build() : windmillStreamFactory.build(), - executorSupplier, + executorSupplier.apply("RefreshWork"), stageInfo); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerUncaughtExceptionHandler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerUncaughtExceptionHandler.java index 5a8e87d23ab9..b4ec170099d5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerUncaughtExceptionHandler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerUncaughtExceptionHandler.java @@ -28,16 +28,16 @@ * This uncaught exception handler logs the {@link Throwable} to the logger, {@link System#err} and * exits the application with status code 1. 
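The handler above is only invoked for exceptions that escape a raw Thread; as the comment added to DataflowWorkerHarnessHelper notes, a ScheduledExecutorService captures task failures in futures that may never be examined. A self-contained sketch of that difference, with invented class and message names:

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

public class SwallowedExceptionDemo {
  public static void main(String[] args) throws InterruptedException {
    Thread.setDefaultUncaughtExceptionHandler(
        (thread, error) -> System.err.println("uncaught on " + thread.getName() + ": " + error));

    ScheduledExecutorService scheduled = Executors.newSingleThreadScheduledExecutor();
    // ScheduledThreadPoolExecutor wraps even execute(...) in a future task, so this
    // failure is stored in the future and the handler above is never invoked.
    scheduled.execute(() -> { throw new RuntimeException("silently swallowed"); });

    ExecutorService plain = Executors.newSingleThreadExecutor();
    // A plain ThreadPoolExecutor lets the exception terminate the worker thread,
    // which does reach the default uncaught-exception handler.
    plain.execute(() -> { throw new RuntimeException("reaches the handler"); });

    TimeUnit.MILLISECONDS.sleep(500);
    scheduled.shutdownNow();
    plain.shutdownNow();
  }
}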
*/ -class WorkerUncaughtExceptionHandler implements UncaughtExceptionHandler { +public final class WorkerUncaughtExceptionHandler implements UncaughtExceptionHandler { + @VisibleForTesting public static final int JVM_TERMINATED_STATUS_CODE = 1; private final JvmRuntime runtime; private final Logger logger; - WorkerUncaughtExceptionHandler(Logger logger) { + public WorkerUncaughtExceptionHandler(Logger logger) { this(JvmRuntime.INSTANCE, logger); } - @VisibleForTesting - WorkerUncaughtExceptionHandler(JvmRuntime runtime, Logger logger) { + public WorkerUncaughtExceptionHandler(JvmRuntime runtime, Logger logger) { this.runtime = runtime; this.logger = logger; } @@ -59,7 +59,7 @@ public void uncaughtException(Thread thread, Throwable e) { t.printStackTrace(originalStdErr); } } finally { - runtime.halt(1); + runtime.halt(JVM_TERMINATED_STATUS_CODE); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 458cf57ca8e7..3eed4ee6d835 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -137,7 +137,7 @@ private FanOutStreamingEngineWorkerHarness( Executors.newCachedThreadPool( new ThreadFactoryBuilder().setNameFormat(STREAM_MANAGER_THREAD_NAME).build()); this.workerMetadataConsumer = - Executors.newSingleThreadScheduledExecutor( + Executors.newSingleThreadExecutor( new ThreadFactoryBuilder().setNameFormat(WORKER_METADATA_CONSUMER_THREAD_NAME).build()); this.getWorkBudgetDistributor = getWorkBudgetDistributor; this.totalGetWorkBudget = totalGetWorkBudget; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java index bc93e6d89c41..06598b61c458 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java @@ -82,7 +82,7 @@ public final class SingleSourceWorkerHarness implements StreamingWorkerHarness { this.waitForResources = waitForResources; this.computationStateFetcher = computationStateFetcher; this.workProviderExecutor = - Executors.newSingleThreadScheduledExecutor( + Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() .setDaemon(true) .setPriority(Thread.MIN_PRIORITY) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java index d092ebf53fc1..6889764afe69 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java @@ -57,7 +57,7 @@ private StreamingApplianceWorkCommitter( WeightedBoundedQueue.create( MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); this.commitWorkers = - Executors.newSingleThreadScheduledExecutor( + Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() .setDaemon(true) .setPriority(Thread.MAX_PRIORITY) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java index e39aa8dbc8a5..d81c7d0593f3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/GetWorkBudgetRefresher.java @@ -51,7 +51,7 @@ public GetWorkBudgetRefresher( Supplier isBudgetRefreshPaused, Runnable redistributeBudget) { this.budgetRefreshTrigger = new AdvancingPhaser(1); this.budgetRefreshExecutor = - Executors.newSingleThreadScheduledExecutor( + Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() .setNameFormat(BUDGET_REFRESH_THREAD) .setUncaughtExceptionHandler( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java index 9a3e6eb6b099..c74874c465a6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingWorkScheduler.java @@ -70,7 +70,7 @@ */ @Internal @ThreadSafe -public final class StreamingWorkScheduler { +public class StreamingWorkScheduler { private static final Logger LOG = LoggerFactory.getLogger(StreamingWorkScheduler.class); private final DataflowWorkerHarnessOptions options; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcherTest.java index 9fa17588c94d..3a0ae7bb2084 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcherTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/config/StreamingEngineComputationConfigFetcherTest.java @@ -47,7 +47,6 @@ @RunWith(JUnit4.class) public class StreamingEngineComputationConfigFetcherTest { - private final WorkUnitClient mockDataflowServiceClient = mock(WorkUnitClient.class, new 
Returns(Optional.empty())); private StreamingEngineComputationConfigFetcher streamingEngineConfigFetcher; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarnessTest.java new file mode 100644 index 000000000000..5a2df4baae61 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarnessTest.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming.harness; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import org.apache.beam.runners.dataflow.worker.WorkerUncaughtExceptionHandler; +import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; +import org.apache.beam.runners.dataflow.worker.util.common.worker.JvmRuntime; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.work.processing.StreamingWorkScheduler; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@RunWith(JUnit4.class) +public class SingleSourceWorkerHarnessTest { + private static final Logger LOG = LoggerFactory.getLogger(SingleSourceWorkerHarnessTest.class); + private final WorkCommitter workCommitter = mock(WorkCommitter.class); + private final GetDataClient getDataClient = mock(GetDataClient.class); + private final HeartbeatSender heartbeatSender = mock(HeartbeatSender.class); + private final Runnable waitForResources = () -> {}; + private final Function> computationStateFetcher = + ignored -> Optional.empty(); + private final StreamingWorkScheduler streamingWorkScheduler = mock(StreamingWorkScheduler.class); + + private SingleSourceWorkerHarness createWorkerHarness( + SingleSourceWorkerHarness.GetWorkSender getWorkSender, JvmRuntime runtime) { + // In non-test scenario this is set in DataflowWorkerHarnessHelper.initializeLogging(...). 
+ Thread.setDefaultUncaughtExceptionHandler(new WorkerUncaughtExceptionHandler(runtime, LOG)); + return SingleSourceWorkerHarness.builder() + .setWorkCommitter(workCommitter) + .setGetDataClient(getDataClient) + .setHeartbeatSender(heartbeatSender) + .setWaitForResources(waitForResources) + .setStreamingWorkScheduler(streamingWorkScheduler) + .setComputationStateFetcher(computationStateFetcher) + .setGetWorkSender(getWorkSender) + .build(); + } + + @Test + public void testDispatchLoop_unexpectedFailureKillsJvm_appliance() { + SingleSourceWorkerHarness.GetWorkSender getWorkSender = + SingleSourceWorkerHarness.GetWorkSender.forAppliance( + () -> { + throw new RuntimeException("something bad happened"); + }); + + FakeJvmRuntime fakeJvmRuntime = new FakeJvmRuntime(); + createWorkerHarness(getWorkSender, fakeJvmRuntime).start(); + assertTrue(fakeJvmRuntime.waitForRuntimeDeath(5, TimeUnit.SECONDS)); + fakeJvmRuntime.assertJvmTerminated(); + } + + @Test + public void testDispatchLoop_unexpectedFailureKillsJvm_streamingEngine() { + SingleSourceWorkerHarness.GetWorkSender getWorkSender = + SingleSourceWorkerHarness.GetWorkSender.forStreamingEngine( + workItemReceiver -> { + throw new RuntimeException("something bad happened"); + }); + + FakeJvmRuntime fakeJvmRuntime = new FakeJvmRuntime(); + createWorkerHarness(getWorkSender, fakeJvmRuntime).start(); + assertTrue(fakeJvmRuntime.waitForRuntimeDeath(5, TimeUnit.SECONDS)); + fakeJvmRuntime.assertJvmTerminated(); + } + + private static class FakeJvmRuntime implements JvmRuntime { + private final CountDownLatch haltedLatch = new CountDownLatch(1); + private volatile int exitStatus = 0; + + @Override + public void halt(int status) { + exitStatus = status; + haltedLatch.countDown(); + } + + public boolean waitForRuntimeDeath(long timeout, TimeUnit unit) { + try { + return haltedLatch.await(timeout, unit); + } catch (InterruptedException e) { + return false; + } + } + + private void assertJvmTerminated() { + assertThat(exitStatus).isEqualTo(WorkerUncaughtExceptionHandler.JVM_TERMINATED_STATUS_CODE); + } + } +} From 13ac5be2ee77651245823d29a4af42039506a189 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Wed, 30 Oct 2024 10:16:00 -0400 Subject: [PATCH 089/181] formatting --- .../python/apache_beam/ml/transforms/handlers.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sdks/python/apache_beam/ml/transforms/handlers.py b/sdks/python/apache_beam/ml/transforms/handlers.py index e732e92e14d5..1e752049f6e5 100644 --- a/sdks/python/apache_beam/ml/transforms/handlers.py +++ b/sdks/python/apache_beam/ml/transforms/handlers.py @@ -19,8 +19,8 @@ import collections import copy import os -from collections.abc import Sequence import typing +from collections.abc import Sequence from typing import Any from typing import NamedTuple from typing import Optional @@ -74,13 +74,13 @@ } tft_process_handler_input_type = Union[NamedTuple, - beam.Row, - dict[str, - Union[str, - float, - int, - bytes, - np.ndarray]]] + beam.Row, + dict[str, + Union[str, + float, + int, + bytes, + np.ndarray]]] tft_process_handler_output_type = Union[beam.Row, dict[str, np.ndarray]] From 2a27cc6182f02418f22b0c303177cb3f644e423a Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Wed, 30 Oct 2024 12:16:19 -0400 Subject: [PATCH 090/181] Copy in correct requirements file (#32974) * Copy in correct requirements file * Trigger postcommit --- .../trigger_files/beam_PostCommit_TransformService_Direct.json | 3 ++- sdks/python/expansion-service-container/build.gradle | 
2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/trigger_files/beam_PostCommit_TransformService_Direct.json b/.github/trigger_files/beam_PostCommit_TransformService_Direct.json index c4edaa85a89d..7663aee09101 100644 --- a/.github/trigger_files/beam_PostCommit_TransformService_Direct.json +++ b/.github/trigger_files/beam_PostCommit_TransformService_Direct.json @@ -1,3 +1,4 @@ { - "comment": "Modify this file in a trivial way to cause this test suite to run" + "comment": "Modify this file in a trivial way to cause this test suite to run", + "revision: "1" } diff --git a/sdks/python/expansion-service-container/build.gradle b/sdks/python/expansion-service-container/build.gradle index 3edcaee35b4a..4e46f060e59f 100644 --- a/sdks/python/expansion-service-container/build.gradle +++ b/sdks/python/expansion-service-container/build.gradle @@ -40,7 +40,7 @@ task copyDockerfileDependencies(type: Copy) { } task copyRequirementsFile(type: Copy) { - from project(':sdks:python:container:py38').fileTree("./") + from project(':sdks:python:container:py39').fileTree("./") include 'base_image_requirements.txt' rename 'base_image_requirements.txt', 'requirements.txt' setDuplicatesStrategy(DuplicatesStrategy.INCLUDE) From dd556b23b249dff72ff0dba2e6e9b809b334d1c9 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Wed, 30 Oct 2024 18:06:02 +0100 Subject: [PATCH 091/181] Merge pull request #32081: make FieldValueTypeInformation creators take a TypeDescriptor parameter --- .../beam/sdk/schemas/AutoValueSchema.java | 20 +- .../beam/sdk/schemas/CachingFactory.java | 22 ++- .../beam/sdk/schemas/FieldValueGetter.java | 3 +- .../schemas/FieldValueTypeInformation.java | 73 ++++--- .../schemas/GetterBasedSchemaProvider.java | 178 ++++++++++------- .../schemas/GetterBasedSchemaProviderV2.java | 7 +- .../beam/sdk/schemas/JavaBeanSchema.java | 32 ++-- .../beam/sdk/schemas/JavaFieldSchema.java | 19 +- .../org/apache/beam/sdk/schemas/Schema.java | 12 +- .../sdk/schemas/utils/AutoValueUtils.java | 44 +++-- .../sdk/schemas/utils/ByteBuddyUtils.java | 180 ++++++++++-------- .../beam/sdk/schemas/utils/JavaBeanUtils.java | 124 +++++++----- .../beam/sdk/schemas/utils/POJOUtils.java | 131 +++++++------ .../beam/sdk/schemas/utils/ReflectUtils.java | 26 ++- .../beam/sdk/util/common/ReflectHelpers.java | 3 +- .../java/org/apache/beam/sdk/values/Row.java | 8 +- .../beam/sdk/values/RowWithGetters.java | 27 ++- .../FieldValueTypeInformationTest.java | 70 +++++++ .../sdk/schemas/utils/JavaBeanUtilsTest.java | 7 +- .../beam/sdk/schemas/utils/POJOUtilsTest.java | 11 +- .../sdk/extensions/arrow/ArrowConversion.java | 7 +- .../avro/schemas/AvroRecordSchema.java | 5 +- .../avro/schemas/utils/AvroUtils.java | 158 +++++++++------ .../protobuf/ProtoByteBuddyUtils.java | 128 +++++++------ .../protobuf/ProtoMessageSchema.java | 21 +- .../io/aws2/schemas/AwsSchemaProvider.java | 17 +- .../sdk/io/aws2/schemas/AwsSchemaUtils.java | 6 +- .../beam/sdk/io/thrift/ThriftSchema.java | 19 +- 28 files changed, 827 insertions(+), 531 deletions(-) create mode 100644 sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/FieldValueTypeInformationTest.java diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java index 5ccfe39b92af..f35782c2b9a2 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/AutoValueSchema.java @@ -19,7 
+19,6 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; -import java.util.Comparator; import java.util.List; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.annotations.SchemaIgnore; @@ -32,13 +31,10 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** A {@link SchemaProvider} for AutoValue classes. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) public class AutoValueSchema extends GetterBasedSchemaProviderV2 { /** {@link FieldValueTypeSupplier} that's based on AutoValue getters. */ @VisibleForTesting @@ -49,7 +45,11 @@ public static class AbstractGetterTypeSupplier implements FieldValueTypeSupplier public List get(TypeDescriptor typeDescriptor) { // If the generated class is passed in, we want to look at the base class to find the getters. - TypeDescriptor targetTypeDescriptor = AutoValueUtils.getBaseAutoValueClass(typeDescriptor); + TypeDescriptor targetTypeDescriptor = + Preconditions.checkNotNull( + AutoValueUtils.getBaseAutoValueClass(typeDescriptor), + "unable to determine base AutoValue class for type {}", + typeDescriptor); List methods = ReflectUtils.getMethods(targetTypeDescriptor.getRawType()).stream() @@ -62,9 +62,9 @@ public List get(TypeDescriptor typeDescriptor) { .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i)); + types.add(FieldValueTypeInformation.forGetter(typeDescriptor, methods.get(i), i)); } - types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); + types.sort(JavaBeanUtils.comparingNullFirst(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); return types; } @@ -89,8 +89,8 @@ private static void validateFieldNumbers(List types) } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return JavaBeanUtils.getGetters( targetTypeDescriptor, schema, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java index 8725833bc1da..6e244fefb263 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/CachingFactory.java @@ -20,6 +20,9 @@ import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.initialization.qual.NotOnlyInitialized; +import org.checkerframework.checker.initialization.qual.UnknownInitialization; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -32,24 +35,25 @@ * significant for larger schemas) on each lookup. This wrapper caches the value returned by the * inner factory, so the schema comparison only need happen on the first lookup. 
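A minimal usage sketch of the caching wrapper described above (MyPojo and mySchema are hypothetical placeholders, and the inner factory just derives the field-name list to keep the example small): the lambda runs at most once per TypeDescriptor, after which the cached value is reused.

    // Hypothetical inner factory: any per-(TypeDescriptor, Schema) derivation could go here.
    Factory<List<String>> inner = (typeDescriptor, schema) -> schema.getFieldNames();
    // Wrap it once; later create() calls for the same TypeDescriptor hit the cache.
    Factory<List<String>> cached = new CachingFactory<>(inner);
    List<String> fieldNames = cached.create(TypeDescriptor.of(MyPojo.class), mySchema);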
*/ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) -public class CachingFactory implements Factory { +public class CachingFactory implements Factory { private transient @Nullable ConcurrentHashMap, CreatedT> cache = null; - private final Factory innerFactory; + private final @NotOnlyInitialized Factory innerFactory; - public CachingFactory(Factory innerFactory) { + public CachingFactory(@UnknownInitialization Factory innerFactory) { this.innerFactory = innerFactory; } - @Override - public CreatedT create(TypeDescriptor typeDescriptor, Schema schema) { + private ConcurrentHashMap, CreatedT> getCache() { if (cache == null) { cache = new ConcurrentHashMap<>(); } + return cache; + } + + @Override + public CreatedT create(TypeDescriptor typeDescriptor, Schema schema) { + ConcurrentHashMap, CreatedT> cache = getCache(); CreatedT cached = cache.get(typeDescriptor); if (cached != null) { return cached; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java index fb98db8e8343..63ab56dc7609 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java @@ -19,6 +19,7 @@ import java.io.Serializable; import org.apache.beam.sdk.annotations.Internal; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -29,7 +30,7 @@ *

Implementations of this interface are generated at runtime to map object fields to Row fields. */ @Internal -public interface FieldValueGetter extends Serializable { +public interface FieldValueGetter extends Serializable { @Nullable ValueT get(ObjectT object); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java index 750709192c08..43aac6a5e20c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java @@ -27,7 +27,9 @@ import java.util.Arrays; import java.util.Collections; import java.util.Map; +import java.util.Optional; import java.util.stream.Stream; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.annotations.SchemaFieldName; @@ -40,10 +42,7 @@ /** Represents type information for a Java type that will be used to infer a Schema type. */ @AutoValue -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@Internal public abstract class FieldValueTypeInformation implements Serializable { /** Optionally returns the field index. */ public abstract @Nullable Integer getNumber(); @@ -125,8 +124,13 @@ public static FieldValueTypeInformation forOneOf( .build(); } - public static FieldValueTypeInformation forField(Field field, int index) { - TypeDescriptor type = TypeDescriptor.of(field.getGenericType()); + public static FieldValueTypeInformation forField( + @Nullable TypeDescriptor typeDescriptor, Field field, int index) { + TypeDescriptor type = + Optional.ofNullable(typeDescriptor) + .map(td -> (TypeDescriptor) td.resolveType(field.getGenericType())) + // fall back to previous behavior + .orElseGet(() -> TypeDescriptor.of(field.getGenericType())); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(field.getName(), field)) .setNumber(getNumberOverride(index, field)) @@ -134,9 +138,9 @@ public static FieldValueTypeInformation forField(Field field, int index) { .setType(type) .setRawType(type.getRawType()) .setField(field) - .setElementType(getIterableComponentType(field)) - .setMapKeyType(getMapKeyType(field)) - .setMapValueType(getMapValueType(field)) + .setElementType(getIterableComponentType(type)) + .setMapKeyType(getMapKeyType(type)) + .setMapValueType(getMapValueType(type)) .setOneOfTypes(Collections.emptyMap()) .setDescription(getFieldDescription(field)) .build(); @@ -185,6 +189,11 @@ public static String getNameOverride( } public static FieldValueTypeInformation forGetter(Method method, int index) { + return forGetter(null, method, index); + } + + public static FieldValueTypeInformation forGetter( + @Nullable TypeDescriptor typeDescriptor, Method method, int index) { String name; if (method.getName().startsWith("get")) { name = ReflectUtils.stripPrefix(method.getName(), "get"); @@ -194,7 +203,12 @@ public static FieldValueTypeInformation forGetter(Method method, int index) { throw new RuntimeException("Getter has wrong prefix " + method.getName()); } - TypeDescriptor type = TypeDescriptor.of(method.getGenericReturnType()); + TypeDescriptor type = + Optional.ofNullable(typeDescriptor) + .map(td -> (TypeDescriptor) td.resolveType(method.getGenericReturnType())) + 
// fall back to previous behavior + .orElseGet(() -> TypeDescriptor.of(method.getGenericReturnType())); + boolean nullable = hasNullableReturnType(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(getNameOverride(name, method)) @@ -253,10 +267,20 @@ private static boolean isNullableAnnotation(Annotation annotation) { } public static FieldValueTypeInformation forSetter(Method method) { - return forSetter(method, "set"); + return forSetter(null, method); } public static FieldValueTypeInformation forSetter(Method method, String setterPrefix) { + return forSetter(null, method, setterPrefix); + } + + public static FieldValueTypeInformation forSetter( + @Nullable TypeDescriptor typeDescriptor, Method method) { + return forSetter(typeDescriptor, method, "set"); + } + + public static FieldValueTypeInformation forSetter( + @Nullable TypeDescriptor typeDescriptor, Method method, String setterPrefix) { String name; if (method.getName().startsWith(setterPrefix)) { name = ReflectUtils.stripPrefix(method.getName(), setterPrefix); @@ -264,7 +288,11 @@ public static FieldValueTypeInformation forSetter(Method method, String setterPr throw new RuntimeException("Setter has wrong prefix " + method.getName()); } - TypeDescriptor type = TypeDescriptor.of(method.getGenericParameterTypes()[0]); + TypeDescriptor type = + Optional.ofNullable(typeDescriptor) + .map(td -> (TypeDescriptor) td.resolveType(method.getGenericParameterTypes()[0])) + // fall back to previous behavior + .orElseGet(() -> TypeDescriptor.of(method.getGenericParameterTypes()[0])); boolean nullable = hasSingleNullableParameter(method); return new AutoValue_FieldValueTypeInformation.Builder() .setName(name) @@ -283,10 +311,6 @@ public FieldValueTypeInformation withName(String name) { return toBuilder().setName(name).build(); } - private static FieldValueTypeInformation getIterableComponentType(Field field) { - return getIterableComponentType(TypeDescriptor.of(field.getGenericType())); - } - static @Nullable FieldValueTypeInformation getIterableComponentType(TypeDescriptor valueType) { // TODO: Figure out nullable elements. TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType); @@ -306,23 +330,13 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) { .build(); } - // If the Field is a map type, returns the key type, otherwise returns a null reference. - - private static @Nullable FieldValueTypeInformation getMapKeyType(Field field) { - return getMapKeyType(TypeDescriptor.of(field.getGenericType())); - } - + // If the type is a map type, returns the key type, otherwise returns a null reference. private static @Nullable FieldValueTypeInformation getMapKeyType( TypeDescriptor typeDescriptor) { return getMapType(typeDescriptor, 0); } - // If the Field is a map type, returns the value type, otherwise returns a null reference. - - private static @Nullable FieldValueTypeInformation getMapValueType(Field field) { - return getMapType(TypeDescriptor.of(field.getGenericType()), 1); - } - + // If the type is a map type, returns the value type, otherwise returns a null reference. private static @Nullable FieldValueTypeInformation getMapValueType( TypeDescriptor typeDescriptor) { return getMapType(typeDescriptor, 1); @@ -330,10 +344,9 @@ private static FieldValueTypeInformation getIterableComponentType(Field field) { // If the Field is a map type, returns the key or value type (0 is key type, 1 is value). // Otherwise returns a null reference. 
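A small sketch of what the new TypeDescriptor parameter buys; the Event/StringEvent types below are hypothetical and only serve to show that the getter's generic return type is now resolved against the concrete enclosing type instead of being left as an unresolved type variable.

    // Hypothetical generic getter hierarchy.
    interface Event<T> {
      T getPayload();
    }

    class StringEvent implements Event<String> {
      @Override
      public String getPayload() {
        return "hello";
      }
    }

    // e.g. in a test (the reflection call throws a checked exception in real code):
    Method getter = Event.class.getMethod("getPayload");
    // Two-argument overload: the type stays the unresolved type variable T.
    FieldValueTypeInformation legacy = FieldValueTypeInformation.forGetter(getter, 0);
    // With the enclosing TypeDescriptor, T resolves to String.
    FieldValueTypeInformation resolved =
        FieldValueTypeInformation.forGetter(TypeDescriptor.of(StringEvent.class), getter, 0);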
- @SuppressWarnings("unchecked") private static @Nullable FieldValueTypeInformation getMapType( TypeDescriptor valueType, int index) { - TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index); + TypeDescriptor mapType = ReflectUtils.getMapType(valueType, index); if (mapType == null) { return null; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java index ce5be71933b8..4e431bb45207 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java @@ -17,13 +17,12 @@ */ package org.apache.beam.sdk.schemas; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; - import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.LogicalType; import org.apache.beam.sdk.schemas.Schema.TypeName; @@ -32,10 +31,13 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Collections2; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.initialization.qual.NotOnlyInitialized; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -46,10 +48,7 @@ * methods which receive {@link TypeDescriptor}s instead of ordinary {@link Class}es as * arguments, which permits to support generic type signatures during schema inference */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) @Deprecated public abstract class GetterBasedSchemaProvider implements SchemaProvider { @@ -67,9 +66,9 @@ public abstract class GetterBasedSchemaProvider implements SchemaProvider { * override it if you want to use the richer type signature contained in the {@link * TypeDescriptor} not subject to the type erasure. 
*/ - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { - return fieldValueGetters(targetTypeDescriptor.getRawType(), schema); + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { + return (List) fieldValueGetters(targetTypeDescriptor.getRawType(), schema); } /** @@ -112,9 +111,10 @@ public SchemaUserTypeCreator schemaTypeCreator( return schemaTypeCreator(targetTypeDescriptor.getRawType(), schema); } - private class ToRowWithValueGetters implements SerializableFunction { + private class ToRowWithValueGetters + implements SerializableFunction { private final Schema schema; - private final Factory> getterFactory; + private final Factory>> getterFactory; public ToRowWithValueGetters(Schema schema) { this.schema = schema; @@ -122,7 +122,12 @@ public ToRowWithValueGetters(Schema schema) { // schema, return a caching factory that caches the first value seen for each class. This // prevents having to lookup the getter list each time createGetters is called. this.getterFactory = - RowValueGettersFactory.of(GetterBasedSchemaProvider.this::fieldValueGetters); + RowValueGettersFactory.of( + (Factory>>) + (typeDescriptor, schema1) -> + (List) + GetterBasedSchemaProvider.this.fieldValueGetters( + typeDescriptor, schema1)); } @Override @@ -160,13 +165,15 @@ public SerializableFunction toRowFunction(TypeDescriptor typeDesc // important to capture the schema once here, so all invocations of the toRowFunction see the // same version of the schema. If schemaFor were to be called inside the lambda below, different // workers would see different versions of the schema. - Schema schema = schemaFor(typeDescriptor); + @NonNull + Schema schema = + Verify.verifyNotNull( + schemaFor(typeDescriptor), "can't create a ToRowFunction with null schema"); return new ToRowWithValueGetters<>(schema); } @Override - @SuppressWarnings("unchecked") public SerializableFunction fromRowFunction(TypeDescriptor typeDescriptor) { return new FromRowUsingCreator<>(typeDescriptor, this); } @@ -181,23 +188,27 @@ public boolean equals(@Nullable Object obj) { return obj != null && this.getClass() == obj.getClass(); } - private static class RowValueGettersFactory implements Factory> { - private final Factory> gettersFactory; - private final Factory> cachingGettersFactory; + private static class RowValueGettersFactory + implements Factory>> { + private final Factory>> gettersFactory; + private final @NotOnlyInitialized Factory>> + cachingGettersFactory; - static Factory> of(Factory> gettersFactory) { - return new RowValueGettersFactory(gettersFactory).cachingGettersFactory; + static Factory>> of( + Factory>> gettersFactory) { + return new RowValueGettersFactory<>(gettersFactory).cachingGettersFactory; } - RowValueGettersFactory(Factory> gettersFactory) { + RowValueGettersFactory(Factory>> gettersFactory) { this.gettersFactory = gettersFactory; this.cachingGettersFactory = new CachingFactory<>(this); } @Override - public List create(TypeDescriptor typeDescriptor, Schema schema) { - List getters = gettersFactory.create(typeDescriptor, schema); - List rowGetters = new ArrayList<>(getters.size()); + public List> create( + TypeDescriptor typeDescriptor, Schema schema) { + List> getters = gettersFactory.create(typeDescriptor, schema); + List> rowGetters = new ArrayList<>(getters.size()); for (int i = 0; i < getters.size(); i++) { rowGetters.add(rowValueGetter(getters.get(i), schema.getField(i).getType())); } @@ -209,71 +220,80 @@ static boolean needsConversion(FieldType type) 
{ return typeName.equals(TypeName.ROW) || typeName.isLogicalType() || ((typeName.equals(TypeName.ARRAY) || typeName.equals(TypeName.ITERABLE)) - && needsConversion(type.getCollectionElementType())) + && needsConversion(Verify.verifyNotNull(type.getCollectionElementType()))) || (typeName.equals(TypeName.MAP) - && (needsConversion(type.getMapKeyType()) - || needsConversion(type.getMapValueType()))); + && (needsConversion(Verify.verifyNotNull(type.getMapKeyType())) + || needsConversion(Verify.verifyNotNull(type.getMapValueType())))); } - FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type) { + FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type) { TypeName typeName = type.getTypeName(); if (!needsConversion(type)) { return base; } if (typeName.equals(TypeName.ROW)) { - return new GetRow(base, type.getRowSchema(), cachingGettersFactory); + return new GetRow(base, Verify.verifyNotNull(type.getRowSchema()), cachingGettersFactory); } else if (typeName.equals(TypeName.ARRAY)) { - FieldType elementType = type.getCollectionElementType(); + FieldType elementType = Verify.verifyNotNull(type.getCollectionElementType()); return elementType.getTypeName().equals(TypeName.ROW) ? new GetEagerCollection(base, converter(elementType)) : new GetCollection(base, converter(elementType)); } else if (typeName.equals(TypeName.ITERABLE)) { - return new GetIterable(base, converter(type.getCollectionElementType())); + return new GetIterable( + base, converter(Verify.verifyNotNull(type.getCollectionElementType()))); } else if (typeName.equals(TypeName.MAP)) { - return new GetMap(base, converter(type.getMapKeyType()), converter(type.getMapValueType())); + return new GetMap( + base, + converter(Verify.verifyNotNull(type.getMapKeyType())), + converter(Verify.verifyNotNull(type.getMapValueType()))); } else if (type.isLogicalType(OneOfType.IDENTIFIER)) { OneOfType oneOfType = type.getLogicalType(OneOfType.class); Schema oneOfSchema = oneOfType.getOneOfSchema(); Map values = oneOfType.getCaseEnumType().getValuesMap(); - Map converters = Maps.newHashMapWithExpectedSize(values.size()); + Map> converters = + Maps.newHashMapWithExpectedSize(values.size()); for (Map.Entry kv : values.entrySet()) { FieldType fieldType = oneOfSchema.getField(kv.getKey()).getType(); - FieldValueGetter converter = converter(fieldType); + FieldValueGetter converter = converter(fieldType); converters.put(kv.getValue(), converter); } return new GetOneOf(base, converters, oneOfType); } else if (typeName.isLogicalType()) { - return new GetLogicalInputType(base, type.getLogicalType()); + return new GetLogicalInputType(base, Verify.verifyNotNull(type.getLogicalType())); } return base; } - FieldValueGetter converter(FieldType type) { + FieldValueGetter converter(FieldType type) { return rowValueGetter(IDENTITY, type); } - static class GetRow extends Converter { + static class GetRow + extends Converter { final Schema schema; - final Factory> factory; + final Factory>> factory; - GetRow(FieldValueGetter getter, Schema schema, Factory> factory) { + GetRow( + FieldValueGetter getter, + Schema schema, + Factory>> factory) { super(getter); this.schema = schema; this.factory = factory; } @Override - Object convert(Object value) { + Object convert(V value) { return Row.withSchema(schema).withFieldValueGetters(factory, value); } } - static class GetEagerCollection extends Converter { + static class GetEagerCollection extends Converter { final FieldValueGetter converter; - GetEagerCollection(FieldValueGetter getter, FieldValueGetter 
converter) { + GetEagerCollection(FieldValueGetter getter, FieldValueGetter converter) { super(getter); this.converter = converter; } @@ -288,15 +308,16 @@ Object convert(Collection collection) { } } - static class GetCollection extends Converter { + static class GetCollection extends Converter { final FieldValueGetter converter; - GetCollection(FieldValueGetter getter, FieldValueGetter converter) { + GetCollection(FieldValueGetter getter, FieldValueGetter converter) { super(getter); this.converter = converter; } @Override + @SuppressWarnings({"nullness"}) Object convert(Collection collection) { if (collection instanceof List) { // For performance reasons if the input is a list, make sure that we produce a list. @@ -309,45 +330,51 @@ Object convert(Collection collection) { } } - static class GetIterable extends Converter { + static class GetIterable extends Converter { final FieldValueGetter converter; - GetIterable(FieldValueGetter getter, FieldValueGetter converter) { + GetIterable(FieldValueGetter getter, FieldValueGetter converter) { super(getter); this.converter = converter; } @Override + @SuppressWarnings({"nullness"}) Object convert(Iterable value) { return Iterables.transform(value, converter::get); } } - static class GetMap extends Converter> { - final FieldValueGetter keyConverter; - final FieldValueGetter valueConverter; + static class GetMap + extends Converter> { + final FieldValueGetter<@NonNull K1, K2> keyConverter; + final FieldValueGetter<@NonNull V1, V2> valueConverter; GetMap( - FieldValueGetter getter, FieldValueGetter keyConverter, FieldValueGetter valueConverter) { + FieldValueGetter> getter, + FieldValueGetter<@NonNull K1, K2> keyConverter, + FieldValueGetter<@NonNull V1, V2> valueConverter) { super(getter); this.keyConverter = keyConverter; this.valueConverter = valueConverter; } @Override - Object convert(Map value) { - Map returnMap = Maps.newHashMapWithExpectedSize(value.size()); - for (Map.Entry entry : value.entrySet()) { - returnMap.put(keyConverter.get(entry.getKey()), valueConverter.get(entry.getValue())); + Map<@Nullable K2, @Nullable V2> convert(Map<@Nullable K1, @Nullable V1> value) { + Map<@Nullable K2, @Nullable V2> returnMap = Maps.newHashMapWithExpectedSize(value.size()); + for (Map.Entry<@Nullable K1, @Nullable V1> entry : value.entrySet()) { + returnMap.put( + Optional.ofNullable(entry.getKey()).map(keyConverter::get).orElse(null), + Optional.ofNullable(entry.getValue()).map(valueConverter::get).orElse(null)); } return returnMap; } } - static class GetLogicalInputType extends Converter { + static class GetLogicalInputType extends Converter { final LogicalType logicalType; - GetLogicalInputType(FieldValueGetter getter, LogicalType logicalType) { + GetLogicalInputType(FieldValueGetter getter, LogicalType logicalType) { super(getter); this.logicalType = logicalType; } @@ -359,12 +386,14 @@ Object convert(Object value) { } } - static class GetOneOf extends Converter { + static class GetOneOf extends Converter { final OneOfType oneOfType; - final Map converters; + final Map> converters; GetOneOf( - FieldValueGetter getter, Map converters, OneOfType oneOfType) { + FieldValueGetter getter, + Map> converters, + OneOfType oneOfType) { super(getter); this.converters = converters; this.oneOfType = oneOfType; @@ -373,24 +402,31 @@ static class GetOneOf extends Converter { @Override Object convert(OneOfType.Value value) { EnumerationType.Value caseType = value.getCaseType(); - FieldValueGetter converter = converters.get(caseType.getValue()); - 
checkState(converter != null, "Missing OneOf converter for case %s.", caseType); + + @NonNull + FieldValueGetter<@NonNull Object, Object> converter = + Verify.verifyNotNull( + converters.get(caseType.getValue()), + "Missing OneOf converter for case %s.", + caseType); + return oneOfType.createValue(caseType, converter.get(value.getValue())); } } - abstract static class Converter implements FieldValueGetter { - final FieldValueGetter getter; + abstract static class Converter + implements FieldValueGetter { + final FieldValueGetter getter; - public Converter(FieldValueGetter getter) { + public Converter(FieldValueGetter getter) { this.getter = getter; } - abstract Object convert(T value); + abstract Object convert(ValueT value); @Override - public @Nullable Object get(Object object) { - T value = (T) getter.get(object); + public @Nullable Object get(ObjectT object) { + ValueT value = getter.get(object); if (value == null) { return null; } @@ -398,7 +434,7 @@ public Converter(FieldValueGetter getter) { } @Override - public @Nullable Object getRaw(Object object) { + public @Nullable Object getRaw(ObjectT object) { return getter.getRaw(object); } @@ -408,16 +444,16 @@ public String name() { } } - private static final FieldValueGetter IDENTITY = - new FieldValueGetter() { + private static final FieldValueGetter<@NonNull Object, Object> IDENTITY = + new FieldValueGetter<@NonNull Object, Object>() { @Override - public @Nullable Object get(Object object) { + public Object get(@NonNull Object object) { return object; } @Override public String name() { - return null; + return "IDENTITY"; } }; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java index de31f9947c36..e7214d8f663a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProviderV2.java @@ -19,6 +19,7 @@ import java.util.List; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; /** * A newer version of {@link GetterBasedSchemaProvider}, which works with {@link TypeDescriptor}s, @@ -28,12 +29,12 @@ public abstract class GetterBasedSchemaProviderV2 extends GetterBasedSchemaProvider { @Override public List fieldValueGetters(Class targetClass, Schema schema) { - return fieldValueGetters(TypeDescriptor.of(targetClass), schema); + return (List) fieldValueGetters(TypeDescriptor.of(targetClass), schema); } @Override - public abstract List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema); + public abstract List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema); @Override public List fieldValueTypeInformations( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java index a9cf01c52057..14adf2f6603e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaBeanSchema.java @@ -19,7 +19,6 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Method; -import java.util.Comparator; import java.util.List; import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.annotations.SchemaCaseFormat; @@ -34,6 +33,7 @@ import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -49,10 +49,7 @@ *

TODO: Validate equals() method is provided, and if not generate a "slow" equals method based * on the schema. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) public class JavaBeanSchema extends GetterBasedSchemaProviderV2 { /** {@link FieldValueTypeSupplier} that's based on getter methods. */ @VisibleForTesting @@ -68,9 +65,9 @@ public List get(TypeDescriptor typeDescriptor) { .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(methods.size()); for (int i = 0; i < methods.size(); ++i) { - types.add(FieldValueTypeInformation.forGetter(methods.get(i), i)); + types.add(FieldValueTypeInformation.forGetter(typeDescriptor, methods.get(i), i)); } - types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); + types.sort(JavaBeanUtils.comparingNullFirst(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); return types; } @@ -114,29 +111,32 @@ public List get(TypeDescriptor typeDescriptor) { return ReflectUtils.getMethods(typeDescriptor.getRawType()).stream() .filter(ReflectUtils::isSetter) .filter(m -> !m.isAnnotationPresent(SchemaIgnore.class)) - .map(FieldValueTypeInformation::forSetter) + .map(m -> FieldValueTypeInformation.forSetter(typeDescriptor, m)) .map( t -> { - if (t.getMethod().getAnnotation(SchemaFieldNumber.class) != null) { + Method m = + Preconditions.checkNotNull( + t.getMethod(), JavaBeanUtils.SETTER_WITH_NULL_METHOD_ERROR); + if (m.getAnnotation(SchemaFieldNumber.class) != null) { throw new RuntimeException( String.format( "@SchemaFieldNumber can only be used on getters in Java Beans. Found on" + " setter '%s'", - t.getMethod().getName())); + m.getName())); } - if (t.getMethod().getAnnotation(SchemaFieldName.class) != null) { + if (m.getAnnotation(SchemaFieldName.class) != null) { throw new RuntimeException( String.format( "@SchemaFieldName can only be used on getters in Java Beans. Found on" + " setter '%s'", - t.getMethod().getName())); + m.getName())); } - if (t.getMethod().getAnnotation(SchemaCaseFormat.class) != null) { + if (m.getAnnotation(SchemaCaseFormat.class) != null) { throw new RuntimeException( String.format( "@SchemaCaseFormat can only be used on getters in Java Beans. 
Found on" + " setter '%s'", - t.getMethod().getName())); + m.getName())); } return t; }) @@ -172,8 +172,8 @@ public Schema schemaFor(TypeDescriptor typeDescriptor) { } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return JavaBeanUtils.getGetters( targetTypeDescriptor, schema, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java index 21f07c47b47f..9a8eef2bf2c8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/JavaFieldSchema.java @@ -21,20 +21,22 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; -import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.annotation.Nullable; import org.apache.beam.sdk.schemas.annotations.SchemaIgnore; import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.DefaultTypeConversionsFactory; import org.apache.beam.sdk.schemas.utils.FieldValueTypeSupplier; +import org.apache.beam.sdk.schemas.utils.JavaBeanUtils; import org.apache.beam.sdk.schemas.utils.POJOUtils; import org.apache.beam.sdk.schemas.utils.ReflectUtils; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; /** * A {@link SchemaProvider} for Java POJO objects. @@ -49,7 +51,6 @@ *

TODO: Validate equals() method is provided, and if not generate a "slow" equals method based * on the schema. */ -@SuppressWarnings({"nullness", "rawtypes"}) public class JavaFieldSchema extends GetterBasedSchemaProviderV2 { /** {@link FieldValueTypeSupplier} that's based on public fields. */ @VisibleForTesting @@ -64,9 +65,9 @@ public List get(TypeDescriptor typeDescriptor) { .collect(Collectors.toList()); List types = Lists.newArrayListWithCapacity(fields.size()); for (int i = 0; i < fields.size(); ++i) { - types.add(FieldValueTypeInformation.forField(fields.get(i), i)); + types.add(FieldValueTypeInformation.forField(typeDescriptor, fields.get(i), i)); } - types.sort(Comparator.comparing(FieldValueTypeInformation::getNumber)); + types.sort(JavaBeanUtils.comparingNullFirst(FieldValueTypeInformation::getNumber)); validateFieldNumbers(types); // If there are no creators registered, then make sure none of the schema fields are final, @@ -75,7 +76,9 @@ public List get(TypeDescriptor typeDescriptor) { && ReflectUtils.getAnnotatedConstructor(typeDescriptor.getRawType()) == null) { Optional finalField = types.stream() - .map(FieldValueTypeInformation::getField) + .flatMap( + fvti -> + Optional.ofNullable(fvti.getField()).map(Stream::of).orElse(Stream.empty())) .filter(f -> Modifier.isFinal(f.getModifiers())) .findAny(); if (finalField.isPresent()) { @@ -115,8 +118,8 @@ public Schema schemaFor(TypeDescriptor typeDescriptor) { } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return POJOUtils.getGetters( targetTypeDescriptor, schema, @@ -149,7 +152,7 @@ public SchemaUserTypeCreator schemaTypeCreator( ReflectUtils.getAnnotatedConstructor(targetTypeDescriptor.getRawType()); if (constructor != null) { return POJOUtils.getConstructorCreator( - targetTypeDescriptor, + (TypeDescriptor) targetTypeDescriptor, constructor, schema, JavaFieldTypeSupplier.INSTANCE, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java index 5af59356b174..02607d91b079 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java @@ -90,6 +90,7 @@ public String toString() { return Arrays.toString(array); } } + // A mapping between field names an indices. private final BiMap fieldIndices; @@ -830,10 +831,11 @@ public static FieldType iterable(FieldType elementType) { public static FieldType map(FieldType keyType, FieldType valueType) { if (FieldType.BYTES.equals(keyType)) { LOG.warn( - "Using byte arrays as keys in a Map may lead to unexpected behavior and may not work as intended. " - + "Since arrays do not override equals() or hashCode, comparisons will be done on reference equality only. " - + "ByteBuffers, when used as keys, present similar challenges because Row stores ByteBuffer as a byte array. " - + "Consider using a different type of key for more consistent and predictable behavior."); + "Using byte arrays as keys in a Map may lead to unexpected behavior and may not work as" + + " intended. Since arrays do not override equals() or hashCode, comparisons will" + + " be done on reference equality only. ByteBuffers, when used as keys, present" + + " similar challenges because Row stores ByteBuffer as a byte array. 
Consider" + + " using a different type of key for more consistent and predictable behavior."); } return FieldType.forTypeName(TypeName.MAP) .setMapKeyType(keyType) @@ -1443,7 +1445,7 @@ private static Schema fromFields(List fields) { } /** Return the list of all field names. */ - public List getFieldNames() { + public List<@NonNull String> getFieldNames() { return getFields().stream().map(Schema.Field::getName).collect(Collectors.toList()); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java index d7fddd8abfed..300dce61e2ea 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java @@ -27,6 +27,7 @@ import java.lang.reflect.Parameter; import java.lang.reflect.Type; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -62,21 +63,25 @@ import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.TypeConversionsFactory; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.checkerframework.checker.nullness.qual.Nullable; /** Utilities for managing AutoValue schemas. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) public class AutoValueUtils { - public static TypeDescriptor getBaseAutoValueClass(TypeDescriptor typeDescriptor) { + public static @Nullable TypeDescriptor getBaseAutoValueClass( + TypeDescriptor typeDescriptor) { // AutoValue extensions may be nested - while (typeDescriptor != null && typeDescriptor.getRawType().getName().contains("AutoValue_")) { - typeDescriptor = TypeDescriptor.of(typeDescriptor.getRawType().getSuperclass()); + @Nullable TypeDescriptor baseTypeDescriptor = typeDescriptor; + while (baseTypeDescriptor != null + && baseTypeDescriptor.getRawType().getName().contains("AutoValue_")) { + baseTypeDescriptor = + Optional.ofNullable(baseTypeDescriptor.getRawType().getSuperclass()) + .map(TypeDescriptor::of) + .orElse(null); } - return typeDescriptor; + return baseTypeDescriptor; } private static TypeDescriptor getAutoValueGenerated(TypeDescriptor typeDescriptor) { @@ -154,7 +159,11 @@ private static boolean matchConstructor( getterTypes.stream() .collect( Collectors.toMap( - f -> ReflectUtils.stripGetterPrefix(f.getMethod().getName()), + f -> + ReflectUtils.stripGetterPrefix( + Preconditions.checkNotNull( + f.getMethod(), JavaBeanUtils.GETTER_WITH_NULL_METHOD_ERROR) + .getName()), Function.identity())); boolean valid = true; @@ -196,18 +205,23 @@ private static boolean matchConstructor( return null; } - Map setterTypes = - ReflectUtils.getMethods(builderClass).stream() - .filter(ReflectUtils::isSetter) - .map(FieldValueTypeInformation::forSetter) - .collect(Collectors.toMap(FieldValueTypeInformation::getName, Function.identity())); + Map setterTypes = new HashMap<>(); + + ReflectUtils.getMethods(builderClass).stream() + .filter(ReflectUtils::isSetter) + .map(m -> FieldValueTypeInformation.forSetter(TypeDescriptor.of(builderClass), m)) + .forEach(fv -> setterTypes.putIfAbsent(fv.getName(), fv)); List setterMethods = 
Lists.newArrayList(); // The builder methods to call in order. List schemaTypes = fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), schema); for (FieldValueTypeInformation type : schemaTypes) { - String autoValueFieldName = ReflectUtils.stripGetterPrefix(type.getMethod().getName()); + String autoValueFieldName = + ReflectUtils.stripGetterPrefix( + Preconditions.checkNotNull( + type.getMethod(), JavaBeanUtils.GETTER_WITH_NULL_METHOD_ERROR) + .getName()); FieldValueTypeInformation setterType = setterTypes.get(autoValueFieldName); if (setterType == null) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java index 540f09b7b553..5297eb113a97 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java @@ -22,6 +22,7 @@ import java.io.Serializable; import java.lang.reflect.Constructor; +import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.lang.reflect.Parameter; @@ -34,6 +35,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.SortedMap; import net.bytebuddy.ByteBuddy; @@ -42,6 +44,7 @@ import net.bytebuddy.asm.AsmVisitorWrapper; import net.bytebuddy.description.method.MethodDescription.ForLoadedConstructor; import net.bytebuddy.description.method.MethodDescription.ForLoadedMethod; +import net.bytebuddy.description.type.PackageDescription; import net.bytebuddy.description.type.TypeDescription; import net.bytebuddy.description.type.TypeDescription.ForLoadedType; import net.bytebuddy.dynamic.DynamicType; @@ -78,6 +81,8 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.TypeParameter; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Function; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Collections2; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -85,6 +90,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Primitives; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ClassUtils; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTimeZone; import org.joda.time.Instant; @@ -95,8 +101,6 @@ @Internal @SuppressWarnings({ "keyfor", - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" }) public class ByteBuddyUtils { private static final ForLoadedType ARRAYS_TYPE = new ForLoadedType(Arrays.class); @@ -147,7 +151,11 @@ protected String name(TypeDescription superClass) { // If the target class is in a prohibited package (java.*) then leave the original package // alone. String realPackage = - overridePackage(targetPackage) ? targetPackage : superClass.getPackage().getName(); + overridePackage(targetPackage) + ? 
targetPackage + : Optional.ofNullable(superClass.getPackage()) + .map(PackageDescription::getName) + .orElse(""); return realPackage + className + "$" + SUFFIX + "$" + randomString.nextString(); } @@ -202,25 +210,27 @@ static class ShortCircuitReturnNull extends IfNullElse { // Create a new FieldValueGetter subclass. @SuppressWarnings("unchecked") - public static DynamicType.Builder subclassGetterInterface( - ByteBuddy byteBuddy, Type objectType, Type fieldType) { + public static + DynamicType.Builder> subclassGetterInterface( + ByteBuddy byteBuddy, Type objectType, Type fieldType) { TypeDescription.Generic getterGenericType = TypeDescription.Generic.Builder.parameterizedType( FieldValueGetter.class, objectType, fieldType) .build(); - return (DynamicType.Builder) + return (DynamicType.Builder>) byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(getterGenericType); } // Create a new FieldValueSetter subclass. @SuppressWarnings("unchecked") - public static DynamicType.Builder subclassSetterInterface( - ByteBuddy byteBuddy, Type objectType, Type fieldType) { + public static + DynamicType.Builder> subclassSetterInterface( + ByteBuddy byteBuddy, Type objectType, Type fieldType) { TypeDescription.Generic setterGenericType = TypeDescription.Generic.Builder.parameterizedType( FieldValueSetter.class, objectType, fieldType) .build(); - return (DynamicType.Builder) + return (DynamicType.Builder) byteBuddy.with(new InjectPackageStrategy((Class) objectType)).subclass(setterGenericType); } @@ -252,9 +262,11 @@ public TypeConversion createSetterConversions(StackManipulati // Base class used below to convert types. @SuppressWarnings("unchecked") public abstract static class TypeConversion { - public T convert(TypeDescriptor typeDescriptor) { + public T convert(TypeDescriptor typeDescriptor) { if (typeDescriptor.isArray() - && !typeDescriptor.getComponentType().getRawType().equals(byte.class)) { + && !Preconditions.checkNotNull(typeDescriptor.getComponentType()) + .getRawType() + .equals(byte.class)) { // Byte arrays are special, so leave those alone. return convertArray(typeDescriptor); } else if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(Map.class))) { @@ -339,25 +351,32 @@ protected ConvertType(boolean returnRawTypes) { @Override protected Type convertArray(TypeDescriptor type) { - TypeDescriptor ret = createCollectionType(type.getComponentType()); + TypeDescriptor ret = + createCollectionType(Preconditions.checkNotNull(type.getComponentType())); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertCollection(TypeDescriptor type) { - TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createCollectionType( + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type))); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertList(TypeDescriptor type) { - TypeDescriptor ret = createCollectionType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createCollectionType( + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type))); return returnRawTypes ? ret.getRawType() : ret.getType(); } @Override protected Type convertIterable(TypeDescriptor type) { - TypeDescriptor ret = createIterableType(ReflectUtils.getIterableComponentType(type)); + TypeDescriptor ret = + createIterableType( + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type))); return returnRawTypes ? 
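A value-level sketch of the wrapping described in the comment above (plain Java illustrating the behavior the generated bytecode performs, not the generated code itself):

    int[] primitiveArray = {1, 2, 3};
    // Primitive arrays are boxed first, since Arrays.asList cannot wrap an int[] directly.
    List<Integer> boxedList = Arrays.asList(ArrayUtils.toObject(primitiveArray));

    String[] objectArray = {"a", "b"};
    // Object arrays can be handed to Arrays.asList as-is.
    List<String> objectList = Arrays.asList(objectArray);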
ret.getRawType() : ret.getType(); } @@ -399,8 +418,9 @@ protected Type convertDefault(TypeDescriptor type) { @SuppressWarnings("unchecked") private TypeDescriptor> createCollectionType( TypeDescriptor componentType) { - TypeDescriptor wrappedComponentType = - TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); + TypeDescriptor wrappedComponentType = + (TypeDescriptor) + TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); return new TypeDescriptor>() {}.where( new TypeParameter() {}, wrappedComponentType); } @@ -408,8 +428,9 @@ private TypeDescriptor> createCollectionType( @SuppressWarnings("unchecked") private TypeDescriptor> createIterableType( TypeDescriptor componentType) { - TypeDescriptor wrappedComponentType = - TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); + TypeDescriptor wrappedComponentType = + (TypeDescriptor) + TypeDescriptor.of(ClassUtils.primitiveToWrapper(componentType.getRawType())); return new TypeDescriptor>() {}.where( new TypeParameter() {}, wrappedComponentType); } @@ -421,7 +442,7 @@ private TypeDescriptor> createIterableType( // This function // generates a subclass of Function that can be used to recursively transform each element of the // container. - static Class createCollectionTransformFunction( + static Class createCollectionTransformFunction( Type fromType, Type toType, Function convertElement) { // Generate a TypeDescription for the class we want to generate. TypeDescription.Generic functionGenericType = @@ -429,8 +450,8 @@ static Class createCollectionTransformFunction( Function.class, Primitives.wrap((Class) fromType), Primitives.wrap((Class) toType)) .build(); - DynamicType.Builder builder = - (DynamicType.Builder) + DynamicType.Builder> builder = + (DynamicType.Builder) BYTE_BUDDY .with(new InjectPackageStrategy((Class) fromType)) .subclass(functionGenericType) @@ -464,9 +485,11 @@ public InstrumentedType prepare(InstrumentedType instrumentedType) { .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .make() .load( - ReflectHelpers.findClassLoader(((Class) fromType).getClassLoader()), + ReflectHelpers.findClassLoader(((Class) fromType).getClassLoader()), getClassLoadingStrategy( - ((Class) fromType).getClassLoader() == null ? Function.class : (Class) fromType)) + ((Class) fromType).getClassLoader() == null + ? Function.class + : (Class) fromType)) .getLoaded(); } @@ -548,17 +571,17 @@ public boolean containsValue(Object value) { } @Override - public V2 get(Object key) { + public @Nullable V2 get(Object key) { return delegateMap.get(key); } @Override - public V2 put(K2 key, V2 value) { + public @Nullable V2 put(K2 key, V2 value) { return delegateMap.put(key, value); } @Override - public V2 remove(Object key) { + public @Nullable V2 remove(Object key) { return delegateMap.remove(key); } @@ -636,12 +659,12 @@ protected StackManipulation convertArray(TypeDescriptor type) { // return isComponentTypePrimitive ? Arrays.asList(ArrayUtils.toObject(value)) // : Arrays.asList(value); - TypeDescriptor componentType = type.getComponentType(); + TypeDescriptor componentType = Preconditions.checkNotNull(type.getComponentType()); ForLoadedType loadedArrayType = new ForLoadedType(type.getRawType()); StackManipulation readArrayValue = readValue; // Row always expects to get an Iterable back for array types. Wrap this array into a // List using Arrays.asList before returning. 
- if (loadedArrayType.getComponentType().isPrimitive()) { + if (Preconditions.checkNotNull(loadedArrayType.getComponentType()).isPrimitive()) { // Arrays.asList doesn't take primitive arrays, so convert first using ArrayUtils.toObject. readArrayValue = new Compound( @@ -669,7 +692,7 @@ protected StackManipulation convertArray(TypeDescriptor type) { // Generate a SerializableFunction to convert the element-type objects. StackManipulation stackManipulation; - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); @@ -688,10 +711,11 @@ protected StackManipulation convertArray(TypeDescriptor type) { @Override protected StackManipulation convertIterable(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { ForLoadedType functionType = new ForLoadedType( @@ -708,9 +732,10 @@ protected StackManipulation convertIterable(TypeDescriptor type) { @Override protected StackManipulation convertCollection(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { ForLoadedType functionType = new ForLoadedType( @@ -727,9 +752,10 @@ protected StackManipulation convertCollection(TypeDescriptor type) { @Override protected StackManipulation convertList(TypeDescriptor type) { - TypeDescriptor componentType = ReflectUtils.getIterableComponentType(type); + TypeDescriptor componentType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); Type convertedComponentType = getFactory().createTypeConversion(true).convert(componentType); - final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); + final TypeDescriptor finalComponentType = ReflectUtils.boxIfPrimitive(componentType); if (!finalComponentType.hasUnresolvedParameters()) { ForLoadedType functionType = new ForLoadedType( @@ -746,8 +772,8 @@ protected StackManipulation convertList(TypeDescriptor type) { @Override protected StackManipulation convertMap(TypeDescriptor type) { - final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0); - final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1); + final TypeDescriptor keyType = ReflectUtils.getMapType(type, 0); + final TypeDescriptor valueType = ReflectUtils.getMapType(type, 1); Type convertedKeyType = getFactory().createTypeConversion(true).convert(keyType); Type convertedValueType = getFactory().createTypeConversion(true).convert(valueType); @@ -971,16 
+997,18 @@ protected StackManipulation convertArray(TypeDescriptor type) { // return isPrimitive ? toArray : ArrayUtils.toPrimitive(toArray); ForLoadedType loadedType = new ForLoadedType(type.getRawType()); + TypeDescription loadedTypeComponentType = Verify.verifyNotNull(loadedType.getComponentType()); + // The type of the array containing the (possibly) boxed values. TypeDescription arrayType = - TypeDescription.Generic.Builder.rawType(loadedType.getComponentType().asBoxed()) + TypeDescription.Generic.Builder.rawType(loadedTypeComponentType.asBoxed()) .asArray() .build() .asErasure(); - Type rowElementType = - getFactory().createTypeConversion(false).convert(type.getComponentType()); - final TypeDescriptor arrayElementType = ReflectUtils.boxIfPrimitive(type.getComponentType()); + TypeDescriptor componentType = Preconditions.checkNotNull(type.getComponentType()); + Type rowElementType = getFactory().createTypeConversion(false).convert(componentType); + final TypeDescriptor arrayElementType = ReflectUtils.boxIfPrimitive(componentType); StackManipulation readTransformedValue = readValue; if (!arrayElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = @@ -1000,7 +1028,7 @@ protected StackManipulation convertArray(TypeDescriptor type) { // Call Collection.toArray(T[[]) to extract the array. Push new T[0] on the stack // before // calling toArray. - ArrayFactory.forType(loadedType.getComponentType().asBoxed().asGenericType()) + ArrayFactory.forType(loadedTypeComponentType.asBoxed().asGenericType()) .withValues(Collections.emptyList()), MethodInvocation.invoke( COLLECTION_TYPE @@ -1017,7 +1045,7 @@ protected StackManipulation convertArray(TypeDescriptor type) { // Cast the result to T[]. TypeCasting.to(arrayType)); - if (loadedType.getComponentType().isPrimitive()) { + if (loadedTypeComponentType.isPrimitive()) { // The array we extract will be an array of objects. If the pojo field is an array of // primitive types, we need to then convert to an array of unboxed objects. 
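The reverse step mentioned in the comment above, again only as a value-level sketch:

    Integer[] extracted = {1, 2, 3};
    // If the POJO field is an int[], the boxed values pulled out of the Row are unwrapped again.
    int[] pojoField = ArrayUtils.toPrimitive(extracted);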
stackManipulation = @@ -1036,11 +1064,9 @@ protected StackManipulation convertArray(TypeDescriptor type) { @Override protected StackManipulation convertIterable(TypeDescriptor type) { - Type rowElementType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor iterableElementType = ReflectUtils.getIterableComponentType(type); + final TypeDescriptor iterableElementType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); + Type rowElementType = getFactory().createTypeConversion(false).convert(iterableElementType); if (!iterableElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = new ForLoadedType( @@ -1058,11 +1084,9 @@ protected StackManipulation convertIterable(TypeDescriptor type) { @Override protected StackManipulation convertCollection(TypeDescriptor type) { - Type rowElementType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type); + final TypeDescriptor collectionElementType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); + Type rowElementType = getFactory().createTypeConversion(false).convert(collectionElementType); if (!collectionElementType.hasUnresolvedParameters()) { ForLoadedType conversionFunction = @@ -1081,11 +1105,9 @@ protected StackManipulation convertCollection(TypeDescriptor type) { @Override protected StackManipulation convertList(TypeDescriptor type) { - Type rowElementType = - getFactory() - .createTypeConversion(false) - .convert(ReflectUtils.getIterableComponentType(type)); - final TypeDescriptor collectionElementType = ReflectUtils.getIterableComponentType(type); + final TypeDescriptor collectionElementType = + Preconditions.checkNotNull(ReflectUtils.getIterableComponentType(type)); + Type rowElementType = getFactory().createTypeConversion(false).convert(collectionElementType); StackManipulation readTrasformedValue = readValue; if (!collectionElementType.hasUnresolvedParameters()) { @@ -1113,12 +1135,12 @@ protected StackManipulation convertList(TypeDescriptor type) { @Override protected StackManipulation convertMap(TypeDescriptor type) { - Type rowKeyType = - getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 0)); - final TypeDescriptor keyElementType = ReflectUtils.getMapType(type, 0); - Type rowValueType = - getFactory().createTypeConversion(false).convert(ReflectUtils.getMapType(type, 1)); - final TypeDescriptor valueElementType = ReflectUtils.getMapType(type, 1); + final TypeDescriptor keyElementType = + Preconditions.checkNotNull(ReflectUtils.getMapType(type, 0)); + final TypeDescriptor valueElementType = + Preconditions.checkNotNull(ReflectUtils.getMapType(type, 1)); + Type rowKeyType = getFactory().createTypeConversion(false).convert(keyElementType); + Type rowValueType = getFactory().createTypeConversion(false).convert(valueElementType); StackManipulation readTrasformedValue = readValue; if (!keyElementType.hasUnresolvedParameters() @@ -1333,12 +1355,12 @@ protected StackManipulation convertDefault(TypeDescriptor type) { * constructor. 
*/ static class ConstructorCreateInstruction extends InvokeUserCreateInstruction { - private final Constructor constructor; + private final Constructor constructor; ConstructorCreateInstruction( List fields, - Class targetClass, - Constructor constructor, + Class targetClass, + Constructor constructor, TypeConversionsFactory typeConversionsFactory) { super( fields, @@ -1376,7 +1398,7 @@ static class StaticFactoryMethodInstruction extends InvokeUserCreateInstruction StaticFactoryMethodInstruction( List fields, - Class targetClass, + Class targetClass, Method creator, TypeConversionsFactory typeConversionsFactory) { super( @@ -1400,14 +1422,14 @@ protected StackManipulation afterPushingParameters() { static class InvokeUserCreateInstruction implements Implementation { protected final List fields; - protected final Class targetClass; + protected final Class targetClass; protected final List parameters; protected final Map fieldMapping; private final TypeConversionsFactory typeConversionsFactory; protected InvokeUserCreateInstruction( List fields, - Class targetClass, + Class targetClass, List parameters, TypeConversionsFactory typeConversionsFactory) { this.fields = fields; @@ -1425,11 +1447,15 @@ protected InvokeUserCreateInstruction( // actual Java field or method names. FieldValueTypeInformation fieldValue = checkNotNull(fields.get(i)); fieldsByLogicalName.put(fieldValue.getName(), i); - if (fieldValue.getField() != null) { - fieldsByJavaClassMember.put(fieldValue.getField().getName(), i); - } else if (fieldValue.getMethod() != null) { - String name = ReflectUtils.stripGetterPrefix(fieldValue.getMethod().getName()); - fieldsByJavaClassMember.put(name, i); + Field field = fieldValue.getField(); + if (field != null) { + fieldsByJavaClassMember.put(field.getName(), i); + } else { + Method method = fieldValue.getMethod(); + if (method != null) { + String name = ReflectUtils.stripGetterPrefix(method.getName()); + fieldsByJavaClassMember.put(name, i); + } } } @@ -1483,7 +1509,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { StackManipulation readParameter = new StackManipulation.Compound( MethodVariableAccess.REFERENCE.loadFrom(1), - IntegerConstant.forValue(fieldMapping.get(i)), + IntegerConstant.forValue(Preconditions.checkNotNull(fieldMapping.get(i))), ArrayAccess.REFERENCE.load(), TypeCasting.to(convertedType)); stackManipulation = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java index 911f79f6eeed..ee4868ddb2b6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java @@ -22,9 +22,11 @@ import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; import net.bytebuddy.ByteBuddy; import net.bytebuddy.asm.AsmVisitorWrapper; @@ -54,14 +56,22 @@ import org.apache.beam.sdk.schemas.utils.ReflectUtils.TypeDescriptorWithSchema; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; /** A set of utilities to generate getter and setter classes for JavaBean objects. */ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) public class JavaBeanUtils { + + private static final String X_WITH_NULL_METHOD_ERROR_FMT = + "a %s FieldValueTypeInformation object has a null method field"; + public static final String GETTER_WITH_NULL_METHOD_ERROR = + String.format(X_WITH_NULL_METHOD_ERROR_FMT, "getter"); + public static final String SETTER_WITH_NULL_METHOD_ERROR = + String.format(X_WITH_NULL_METHOD_ERROR_FMT, "setter"); + /** Create a {@link Schema} for a Java Bean class. */ public static Schema schemaFromJavaBeanClass( TypeDescriptor typeDescriptor, FieldValueTypeSupplier fieldValueTypeSupplier) { @@ -69,7 +79,9 @@ public static Schema schemaFromJavaBeanClass( } private static final String CONSTRUCTOR_HELP_STRING = - "In order to infer a Schema from a Java Bean, it must have a constructor annotated with @SchemaCreate, or it must have a compatible setter for every getter used as a Schema field."; + "In order to infer a Schema from a Java Bean, it must have a constructor annotated with" + + " @SchemaCreate, or it must have a compatible setter for every getter used as a Schema" + + " field."; // Make sure that there are matching setters and getters. public static void validateJavaBean( @@ -88,23 +100,26 @@ public static void validateJavaBean( for (FieldValueTypeInformation type : getters) { FieldValueTypeInformation setterType = setterMap.get(type.getName()); + Method m = Preconditions.checkNotNull(type.getMethod(), GETTER_WITH_NULL_METHOD_ERROR); if (setterType == null) { throw new RuntimeException( String.format( - "Java Bean '%s' contains a getter for field '%s', but does not contain a matching setter. %s", - type.getMethod().getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); + "Java Bean '%s' contains a getter for field '%s', but does not contain a matching" + + " setter. %s", + m.getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); } if (!type.getType().equals(setterType.getType())) { throw new RuntimeException( String.format( "Java Bean '%s' contains a setter for field '%s' that has a mismatching type. %s", - type.getMethod().getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); + m.getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); } if (!type.isNullable() == setterType.isNullable()) { throw new RuntimeException( String.format( - "Java Bean '%s' contains a setter for field '%s' that has a mismatching nullable attribute. %s", - type.getMethod().getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); + "Java Bean '%s' contains a setter for field '%s' that has a mismatching nullable" + + " attribute. %s", + m.getDeclaringClass(), type.getName(), CONSTRUCTOR_HELP_STRING)); } } } @@ -126,36 +141,41 @@ public static List getFieldTypes( // The list of getters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map, List> CACHED_GETTERS = - Maps.newConcurrentMap(); + private static final Map, List>> + CACHED_GETTERS = Maps.newConcurrentMap(); /** * Return the list of {@link FieldValueGetter}s for a Java Bean class * *
The returned list is ordered by the order of fields in the schema. */ - public static List getGetters( - TypeDescriptor typeDescriptor, + public static List> getGetters( + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { - return CACHED_GETTERS.computeIfAbsent( - TypeDescriptorWithSchema.create(typeDescriptor, schema), - c -> { - List types = - fieldValueTypeSupplier.get(typeDescriptor, schema); - return types.stream() - .map(t -> createGetter(t, typeConversionsFactory)) - .collect(Collectors.toList()); - }); + return (List) + CACHED_GETTERS.computeIfAbsent( + TypeDescriptorWithSchema.create(typeDescriptor, schema), + c -> { + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + return types.stream() + .map(t -> JavaBeanUtils.createGetter(t, typeConversionsFactory)) + .collect(Collectors.toList()); + }); } - public static FieldValueGetter createGetter( - FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - DynamicType.Builder builder = + public static + FieldValueGetter createGetter( + FieldValueTypeInformation typeInformation, + TypeConversionsFactory typeConversionsFactory) { + final Method m = + Preconditions.checkNotNull(typeInformation.getMethod(), GETTER_WITH_NULL_METHOD_ERROR); + DynamicType.Builder> builder = ByteBuddyUtils.subclassGetterInterface( BYTE_BUDDY, - typeInformation.getMethod().getDeclaringClass(), + m.getDeclaringClass(), typeConversionsFactory.createTypeConversion(false).convert(typeInformation.getType())); builder = implementGetterMethods(builder, typeInformation, typeConversionsFactory); try { @@ -163,9 +183,8 @@ public static FieldValueGetter createGetter( .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .make() .load( - ReflectHelpers.findClassLoader( - typeInformation.getMethod().getDeclaringClass().getClassLoader()), - getClassLoadingStrategy(typeInformation.getMethod().getDeclaringClass())) + ReflectHelpers.findClassLoader(m.getDeclaringClass().getClassLoader()), + getClassLoadingStrategy(m.getDeclaringClass())) .getLoaded() .getDeclaredConstructor() .newInstance(); @@ -178,10 +197,11 @@ public static FieldValueGetter createGetter( } } - private static DynamicType.Builder implementGetterMethods( - DynamicType.Builder builder, - FieldValueTypeInformation typeInformation, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementGetterMethods( + DynamicType.Builder> builder, + FieldValueTypeInformation typeInformation, + TypeConversionsFactory typeConversionsFactory) { return builder .method(ElementMatchers.named("name")) .intercept(FixedValue.reference(typeInformation.getName())) @@ -215,12 +235,14 @@ public static List getSetters( }); } - public static FieldValueSetter createSetter( + public static FieldValueSetter createSetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - DynamicType.Builder builder = + final Method m = + Preconditions.checkNotNull(typeInformation.getMethod(), SETTER_WITH_NULL_METHOD_ERROR); + DynamicType.Builder> builder = ByteBuddyUtils.subclassSetterInterface( BYTE_BUDDY, - typeInformation.getMethod().getDeclaringClass(), + m.getDeclaringClass(), typeConversionsFactory.createTypeConversion(false).convert(typeInformation.getType())); builder = implementSetterMethods(builder, typeInformation, typeConversionsFactory); try { @@ -228,9 +250,8 @@ public static 
FieldValueSetter createSetter( .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .make() .load( - ReflectHelpers.findClassLoader( - typeInformation.getMethod().getDeclaringClass().getClassLoader()), - getClassLoadingStrategy(typeInformation.getMethod().getDeclaringClass())) + ReflectHelpers.findClassLoader(m.getDeclaringClass().getClassLoader()), + getClassLoadingStrategy(m.getDeclaringClass())) .getLoaded() .getDeclaredConstructor() .newInstance(); @@ -243,10 +264,11 @@ public static FieldValueSetter createSetter( } } - private static DynamicType.Builder implementSetterMethods( - DynamicType.Builder builder, - FieldValueTypeInformation fieldValueTypeInformation, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementSetterMethods( + DynamicType.Builder> builder, + FieldValueTypeInformation fieldValueTypeInformation, + TypeConversionsFactory typeConversionsFactory) { return builder .method(ElementMatchers.named("name")) .intercept(FixedValue.reference(fieldValueTypeInformation.getName())) @@ -358,6 +380,11 @@ public static SchemaUserTypeCreator createStaticCreator( } } + public static > Comparator comparingNullFirst( + Function keyExtractor) { + return Comparator.comparing(keyExtractor, Comparator.nullsFirst(Comparator.naturalOrder())); + } + // Implements a method to read a public getter out of an object. private static class InvokeGetterInstruction implements Implementation { private final FieldValueTypeInformation typeInformation; @@ -386,7 +413,10 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Method param is offset 1 (offset 0 is the this parameter). MethodVariableAccess.REFERENCE.loadFrom(1), // Invoke the getter - MethodInvocation.invoke(new ForLoadedMethod(typeInformation.getMethod()))); + MethodInvocation.invoke( + new ForLoadedMethod( + Preconditions.checkNotNull( + typeInformation.getMethod(), GETTER_WITH_NULL_METHOD_ERROR)))); StackManipulation stackManipulation = new StackManipulation.Compound( @@ -428,7 +458,9 @@ public ByteCodeAppender appender(final Target implementationTarget) { // The instruction to read the field. StackManipulation readField = MethodVariableAccess.REFERENCE.loadFrom(2); - Method method = fieldValueTypeInformation.getMethod(); + Method method = + Preconditions.checkNotNull( + fieldValueTypeInformation.getMethod(), SETTER_WITH_NULL_METHOD_ERROR); boolean setterMethodReturnsVoid = method.getReturnType().equals(Void.TYPE); // Read the object onto the stack. StackManipulation stackManipulation = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java index 571b9c690900..8e33d321a1c6 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java @@ -62,8 +62,9 @@ import org.apache.beam.sdk.schemas.utils.ReflectUtils.TypeDescriptorWithSchema; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; -import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.NonNull; /** A set of utilities to generate getter and setter classes for POJOs. 
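JavaBeanUtils also gains a comparingNullFirst helper above, which builds a Comparator that orders elements by a possibly-null key with null keys first. A usage sketch under the assumption that the helper keeps the shape shown in the hunk; the NullsFirstSketch class and the sample data are illustrative:

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.Comparator;
    import java.util.List;
    import java.util.function.Function;

    // Illustrative sketch mirroring comparingNullFirst: sort by a possibly-null
    // key, placing null keys before non-null ones.
    class NullsFirstSketch {
      static <T, K extends Comparable<K>> Comparator<T> comparingNullFirst(Function<T, K> key) {
        return Comparator.comparing(key, Comparator.nullsFirst(Comparator.naturalOrder()));
      }

      public static void main(String[] args) {
        List<String> names = new ArrayList<>(Arrays.asList("b", null, "a"));
        names.sort(comparingNullFirst((String s) -> s));
        System.out.println(names); // prints [null, a, b]
      }
    }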
*/ @SuppressWarnings({ @@ -94,38 +95,40 @@ public static List getFieldTypes( // The list of getters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_GETTERS = - Maps.newConcurrentMap(); + private static final Map, List>> + CACHED_GETTERS = Maps.newConcurrentMap(); - public static List getGetters( - TypeDescriptor typeDescriptor, + public static List> getGetters( + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { // Return the getters ordered by their position in the schema. - return CACHED_GETTERS.computeIfAbsent( - TypeDescriptorWithSchema.create(typeDescriptor, schema), - c -> { - List types = - fieldValueTypeSupplier.get(typeDescriptor, schema); - List getters = - types.stream() - .map(t -> createGetter(t, typeConversionsFactory)) - .collect(Collectors.toList()); - if (getters.size() != schema.getFieldCount()) { - throw new RuntimeException( - "Was not able to generate getters for schema: " - + schema - + " class: " - + typeDescriptor); - } - return getters; - }); + return (List) + CACHED_GETTERS.computeIfAbsent( + TypeDescriptorWithSchema.create(typeDescriptor, schema), + c -> { + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + List> getters = + types.stream() + .>map( + t -> POJOUtils.createGetter(t, typeConversionsFactory)) + .collect(Collectors.toList()); + if (getters.size() != schema.getFieldCount()) { + throw new RuntimeException( + "Was not able to generate getters for schema: " + + schema + + " class: " + + typeDescriptor); + } + return (List) getters; + }); } // The list of constructors for a class is cached, so we only create the classes the first time // getConstructor is called. - public static final Map CACHED_CREATORS = + public static final Map, SchemaUserTypeCreator> CACHED_CREATORS = Maps.newConcurrentMap(); public static SchemaUserTypeCreator getSetFieldCreator( @@ -150,7 +153,9 @@ private static SchemaUserTypeCreator createSetFieldCreator( TypeConversionsFactory typeConversionsFactory) { // Get the list of class fields ordered by schema. 
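Both JavaBeanUtils and POJOUtils above keep their generated accessors in a concurrent map keyed by TypeDescriptorWithSchema and populate it through computeIfAbsent, so byte-code generation runs at most once per type/schema pair. A generic sketch of that caching shape; GeneratedAccessorCache and its type parameters are stand-ins, not Beam API:

    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.function.Function;

    // Illustrative sketch of the CACHED_GETTERS/CACHED_SETTERS pattern: the
    // generator runs at most once per key; later lookups reuse the cached value.
    class GeneratedAccessorCache<K, V> {
      private final Map<K, V> cache = new ConcurrentHashMap<>();

      V getOrGenerate(K key, Function<K, V> generator) {
        return cache.computeIfAbsent(key, generator);
      }
    }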
List fields = - types.stream().map(FieldValueTypeInformation::getField).collect(Collectors.toList()); + types.stream() + .map(type -> Preconditions.checkNotNull(type.getField())) + .collect(Collectors.toList()); try { DynamicType.Builder builder = BYTE_BUDDY @@ -175,14 +180,16 @@ private static SchemaUserTypeCreator createSetFieldCreator( | InvocationTargetException e) { throw new RuntimeException( String.format( - "Unable to generate a creator for POJO '%s' with inferred schema: %s%nNote POJOs must have a zero-argument constructor, or a constructor annotated with @SchemaCreate.", + "Unable to generate a creator for POJO '%s' with inferred schema: %s%nNote POJOs must" + + " have a zero-argument constructor, or a constructor annotated with" + + " @SchemaCreate.", clazz, schema)); } } - public static SchemaUserTypeCreator getConstructorCreator( - TypeDescriptor typeDescriptor, - Constructor constructor, + public static SchemaUserTypeCreator getConstructorCreator( + TypeDescriptor typeDescriptor, + Constructor constructor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { @@ -191,13 +198,13 @@ public static SchemaUserTypeCreator getConstructorCreator( c -> { List types = fieldValueTypeSupplier.get(typeDescriptor, schema); - return createConstructorCreator( + return POJOUtils.createConstructorCreator( typeDescriptor.getRawType(), constructor, schema, types, typeConversionsFactory); }); } public static SchemaUserTypeCreator createConstructorCreator( - Class clazz, + Class clazz, Constructor constructor, Schema schema, List types, @@ -291,11 +298,10 @@ public static SchemaUserTypeCreator createStaticCreator( * } * */ - @SuppressWarnings("unchecked") - static @Nullable FieldValueGetter createGetter( + static FieldValueGetter<@NonNull ObjectT, ValueT> createGetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - Field field = typeInformation.getField(); - DynamicType.Builder builder = + Field field = Preconditions.checkNotNull(typeInformation.getField()); + DynamicType.Builder> builder = ByteBuddyUtils.subclassGetterInterface( BYTE_BUDDY, field.getDeclaringClass(), @@ -322,11 +328,12 @@ public static SchemaUserTypeCreator createStaticCreator( } } - private static DynamicType.Builder implementGetterMethods( - DynamicType.Builder builder, - Field field, - String name, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementGetterMethods( + DynamicType.Builder> builder, + Field field, + String name, + TypeConversionsFactory typeConversionsFactory) { return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .method(ElementMatchers.named("name")) @@ -337,24 +344,25 @@ private static DynamicType.Builder implementGetterMethods( // The list of setters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_SETTERS = - Maps.newConcurrentMap(); + private static final Map, List>> + CACHED_SETTERS = Maps.newConcurrentMap(); - public static List getSetters( - TypeDescriptor typeDescriptor, + public static List> getSetters( + TypeDescriptor typeDescriptor, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { // Return the setters, ordered by their position in the schema. 
- return CACHED_SETTERS.computeIfAbsent( - TypeDescriptorWithSchema.create(typeDescriptor, schema), - c -> { - List types = - fieldValueTypeSupplier.get(typeDescriptor, schema); - return types.stream() - .map(t -> createSetter(t, typeConversionsFactory)) - .collect(Collectors.toList()); - }); + return (List) + CACHED_SETTERS.computeIfAbsent( + TypeDescriptorWithSchema.create(typeDescriptor, schema), + c -> { + List types = + fieldValueTypeSupplier.get(typeDescriptor, schema); + return types.stream() + .map(t -> createSetter(t, typeConversionsFactory)) + .collect(Collectors.toList()); + }); } /** @@ -376,8 +384,8 @@ public static List getSetters( @SuppressWarnings("unchecked") private static FieldValueSetter createSetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { - Field field = typeInformation.getField(); - DynamicType.Builder builder = + Field field = Preconditions.checkNotNull(typeInformation.getField()); + DynamicType.Builder> builder = ByteBuddyUtils.subclassSetterInterface( BYTE_BUDDY, field.getDeclaringClass(), @@ -403,10 +411,11 @@ private static FieldValueSetter createSetter( } } - private static DynamicType.Builder implementSetterMethods( - DynamicType.Builder builder, - Field field, - TypeConversionsFactory typeConversionsFactory) { + private static + DynamicType.Builder> implementSetterMethods( + DynamicType.Builder> builder, + Field field, + TypeConversionsFactory typeConversionsFactory) { return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) .method(ElementMatchers.named("name")) @@ -505,11 +514,11 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Implements a method to construct an object. static class SetFieldCreateInstruction implements Implementation { private final List fields; - private final Class pojoClass; + private final Class pojoClass; private final TypeConversionsFactory typeConversionsFactory; SetFieldCreateInstruction( - List fields, Class pojoClass, TypeConversionsFactory typeConversionsFactory) { + List fields, Class pojoClass, TypeConversionsFactory typeConversionsFactory) { this.fields = fields; this.pojoClass = pojoClass; this.typeConversionsFactory = typeConversionsFactory; diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java index 4349a04c28ad..423fea4c3845 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java @@ -32,7 +32,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.annotations.SchemaCreate; @@ -88,14 +87,23 @@ public static List getMethods(Class clazz) { return DECLARED_METHODS.computeIfAbsent( clazz, c -> { - return Arrays.stream(c.getDeclaredMethods()) - .filter( - m -> !m.isBridge()) // Covariant overloads insert bridge functions, which we must - // ignore. 
- .filter(m -> !Modifier.isPrivate(m.getModifiers())) - .filter(m -> !Modifier.isProtected(m.getModifiers())) - .filter(m -> !Modifier.isStatic(m.getModifiers())) - .collect(Collectors.toList()); + List methods = Lists.newArrayList(); + do { + if (c.getPackage() != null && c.getPackage().getName().startsWith("java.")) { + break; // skip java built-in classes + } + Arrays.stream(c.getDeclaredMethods()) + .filter( + m -> + !m.isBridge()) // Covariant overloads insert bridge functions, which we must + // ignore. + .filter(m -> !Modifier.isPrivate(m.getModifiers())) + .filter(m -> !Modifier.isProtected(m.getModifiers())) + .filter(m -> !Modifier.isStatic(m.getModifiers())) + .forEach(methods::add); + c = c.getSuperclass(); + } while (c != null); + return methods; }); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java index aeb76492bb6d..c2d945bbaac1 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/common/ReflectHelpers.java @@ -44,6 +44,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSortedSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Queues; +import org.checkerframework.checker.nullness.qual.Nullable; /** Utilities for working with with {@link Class Classes} and {@link Method Methods}. */ @SuppressWarnings({"nullness", "keyfor"}) // TODO(https://github.com/apache/beam/issues/20497) @@ -216,7 +217,7 @@ public static Iterable loadServicesOrdered(Class iface) { * which by default would use the proposed {@code ClassLoader}, which can be null. The fallback is * as follows: context ClassLoader, class ClassLoader and finally the system ClassLoader. */ - public static ClassLoader findClassLoader(final ClassLoader proposed) { + public static ClassLoader findClassLoader(@Nullable final ClassLoader proposed) { ClassLoader classLoader = proposed; if (classLoader == null) { classLoader = ReflectHelpers.class.getClassLoader(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index ee3852d70bbe..591a83600561 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -48,6 +48,7 @@ import org.apache.beam.sdk.values.RowUtils.RowFieldMatcher; import org.apache.beam.sdk.values.RowUtils.RowPosition; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTime; import org.joda.time.ReadableDateTime; @@ -771,6 +772,7 @@ public FieldValueBuilder withFieldValue( checkState(values.isEmpty()); return new FieldValueBuilder(schema, null).withFieldValue(fieldAccessDescriptor, value); } + /** * Sets field values using the field names. Nested values can be set using the field selection * syntax. 
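The ReflectUtils.getMethods hunk above replaces the single-class stream with a loop that walks up the superclass chain, stops at java.* built-ins, and keeps the existing filters for bridge, private, protected, and static methods. A standalone sketch of that traversal; the DeclaredMethodsSketch wrapper is illustrative:

    import java.lang.reflect.Method;
    import java.lang.reflect.Modifier;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    // Illustrative sketch: collect candidate accessor methods declared on the class
    // and its non-java.* superclasses, skipping bridge/private/protected/static ones.
    class DeclaredMethodsSketch {
      static List<Method> publicInstanceMethods(Class<?> clazz) {
        List<Method> methods = new ArrayList<>();
        Class<?> c = clazz;
        do {
          if (c.getPackage() != null && c.getPackage().getName().startsWith("java.")) {
            break; // skip java built-in classes such as java.lang.Object
          }
          Arrays.stream(c.getDeclaredMethods())
              .filter(m -> !m.isBridge()) // covariant overrides insert bridge methods
              .filter(m -> !Modifier.isPrivate(m.getModifiers()))
              .filter(m -> !Modifier.isProtected(m.getModifiers()))
              .filter(m -> !Modifier.isStatic(m.getModifiers()))
              .forEach(methods::add);
          c = c.getSuperclass();
        } while (c != null);
        return methods;
      }
    }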
@@ -836,10 +838,10 @@ public int nextFieldId() { } @Internal - public Row withFieldValueGetters( - Factory> fieldValueGetterFactory, Object getterTarget) { + public <@NonNull T> Row withFieldValueGetters( + Factory>> fieldValueGetterFactory, T getterTarget) { checkState(getterTarget != null, "getters require withGetterTarget."); - return new RowWithGetters(schema, fieldValueGetterFactory, getterTarget); + return new RowWithGetters<>(schema, fieldValueGetterFactory, getterTarget); } public Row build() { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java index 9731507fb0f6..35e0ac20d3f7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java @@ -42,13 +42,13 @@ * the appropriate fields from the POJO. */ @SuppressWarnings("rawtypes") -public class RowWithGetters extends Row { - private final Object getterTarget; - private final List getters; +public class RowWithGetters extends Row { + private final T getterTarget; + private final List> getters; private @Nullable Map cache = null; RowWithGetters( - Schema schema, Factory> getterFactory, Object getterTarget) { + Schema schema, Factory>> getterFactory, T getterTarget) { super(schema); this.getterTarget = getterTarget; this.getters = getterFactory.create(TypeDescriptor.of(getterTarget.getClass()), schema); @@ -56,7 +56,7 @@ public class RowWithGetters extends Row { @Override @SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"}) - public @Nullable T getValue(int fieldIdx) { + public W getValue(int fieldIdx) { Field field = getSchema().getField(fieldIdx); boolean cacheField = cacheFieldType(field); @@ -64,7 +64,7 @@ public class RowWithGetters extends Row { cache = new TreeMap<>(); } - Object fieldValue; + @Nullable Object fieldValue; if (cacheField) { if (cache == null) { cache = new TreeMap<>(); @@ -72,15 +72,12 @@ public class RowWithGetters extends Row { fieldValue = cache.computeIfAbsent( fieldIdx, - new Function() { + new Function() { @Override - public Object apply(Integer idx) { - FieldValueGetter getter = getters.get(idx); + public @Nullable Object apply(Integer idx) { + FieldValueGetter getter = getters.get(idx); checkStateNotNull(getter); - @SuppressWarnings("nullness") - @NonNull - Object value = getter.get(getterTarget); - return value; + return getter.get(getterTarget); } }); } else { @@ -90,7 +87,7 @@ public Object apply(Integer idx) { if (fieldValue == null && !field.getType().getNullable()) { throw new RuntimeException("Null value set on non-nullable field " + field); } - return (T) fieldValue; + return (W) fieldValue; } private boolean cacheFieldType(Field field) { @@ -116,7 +113,7 @@ public int getFieldCount() { return rawValues; } - public List getGetters() { + public List> getGetters() { return getters; } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/FieldValueTypeInformationTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/FieldValueTypeInformationTest.java new file mode 100644 index 000000000000..26e3278df025 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/FieldValueTypeInformationTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.schemas; + +import static org.junit.Assert.assertEquals; + +import java.util.Map; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.junit.Test; + +public class FieldValueTypeInformationTest { + public static class GenericClass { + public T t; + + public GenericClass(T t) { + this.t = t; + } + + public T getT() { + return t; + } + + public void setT(T t) { + this.t = t; + } + } + + private final TypeDescriptor>> typeDescriptor = + new TypeDescriptor>>() {}; + private final TypeDescriptor> expectedFieldTypeDescriptor = + new TypeDescriptor>() {}; + + @Test + public void testForGetter() throws Exception { + FieldValueTypeInformation actual = + FieldValueTypeInformation.forGetter( + typeDescriptor, GenericClass.class.getMethod("getT"), 0); + assertEquals(expectedFieldTypeDescriptor, actual.getType()); + } + + @Test + public void testForField() throws Exception { + FieldValueTypeInformation actual = + FieldValueTypeInformation.forField(typeDescriptor, GenericClass.class.getField("t"), 0); + assertEquals(expectedFieldTypeDescriptor, actual.getType()); + } + + @Test + public void testForSetter() throws Exception { + FieldValueTypeInformation actual = + FieldValueTypeInformation.forSetter( + typeDescriptor, GenericClass.class.getMethod("setT", Object.class)); + assertEquals(expectedFieldTypeDescriptor, actual.getType()); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java index 021e39b84849..7e9cf9a894b9 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtilsTest.java @@ -53,6 +53,7 @@ import org.apache.beam.sdk.schemas.utils.TestJavaBeans.PrimitiveMapBean; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.SimpleBean; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; import org.joda.time.DateTime; import org.junit.Test; @@ -142,11 +143,11 @@ public void testGeneratedSimpleGetters() { simpleBean.setBigDecimal(new BigDecimal(42)); simpleBean.setStringBuilder(new StringBuilder("stringBuilder")); - List getters = + List> getters = JavaBeanUtils.getGetters( new TypeDescriptor() {}, SIMPLE_BEAN_SCHEMA, - new JavaBeanSchema.GetterTypeSupplier(), + new GetterTypeSupplier(), new DefaultTypeConversionsFactory()); assertEquals(12, getters.size()); assertEquals("str", getters.get(0).name()); @@ -220,7 +221,7 @@ public void testGeneratedSimpleBoxedGetters() { bean.setaLong(44L); bean.setaBoolean(true); - List getters = + List> getters = JavaBeanUtils.getGetters( new TypeDescriptor() {}, BEAN_WITH_BOXED_FIELDS_SCHEMA, diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java index 723353ed8d15..378cdc06805f 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/POJOUtilsTest.java @@ -52,6 +52,7 @@ import org.apache.beam.sdk.schemas.utils.TestPOJOs.PrimitiveMapPOJO; import org.apache.beam.sdk.schemas.utils.TestPOJOs.SimplePOJO; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; import org.joda.time.DateTime; import org.joda.time.Instant; import org.junit.Test; @@ -158,7 +159,7 @@ public void testGeneratedSimpleGetters() { new BigDecimal(42), new StringBuilder("stringBuilder")); - List getters = + List> getters = POJOUtils.getGetters( new TypeDescriptor() {}, SIMPLE_POJO_SCHEMA, @@ -184,7 +185,7 @@ public void testGeneratedSimpleGetters() { @Test public void testGeneratedSimpleSetters() { SimplePOJO simplePojo = new SimplePOJO(); - List setters = + List> setters = POJOUtils.getSetters( new TypeDescriptor() {}, SIMPLE_POJO_SCHEMA, @@ -223,7 +224,7 @@ public void testGeneratedSimpleSetters() { public void testGeneratedSimpleBoxedGetters() { POJOWithBoxedFields pojo = new POJOWithBoxedFields((byte) 41, (short) 42, 43, 44L, true); - List getters = + List> getters = POJOUtils.getGetters( new TypeDescriptor() {}, POJO_WITH_BOXED_FIELDS_SCHEMA, @@ -239,7 +240,7 @@ public void testGeneratedSimpleBoxedGetters() { @Test public void testGeneratedSimpleBoxedSetters() { POJOWithBoxedFields pojo = new POJOWithBoxedFields(); - List setters = + List> setters = POJOUtils.getSetters( new TypeDescriptor() {}, POJO_WITH_BOXED_FIELDS_SCHEMA, @@ -262,7 +263,7 @@ public void testGeneratedSimpleBoxedSetters() { @Test public void testGeneratedByteBufferSetters() { POJOWithByteArray pojo = new POJOWithByteArray(); - List setters = + List> setters = POJOUtils.getSetters( new TypeDescriptor() {}, POJO_WITH_BYTE_ARRAY_SCHEMA, diff --git a/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java b/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java index 4b6538157fd0..78ba610ad4d1 100644 --- a/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java +++ b/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java @@ -276,11 +276,11 @@ public static class RecordBatchRowIterator implements Iterator, AutoCloseab new ArrowValueConverterVisitor(); private final Schema schema; private final VectorSchemaRoot vectorSchemaRoot; - private final Factory> fieldValueGetters; + private final Factory>> fieldValueGetters; private Integer currRowIndex; private static class FieldVectorListValueGetterFactory - implements Factory> { + implements Factory>> { private final List fieldVectors; static FieldVectorListValueGetterFactory of(List fieldVectors) { @@ -292,7 +292,8 @@ private FieldVectorListValueGetterFactory(List fieldVectors) { } @Override - public List create(TypeDescriptor typeDescriptor, Schema schema) { + public List> create( + TypeDescriptor typeDescriptor, Schema schema) { return this.fieldVectors.stream() .map( (fieldVector) -> { diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java 
b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java index e75647a2ccfa..203bcccbf562 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/AvroRecordSchema.java @@ -26,6 +26,7 @@ import org.apache.beam.sdk.schemas.SchemaProvider; import org.apache.beam.sdk.schemas.SchemaUserTypeCreator; import org.apache.beam.sdk.values.TypeDescriptor; +import org.checkerframework.checker.nullness.qual.NonNull; /** * A {@link SchemaProvider} for AVRO generated SpecificRecords and POJOs. @@ -44,8 +45,8 @@ public Schema schemaFor(TypeDescriptor typeDescriptor) { } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return AvroUtils.getGetters(targetTypeDescriptor, schema); } diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java index 1b1c45969307..bfbab6fe87f6 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java @@ -94,11 +94,13 @@ import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.CaseFormat; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.EnsuresNonNullIf; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.PolyNull; import org.joda.time.Days; import org.joda.time.Duration; import org.joda.time.Instant; @@ -139,10 +141,7 @@ * * is used. 
*/ -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" -}) +@SuppressWarnings({"rawtypes"}) public class AvroUtils { private static final ForLoadedType BYTES = new ForLoadedType(byte[].class); private static final ForLoadedType JAVA_INSTANT = new ForLoadedType(java.time.Instant.class); @@ -152,6 +151,38 @@ public class AvroUtils { new ForLoadedType(ReadableInstant.class); private static final ForLoadedType JODA_INSTANT = new ForLoadedType(Instant.class); + // contains workarounds for third-party methods that accept nullable arguments but lack proper + // annotations + private static class NullnessCheckerWorkarounds { + + private static ReflectData newReflectData(Class clazz) { + // getClassLoader returns @Nullable Classloader, but it's ok, as ReflectData constructor + // actually tolerates null classloader argument despite lacking the @Nullable annotation + @SuppressWarnings("nullness") + @NonNull + ClassLoader classLoader = clazz.getClassLoader(); + return new ReflectData(classLoader); + } + + private static void builderSet( + GenericRecordBuilder builder, String fieldName, @Nullable Object value) { + // the value argument can actually be null here, it's not annotated as such in the method + // though, hence this wrapper + builder.set(fieldName, castToNonNull(value)); + } + + private static Object createFixed( + @Nullable Object old, byte[] bytes, org.apache.avro.Schema schema) { + // old is tolerated when null, due to an instanceof check + return GenericData.get().createFixed(castToNonNull(old), bytes, schema); + } + + @SuppressWarnings("nullness") + private static @NonNull T castToNonNull(@Nullable T value) { + return value; + } + } + public static void addLogicalTypeConversions(final GenericData data) { // do not add DecimalConversion by default as schema must have extra 'scale' and 'precision' // properties. avro reflect already handles BigDecimal as string with the 'java-class' property @@ -235,7 +266,9 @@ public static FixedBytesField withSize(int size) { /** Create a {@link FixedBytesField} from a Beam {@link FieldType}. */ public static @Nullable FixedBytesField fromBeamFieldType(FieldType fieldType) { if (fieldType.getTypeName().isLogicalType() - && fieldType.getLogicalType().getIdentifier().equals(FixedBytes.IDENTIFIER)) { + && checkNotNull(fieldType.getLogicalType()) + .getIdentifier() + .equals(FixedBytes.IDENTIFIER)) { int length = fieldType.getLogicalType(FixedBytes.class).getLength(); return new FixedBytesField(length); } else { @@ -264,7 +297,7 @@ public FieldType toBeamType() { /** Convert to an AVRO type. 
*/ public org.apache.avro.Schema toAvroType(String name, String namespace) { - return org.apache.avro.Schema.createFixed(name, null, namespace, size); + return org.apache.avro.Schema.createFixed(name, "", namespace, size); } } @@ -451,8 +484,7 @@ public static Field toBeamField(org.apache.avro.Schema.Field field) { public static org.apache.avro.Schema.Field toAvroField(Field field, String namespace) { org.apache.avro.Schema fieldSchema = getFieldSchema(field.getType(), field.getName(), namespace); - return new org.apache.avro.Schema.Field( - field.getName(), fieldSchema, field.getDescription(), (Object) null); + return new org.apache.avro.Schema.Field(field.getName(), fieldSchema, field.getDescription()); } private AvroUtils() {} @@ -463,7 +495,7 @@ private AvroUtils() {} * @param clazz avro class */ public static Schema toBeamSchema(Class clazz) { - ReflectData data = new ReflectData(clazz.getClassLoader()); + ReflectData data = NullnessCheckerWorkarounds.newReflectData(clazz); return toBeamSchema(data.getSchema(clazz)); } @@ -486,10 +518,17 @@ public static Schema toBeamSchema(org.apache.avro.Schema schema) { return builder.build(); } + @EnsuresNonNullIf( + expression = {"#1"}, + result = false) + private static boolean isNullOrEmpty(@Nullable String str) { + return str == null || str.isEmpty(); + } + /** Converts a Beam Schema into an AVRO schema. */ public static org.apache.avro.Schema toAvroSchema( Schema beamSchema, @Nullable String name, @Nullable String namespace) { - final String schemaName = Strings.isNullOrEmpty(name) ? "topLevelRecord" : name; + final String schemaName = isNullOrEmpty(name) ? "topLevelRecord" : name; final String schemaNamespace = namespace == null ? "" : namespace; String childNamespace = !"".equals(schemaNamespace) ? schemaNamespace + "." 
+ schemaName : schemaName; @@ -498,7 +537,7 @@ public static org.apache.avro.Schema toAvroSchema( org.apache.avro.Schema.Field recordField = toAvroField(field, childNamespace); fields.add(recordField); } - return org.apache.avro.Schema.createRecord(schemaName, null, schemaNamespace, false, fields); + return org.apache.avro.Schema.createRecord(schemaName, "", schemaNamespace, false, fields); } public static org.apache.avro.Schema toAvroSchema(Schema beamSchema) { @@ -557,7 +596,8 @@ public static GenericRecord toGenericRecord( GenericRecordBuilder builder = new GenericRecordBuilder(avroSchema); for (int i = 0; i < beamSchema.getFieldCount(); ++i) { Field field = beamSchema.getField(i); - builder.set( + NullnessCheckerWorkarounds.builderSet( + builder, field.getName(), genericFromBeamField( field.getType(), avroSchema.getField(field.getName()).schema(), row.getValue(i))); @@ -567,7 +607,7 @@ public static GenericRecord toGenericRecord( @SuppressWarnings("unchecked") public static SerializableFunction getToRowFunction( - Class clazz, org.apache.avro.@Nullable Schema schema) { + Class clazz, org.apache.avro.Schema schema) { if (GenericRecord.class.equals(clazz)) { Schema beamSchema = toBeamSchema(schema); return (SerializableFunction) getGenericRecordToRowFunction(beamSchema); @@ -662,9 +702,9 @@ public static SerializableFunction getGenericRecordToRowFunc } private static class GenericRecordToRowFn implements SerializableFunction { - private final Schema schema; + private final @Nullable Schema schema; - GenericRecordToRowFn(Schema schema) { + GenericRecordToRowFn(@Nullable Schema schema) { this.schema = schema; } @@ -701,7 +741,7 @@ public static SerializableFunction getRowToGenericRecordFunc } private static class RowToGenericRecordFn implements SerializableFunction { - private transient org.apache.avro.Schema avroSchema; + private transient org.apache.avro.@Nullable Schema avroSchema; RowToGenericRecordFn(org.apache.avro.@Nullable Schema avroSchema) { this.avroSchema = avroSchema; @@ -751,7 +791,8 @@ private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundE public static SchemaCoder schemaCoder(TypeDescriptor type) { @SuppressWarnings("unchecked") Class clazz = (Class) type.getRawType(); - org.apache.avro.Schema avroSchema = new ReflectData(clazz.getClassLoader()).getSchema(clazz); + org.apache.avro.Schema avroSchema = + NullnessCheckerWorkarounds.newReflectData(clazz).getSchema(clazz); Schema beamSchema = toBeamSchema(avroSchema); return SchemaCoder.of( beamSchema, type, getToRowFunction(clazz, avroSchema), getFromRowFunction(clazz)); @@ -790,7 +831,7 @@ public static SchemaCoder schemaCoder(org.apache.avro.Schema sche */ public static SchemaCoder schemaCoder(Class clazz, org.apache.avro.Schema schema) { return SchemaCoder.of( - getSchema(clazz, schema), + checkNotNull(getSchema(clazz, schema)), TypeDescriptor.of(clazz), getToRowFunction(clazz, schema), getFromRowFunction(clazz)); @@ -821,7 +862,7 @@ public List get(TypeDescriptor typeDescriptor, Sch Method method = methods.get(i); if (ReflectUtils.isGetter(method)) { FieldValueTypeInformation fieldValueTypeInformation = - FieldValueTypeInformation.forGetter(method, i); + FieldValueTypeInformation.forGetter(typeDescriptor, method, i); String name = mapping.get(fieldValueTypeInformation.getName()); if (name != null) { types.add(fieldValueTypeInformation.withName(name)); @@ -871,7 +912,8 @@ public List get(TypeDescriptor typeDescriptor) { for (int i = 0; i < classFields.size(); ++i) { java.lang.reflect.Field f = 
classFields.get(i); if (!f.isAnnotationPresent(AvroIgnore.class)) { - FieldValueTypeInformation typeInformation = FieldValueTypeInformation.forField(f, i); + FieldValueTypeInformation typeInformation = + FieldValueTypeInformation.forField(typeDescriptor, f, i); AvroName avroname = f.getAnnotation(AvroName.class); if (avroname != null) { typeInformation = typeInformation.withName(avroname.value()); @@ -895,7 +937,7 @@ public static List getFieldTypes( } /** Get generated getters for an AVRO-generated SpecificRecord or a POJO. */ - public static List getGetters( + public static List> getGetters( TypeDescriptor typeDescriptor, Schema schema) { if (typeDescriptor.isSubtypeOf(TypeDescriptor.of(SpecificRecord.class))) { return JavaBeanUtils.getGetters( @@ -968,7 +1010,7 @@ private static FieldType toFieldType(TypeWithNullability type) { break; case FIXED: - fieldType = FixedBytesField.fromAvroType(type.type).toBeamType(); + fieldType = checkNotNull(FixedBytesField.fromAvroType(type.type)).toBeamType(); break; case STRING: @@ -1066,7 +1108,8 @@ private static org.apache.avro.Schema getFieldSchema( break; case LOGICAL_TYPE: - String identifier = fieldType.getLogicalType().getIdentifier(); + Schema.LogicalType logicalType = checkNotNull(fieldType.getLogicalType()); + String identifier = logicalType.getIdentifier(); if (FixedBytes.IDENTIFIER.equals(identifier)) { FixedBytesField fixedBytesField = checkNotNull(FixedBytesField.fromBeamFieldType(fieldType)); @@ -1077,15 +1120,13 @@ private static org.apache.avro.Schema getFieldSchema( } else if (FixedString.IDENTIFIER.equals(identifier) || "CHAR".equals(identifier) || "NCHAR".equals(identifier)) { - baseType = - buildHiveLogicalTypeSchema("char", (int) fieldType.getLogicalType().getArgument()); + baseType = buildHiveLogicalTypeSchema("char", checkNotNull(logicalType.getArgument())); } else if (VariableString.IDENTIFIER.equals(identifier) || "NVARCHAR".equals(identifier) || "VARCHAR".equals(identifier) || "LONGNVARCHAR".equals(identifier) || "LONGVARCHAR".equals(identifier)) { - baseType = - buildHiveLogicalTypeSchema("varchar", (int) fieldType.getLogicalType().getArgument()); + baseType = buildHiveLogicalTypeSchema("varchar", checkNotNull(logicalType.getArgument())); } else if (EnumerationType.IDENTIFIER.equals(identifier)) { EnumerationType enumerationType = fieldType.getLogicalType(EnumerationType.class); baseType = @@ -1103,7 +1144,7 @@ private static org.apache.avro.Schema getFieldSchema( baseType = LogicalTypes.timeMillis().addToSchema(org.apache.avro.Schema.create(Type.INT)); } else { throw new RuntimeException( - "Unhandled logical type " + fieldType.getLogicalType().getIdentifier()); + "Unhandled logical type " + checkNotNull(fieldType.getLogicalType()).getIdentifier()); } break; @@ -1111,22 +1152,23 @@ private static org.apache.avro.Schema getFieldSchema( case ITERABLE: baseType = org.apache.avro.Schema.createArray( - getFieldSchema(fieldType.getCollectionElementType(), fieldName, namespace)); + getFieldSchema( + checkNotNull(fieldType.getCollectionElementType()), fieldName, namespace)); break; case MAP: - if (fieldType.getMapKeyType().getTypeName().isStringType()) { + if (checkNotNull(fieldType.getMapKeyType()).getTypeName().isStringType()) { // Avro only supports string keys in maps. 
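The isNullOrEmpty helper added to AvroUtils above carries @EnsuresNonNullIf, which tells the nullness checker that the argument is non-null whenever the method returns false, so callers need no extra checkNotNull. A small sketch of that contract; the NullableStringSketch class and its orDefault method are illustrative:

    import org.checkerframework.checker.nullness.qual.EnsuresNonNullIf;
    import org.checkerframework.checker.nullness.qual.Nullable;

    // Illustrative sketch: on the false branch the checker treats `value` as non-null.
    class NullableStringSketch {
      @EnsuresNonNullIf(
          expression = {"#1"},
          result = false)
      static boolean isNullOrEmpty(@Nullable String value) {
        return value == null || value.isEmpty();
      }

      static String orDefault(@Nullable String name, String fallback) {
        return isNullOrEmpty(name) ? fallback : name; // name is provably non-null here
      }
    }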
baseType = org.apache.avro.Schema.createMap( - getFieldSchema(fieldType.getMapValueType(), fieldName, namespace)); + getFieldSchema(checkNotNull(fieldType.getMapValueType()), fieldName, namespace)); } else { throw new IllegalArgumentException("Avro only supports maps with string keys"); } break; case ROW: - baseType = toAvroSchema(fieldType.getRowSchema(), fieldName, namespace); + baseType = toAvroSchema(checkNotNull(fieldType.getRowSchema()), fieldName, namespace); break; default: @@ -1167,7 +1209,9 @@ private static org.apache.avro.Schema getFieldSchema( case DECIMAL: BigDecimal decimal = (BigDecimal) value; LogicalType logicalType = typeWithNullability.type.getLogicalType(); - return new Conversions.DecimalConversion().toBytes(decimal, null, logicalType); + @SuppressWarnings("nullness") + ByteBuffer result = new Conversions.DecimalConversion().toBytes(decimal, null, logicalType); + return result; case DATETIME: if (typeWithNullability.type.getType() == Type.INT) { @@ -1185,7 +1229,7 @@ private static org.apache.avro.Schema getFieldSchema( return ByteBuffer.wrap((byte[]) value); case LOGICAL_TYPE: - String identifier = fieldType.getLogicalType().getIdentifier(); + String identifier = checkNotNull(fieldType.getLogicalType()).getIdentifier(); if (FixedBytes.IDENTIFIER.equals(identifier)) { FixedBytesField fixedBytesField = checkNotNull(FixedBytesField.fromBeamFieldType(fieldType)); @@ -1193,9 +1237,11 @@ private static org.apache.avro.Schema getFieldSchema( if (byteArray.length != fixedBytesField.getSize()) { throw new IllegalArgumentException("Incorrectly sized byte array."); } - return GenericData.get().createFixed(null, (byte[]) value, typeWithNullability.type); + return NullnessCheckerWorkarounds.createFixed( + null, (byte[]) value, typeWithNullability.type); } else if (VariableBytes.IDENTIFIER.equals(identifier)) { - return GenericData.get().createFixed(null, (byte[]) value, typeWithNullability.type); + return NullnessCheckerWorkarounds.createFixed( + null, (byte[]) value, typeWithNullability.type); } else if (FixedString.IDENTIFIER.equals(identifier) || "CHAR".equals(identifier) || "NCHAR".equals(identifier)) { @@ -1239,26 +1285,27 @@ private static org.apache.avro.Schema getFieldSchema( case ARRAY: case ITERABLE: Iterable iterable = (Iterable) value; - List translatedArray = Lists.newArrayListWithExpectedSize(Iterables.size(iterable)); + List<@Nullable Object> translatedArray = + Lists.newArrayListWithExpectedSize(Iterables.size(iterable)); for (Object arrayElement : iterable) { translatedArray.add( genericFromBeamField( - fieldType.getCollectionElementType(), + checkNotNull(fieldType.getCollectionElementType()), typeWithNullability.type.getElementType(), arrayElement)); } return translatedArray; case MAP: - Map map = Maps.newHashMap(); + Map map = Maps.newHashMap(); Map valueMap = (Map) value; for (Map.Entry entry : valueMap.entrySet()) { - Utf8 key = new Utf8((String) entry.getKey()); + Utf8 key = new Utf8((String) checkNotNull(entry.getKey())); map.put( key, genericFromBeamField( - fieldType.getMapValueType(), + checkNotNull(fieldType.getMapValueType()), typeWithNullability.type.getValueType(), entry.getValue())); } @@ -1282,8 +1329,8 @@ private static org.apache.avro.Schema getFieldSchema( * @return value converted for {@link Row} */ @SuppressWarnings("unchecked") - public static @Nullable Object convertAvroFieldStrict( - @Nullable Object value, + public static @PolyNull Object convertAvroFieldStrict( + @PolyNull Object value, @Nonnull org.apache.avro.Schema avroSchema, @Nonnull 
FieldType fieldType) { if (value == null) { @@ -1383,7 +1430,8 @@ private static Object convertBytesStrict(ByteBuffer bb, FieldType fieldType) { private static Object convertFixedStrict(GenericFixed fixed, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "fixed"); - checkArgument(FixedBytes.IDENTIFIER.equals(fieldType.getLogicalType().getIdentifier())); + checkArgument( + FixedBytes.IDENTIFIER.equals(checkNotNull(fieldType.getLogicalType()).getIdentifier())); return fixed.bytes().clone(); // clone because GenericFixed is mutable } @@ -1434,7 +1482,10 @@ private static Object convertBooleanStrict(Boolean value, FieldType fieldType) { private static Object convertEnumStrict(Object value, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "enum"); - checkArgument(fieldType.getLogicalType().getIdentifier().equals(EnumerationType.IDENTIFIER)); + checkArgument( + checkNotNull(fieldType.getLogicalType()) + .getIdentifier() + .equals(EnumerationType.IDENTIFIER)); EnumerationType enumerationType = fieldType.getLogicalType(EnumerationType.class); return enumerationType.valueOf(value.toString()); } @@ -1442,7 +1493,8 @@ private static Object convertEnumStrict(Object value, FieldType fieldType) { private static Object convertUnionStrict( Object value, org.apache.avro.Schema unionAvroSchema, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.LOGICAL_TYPE, "oneOfType"); - checkArgument(fieldType.getLogicalType().getIdentifier().equals(OneOfType.IDENTIFIER)); + checkArgument( + checkNotNull(fieldType.getLogicalType()).getIdentifier().equals(OneOfType.IDENTIFIER)); OneOfType oneOfType = fieldType.getLogicalType(OneOfType.class); int fieldNumber = GenericData.get().resolveUnion(unionAvroSchema, value); FieldType baseFieldType = oneOfType.getOneOfSchema().getField(fieldNumber).getType(); @@ -1459,7 +1511,7 @@ private static Object convertArrayStrict( FieldType elemFieldType = fieldType.getCollectionElementType(); for (Object value : values) { - ret.add(convertAvroFieldStrict(value, elemAvroSchema, elemFieldType)); + ret.add(convertAvroFieldStrict(value, elemAvroSchema, checkNotNull(elemFieldType))); } return ret; @@ -1470,10 +1522,10 @@ private static Object convertMapStrict( org.apache.avro.Schema valueAvroSchema, FieldType fieldType) { checkTypeName(fieldType.getTypeName(), TypeName.MAP, "map"); - checkNotNull(fieldType.getMapKeyType()); - checkNotNull(fieldType.getMapValueType()); + FieldType mapKeyType = checkNotNull(fieldType.getMapKeyType()); + FieldType mapValueType = checkNotNull(fieldType.getMapValueType()); - if (!fieldType.getMapKeyType().equals(FieldType.STRING)) { + if (!FieldType.STRING.equals(fieldType.getMapKeyType())) { throw new IllegalArgumentException( "Can't convert 'string' map keys to " + fieldType.getMapKeyType()); } @@ -1482,8 +1534,8 @@ private static Object convertMapStrict( for (Map.Entry value : values.entrySet()) { ret.put( - convertStringStrict(value.getKey(), fieldType.getMapKeyType()), - convertAvroFieldStrict(value.getValue(), valueAvroSchema, fieldType.getMapValueType())); + convertStringStrict(value.getKey(), mapKeyType), + convertAvroFieldStrict(value.getValue(), valueAvroSchema, mapValueType)); } return ret; diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java index d159e9de44a8..9fe6162ec936 100644 --- 
a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java @@ -104,16 +104,14 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) class ProtoByteBuddyUtils { private static final ByteBuddy BYTE_BUDDY = new ByteBuddy(); private static final TypeDescriptor BYTE_STRING_TYPE_DESCRIPTOR = @@ -270,7 +268,7 @@ static class ProtoConvertType extends ConvertType { .build(); @Override - public Type convert(TypeDescriptor typeDescriptor) { + public Type convert(TypeDescriptor typeDescriptor) { if (typeDescriptor.equals(BYTE_STRING_TYPE_DESCRIPTOR) || typeDescriptor.isSubtypeOf(BYTE_STRING_TYPE_DESCRIPTOR)) { return byte[].class; @@ -297,7 +295,7 @@ protected ProtoTypeConversionsFactory getFactory() { } @Override - public StackManipulation convert(TypeDescriptor type) { + public StackManipulation convert(TypeDescriptor type) { if (type.equals(BYTE_STRING_TYPE_DESCRIPTOR) || type.isSubtypeOf(BYTE_STRING_TYPE_DESCRIPTOR)) { return new Compound( @@ -372,7 +370,7 @@ protected ProtoTypeConversionsFactory getFactory() { } @Override - public StackManipulation convert(TypeDescriptor type) { + public StackManipulation convert(TypeDescriptor type) { if (type.isSubtypeOf(TypeDescriptor.of(ByteString.class))) { return new Compound( readValue, @@ -459,7 +457,7 @@ public TypeConversion createSetterConversions(StackManipulati // The list of getters for a class is cached, so we only create the classes the first time // getSetters is called. - private static final Map> CACHED_GETTERS = + private static final Map>> CACHED_GETTERS = Maps.newConcurrentMap(); /** @@ -467,35 +465,36 @@ public TypeConversion createSetterConversions(StackManipulati * *

The returned list is ordered by the order of fields in the schema. */ - public static List getGetters( - Class clazz, + public static List> getGetters( + Class clazz, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { Multimap methods = ReflectUtils.getMethodsMap(clazz); - return CACHED_GETTERS.computeIfAbsent( - ClassWithSchema.create(clazz, schema), - c -> { - List types = - fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), schema); - return types.stream() - .map( - t -> - createGetter( - t, - typeConversionsFactory, - clazz, - methods, - schema.getField(t.getName()), - fieldValueTypeSupplier)) - .collect(Collectors.toList()); - }); + return (List) + CACHED_GETTERS.computeIfAbsent( + ClassWithSchema.create(clazz, schema), + c -> { + List types = + fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), schema); + return types.stream() + .map( + t -> + createGetter( + t, + typeConversionsFactory, + clazz, + methods, + schema.getField(t.getName()), + fieldValueTypeSupplier)) + .collect(Collectors.toList()); + }); } - static FieldValueGetter createOneOfGetter( + static FieldValueGetter<@NonNull ProtoT, OneOfType.Value> createOneOfGetter( FieldValueTypeInformation typeInformation, - TreeMap> getterMethodMap, - Class protoClass, + TreeMap> getterMethodMap, + Class protoClass, OneOfType oneOfType, Method getCaseMethod) { Set indices = getterMethodMap.keySet(); @@ -505,7 +504,7 @@ static FieldValueGetter createOneOfGetter( int[] keys = getterMethodMap.keySet().stream().mapToInt(Integer::intValue).toArray(); - DynamicType.Builder builder = + DynamicType.Builder> builder = ByteBuddyUtils.subclassGetterInterface(BYTE_BUDDY, protoClass, OneOfType.Value.class); builder = builder @@ -514,7 +513,8 @@ static FieldValueGetter createOneOfGetter( .method(ElementMatchers.named("get")) .intercept(new OneOfGetterInstruction(contiguous, keys, getCaseMethod)); - List getters = Lists.newArrayList(getterMethodMap.values()); + List> getters = + Lists.newArrayList(getterMethodMap.values()); builder = builder // Store a field with the list of individual getters. 
The get() instruction will pick @@ -556,12 +556,12 @@ static FieldValueGetter createOneOfGetter( FieldValueSetter createOneOfSetter( String name, TreeMap> setterMethodMap, - Class protoBuilderClass) { + Class protoBuilderClass) { Set indices = setterMethodMap.keySet(); boolean contiguous = isContiguous(indices); int[] keys = setterMethodMap.keySet().stream().mapToInt(Integer::intValue).toArray(); - DynamicType.Builder builder = + DynamicType.Builder> builder = ByteBuddyUtils.subclassSetterInterface( BYTE_BUDDY, protoBuilderClass, OneOfType.Value.class); builder = @@ -585,7 +585,8 @@ FieldValueSetter createOneOfSetter( .withParameters(List.class) .intercept(new OneOfSetterConstructor()); - List setters = Lists.newArrayList(setterMethodMap.values()); + List> setters = + Lists.newArrayList(setterMethodMap.values()); try { return builder .visit(new AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) @@ -947,10 +948,10 @@ public ByteCodeAppender appender(final Target implementationTarget) { } } - private static FieldValueGetter createGetter( + private static FieldValueGetter<@NonNull ProtoT, ?> createGetter( FieldValueTypeInformation fieldValueTypeInformation, TypeConversionsFactory typeConversionsFactory, - Class clazz, + Class clazz, Multimap methods, Field field, FieldValueTypeSupplier fieldValueTypeSupplier) { @@ -964,21 +965,23 @@ private static FieldValueGetter createGetter( field.getName() + "_case", FieldType.logicalType(oneOfType.getCaseEnumType())); // Create a map of case enum value to getter. This must be sorted, so store in a TreeMap. - TreeMap> oneOfGetters = Maps.newTreeMap(); + TreeMap> oneOfGetters = + Maps.newTreeMap(); Map oneOfFieldTypes = fieldValueTypeSupplier.get(TypeDescriptor.of(clazz), oneOfType.getOneOfSchema()).stream() .collect(Collectors.toMap(FieldValueTypeInformation::getName, f -> f)); for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { int protoFieldIndex = getFieldNumber(oneOfField); - FieldValueGetter oneOfFieldGetter = + FieldValueGetter<@NonNull ProtoT, ?> oneOfFieldGetter = createGetter( - oneOfFieldTypes.get(oneOfField.getName()), + Verify.verifyNotNull(oneOfFieldTypes.get(oneOfField.getName())), typeConversionsFactory, clazz, methods, oneOfField, fieldValueTypeSupplier); - oneOfGetters.put(protoFieldIndex, oneOfFieldGetter); + oneOfGetters.put( + protoFieldIndex, (FieldValueGetter<@NonNull ProtoT, OneOfType.Value>) oneOfFieldGetter); } return createOneOfGetter( fieldValueTypeInformation, oneOfGetters, clazz, oneOfType, caseMethod); @@ -987,10 +990,11 @@ private static FieldValueGetter createGetter( } } - private static Class getProtoGeneratedBuilder(Class clazz) { + private static @Nullable Class getProtoGeneratedBuilder( + Class clazz) { String builderClassName = clazz.getName() + "$Builder"; try { - return Class.forName(builderClassName); + return (Class) Class.forName(builderClassName); } catch (ClassNotFoundException e) { return null; } @@ -1018,51 +1022,59 @@ static Method getProtoGetter(Multimap methods, String name, Fiel public static @Nullable SchemaUserTypeCreator getBuilderCreator( - Class protoClass, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier) { - Class builderClass = getProtoGeneratedBuilder(protoClass); + TypeDescriptor protoTypeDescriptor, + Schema schema, + FieldValueTypeSupplier fieldValueTypeSupplier) { + Class builderClass = getProtoGeneratedBuilder(protoTypeDescriptor.getRawType()); if (builderClass == null) { return null; } Multimap methods = 
ReflectUtils.getMethodsMap(builderClass); List> setters = schema.getFields().stream() - .map(f -> getProtoFieldValueSetter(f, methods, builderClass)) + .map(f -> getProtoFieldValueSetter(protoTypeDescriptor, f, methods, builderClass)) .collect(Collectors.toList()); - return createBuilderCreator(protoClass, builderClass, setters, schema); + return createBuilderCreator(protoTypeDescriptor.getRawType(), builderClass, setters, schema); } private static FieldValueSetter getProtoFieldValueSetter( - Field field, Multimap methods, Class builderClass) { + TypeDescriptor typeDescriptor, + Field field, + Multimap methods, + Class builderClass) { if (field.getType().isLogicalType(OneOfType.IDENTIFIER)) { OneOfType oneOfType = field.getType().getLogicalType(OneOfType.class); TreeMap> oneOfSetters = Maps.newTreeMap(); for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { - FieldValueSetter setter = getProtoFieldValueSetter(oneOfField, methods, builderClass); + FieldValueSetter setter = + getProtoFieldValueSetter(typeDescriptor, oneOfField, methods, builderClass); oneOfSetters.put(getFieldNumber(oneOfField), setter); } return createOneOfSetter(field.getName(), oneOfSetters, builderClass); } else { Method method = getProtoSetter(methods, field.getName(), field.getType()); return JavaBeanUtils.createSetter( - FieldValueTypeInformation.forSetter(method, protoSetterPrefix(field.getType())), + FieldValueTypeInformation.forSetter( + typeDescriptor, method, protoSetterPrefix(field.getType())), new ProtoTypeConversionsFactory()); } } static SchemaUserTypeCreator createBuilderCreator( Class protoClass, - Class builderClass, + Class builderClass, List> setters, Schema schema) { try { - DynamicType.Builder builder = - BYTE_BUDDY - .with(new InjectPackageStrategy(builderClass)) - .subclass(Supplier.class) - .method(ElementMatchers.named("get")) - .intercept(new BuilderSupplier(protoClass)); - Supplier supplier = + DynamicType.Builder> builder = + (DynamicType.Builder) + BYTE_BUDDY + .with(new InjectPackageStrategy(builderClass)) + .subclass(Supplier.class) + .method(ElementMatchers.named("get")) + .intercept(new BuilderSupplier(protoClass)); + Supplier supplier = builder .visit( new AsmVisitorWrapper.ForDeclaredMethods() diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java index faf3ad407af5..b0bb9071524b 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchema.java @@ -43,12 +43,9 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; -@SuppressWarnings({ - "nullness", // TODO(https://github.com/apache/beam/issues/20497) - "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) -}) public class ProtoMessageSchema extends GetterBasedSchemaProviderV2 { private static final class ProtoClassFieldValueTypeSupplier implements FieldValueTypeSupplier { @@ -72,7 +69,8 @@ public List get(TypeDescriptor typeDescriptor, Sch Method method = 
getProtoGetter(methods, oneOfField.getName(), oneOfField.getType()); oneOfTypes.put( oneOfField.getName(), - FieldValueTypeInformation.forGetter(method, i).withName(field.getName())); + FieldValueTypeInformation.forGetter(typeDescriptor, method, i) + .withName(field.getName())); } // Add an entry that encapsulates information about all possible getters. types.add( @@ -82,7 +80,9 @@ public List get(TypeDescriptor typeDescriptor, Sch } else { // This is a simple field. Add the getter. Method method = getProtoGetter(methods, field.getName(), field.getType()); - types.add(FieldValueTypeInformation.forGetter(method, i).withName(field.getName())); + types.add( + FieldValueTypeInformation.forGetter(typeDescriptor, method, i) + .withName(field.getName())); } } return types; @@ -96,8 +96,8 @@ public List get(TypeDescriptor typeDescriptor, Sch } @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { return ProtoByteBuddyUtils.getGetters( targetTypeDescriptor.getRawType(), schema, @@ -117,7 +117,7 @@ public SchemaUserTypeCreator schemaTypeCreator( TypeDescriptor targetTypeDescriptor, Schema schema) { SchemaUserTypeCreator creator = ProtoByteBuddyUtils.getBuilderCreator( - targetTypeDescriptor.getRawType(), schema, new ProtoClassFieldValueTypeSupplier()); + targetTypeDescriptor, schema, new ProtoClassFieldValueTypeSupplier()); if (creator == null) { throw new RuntimeException("Cannot create creator for " + targetTypeDescriptor); } @@ -152,7 +152,8 @@ public static SimpleFunction getRowToProtoBytesFn(Class claz private void checkForDynamicType(TypeDescriptor typeDescriptor) { if (typeDescriptor.getRawType().equals(DynamicMessage.class)) { throw new RuntimeException( - "DynamicMessage is not allowed for the standard ProtoSchemaProvider, use ProtoDynamicMessageSchema instead."); + "DynamicMessage is not allowed for the standard ProtoSchemaProvider, use" + + " ProtoDynamicMessageSchema instead."); } } diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java index acdfcfc1ad09..e8b05a8a319e 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java @@ -19,7 +19,6 @@ import static java.util.function.Function.identity; import static java.util.stream.Collectors.toMap; -import static org.apache.beam.sdk.io.aws2.schemas.AwsSchemaUtils.getter; import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets.difference; @@ -46,6 +45,7 @@ import org.apache.beam.sdk.values.RowWithGetters; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.awssdk.core.SdkField; import software.amazon.awssdk.core.SdkPojo; @@ -73,17 +73,20 @@ public class AwsSchemaProvider extends GetterBasedSchemaProviderV2 { return AwsTypes.schemaFor(sdkFields((Class) 
type.getRawType())); } - @SuppressWarnings("rawtypes") @Override - public List fieldValueGetters( - TypeDescriptor targetTypeDescriptor, Schema schema) { + public List> fieldValueGetters( + TypeDescriptor targetTypeDescriptor, Schema schema) { ConverterFactory fromAws = ConverterFactory.fromAws(); Map> sdkFields = sdkFieldsByName((Class) targetTypeDescriptor.getRawType()); - List getters = new ArrayList<>(schema.getFieldCount()); - for (String field : schema.getFieldNames()) { + List> getters = new ArrayList<>(schema.getFieldCount()); + for (@NonNull String field : schema.getFieldNames()) { SdkField sdkField = checkStateNotNull(sdkFields.get(field), "Unknown field"); - getters.add(getter(field, fromAws.create(sdkField::getValueOrDefault, sdkField))); + getters.add( + AwsSchemaUtils.getter( + field, + (SerializableFunction<@NonNull T, Object>) + fromAws.create(sdkField::getValueOrDefault, sdkField))); } return getters; } diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java index d36c197d80a4..9e994702fe61 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java @@ -33,6 +33,7 @@ import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.awssdk.core.SdkPojo; import software.amazon.awssdk.utils.builder.SdkBuilder; @@ -78,7 +79,7 @@ static SdkBuilderSetter setter(String name, BiConsumer, Object> return new ValueSetter(name, setter); } - static FieldValueGetter getter( + static FieldValueGetter getter( String name, SerializableFunction getter) { return new ValueGetter<>(name, getter); } @@ -107,7 +108,8 @@ public String name() { } } - private static class ValueGetter implements FieldValueGetter { + private static class ValueGetter + implements FieldValueGetter { private final SerializableFunction getter; private final String name; diff --git a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java index 5f4e195f227f..3094ea47d6ad 100644 --- a/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java +++ b/sdks/java/io/thrift/src/main/java/org/apache/beam/sdk/io/thrift/ThriftSchema.java @@ -202,10 +202,10 @@ private Schema.Field beamField(FieldMetaData fieldDescriptor) { @SuppressWarnings("rawtypes") @Override - public @NonNull List fieldValueGetters( - @NonNull TypeDescriptor targetTypeDescriptor, @NonNull Schema schema) { + public @NonNull List> fieldValueGetters( + @NonNull TypeDescriptor targetTypeDescriptor, @NonNull Schema schema) { return schemaFieldDescriptors(targetTypeDescriptor.getRawType(), schema).keySet().stream() - .map(FieldExtractor::new) + .>map(FieldExtractor::new) .collect(Collectors.toList()); } @@ -242,10 +242,12 @@ private FieldValueTypeInformation fieldValueTypeInfo(Class type, String field if (factoryMethods.size() > 1) { throw new IllegalStateException("Overloaded factory methods: " + factoryMethods); } - return 
FieldValueTypeInformation.forSetter(factoryMethods.get(0), ""); + return FieldValueTypeInformation.forSetter( + TypeDescriptor.of(type), factoryMethods.get(0), ""); } else { try { - return FieldValueTypeInformation.forField(type.getDeclaredField(fieldName), 0); + return FieldValueTypeInformation.forField( + TypeDescriptor.of(type), type.getDeclaredField(fieldName), 0); } catch (NoSuchFieldException e) { throw new IllegalArgumentException(e); } @@ -373,7 +375,7 @@ private & TEnum> FieldType beamType(FieldValueMetaDat } } - private static class FieldExtractor> + private static class FieldExtractor implements FieldValueGetter { private final FieldT field; @@ -383,8 +385,9 @@ private FieldExtractor(FieldT field) { } @Override public @Nullable Object get(T thrift) { - if (!(thrift instanceof TUnion) || thrift.isSet(field)) { - final Object value = thrift.getFieldValue(field); + TBase t = (TBase) thrift; + if (!(thrift instanceof TUnion) || t.isSet(field)) { + final Object value = t.getFieldValue(field); if (value instanceof Enum) { return ((Enum) value).ordinal(); } else { From 24a0447518463f7b1a244e02431be37f8c5d6cbc Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Wed, 30 Oct 2024 13:54:10 -0400 Subject: [PATCH 092/181] Bump Flink and Spark job server container base image to temurin-11 (#32976) --- .github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json | 2 +- .../trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json | 2 +- .../beam_PostCommit_Java_PVR_Spark3_Streaming.json | 2 +- .github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json | 2 +- runners/flink/job-server-container/Dockerfile | 2 +- runners/spark/job-server/container/Dockerfile | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json index e3d6056a5de9..b26833333238 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Batch.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 2 } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json index b970762c8397..bdd2197e534a 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Flink_Docker.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test" + "modification": "1" } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json index e3d6056a5de9..c537844dc84a 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark3_Streaming.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 3 } diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json index e3d6056a5de9..e0266d62f2e0 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json @@
-1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 4 } diff --git a/runners/flink/job-server-container/Dockerfile b/runners/flink/job-server-container/Dockerfile index c5a81ecf6466..cbb73512400e 100644 --- a/runners/flink/job-server-container/Dockerfile +++ b/runners/flink/job-server-container/Dockerfile @@ -16,7 +16,7 @@ # limitations under the License. ############################################################################### -FROM openjdk:8 +FROM eclipse-temurin:11 MAINTAINER "Apache Beam " RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y libltdl7 diff --git a/runners/spark/job-server/container/Dockerfile b/runners/spark/job-server/container/Dockerfile index ec4a123f2b9d..f5639430a33b 100644 --- a/runners/spark/job-server/container/Dockerfile +++ b/runners/spark/job-server/container/Dockerfile @@ -16,7 +16,7 @@ # limitations under the License. ############################################################################### -FROM openjdk:8 +FROM eclipse-temurin:11 MAINTAINER "Apache Beam " RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y libltdl7 From f892cfef50a7ce0cc6c72a7ef29be8f2b98d4d68 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Wed, 30 Oct 2024 21:35:24 -0400 Subject: [PATCH 093/181] Bump Hadoop versions for compatibility test --- sdks/java/io/hadoop-common/build.gradle | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/java/io/hadoop-common/build.gradle b/sdks/java/io/hadoop-common/build.gradle index 466aa8fb6730..b0303d29ff98 100644 --- a/sdks/java/io/hadoop-common/build.gradle +++ b/sdks/java/io/hadoop-common/build.gradle @@ -25,10 +25,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Hadoop Common" ext.summary = "Library to add shared Hadoop classes among Beam IOs." def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", "324": "3.2.4", + "336": "3.3.6", + "341": "3.4.1", ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} From f105f26a93a5dbb205a35f672da11a45e8216fb7 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Thu, 31 Oct 2024 12:10:49 -0400 Subject: [PATCH 094/181] Reapply disable Gradle cache for expansion service (#32984) * Reapply disable Gradle cache for expansion service This reverts commit 379dcd4903b577c13d07b3e1ee423281181ba82e. 
* trigger test --- .github/trigger_files/beam_PostCommit_XVR_Direct.json | 3 +++ .github/trigger_files/beam_PostCommit_XVR_Flink.json | 2 +- sdks/java/expansion-service/build.gradle | 4 ++++ sdks/java/extensions/sql/expansion-service/build.gradle | 4 ++++ sdks/java/io/expansion-service/build.gradle | 1 + 5 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 .github/trigger_files/beam_PostCommit_XVR_Direct.json diff --git a/.github/trigger_files/beam_PostCommit_XVR_Direct.json b/.github/trigger_files/beam_PostCommit_XVR_Direct.json new file mode 100644 index 000000000000..236b7bee8af8 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_XVR_Direct.json @@ -0,0 +1,3 @@ +{ + "https://github.com/apache/beam/pull/32648": "testing Flink 1.19 support" +} diff --git a/.github/trigger_files/beam_PostCommit_XVR_Flink.json b/.github/trigger_files/beam_PostCommit_XVR_Flink.json index 0b34d452d42c..236b7bee8af8 100644 --- a/.github/trigger_files/beam_PostCommit_XVR_Flink.json +++ b/.github/trigger_files/beam_PostCommit_XVR_Flink.json @@ -1,3 +1,3 @@ { - "https://github.com/apache/beam/pull/32648": "testing addition of Flink 1.19 support" + "https://github.com/apache/beam/pull/32648": "testing Flink 1.19 support" } diff --git a/sdks/java/expansion-service/build.gradle b/sdks/java/expansion-service/build.gradle index 4dd8c8968ed9..a25583870acf 100644 --- a/sdks/java/expansion-service/build.gradle +++ b/sdks/java/expansion-service/build.gradle @@ -57,3 +57,7 @@ task runExpansionService (type: JavaExec) { classpath = sourceSets.main.runtimeClasspath args = [project.findProperty("constructionService.port") ?: "8097"] } + +compileJava { + outputs.upToDateWhen { false } +} \ No newline at end of file diff --git a/sdks/java/extensions/sql/expansion-service/build.gradle b/sdks/java/extensions/sql/expansion-service/build.gradle index b6963cf7547b..b8d78e4e1bb9 100644 --- a/sdks/java/extensions/sql/expansion-service/build.gradle +++ b/sdks/java/extensions/sql/expansion-service/build.gradle @@ -46,3 +46,7 @@ task runExpansionService (type: JavaExec) { classpath = sourceSets.main.runtimeClasspath args = [project.findProperty("constructionService.port") ?: "8097"] } + +shadowJar { + outputs.upToDateWhen { false } +} \ No newline at end of file diff --git a/sdks/java/io/expansion-service/build.gradle b/sdks/java/io/expansion-service/build.gradle index 8b817163ae39..cc8eccf98997 100644 --- a/sdks/java/io/expansion-service/build.gradle +++ b/sdks/java/io/expansion-service/build.gradle @@ -35,6 +35,7 @@ configurations.runtimeClasspath { shadowJar { mergeServiceFiles() + outputs.upToDateWhen { false } } description = "Apache Beam :: SDKs :: Java :: IO :: Expansion Service" From 61e7258344c0df09ac1536c05bad114904a6b44b Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Thu, 31 Oct 2024 13:01:02 -0400 Subject: [PATCH 095/181] Upgrade mypy to version 1.13.0 (#32978) * Upgrade mypy to version 1.13.0 * formatting, yaml io fix --- sdks/python/apache_beam/coders/coder_impl.py | 2 +- .../apache_beam/coders/coders_property_based_test.py | 2 +- sdks/python/apache_beam/coders/coders_test_common.py | 2 +- sdks/python/apache_beam/metrics/metric.py | 8 ++++---- .../apache_beam/ml/inference/xgboost_inference_test.py | 2 +- sdks/python/apache_beam/ml/transforms/base_test.py | 4 ++-- .../dataflow/internal/clients/dataflow/__init__.py | 2 +- sdks/python/apache_beam/runners/pipeline_context.py | 2 +- 
.../runners/portability/fn_api_runner/translations.py | 2 +- .../runners/portability/fn_api_runner/worker_handlers.py | 2 +- .../apache_beam/runners/portability/local_job_service.py | 2 +- sdks/python/apache_beam/runners/render.py | 2 +- sdks/python/apache_beam/runners/worker/log_handler.py | 2 +- sdks/python/apache_beam/runners/worker/opcounters.py | 2 +- sdks/python/apache_beam/runners/worker/sdk_worker.py | 4 ++-- sdks/python/apache_beam/runners/worker/sdk_worker_main.py | 2 +- .../benchmarks/chicago_taxi/tfdv_analyze_and_validate.py | 2 +- .../testing/benchmarks/chicago_taxi/trainer/taxi.py | 2 +- .../apache_beam/testing/load_tests/sideinput_test.py | 8 ++++---- .../apache_beam/testing/test_stream_service_test.py | 6 +++--- sdks/python/apache_beam/transforms/core.py | 4 ++-- sdks/python/apache_beam/transforms/display.py | 2 +- sdks/python/apache_beam/transforms/external.py | 2 +- sdks/python/apache_beam/transforms/ptransform.py | 2 +- sdks/python/apache_beam/transforms/stats.py | 2 +- sdks/python/apache_beam/transforms/userstate.py | 4 ++-- sdks/python/apache_beam/utils/profiler.py | 4 +++- sdks/python/apache_beam/utils/proto_utils.py | 2 +- sdks/python/mypy.ini | 8 +++++++- sdks/python/tox.ini | 2 +- 30 files changed, 50 insertions(+), 42 deletions(-) diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index ff5fb5bef7ac..dfdb247d781d 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -1975,7 +1975,7 @@ class DecimalCoderImpl(StreamCoderImpl): def encode_to_stream(self, value, out, nested): # type: (decimal.Decimal, create_OutputStream, bool) -> None - scale = -value.as_tuple().exponent + scale = -value.as_tuple().exponent # type: ignore[operator] int_value = int(value.scaleb(scale)) out.write_var_int64(scale) self.BIG_INT_CODER_IMPL.encode_to_stream(int_value, out, nested) diff --git a/sdks/python/apache_beam/coders/coders_property_based_test.py b/sdks/python/apache_beam/coders/coders_property_based_test.py index be18dd3586b0..9279fc31c099 100644 --- a/sdks/python/apache_beam/coders/coders_property_based_test.py +++ b/sdks/python/apache_beam/coders/coders_property_based_test.py @@ -141,7 +141,7 @@ def test_row_coder(self, data: st.DataObject): coders_registry.register_coder(RowType, RowCoder) # TODO(https://github.com/apache/beam/issues/23002): Apply nulls for these - row = RowType( # type: ignore + row = RowType( **{ name: data.draw(SCHEMA_TYPES_TO_STRATEGY[type_]) for name, diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index 7dcfae83f10e..4bd9698dd57b 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -53,7 +53,7 @@ except ImportError: dataclasses = None # type: ignore -MyNamedTuple = collections.namedtuple('A', ['x', 'y']) +MyNamedTuple = collections.namedtuple('A', ['x', 'y']) # type: ignore[name-match] MyTypedNamedTuple = NamedTuple('MyTypedNamedTuple', [('f1', int), ('f2', str)]) diff --git a/sdks/python/apache_beam/metrics/metric.py b/sdks/python/apache_beam/metrics/metric.py index f402c0acab2f..3e665dd805ea 100644 --- a/sdks/python/apache_beam/metrics/metric.py +++ b/sdks/python/apache_beam/metrics/metric.py @@ -140,7 +140,7 @@ class DelegatingCounter(Counter): def __init__( self, metric_name: MetricName, process_wide: bool = False) -> None: super().__init__(metric_name) - self.inc = MetricUpdater( # type: 
ignore[assignment] + self.inc = MetricUpdater( # type: ignore[method-assign] cells.CounterCell, metric_name, default_value=1, @@ -150,19 +150,19 @@ class DelegatingDistribution(Distribution): """Metrics Distribution Delegates functionality to MetricsEnvironment.""" def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) - self.update = MetricUpdater(cells.DistributionCell, metric_name) # type: ignore[assignment] + self.update = MetricUpdater(cells.DistributionCell, metric_name) # type: ignore[method-assign] class DelegatingGauge(Gauge): """Metrics Gauge that Delegates functionality to MetricsEnvironment.""" def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) - self.set = MetricUpdater(cells.GaugeCell, metric_name) # type: ignore[assignment] + self.set = MetricUpdater(cells.GaugeCell, metric_name) # type: ignore[method-assign] class DelegatingStringSet(StringSet): """Metrics StringSet that Delegates functionality to MetricsEnvironment.""" def __init__(self, metric_name: MetricName) -> None: super().__init__(metric_name) - self.add = MetricUpdater(cells.StringSetCell, metric_name) # type: ignore[assignment] + self.add = MetricUpdater(cells.StringSetCell, metric_name) # type: ignore[method-assign] class MetricResults(object): diff --git a/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py b/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py index eab547b1c17b..e09f116dfb38 100644 --- a/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/xgboost_inference_test.py @@ -53,7 +53,7 @@ def _compare_prediction_result(a: PredictionResult, b: PredictionResult): example_equal = numpy.array_equal(a.example.todense(), b.example.todense()) else: - example_equal = numpy.array_equal(a.example, b.example) + example_equal = numpy.array_equal(a.example, b.example) # type: ignore[arg-type] if isinstance(a.inference, dict): return all( x == y for x, y in zip(a.inference.values(), diff --git a/sdks/python/apache_beam/ml/transforms/base_test.py b/sdks/python/apache_beam/ml/transforms/base_test.py index 743c3683ce8e..0e65b350211e 100644 --- a/sdks/python/apache_beam/ml/transforms/base_test.py +++ b/sdks/python/apache_beam/ml/transforms/base_test.py @@ -345,7 +345,7 @@ def __init__(self, columns, **kwargs): super().__init__(columns=columns, **kwargs) def get_model_handler(self) -> ModelHandler: - FakeModelHandler.__repr__ = lambda x: 'FakeEmbeddingsManager' # type: ignore[assignment] + FakeModelHandler.__repr__ = lambda x: 'FakeEmbeddingsManager' # type: ignore[method-assign] return FakeModelHandler() def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: @@ -532,7 +532,7 @@ def __init__(self, columns, **kwargs): super().__init__(columns=columns, **kwargs) def get_model_handler(self) -> ModelHandler: - FakeModelHandler.__repr__ = lambda x: 'FakeImageEmbeddingsManager' # type: ignore[assignment] + FakeModelHandler.__repr__ = lambda x: 'FakeImageEmbeddingsManager' # type: ignore[method-assign] return FakeImageModelHandler() def get_ptransform_for_processing(self, **kwargs) -> beam.PTransform: diff --git a/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/__init__.py b/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/__init__.py index 239ee8c700a2..c0d20c3ec8f9 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/__init__.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/clients/dataflow/__init__.py 
@@ -30,4 +30,4 @@ pass # pylint: enable=wrong-import-order, wrong-import-position -__path__ = pkgutil.extend_path(__path__, __name__) # type: ignore +__path__ = pkgutil.extend_path(__path__, __name__) diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index 0a03c96bc19b..13ab665c1eb1 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -306,7 +306,7 @@ def get_or_create_environment_with_resource_hints( # "Message"; expected "Environment" [arg-type] # Here, Environment is a subclass of Message but mypy still # throws an error. - cloned_env.CopyFrom(template_env) # type: ignore[arg-type] + cloned_env.CopyFrom(template_env) cloned_env.resource_hints.clear() cloned_env.resource_hints.update(resource_hints) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py index 07af1c958cfd..c1c7f649f77a 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py @@ -256,7 +256,7 @@ def has_as_main_input(self, pcoll): transform.spec.payload, beam_runner_api_pb2.ParDoPayload) local_side_inputs = payload.side_inputs else: - local_side_inputs = {} # type: ignore[assignment] + local_side_inputs = {} for local_id, pipeline_id in transform.inputs.items(): if pcoll == pipeline_id and local_id not in local_side_inputs: return True diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py index c5423e167026..d798e96d3aa3 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py @@ -1071,7 +1071,7 @@ def get_raw(self, if state_key.WhichOneof('type') not in self._SUPPORTED_STATE_TYPES: raise NotImplementedError( - 'Unknown state type: ' + state_key.WhichOneof('type')) + 'Unknown state type: ' + state_key.WhichOneof('type')) # type: ignore[operator] with self._lock: if not continuation_token: diff --git a/sdks/python/apache_beam/runners/portability/local_job_service.py b/sdks/python/apache_beam/runners/portability/local_job_service.py index 869f013d0d26..a2b4e5e7f939 100644 --- a/sdks/python/apache_beam/runners/portability/local_job_service.py +++ b/sdks/python/apache_beam/runners/portability/local_job_service.py @@ -35,7 +35,7 @@ import grpc from google.protobuf import json_format from google.protobuf import struct_pb2 -from google.protobuf import text_format # type: ignore # not in typeshed +from google.protobuf import text_format from apache_beam import pipeline from apache_beam.metrics import monitoring_infos diff --git a/sdks/python/apache_beam/runners/render.py b/sdks/python/apache_beam/runners/render.py index fccfa8aacd61..45e66e1ba06a 100644 --- a/sdks/python/apache_beam/runners/render.py +++ b/sdks/python/apache_beam/runners/render.py @@ -64,7 +64,7 @@ import urllib.parse from google.protobuf import json_format -from google.protobuf import text_format # type: ignore +from google.protobuf import text_format from apache_beam.options import pipeline_options from apache_beam.portability.api import beam_runner_api_pb2 diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py 
index 88cc3c9791d5..979c7cdb53be 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler.py +++ b/sdks/python/apache_beam/runners/worker/log_handler.py @@ -125,7 +125,7 @@ def emit(self, record: logging.LogRecord) -> None: log_entry.message = ( "Failed to format '%s' with args '%s' during logging." % (str(record.msg), record.args)) - log_entry.thread = record.threadName + log_entry.thread = record.threadName # type: ignore[assignment] log_entry.log_location = '%s:%s' % ( record.pathname or record.module, record.lineno or record.funcName) (fraction, seconds) = math.modf(record.created) diff --git a/sdks/python/apache_beam/runners/worker/opcounters.py b/sdks/python/apache_beam/runners/worker/opcounters.py index 51ca4cf0545b..5496bccd014e 100644 --- a/sdks/python/apache_beam/runners/worker/opcounters.py +++ b/sdks/python/apache_beam/runners/worker/opcounters.py @@ -259,7 +259,7 @@ def do_sample(self, windowed_value): self.type_check(windowed_value.value) size, observables = ( - self.coder_impl.get_estimated_size_and_observables(windowed_value)) + self.coder_impl.get_estimated_size_and_observables(windowed_value)) # type: ignore[union-attr] if not observables: self.current_size = size else: diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 2a1423fccba9..b091220a06b5 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -1335,8 +1335,8 @@ def _get_cache_token(self, state_key): return self._context.user_state_cache_token else: return self._context.bundle_cache_token - elif state_key.WhichOneof('type').endswith('_side_input'): - side_input = getattr(state_key, state_key.WhichOneof('type')) + elif state_key.WhichOneof('type').endswith('_side_input'): # type: ignore[union-attr] + side_input = getattr(state_key, state_key.WhichOneof('type')) # type: ignore[arg-type] return self._context.side_input_cache_tokens.get( (side_input.transform_id, side_input.side_input_id), self._context.bundle_cache_token) diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py index cd49c69a80aa..3389f0c7afb1 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py @@ -27,7 +27,7 @@ import sys import traceback -from google.protobuf import text_format # type: ignore # not in typeshed +from google.protobuf import text_format from apache_beam.internal import pickler from apache_beam.io import filesystems diff --git a/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/tfdv_analyze_and_validate.py b/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/tfdv_analyze_and_validate.py index 87c631762287..ba3fae6819bd 100644 --- a/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/tfdv_analyze_and_validate.py +++ b/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/tfdv_analyze_and_validate.py @@ -29,7 +29,7 @@ from apache_beam.testing.load_tests.load_test_metrics_utils import MeasureTime from apache_beam.testing.load_tests.load_test_metrics_utils import MetricsReader -from google.protobuf import text_format # type: ignore # typeshed out of date +from google.protobuf import text_format from trainer import taxi diff --git a/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/taxi.py b/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/taxi.py index 6d84c995bfca..88ed53e11fc4 100644 
--- a/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/taxi.py +++ b/sdks/python/apache_beam/testing/benchmarks/chicago_taxi/trainer/taxi.py @@ -18,7 +18,7 @@ from tensorflow_transform import coders as tft_coders from tensorflow_transform.tf_metadata import schema_utils -from google.protobuf import text_format # type: ignore # typeshed out of date +from google.protobuf import text_format from tensorflow.python.lib.io import file_io from tensorflow_metadata.proto.v0 import schema_pb2 diff --git a/sdks/python/apache_beam/testing/load_tests/sideinput_test.py b/sdks/python/apache_beam/testing/load_tests/sideinput_test.py index 745d961d2aac..3b5dfdf38cd9 100644 --- a/sdks/python/apache_beam/testing/load_tests/sideinput_test.py +++ b/sdks/python/apache_beam/testing/load_tests/sideinput_test.py @@ -111,7 +111,7 @@ class SequenceSideInputTestDoFn(beam.DoFn): def __init__(self, first_n: int): self._first_n = first_n - def process( # type: ignore[override] + def process( self, element: Any, side_input: Iterable[Tuple[bytes, bytes]]) -> None: i = 0 @@ -129,7 +129,7 @@ class MappingSideInputTestDoFn(beam.DoFn): def __init__(self, first_n: int): self._first_n = first_n - def process( # type: ignore[override] + def process( self, element: Any, dict_side_input: Dict[bytes, bytes]) -> None: i = 0 for key in dict_side_input: @@ -146,7 +146,7 @@ def __init__(self): # Avoid having to use save_main_session self.window = window - def process(self, element: int) -> Iterable[window.TimestampedValue]: # type: ignore[override] + def process(self, element: int) -> Iterable[window.TimestampedValue]: yield self.window.TimestampedValue(element, element) class GetSyntheticSDFOptions(beam.DoFn): @@ -156,7 +156,7 @@ def __init__( self.key_size = key_size self.value_size = value_size - def process(self, element: Any) -> Iterable[Dict[str, Union[int, str]]]: # type: ignore[override] + def process(self, element: Any) -> Iterable[Dict[str, Union[int, str]]]: yield { 'num_records': self.elements_per_record, 'key_size': self.key_size, diff --git a/sdks/python/apache_beam/testing/test_stream_service_test.py b/sdks/python/apache_beam/testing/test_stream_service_test.py index 5bfd0c104ba0..a04fa2303d08 100644 --- a/sdks/python/apache_beam/testing/test_stream_service_test.py +++ b/sdks/python/apache_beam/testing/test_stream_service_test.py @@ -30,9 +30,9 @@ # Nose automatically detects tests if they match a regex. Here, it mistakens # these protos as tests. 
For more info see the Nose docs at: # https://nose.readthedocs.io/en/latest/writing_tests.html -beam_runner_api_pb2.TestStreamPayload.__test__ = False # type: ignore[attr-defined] -beam_interactive_api_pb2.TestStreamFileHeader.__test__ = False # type: ignore[attr-defined] -beam_interactive_api_pb2.TestStreamFileRecord.__test__ = False # type: ignore[attr-defined] +beam_runner_api_pb2.TestStreamPayload.__test__ = False +beam_interactive_api_pb2.TestStreamFileHeader.__test__ = False +beam_interactive_api_pb2.TestStreamFileRecord.__test__ = False class EventsReader: diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index b30e9e0b70c7..9c798d3ce6dc 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -1415,7 +1415,7 @@ class PartitionFn(WithTypeHints): def default_label(self): return self.__class__.__name__ - def partition_for(self, element, num_partitions, *args, **kwargs): + def partition_for(self, element, num_partitions, *args, **kwargs): # type: ignore[empty-body] # type: (T, int, *typing.Any, **typing.Any) -> int """Specify which partition will receive this element. @@ -3451,7 +3451,7 @@ def _dynamic_named_tuple(type_name, field_names): type_name, field_names) # typing: can't override a method. also, self type is unknown and can't # be cast to tuple - result.__reduce__ = lambda self: ( # type: ignore[assignment] + result.__reduce__ = lambda self: ( # type: ignore[method-assign] _unpickle_dynamic_named_tuple, (type_name, field_names, tuple(self))) # type: ignore[arg-type] return result diff --git a/sdks/python/apache_beam/transforms/display.py b/sdks/python/apache_beam/transforms/display.py index 86bbf101f567..14cd485d1f8e 100644 --- a/sdks/python/apache_beam/transforms/display.py +++ b/sdks/python/apache_beam/transforms/display.py @@ -173,7 +173,7 @@ def create_payload(dd) -> Optional[beam_runner_api_pb2.LabelledPayload]: elif isinstance(value, (float, complex)): return beam_runner_api_pb2.LabelledPayload( label=label, - double_value=value, + double_value=value, # type: ignore[arg-type] key=display_data_dict['key'], namespace=display_data_dict.get('namespace', '')) else: diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index 8a04e7efb195..83c439ca8ddd 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -740,7 +740,7 @@ def expand(self, pvalueish): components = context.to_runner_api() request = beam_expansion_api_pb2.ExpansionRequest( components=components, - namespace=self._external_namespace, + namespace=self._external_namespace, # type: ignore[arg-type] transform=transform_proto, output_coder_requests=output_coders, pipeline_options=pipeline._options.to_runner_api()) diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 8554ebce5dbd..6ec741705376 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -748,7 +748,7 @@ def to_runner_api(self, context, has_parts=False, **extra_kwargs): # type: (PipelineContext, bool, Any) -> beam_runner_api_pb2.FunctionSpec from apache_beam.portability.api import beam_runner_api_pb2 # typing: only ParDo supports extra_kwargs - urn, typed_param = self.to_runner_api_parameter(context, **extra_kwargs) # type: ignore[call-arg] + urn, typed_param = self.to_runner_api_parameter(context, **extra_kwargs) if 
urn == python_urns.GENERIC_COMPOSITE_TRANSFORM and not has_parts: # TODO(https://github.com/apache/beam/issues/18713): Remove this fallback. urn, typed_param = self.to_runner_api_pickled(context) diff --git a/sdks/python/apache_beam/transforms/stats.py b/sdks/python/apache_beam/transforms/stats.py index d389463e55a2..0d56b60b050f 100644 --- a/sdks/python/apache_beam/transforms/stats.py +++ b/sdks/python/apache_beam/transforms/stats.py @@ -919,7 +919,7 @@ def _offset(self, new_weight): # TODO(https://github.com/apache/beam/issues/19737): Signature incompatible # with supertype - def create_accumulator(self): # type: ignore[override] + def create_accumulator(self): # type: () -> _QuantileState self._qs = _QuantileState( unbuffered_elements=[], diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py index cad733538111..3b876bf9dbfb 100644 --- a/sdks/python/apache_beam/transforms/userstate.py +++ b/sdks/python/apache_beam/transforms/userstate.py @@ -299,7 +299,7 @@ def validate_stateful_dofn(dofn: 'DoFn') -> None: 'callback: %s.') % (dofn, timer_spec)) method_name = timer_spec._attached_callback.__name__ if (timer_spec._attached_callback != getattr(dofn, method_name, - None).__func__): + None).__func__): # type: ignore[union-attr] raise ValueError(( 'The on_timer callback for %s is not the specified .%s method ' 'for DoFn %r (perhaps it was overwritten?).') % @@ -314,7 +314,7 @@ def set(self, timestamp: Timestamp, dynamic_timer_tag: str = '') -> None: raise NotImplementedError -_TimerTuple = collections.namedtuple('timer_tuple', ('cleared', 'timestamp')) +_TimerTuple = collections.namedtuple('timer_tuple', ('cleared', 'timestamp')) # type: ignore[name-match] class RuntimeTimer(BaseTimer): diff --git a/sdks/python/apache_beam/utils/profiler.py b/sdks/python/apache_beam/utils/profiler.py index c75fdcc5878d..61c2371bd07d 100644 --- a/sdks/python/apache_beam/utils/profiler.py +++ b/sdks/python/apache_beam/utils/profiler.py @@ -104,7 +104,9 @@ def __exit__(self, *args): self.profile.create_stats() self.profile_output = self._upload_profile_data( # typing: seems stats attr is missing from typeshed - self.profile_location, 'cpu_profile', self.profile.stats) # type: ignore[attr-defined] + self.profile_location, + 'cpu_profile', + self.profile.stats) if self.enable_memory_profiling: if not self.hpy: diff --git a/sdks/python/apache_beam/utils/proto_utils.py b/sdks/python/apache_beam/utils/proto_utils.py index 9a93c9e48ea3..60c0af2ebac0 100644 --- a/sdks/python/apache_beam/utils/proto_utils.py +++ b/sdks/python/apache_beam/utils/proto_utils.py @@ -46,7 +46,7 @@ def pack_Any(msg: message.Message) -> any_pb2.Any: @overload -def pack_Any(msg: None) -> None: +def pack_Any(msg: None) -> None: # type: ignore[overload-cannot-match] pass diff --git a/sdks/python/mypy.ini b/sdks/python/mypy.ini index 562cb8d56dcc..ee76089fec0b 100644 --- a/sdks/python/mypy.ini +++ b/sdks/python/mypy.ini @@ -28,11 +28,17 @@ files = apache_beam color_output = true # uncomment this to see how close we are to being complete # check_untyped_defs = true -disable_error_code = var-annotated +disable_error_code = var-annotated, import-untyped, valid-type, truthy-function, attr-defined, annotation-unchecked + +[tool.mypy] +ignore_missing_imports = true [mypy-apache_beam.coders.proto2_coder_test_messages_pb2] ignore_errors = true +[mypy-apache_beam.dataframe.*] +ignore_errors = true + [mypy-apache_beam.examples.*] ignore_errors = true diff --git a/sdks/python/tox.ini 
b/sdks/python/tox.ini index 8cdc4a98bbfe..c7713498d87d 100644 --- a/sdks/python/tox.ini +++ b/sdks/python/tox.ini @@ -149,7 +149,7 @@ commands = [testenv:mypy] deps = - mypy==0.790 + mypy==1.13.0 dask==2022.01.0 distributed==2022.01.0 # make extras available in case any of these libs are typed From c7a161d23e588b57df6e4a304bf7108867afb51a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 15:11:26 -0400 Subject: [PATCH 096/181] Bump github.com/nats-io/nats-server/v2 from 2.10.18 to 2.10.22 in /sdks (#32856) Bumps [github.com/nats-io/nats-server/v2](https://github.com/nats-io/nats-server) from 2.10.18 to 2.10.22. - [Release notes](https://github.com/nats-io/nats-server/releases) - [Changelog](https://github.com/nats-io/nats-server/blob/main/.goreleaser.yml) - [Commits](https://github.com/nats-io/nats-server/compare/v2.10.18...v2.10.22) --- updated-dependencies: - dependency-name: github.com/nats-io/nats-server/v2 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- sdks/go.mod | 4 ++-- sdks/go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sdks/go.mod b/sdks/go.mod index ed7e58b9a7bb..3f7bdc2c8ce4 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -44,7 +44,7 @@ require ( github.com/johannesboyne/gofakes3 v0.0.0-20221110173912-32fb85c5aed6 github.com/lib/pq v1.10.9 github.com/linkedin/goavro/v2 v2.13.0 - github.com/nats-io/nats-server/v2 v2.10.18 + github.com/nats-io/nats-server/v2 v2.10.22 github.com/nats-io/nats.go v1.37.0 github.com/proullon/ramsql v0.1.4 github.com/spf13/cobra v1.8.1 @@ -167,7 +167,7 @@ require ( github.com/gorilla/handlers v1.5.2 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/klauspost/compress v1.17.9 // indirect + github.com/klauspost/compress v1.17.11 // indirect github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/moby/patternmatcher v0.6.0 // indirect diff --git a/sdks/go.sum b/sdks/go.sum index 1c09fbb1710b..640a4233f982 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -1029,8 +1029,8 @@ github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= -github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= -github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= github.com/klauspost/cpuid/v2 v2.2.6/go.mod 
h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= @@ -1085,8 +1085,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/nats-io/jwt/v2 v2.5.8 h1:uvdSzwWiEGWGXf+0Q+70qv6AQdvcvxrv9hPM0RiPamE= github.com/nats-io/jwt/v2 v2.5.8/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A= -github.com/nats-io/nats-server/v2 v2.10.18 h1:tRdZmBuWKVAFYtayqlBB2BuCHNGAQPvoQIXOKwU3WSM= -github.com/nats-io/nats-server/v2 v2.10.18/go.mod h1:97Qyg7YydD8blKlR8yBsUlPlWyZKjA7Bp5cl3MUE9K8= +github.com/nats-io/nats-server/v2 v2.10.22 h1:Yt63BGu2c3DdMoBZNcR6pjGQwk/asrKU7VX846ibxDA= +github.com/nats-io/nats-server/v2 v2.10.22/go.mod h1:X/m1ye9NYansUXYFrbcDwUi/blHkrgHh2rgCJaakonk= github.com/nats-io/nats.go v1.37.0 h1:07rauXbVnnJvv1gfIyghFEo6lUcYRY0WXc3x7x0vUxE= github.com/nats-io/nats.go v1.37.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8= github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI= From afaa4c3f7d1709bd5eac229566c9eced1b959fe8 Mon Sep 17 00:00:00 2001 From: Jeff Kinard Date: Thu, 31 Oct 2024 15:19:51 -0400 Subject: [PATCH 097/181] [yaml] Enhance YAML API docs (#32825) * [yaml] Enhance YAML API docs Signed-off-by: Jeffrey Kinard * use single html file Signed-off-by: Jeffrey Kinard * fix test failures Signed-off-by: Jeffrey Kinard * rebase on master Signed-off-by: Jeffrey Kinard --------- Signed-off-by: Jeffrey Kinard --- .../apache_beam/yaml/generate_yaml_docs.py | 164 +++++++++++++++--- sdks/python/apache_beam/yaml/tests/map.yaml | 6 +- 2 files changed, 141 insertions(+), 29 deletions(-) diff --git a/sdks/python/apache_beam/yaml/generate_yaml_docs.py b/sdks/python/apache_beam/yaml/generate_yaml_docs.py index 4719bc3e66aa..84a5e62f0abd 100644 --- a/sdks/python/apache_beam/yaml/generate_yaml_docs.py +++ b/sdks/python/apache_beam/yaml/generate_yaml_docs.py @@ -24,6 +24,7 @@ from apache_beam.portability.api import schema_pb2 from apache_beam.utils import subprocess_server +from apache_beam.version import __version__ as beam_version from apache_beam.yaml import json_utils from apache_beam.yaml import yaml_provider @@ -284,42 +285,143 @@ def main(): markdown.extensions.toc.TocExtension(toc_depth=2), 'codehilite', ]) - html = md.convert(markdown_out.getvalue()) pygments_style = pygments.formatters.HtmlFormatter().get_style_defs( '.codehilite') extra_style = ''' - .nav { - height: 100%; - width: 12em; + * { + box-sizing: border-box; + } + body { + font-family: 'Roboto', sans-serif; + font-weight: normal; + color: #404040; + background: #edf0f2; + } + .body-for-nav { + background: #fcfcfc; + } + .grid-for-nav { + width: 100%; + } + .nav-side { position: fixed; top: 0; left: 0; - overflow-x: hidden; + width: 300px; + height: 100%; + padding-bottom: 2em; + color: #9b9b9b; + background: #343131; } - .nav a { - color: #333; - padding: .2em; + .nav-header { display: block; - text-decoration: none; + width: 300px; + padding: 1em; + background-color: #2980B9; + text-align: center; + color: #fcfcfc; + } + .nav-header a { + color: #fcfcfc; + font-weight: bold; + display: inline-block; + padding: 4px 6px; + margin-bottom: 1em; + text-decoration:none; + } + .nav-header>div.version { + margin-top: -.5em; + margin-bottom: 1em; + font-weight: normal; + color: rgba(255, 255, 255, 0.3); } - .nav a:hover { - color: #888; + .toc { + width: 300px; + text-align: left; + overflow-y: auto; + 
max-height: calc(100% - 4.3em); + scrollbar-width: thin; + scrollbar-color: #9b9b9b #343131; } - .nav li { - list-style-type: none; + .toc ul { margin: 0; padding: 0; + list-style: none; } - .content { - margin-left: 12em; + .toc li { + border-bottom: 1px solid #4e4a4a; + margin-left: 1em; + } + .toc a { + display: block; + line-height: 36px; + font-size: 90%; + color: #d9d9d9; + padding: .1em 0.6em; + text-decoration: none; + transition: background-color 0.3s ease, color 0.3s ease; } - h2 { - margin-top: 2em; + .toc a:hover { + background-color: #4e4a4a; + color: #ffffff; + } + .transform-content-wrap { + margin-left: 300px; + background: #fcfcfc; + } + .transform-content { + padding: 1.5em 3em; + margin: 20px; + padding-bottom: 2em; + } + .transform-content li::marker { + display: inline-block; + width: 0.5em; + } + .transform-content h1 { + font-size: 40px; + } + .transform-content ul { + margin-left: 0.75em; + text-align: left; + list-style-type: disc; + } + hr { + color: gray; + display: block; + height: 1px; + border: 0; + border-top: 1px solid #e1e4e5; + margin-bottom: 3em; + margin-top: 3em; + padding: 0; + } + .codehilite { + background: #f5f5f5; + border: 1px solid #ccc; + border-radius: 4px; + padding: 0.2em 1em; + overflow: auto; + font-family: monospace; + font-size: 14px; + line-height: 1.5; + } + p code, li code { + white-space: nowrap; + max-width: 100%; + background: #fff; + border: solid 1px #e1e4e5; + padding: 0 5px; + font-family: monospace; + color: #404040; + font-weight: bold; + padding: 2px 5px; } ''' - with open(options.html_file, 'w') as fout: - fout.write( + html = md.convert(markdown_out.getvalue()) + with open(options.html_file, 'w') as html_out: + html_out.write( f''' @@ -329,13 +431,23 @@ def main(): {extra_style} - -

-          [old page body markup lost in extraction: a plain nav/content layout interpolating {title} and {html}]
+          [new page body markup lost in extraction: containers matching the CSS above (nav-side, nav-header, toc, transform-content) interpolating {title} and {html.replace(...)}]
diff --git a/sdks/python/apache_beam/yaml/tests/map.yaml b/sdks/python/apache_beam/yaml/tests/map.yaml index bbb7fc4527de..04f057cb2e82 100644 --- a/sdks/python/apache_beam/yaml/tests/map.yaml +++ b/sdks/python/apache_beam/yaml/tests/map.yaml @@ -31,8 +31,8 @@ pipelines: append: true fields: # TODO(https://github.com/apache/beam/issues/32832): Figure out why Java sometimes re-orders these fields. - literal_int: 10 named_field: element + literal_int: 10 literal_float: 1.5 literal_str: '"abc"' @@ -43,5 +43,5 @@ pipelines: - type: AssertEqual config: elements: - - {element: 100, literal_int: 10, named_field: 100, literal_float: 1.5, literal_str: "abc"} - - {element: 200, literal_int: 10, named_field: 200, literal_float: 1.5, literal_str: "abc"} + - {element: 100, named_field: 100, literal_int: 10, literal_float: 1.5, literal_str: "abc"} + - {element: 200, named_field: 200, literal_int: 10, literal_float: 1.5, literal_str: "abc"} From 5226b73c229803295a0d0e30c433ce4d2edd22be Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 31 Oct 2024 15:20:08 -0400 Subject: [PATCH 098/181] Bump google.golang.org/api from 0.202.0 to 0.203.0 in /sdks (#32918) Bumps [google.golang.org/api](https://github.com/googleapis/google-api-go-client) from 0.202.0 to 0.203.0. - [Release notes](https://github.com/googleapis/google-api-go-client/releases) - [Changelog](https://github.com/googleapis/google-api-go-client/blob/main/CHANGES.md) - [Commits](https://github.com/googleapis/google-api-go-client/compare/v0.202.0...v0.203.0) --- updated-dependencies: - dependency-name: google.golang.org/api dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- sdks/go.mod | 4 ++-- sdks/go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sdks/go.mod b/sdks/go.mod index 3f7bdc2c8ce4..81221f98e276 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -58,7 +58,7 @@ require ( golang.org/x/sync v0.8.0 golang.org/x/sys v0.26.0 golang.org/x/text v0.19.0 - google.golang.org/api v0.202.0 + google.golang.org/api v0.203.0 google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53 google.golang.org/grpc v1.67.1 google.golang.org/protobuf v1.35.1 @@ -75,7 +75,7 @@ require ( require ( cel.dev/expr v0.16.1 // indirect - cloud.google.com/go/auth v0.9.8 // indirect + cloud.google.com/go/auth v0.9.9 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect cloud.google.com/go/monitoring v1.21.1 // indirect dario.cat/mergo v1.0.0 // indirect diff --git a/sdks/go.sum b/sdks/go.sum index 640a4233f982..a45baf72a02b 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -101,8 +101,8 @@ cloud.google.com/go/assuredworkloads v1.7.0/go.mod h1:z/736/oNmtGAyU47reJgGN+KVo cloud.google.com/go/assuredworkloads v1.8.0/go.mod h1:AsX2cqyNCOvEQC8RMPnoc0yEarXQk6WEKkxYfL6kGIo= cloud.google.com/go/assuredworkloads v1.9.0/go.mod h1:kFuI1P78bplYtT77Tb1hi0FMxM0vVpRC7VVoJC3ZoT0= cloud.google.com/go/assuredworkloads v1.10.0/go.mod h1:kwdUQuXcedVdsIaKgKTp9t0UJkE5+PAVNhdQm4ZVq2E= -cloud.google.com/go/auth v0.9.8 h1:+CSJ0Gw9iVeSENVCKJoLHhdUykDgXSc4Qn+gu2BRtR8= -cloud.google.com/go/auth v0.9.8/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI= +cloud.google.com/go/auth v0.9.9 h1:BmtbpNQozo8ZwW2t7QJjnrQtdganSdmqeIBxHxNkEZQ= +cloud.google.com/go/auth v0.9.9/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI= 
cloud.google.com/go/auth/oauth2adapt v0.2.4 h1:0GWE/FUsXhf6C+jAkWgYm7X9tK8cuEIfy19DBn6B6bY= cloud.google.com/go/auth/oauth2adapt v0.2.4/go.mod h1:jC/jOpwFP6JBxhB3P5Rr0a9HLMC/Pe3eaL4NmdvqPtc= cloud.google.com/go/automl v1.5.0/go.mod h1:34EjfoFGMZ5sgJ9EoLsRtdPSNZLcfflJR39VbVNS2M0= @@ -1705,8 +1705,8 @@ google.golang.org/api v0.108.0/go.mod h1:2Ts0XTHNVWxypznxWOYUeI4g3WdP9Pk2Qk58+a/ google.golang.org/api v0.110.0/go.mod h1:7FC4Vvx1Mooxh8C5HWjzZHcavuS2f6pmJpZx60ca7iI= google.golang.org/api v0.111.0/go.mod h1:qtFHvU9mhgTJegR31csQ+rwxyUTHOKFqCKWp1J0fdw0= google.golang.org/api v0.114.0/go.mod h1:ifYI2ZsFK6/uGddGfAD5BMxlnkBqCmqHSDUVi45N5Yg= -google.golang.org/api v0.202.0 h1:y1iuVHMqokQbimW79ZqPZWo4CiyFu6HcCYHwSNyzlfo= -google.golang.org/api v0.202.0/go.mod h1:3Jjeq7M/SFblTNCp7ES2xhq+WvGL0KeXI0joHQBfwTQ= +google.golang.org/api v0.203.0 h1:SrEeuwU3S11Wlscsn+LA1kb/Y5xT8uggJSkIhD08NAU= +google.golang.org/api v0.203.0/go.mod h1:BuOVyCSYEPwJb3npWvDnNmFI92f3GeRnHNkETneT3SI= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= From 8c235043ec05788516f583f6e1d829887358aec8 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Thu, 31 Oct 2024 16:18:33 -0400 Subject: [PATCH 099/181] Swallow errors removing awaiting triage label (#32989) --- .github/workflows/self-assign.yml | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/.github/workflows/self-assign.yml b/.github/workflows/self-assign.yml index 084581db7340..6c2f2219b4e3 100644 --- a/.github/workflows/self-assign.yml +++ b/.github/workflows/self-assign.yml @@ -40,12 +40,16 @@ jobs: repo: context.repo.repo, assignees: [context.payload.comment.user.login] }); - github.rest.issues.removeLabel({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - name: 'awaiting triage' - }); + try { + github.rest.issues.removeLabel({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + name: 'awaiting triage' + }); + } catch (error) { + console.log(`Failed to remove awaiting triage label. It may not exist on this issue. 
Error ${error}`); + } } else if (bodyString == '.close-issue') { console.log('Closing issue'); if (i + 1 < body.length && body[i+1].toLowerCase() == 'not_planned') { From d3a841c100dd91f8daebd18bd807cfe438b6b988 Mon Sep 17 00:00:00 2001 From: Vlado Djerek Date: Thu, 31 Oct 2024 21:20:15 +0100 Subject: [PATCH 100/181] playground precommit move to selfhosted and update (#32987) * playground precommit move to selfhosted and update * playground precommit move to selfhosted and update --- ...mmit.yml => beam_Playground_Precommit.yml} | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) rename .github/workflows/{playground_backend_precommit.yml => beam_Playground_Precommit.yml} (75%) diff --git a/.github/workflows/playground_backend_precommit.yml b/.github/workflows/beam_Playground_Precommit.yml similarity index 75% rename from .github/workflows/playground_backend_precommit.yml rename to .github/workflows/beam_Playground_Precommit.yml index 9ba6cf20534f..edb50661b1ee 100644 --- a/.github/workflows/playground_backend_precommit.yml +++ b/.github/workflows/beam_Playground_Precommit.yml @@ -17,10 +17,12 @@ name: Playground PreCommit on: workflow_dispatch: - pull_request: + pull_request_target: paths: - .github/workflows/playground_backend_precommit.yml - playground/backend/** + issue_comment: + types: [created] env: DEVELOCITY_ACCESS_KEY: ${{ secrets.GE_ACCESS_TOKEN }} @@ -28,17 +30,30 @@ env: GRADLE_ENTERPRISE_CACHE_PASSWORD: ${{ secrets.GE_CACHE_PASSWORD }} jobs: - precommit_check: - name: precommit-check - runs-on: ubuntu-latest + beam_Playground_PreCommit: + if: | + github.event_name == 'workflow_dispatch' || + github.event_name == 'pull_request_target' || + github.event.comment.body == 'Run Playground PreCommit' + name: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) + runs-on: [self-hosted, ubuntu-20.04, main] + strategy: + fail-fast: false + matrix: + job_name: [beam_Playground_PreCommit] + job_phrase: [Run Playground PreCommit] env: DATASTORE_EMULATOR_VERSION: '423.0.0' PYTHON_VERSION: '3.9' JAVA_VERSION: '11' steps: - - name: Check out the repo - uses: actions/checkout@v4 - + - uses: actions/checkout@v4 + - name: Setup repository + uses: ./.github/actions/setup-action + with: + comment_phrase: ${{ matrix.job_phrase }} + github_token: ${{ secrets.GITHUB_TOKEN }} + github_job: ${{ matrix.job_name }} (${{ matrix.job_phrase }}) - name: Setup environment uses: ./.github/actions/setup-environment-action with: @@ -58,7 +73,7 @@ jobs: sudo chmod 644 /etc/apt/trusted.gpg.d/scalasbt-release.gpg sudo apt-get update --yes sudo apt-get install sbt --yes - sudo wget https://codeload.github.com/spotify/scio.g8/zip/7c1ba7c1651dfd70976028842e721da4107c0d6d -O scio.g8.zip && unzip scio.g8.zip && mv scio.g8-7c1ba7c1651dfd70976028842e721da4107c0d6d /opt/scio.g8 + sudo wget https://codeload.github.com/spotify/scio.g8/zip/7c1ba7c1651dfd70976028842e721da4107c0d6d -O scio.g8.zip && unzip scio.g8.zip && sudo mv scio.g8-7c1ba7c1651dfd70976028842e721da4107c0d6d /opt/scio.g8 - name: Set up Cloud SDK and its components uses: google-github-actions/setup-gcloud@v2 with: From 78421b5fc0b1030d3ac35f3e7d241e7491488234 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Wed, 30 Oct 2024 14:16:42 -0400 Subject: [PATCH 101/181] Suppress errors in JvmInitializer#beforeProcessing if successfully initialized before --- .../beam_PostCommit_Java_PVR_Spark_Batch.json | 2 +- .../EmbeddedEnvironmentFactory.java | 2 ++ .../apache/beam/sdk/fn/JvmInitializers.java | 24 +++++++++++++++---- 3 files changed, 23 insertions(+), 
5 deletions(-) diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json index e0266d62f2e0..f1ba03a243ee 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 4 + "modification": 5 } diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/EmbeddedEnvironmentFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/EmbeddedEnvironmentFactory.java index 72fa991c1f73..470692e75103 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/EmbeddedEnvironmentFactory.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/environment/EmbeddedEnvironmentFactory.java @@ -140,6 +140,8 @@ public RemoteEnvironment createEnvironment(Environment environment, String worke try { fnHarness.get(); } catch (Throwable t) { + // Print stacktrace to stderr. Could be useful if underlying error not surfaced earlier + t.printStackTrace(); executor.shutdownNow(); } }); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/JvmInitializers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/JvmInitializers.java index f739a797af80..453c1cd79a42 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/JvmInitializers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/JvmInitializers.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.fn; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.sdk.harness.JvmInitializer; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.common.ReflectHelpers; @@ -25,6 +26,8 @@ /** Helpers for executing {@link JvmInitializer} implementations. */ public class JvmInitializers { + private static final AtomicBoolean initialized = new AtomicBoolean(false); + /** * Finds all registered implementations of JvmInitializer and executes their {@code onStartup} * methods. Should be called in worker harness implementations at the very beginning of their main @@ -50,10 +53,23 @@ public static void runBeforeProcessing(PipelineOptions options) { // We load the logger in the method to minimize the amount of class loading that happens // during class initialization. Logger logger = LoggerFactory.getLogger(JvmInitializers.class); - for (JvmInitializer initializer : ReflectHelpers.loadServicesOrdered(JvmInitializer.class)) { - logger.info("Running JvmInitializer#beforeProcessing for {}", initializer); - initializer.beforeProcessing(options); - logger.info("Completed JvmInitializer#beforeProcessing for {}", initializer); + + try { + for (JvmInitializer initializer : ReflectHelpers.loadServicesOrdered(JvmInitializer.class)) { + logger.info("Running JvmInitializer#beforeProcessing for {}", initializer); + initializer.beforeProcessing(options); + logger.info("Completed JvmInitializer#beforeProcessing for {}", initializer); + } + initialized.compareAndSet(false, true); + } catch (Error e) { + if (initialized.get()) { + logger.warn( + "Error at JvmInitializer#beforeProcessing. This error is suppressed after " + + "previous success runs. 
It is expected on Embedded environment", + e); + } else { + throw e; + } } } } From daa51ff35aa663066dc8b00b90db16525f89f1d5 Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Thu, 31 Oct 2024 18:31:36 -0400 Subject: [PATCH 102/181] whitespace change to trigger gh checks --- sdks/python/apache_beam/transforms/ptransform_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index a51d5cd83d26..460c9affc5b8 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -1465,6 +1465,7 @@ def bool_to_int(a): def test_filter_does_not_type_check_using_type_hints_method(self): # Filter is expecting an int but instead looks to the 'left' and sees a str # incoming. + with self.assertRaises(typehints.TypeCheckError) as e: ( self.p From 489ebc6f74cc5d11313fdd4f6b88c4620bd6d74b Mon Sep 17 00:00:00 2001 From: Joey Tran Date: Thu, 31 Oct 2024 18:31:49 -0400 Subject: [PATCH 103/181] Revert "whitespace change to trigger gh checks" This reverts commit daa51ff35aa663066dc8b00b90db16525f89f1d5. --- sdks/python/apache_beam/transforms/ptransform_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index 460c9affc5b8..a51d5cd83d26 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -1465,7 +1465,6 @@ def bool_to_int(a): def test_filter_does_not_type_check_using_type_hints_method(self): # Filter is expecting an int but instead looks to the 'left' and sees a str # incoming. - with self.assertRaises(typehints.TypeCheckError) as e: ( self.p From 2160737501ecffce5636d6e0e331a8b5c91a417c Mon Sep 17 00:00:00 2001 From: s21lee <51240829+s21lee@users.noreply.github.com> Date: Fri, 1 Nov 2024 08:42:44 +0900 Subject: [PATCH 104/181] Add JobServerOption for `--jar_cache_dir` (#32033) * Add JobServerOption for jar_cache_dir Signed-off-by: s21.lee * fixed for job_server_test error - add missing comma Signed-off-by: s21.lee * fix error for missing comma Signed-off-by: s21lee * fix lint Signed-off-by: s21lee * fix lint Signed-off-by: s21lee * fix for unit test error Signed-off-by: s21lee * fix lint Signed-off-by: s21lee * fix test error Signed-off-by: s21lee * fix lint Signed-off-by: s21lee * fix error for typo Signed-off-by: s21lee * fix test error Signed-off-by: s21lee * fix lint Signed-off-by: s21lee * fix error Signed-off-by: s21lee * fix error Signed-off-by: s21lee * fix lint Signed-off-by: s21lee * fix error Signed-off-by: s21lee * fix error Signed-off-by: s21lee * fix error Signed-off-by: s21lee * fix lint Signed-off-by: s21lee * fix lint Signed-off-by: s21lee * fix error and lint Signed-off-by: s21lee * fix lint Signed-off-by: s21lee * fix lint Signed-off-by: s21lee --------- Signed-off-by: s21.lee Signed-off-by: s21lee Co-authored-by: s21.lee --- sdks/python/apache_beam/options/pipeline_options.py | 4 ++++ .../apache_beam/options/pipeline_options_test.py | 6 ++++++ .../apache_beam/runners/portability/job_server.py | 7 ++++--- .../apache_beam/runners/portability/job_server_test.py | 3 ++- sdks/python/apache_beam/utils/subprocess_server.py | 10 ++++++++-- 5 files changed, 24 insertions(+), 6 deletions(-) diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 455d12b4d3c1..af0c5e3de66f 
100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -1674,6 +1674,10 @@ def _add_argparse_args(cls, parser): action='append', default=[], help='JVM properties to pass to a Java job server.') + parser.add_argument( + '--jar_cache_dir', + default=None, + help='The location to store jar cache for job server.') class FlinkRunnerOptions(PipelineOptions): diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py index c0616bc6451c..66acfe654791 100644 --- a/sdks/python/apache_beam/options/pipeline_options_test.py +++ b/sdks/python/apache_beam/options/pipeline_options_test.py @@ -31,6 +31,7 @@ from apache_beam.options.pipeline_options import CrossLanguageOptions from apache_beam.options.pipeline_options import DebugOptions from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.options.pipeline_options import JobServerOptions from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.options.pipeline_options import ProfilingOptions from apache_beam.options.pipeline_options import TypeOptions @@ -645,6 +646,11 @@ def test_transform_name_mapping(self): mapping = options.view_as(GoogleCloudOptions).transform_name_mapping self.assertEqual(mapping['from'], 'to') + def test_jar_cache_dir(self): + options = PipelineOptions(['--jar_cache_dir=/path/to/jar_cache_dir']) + jar_cache_dir = options.view_as(JobServerOptions).jar_cache_dir + self.assertEqual(jar_cache_dir, '/path/to/jar_cache_dir') + def test_dataflow_service_options(self): options = PipelineOptions([ '--dataflow_service_option', diff --git a/sdks/python/apache_beam/runners/portability/job_server.py b/sdks/python/apache_beam/runners/portability/job_server.py index e44d8ab0ae93..eee75f66a277 100644 --- a/sdks/python/apache_beam/runners/portability/job_server.py +++ b/sdks/python/apache_beam/runners/portability/job_server.py @@ -127,6 +127,7 @@ def __init__(self, options): self._artifacts_dir = options.artifacts_dir self._java_launcher = options.job_server_java_launcher self._jvm_properties = options.job_server_jvm_properties + self._jar_cache_dir = options.jar_cache_dir def java_arguments( self, job_port, artifact_port, expansion_port, artifacts_dir): @@ -141,11 +142,11 @@ def path_to_beam_jar(gradle_target, artifact_id=None): gradle_target, artifact_id=artifact_id) @staticmethod - def local_jar(url): - return subprocess_server.JavaJarServer.local_jar(url) + def local_jar(url, jar_cache_dir=None): + return subprocess_server.JavaJarServer.local_jar(url, jar_cache_dir) def subprocess_cmd_and_endpoint(self): - jar_path = self.local_jar(self.path_to_jar()) + jar_path = self.local_jar(self.path_to_jar(), self._jar_cache_dir) artifacts_dir = ( self._artifacts_dir if self._artifacts_dir else self.local_temp_dir( prefix='artifacts')) diff --git a/sdks/python/apache_beam/runners/portability/job_server_test.py b/sdks/python/apache_beam/runners/portability/job_server_test.py index 1e2ede281c9d..13b3629b24bf 100644 --- a/sdks/python/apache_beam/runners/portability/job_server_test.py +++ b/sdks/python/apache_beam/runners/portability/job_server_test.py @@ -41,7 +41,8 @@ def path_to_jar(self): return '/path/to/jar' @staticmethod - def local_jar(url): + def local_jar(url, jar_cache_dir=None): + logging.debug("url({%s}), jar_cache_dir({%s})", url, jar_cache_dir) return url diff --git a/sdks/python/apache_beam/utils/subprocess_server.py 
b/sdks/python/apache_beam/utils/subprocess_server.py index 944c12625d7c..b1080cb643af 100644 --- a/sdks/python/apache_beam/utils/subprocess_server.py +++ b/sdks/python/apache_beam/utils/subprocess_server.py @@ -266,11 +266,17 @@ class JavaJarServer(SubprocessServer): 'local', (threading.local, ), dict(__init__=lambda self: setattr(self, 'replacements', {})))() - def __init__(self, stub_class, path_to_jar, java_arguments, classpath=None): + def __init__( + self, + stub_class, + path_to_jar, + java_arguments, + classpath=None, + cache_dir=None): if classpath: # java -jar ignores the classpath, so we make a new jar that embeds # the requested classpath. - path_to_jar = self.make_classpath_jar(path_to_jar, classpath) + path_to_jar = self.make_classpath_jar(path_to_jar, classpath, cache_dir) super().__init__( stub_class, ['java', '-jar', path_to_jar] + list(java_arguments)) self._existing_service = path_to_jar if is_service_endpoint( From d37ba37778dff3b14533825b01d127a5e1a35846 Mon Sep 17 00:00:00 2001 From: Rebecca Szper <98840847+rszper@users.noreply.github.com> Date: Thu, 31 Oct 2024 16:54:04 -0700 Subject: [PATCH 105/181] Update docs for MapState and SetState support update (#32915) * Update docs for MapState and SetState support update * remove the code formatting * clarify state support for the Dataflow runner * clarify state support for the Dataflow runner * Update website/www/site/data/capability_matrix.yaml Co-authored-by: Danny McCormick --------- Co-authored-by: Danny McCormick --- website/www/site/data/capability_matrix.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/www/site/data/capability_matrix.yaml b/website/www/site/data/capability_matrix.yaml index dcbbca438b6e..e6fd51a9bb17 100644 --- a/website/www/site/data/capability_matrix.yaml +++ b/website/www/site/data/capability_matrix.yaml @@ -393,7 +393,7 @@ capability-matrix: - class: dataflow l1: "Partially" l2: non-merging windows - l3: State is supported for non-merging windows. SetState and MapState are not yet supported. + l3: "State is supported for non-merging windows. The MapState, SetState, and MultimapState state types are supported in the following scenarios: Java pipelines that don't use Streaming Engine; Java pipelines that use Streaming Engine and version 2.58.0 or later of the Java SDK. SetState, MapState, and MultimapState are not supported for pipelines that use Runner v2." 
- class: flink l1: "Partially" l2: non-merging windows From d749c08bd3820e625fc4f1892d464b91829b89dc Mon Sep 17 00:00:00 2001 From: Hai Joey Tran Date: Thu, 31 Oct 2024 23:46:26 -0400 Subject: [PATCH 106/181] fix typo (#32992) --- website/www/site/content/en/documentation/programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/www/site/content/en/documentation/programming-guide.md b/website/www/site/content/en/documentation/programming-guide.md index df6907f672f4..955c2b8797d1 100644 --- a/website/www/site/content/en/documentation/programming-guide.md +++ b/website/www/site/content/en/documentation/programming-guide.md @@ -6209,7 +6209,7 @@ class MyDoFn(beam.DoFn): self.gauge = metrics.Metrics.gauge("namespace", "gauge1") def process(self, element): - self.gaguge.set(element) + self.gauge.set(element) yield element {{< /highlight >}} From 0987e799756ad0e7cd8eb8eb9312bb61c95d6648 Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Fri, 1 Nov 2024 09:30:18 -0400 Subject: [PATCH 107/181] Update type hinting in examples to PEP 585 standards (#32970) --- .../apache_beam/examples/complete/estimate_pi.py | 9 ++++----- .../examples/cookbook/bigtableio_it_test.py | 3 +-- .../examples/cookbook/datastore_wordcount.py | 2 +- .../examples/cookbook/group_with_coder.py | 3 +-- .../inference/huggingface_language_modeling.py | 14 ++++++-------- .../inference/huggingface_question_answering.py | 5 ++--- .../inference/onnx_sentiment_classification.py | 9 ++++----- .../inference/pytorch_image_classification.py | 9 ++++----- ...ytorch_image_classification_with_side_inputs.py | 9 ++++----- .../inference/pytorch_image_segmentation.py | 9 ++++----- .../inference/pytorch_language_modeling.py | 14 ++++++-------- .../pytorch_model_per_key_image_segmentation.py | 13 ++++++------- .../inference/run_inference_side_inputs.py | 4 ++-- .../sklearn_japanese_housing_regression.py | 2 +- .../inference/sklearn_mnist_classification.py | 8 +++----- .../inference/tensorflow_imagenet_segmentation.py | 4 ++-- .../inference/tensorflow_mnist_classification.py | 7 +++---- .../inference/tensorrt_object_detection.py | 9 ++++----- .../tfx_bsl/tensorflow_image_classification.py | 9 ++++----- .../inference/vertex_ai_image_classification.py | 10 ++++------ .../examples/inference/vllm_text_completion.py | 2 +- .../inference/xgboost_iris_classification.py | 10 ++++------ .../apache_beam/examples/kafkataxi/kafka_taxi.py | 3 +-- .../apache_beam/examples/wordcount_xlang_sql.py | 4 ++-- 24 files changed, 74 insertions(+), 97 deletions(-) diff --git a/sdks/python/apache_beam/examples/complete/estimate_pi.py b/sdks/python/apache_beam/examples/complete/estimate_pi.py index 089767d2a99e..530a270308d9 100644 --- a/sdks/python/apache_beam/examples/complete/estimate_pi.py +++ b/sdks/python/apache_beam/examples/complete/estimate_pi.py @@ -30,9 +30,8 @@ import json import logging import random +from collections.abc import Iterable from typing import Any -from typing import Iterable -from typing import Tuple import apache_beam as beam from apache_beam.io import WriteToText @@ -40,7 +39,7 @@ from apache_beam.options.pipeline_options import SetupOptions -@beam.typehints.with_output_types(Tuple[int, int, int]) +@beam.typehints.with_output_types(tuple[int, int, int]) @beam.typehints.with_input_types(int) def run_trials(runs): """Run trials and return a 3-tuple representing the results. 
@@ -62,8 +61,8 @@ def run_trials(runs): return runs, inside_runs, 0 -@beam.typehints.with_output_types(Tuple[int, int, float]) -@beam.typehints.with_input_types(Iterable[Tuple[int, int, Any]]) +@beam.typehints.with_output_types(tuple[int, int, float]) +@beam.typehints.with_input_types(Iterable[tuple[int, int, Any]]) def combine_results(results): """Combiner function to sum up trials and compute the estimate. diff --git a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py index 6b5573aa4569..0a8c55d17d3a 100644 --- a/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py +++ b/sdks/python/apache_beam/examples/cookbook/bigtableio_it_test.py @@ -25,7 +25,6 @@ import unittest import uuid from typing import TYPE_CHECKING -from typing import List import pytest import pytz @@ -53,7 +52,7 @@ if TYPE_CHECKING: import google.cloud.bigtable.instance -EXISTING_INSTANCES: List['google.cloud.bigtable.instance.Instance'] = [] +EXISTING_INSTANCES: list['google.cloud.bigtable.instance.Instance'] = [] LABEL_KEY = 'python-bigtable-beam' label_stamp = datetime.datetime.utcnow().replace(tzinfo=UTC) label_stamp_micros = _microseconds_from_datetime(label_stamp) diff --git a/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py b/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py index 65ea7990a2d8..9d71ac32aff2 100644 --- a/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py +++ b/sdks/python/apache_beam/examples/cookbook/datastore_wordcount.py @@ -59,7 +59,7 @@ import logging import re import sys -from typing import Iterable +from collections.abc import Iterable from typing import Optional from typing import Text import uuid diff --git a/sdks/python/apache_beam/examples/cookbook/group_with_coder.py b/sdks/python/apache_beam/examples/cookbook/group_with_coder.py index 3ce7836b491a..8a959138d3da 100644 --- a/sdks/python/apache_beam/examples/cookbook/group_with_coder.py +++ b/sdks/python/apache_beam/examples/cookbook/group_with_coder.py @@ -30,7 +30,6 @@ import argparse import logging import sys -import typing import apache_beam as beam from apache_beam import coders @@ -71,7 +70,7 @@ def is_deterministic(self): # Annotate the get_players function so that the typehint system knows that the # input to the CombinePerKey operation is a key-value pair of a Player object # and an integer. 
-@with_output_types(typing.Tuple[Player, int]) +@with_output_types(tuple[Player, int]) def get_players(descriptor): name, points = descriptor.split(',') return Player(name), int(points) diff --git a/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py b/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py index 5eb57c8fc080..69c2eacc593d 100644 --- a/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py +++ b/sdks/python/apache_beam/examples/inference/huggingface_language_modeling.py @@ -27,10 +27,8 @@ import argparse import logging -from typing import Dict -from typing import Iterable -from typing import Iterator -from typing import Tuple +from collections.abc import Iterable +from collections.abc import Iterator import apache_beam as beam import torch @@ -45,14 +43,14 @@ from transformers import AutoTokenizer -def add_mask_to_last_word(text: str) -> Tuple[str, str]: +def add_mask_to_last_word(text: str) -> tuple[str, str]: text_list = text.split() return text, ' '.join(text_list[:-2] + ['', text_list[-1]]) def tokenize_sentence( - text_and_mask: Tuple[str, str], - tokenizer: AutoTokenizer) -> Tuple[str, Dict[str, torch.Tensor]]: + text_and_mask: tuple[str, str], + tokenizer: AutoTokenizer) -> tuple[str, dict[str, torch.Tensor]]: text, masked_text = text_and_mask tokenized_sentence = tokenizer.encode_plus(masked_text, return_tensors="pt") @@ -81,7 +79,7 @@ def __init__(self, tokenizer: AutoTokenizer): super().__init__() self.tokenizer = tokenizer - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: text, prediction_result = element inputs = prediction_result.example logits = prediction_result.inference['logits'] diff --git a/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py b/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py index 9005ea5d11d7..7d4899cc38d9 100644 --- a/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py +++ b/sdks/python/apache_beam/examples/inference/huggingface_question_answering.py @@ -28,8 +28,7 @@ import argparse import logging -from typing import Iterable -from typing import Tuple +from collections.abc import Iterable import apache_beam as beam from apache_beam.ml.inference.base import KeyedModelHandler @@ -49,7 +48,7 @@ class PostProcessor(beam.DoFn): Hugging Face Pipeline for Question Answering returns a dictionary with score, start and end index of answer and the answer. 
""" - def process(self, result: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, result: tuple[str, PredictionResult]) -> Iterable[str]: text, prediction = result predicted_answer = prediction.inference['answer'] yield text + ';' + predicted_answer diff --git a/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py b/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py index 18f697f673bf..0e62ab865431 100644 --- a/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py +++ b/sdks/python/apache_beam/examples/inference/onnx_sentiment_classification.py @@ -28,9 +28,8 @@ import argparse import logging -from typing import Iterable -from typing import Iterator -from typing import Tuple +from collections.abc import Iterable +from collections.abc import Iterator import numpy as np @@ -47,7 +46,7 @@ def tokenize_sentence(text: str, - tokenizer: RobertaTokenizer) -> Tuple[str, torch.Tensor]: + tokenizer: RobertaTokenizer) -> tuple[str, torch.Tensor]: tokenized_sentence = tokenizer.encode(text, add_special_tokens=True) # Workaround to manually remove batch dim until we have the feature to @@ -63,7 +62,7 @@ def filter_empty_lines(text: str) -> Iterator[str]: class PostProcessor(beam.DoFn): - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: filename, prediction_result = element prediction = np.argmax(prediction_result.inference, axis=0) yield filename + ';' + str(prediction) diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py index d627001bcb82..c24a6d0a910e 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_image_classification.py @@ -21,9 +21,8 @@ import io import logging import os -from typing import Iterator +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import torch @@ -41,7 +40,7 @@ def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -122,13 +121,13 @@ def run( model_class = models.mobilenet_v2 model_params = {'num_classes': 1000} - def preprocess(image_name: str) -> Tuple[str, torch.Tensor]: + def preprocess(image_name: str) -> tuple[str, torch.Tensor]: image_name, image = read_image( image_file_name=image_name, path_to_dir=known_args.images_dir) return (image_name, preprocess_image(image)) - def postprocess(element: Tuple[str, PredictionResult]) -> str: + def postprocess(element: tuple[str, PredictionResult]) -> str: filename, prediction_result = element prediction = torch.argmax(prediction_result.inference, dim=0) return filename + ',' + str(prediction.item()) diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py b/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py index 2a4e6e9a9bc6..787341263fde 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py +++ 
b/sdks/python/apache_beam/examples/inference/pytorch_image_classification_with_side_inputs.py @@ -62,10 +62,9 @@ import io import logging import os -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import torch @@ -84,7 +83,7 @@ def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -116,7 +115,7 @@ class PostProcessor(beam.DoFn): Return filename, prediction and the model id used to perform the prediction """ - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: filename, prediction_result = element prediction = torch.argmax(prediction_result.inference, dim=0) yield filename, prediction, prediction_result.model_id diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py b/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py index cdecb826d6e3..5e5f77a679c3 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_image_segmentation.py @@ -21,10 +21,9 @@ import io import logging import os -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import torch @@ -138,7 +137,7 @@ def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -161,7 +160,7 @@ def filter_empty_lines(text: str) -> Iterator[str]: class PostProcessor(beam.DoFn): - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: filename, prediction_result = element prediction_labels = prediction_result.inference['labels'] classes = [CLASS_ID_TO_NAME[label.item()] for label in prediction_labels] diff --git a/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py b/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py index 9de10e73e11b..a616998d2c73 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py @@ -26,10 +26,8 @@ import argparse import logging -from typing import Dict -from typing import Iterable -from typing import Iterator -from typing import Tuple +from collections.abc import Iterable +from collections.abc import Iterator import apache_beam as beam import torch @@ -45,14 +43,14 @@ from transformers import BertTokenizer -def add_mask_to_last_word(text: str) -> Tuple[str, str]: +def add_mask_to_last_word(text: str) -> tuple[str, str]: text_list = text.split() return text, ' '.join(text_list[:-2] + ['[MASK]', text_list[-1]]) def tokenize_sentence( - text_and_mask: Tuple[str, str], - bert_tokenizer: BertTokenizer) -> Tuple[str, 
Dict[str, torch.Tensor]]: + text_and_mask: tuple[str, str], + bert_tokenizer: BertTokenizer) -> tuple[str, dict[str, torch.Tensor]]: text, masked_text = text_and_mask tokenized_sentence = bert_tokenizer.encode_plus( masked_text, return_tensors="pt") @@ -84,7 +82,7 @@ def __init__(self, bert_tokenizer: BertTokenizer): super().__init__() self.bert_tokenizer = bert_tokenizer - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: text, prediction_result = element inputs = prediction_result.example logits = prediction_result.inference['logits'] diff --git a/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py b/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py index f0b5462d5335..18c4c3e653b4 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_model_per_key_image_segmentation.py @@ -24,10 +24,9 @@ import io import logging import os -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import torch @@ -143,7 +142,7 @@ def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -168,15 +167,15 @@ def filter_empty_lines(text: str) -> Iterator[str]: class KeyExamplesForEachModelType(beam.DoFn): """Duplicate data to run against each model type""" def process( - self, element: Tuple[torch.Tensor, - str]) -> Iterable[Tuple[str, torch.Tensor]]: + self, element: tuple[torch.Tensor, + str]) -> Iterable[tuple[str, torch.Tensor]]: yield 'v1', element[0] yield 'v2', element[0] class PostProcessor(beam.DoFn): def process( - self, element: Tuple[str, PredictionResult]) -> Tuple[torch.Tensor, str]: + self, element: tuple[str, PredictionResult]) -> tuple[torch.Tensor, str]: model, prediction_result = element prediction_labels = prediction_result.inference['labels'] classes = [CLASS_ID_TO_NAME[label.item()] for label in prediction_labels] diff --git a/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py b/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py index a6e4dc2bdb03..755eff17c163 100644 --- a/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py +++ b/sdks/python/apache_beam/examples/inference/run_inference_side_inputs.py @@ -22,9 +22,9 @@ import argparse import logging import time -from typing import Iterable +from collections.abc import Iterable +from collections.abc import Sequence from typing import Optional -from typing import Sequence import apache_beam as beam from apache_beam.ml.inference import base diff --git a/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py b/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py index 3aa2f362fa64..0a527e88dec2 100644 --- a/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py +++ b/sdks/python/apache_beam/examples/inference/sklearn_japanese_housing_regression.py @@ -31,7 +31,7 @@ import argparse import os -from typing import Iterable 
+from collections.abc import Iterable import pandas diff --git a/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py b/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py index 5392cdf7ddae..d7d08e294e9d 100644 --- a/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py +++ b/sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py @@ -27,9 +27,7 @@ import argparse import logging import os -from typing import Iterable -from typing import List -from typing import Tuple +from collections.abc import Iterable import apache_beam as beam from apache_beam.ml.inference.base import KeyedModelHandler @@ -42,7 +40,7 @@ from apache_beam.runners.runner import PipelineResult -def process_input(row: str) -> Tuple[int, List[int]]: +def process_input(row: str) -> tuple[int, list[int]]: data = row.split(',') label, pixels = int(data[0]), data[1:] pixels = [int(pixel) for pixel in pixels] @@ -53,7 +51,7 @@ class PostProcessor(beam.DoFn): """Process the PredictionResult to get the predicted label. Returns a comma separated string with true label and predicted label. """ - def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[int, PredictionResult]) -> Iterable[str]: label, prediction_result = element prediction = prediction_result.inference yield '{},{}'.format(label, prediction) diff --git a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py index a0f249dcfbf0..b44d775f4ad3 100644 --- a/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py +++ b/sdks/python/apache_beam/examples/inference/tensorflow_imagenet_segmentation.py @@ -17,8 +17,8 @@ import argparse import logging -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator import numpy diff --git a/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py b/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py index 6cf746e77cd2..bf85bb1aef16 100644 --- a/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py +++ b/sdks/python/apache_beam/examples/inference/tensorflow_mnist_classification.py @@ -17,8 +17,7 @@ import argparse import logging -from typing import Iterable -from typing import Tuple +from collections.abc import Iterable import numpy @@ -33,7 +32,7 @@ from apache_beam.runners.runner import PipelineResult -def process_input(row: str) -> Tuple[int, numpy.ndarray]: +def process_input(row: str) -> tuple[int, numpy.ndarray]: data = row.split(',') label, pixels = int(data[0]), data[1:] pixels = [int(pixel) for pixel in pixels] @@ -46,7 +45,7 @@ class PostProcessor(beam.DoFn): """Process the PredictionResult to get the predicted label. Returns a comma separated string with true label and predicted label. 
""" - def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[int, PredictionResult]) -> Iterable[str]: label, prediction_result = element prediction = numpy.argmax(prediction_result.inference, axis=0) yield '{},{}'.format(label, prediction) diff --git a/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py b/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py index 1faf502c71af..677d36b9b767 100644 --- a/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py +++ b/sdks/python/apache_beam/examples/inference/tensorrt_object_detection.py @@ -22,9 +22,8 @@ import argparse import io import os -from typing import Iterable +from collections.abc import Iterable from typing import Optional -from typing import Tuple import numpy as np @@ -134,14 +133,14 @@ def attach_im_size_to_key( - data: Tuple[str, Image.Image]) -> Tuple[Tuple[str, int, int], Image.Image]: + data: tuple[str, Image.Image]) -> tuple[tuple[str, int, int], Image.Image]: filename, image = data width, height = image.size return ((filename, width, height), image) def read_image(image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]: + path_to_dir: Optional[str] = None) -> tuple[str, Image.Image]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -168,7 +167,7 @@ class PostProcessor(beam.DoFn): an integer that we can transform into actual string class using COCO_OBJ_DET_CLASSES as reference. """ - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: key, prediction_result = element filename, im_width, im_height = key num_detections = prediction_result.inference[0] diff --git a/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py b/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py index 09a70caa4ede..5df0b51e36d7 100644 --- a/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py +++ b/sdks/python/apache_beam/examples/inference/tfx_bsl/tensorflow_image_classification.py @@ -32,10 +32,9 @@ import io import logging import os -from typing import Iterable -from typing import Iterator +from collections.abc import Iterable +from collections.abc import Iterator from typing import Optional -from typing import Tuple import apache_beam as beam import tensorflow as tf @@ -60,7 +59,7 @@ def filter_empty_lines(text: str) -> Iterator[str]: def read_and_process_image( image_file_name: str, - path_to_dir: Optional[str] = None) -> Tuple[str, tf.Tensor]: + path_to_dir: Optional[str] = None) -> tuple[str, tf.Tensor]: if path_to_dir is not None: image_file_name = os.path.join(path_to_dir, image_file_name) with FileSystems().open(image_file_name, 'r') as file: @@ -97,7 +96,7 @@ def convert_image_to_example_proto(tensor: tf.Tensor) -> tf.train.Example: class ProcessInferenceToString(beam.DoFn): def process( - self, element: Tuple[str, + self, element: tuple[str, prediction_log_pb2.PredictionLog]) -> Iterable[str]: """ Args: diff --git a/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py b/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py index 73126569e988..20312e7d3c88 100644 --- a/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py +++ 
b/sdks/python/apache_beam/examples/inference/vertex_ai_image_classification.py @@ -27,9 +27,7 @@ import argparse import io import logging -from typing import Iterable -from typing import List -from typing import Tuple +from collections.abc import Iterable import apache_beam as beam import tensorflow as tf @@ -102,13 +100,13 @@ def parse_known_args(argv): COLUMNS = ['dandelion', 'daisy', 'tulips', 'sunflowers', 'roses'] -def read_image(image_file_name: str) -> Tuple[str, bytes]: +def read_image(image_file_name: str) -> tuple[str, bytes]: with FileSystems().open(image_file_name, 'r') as file: data = io.BytesIO(file.read()).getvalue() return image_file_name, data -def preprocess_image(data: bytes) -> List[float]: +def preprocess_image(data: bytes) -> list[float]: """Preprocess the image, resizing it and normalizing it before converting to a list. """ @@ -119,7 +117,7 @@ def preprocess_image(data: bytes) -> List[float]: class PostProcessor(beam.DoFn): - def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[str, PredictionResult]) -> Iterable[str]: img_name, prediction_result = element prediction_vals = prediction_result.inference index = prediction_vals.index(max(prediction_vals)) diff --git a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py index 3cf7d04cb03e..2708c0f3d1a1 100644 --- a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py +++ b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py @@ -25,7 +25,7 @@ import argparse import logging -from typing import Iterable +from collections.abc import Iterable import apache_beam as beam from apache_beam.ml.inference.base import PredictionResult diff --git a/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py b/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py index 963187fd210d..498511a5a2cf 100644 --- a/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py +++ b/sdks/python/apache_beam/examples/inference/xgboost_iris_classification.py @@ -17,10 +17,8 @@ import argparse import logging -from typing import Callable -from typing import Iterable -from typing import List -from typing import Tuple +from collections.abc import Callable +from collections.abc import Iterable from typing import Union import numpy @@ -48,7 +46,7 @@ class PostProcessor(beam.DoFn): """Process the PredictionResult to get the predicted label. Returns a comma separated string with true label and predicted label. """ - def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]: + def process(self, element: tuple[int, PredictionResult]) -> Iterable[str]: label, prediction_result = element prediction = prediction_result.inference yield '{},{}'.format(label, prediction) @@ -89,7 +87,7 @@ def parse_known_args(argv): def load_sklearn_iris_test_data( data_type: Callable, split: bool = True, - seed: int = 999) -> List[Union[numpy.array, pandas.DataFrame]]: + seed: int = 999) -> list[Union[numpy.array, pandas.DataFrame]]: """ Loads test data from the sklearn Iris dataset in a given format, either in a single or multiple batches. 
diff --git a/sdks/python/apache_beam/examples/kafkataxi/kafka_taxi.py b/sdks/python/apache_beam/examples/kafkataxi/kafka_taxi.py index 1cdd266c3df4..9b4889017077 100644 --- a/sdks/python/apache_beam/examples/kafkataxi/kafka_taxi.py +++ b/sdks/python/apache_beam/examples/kafkataxi/kafka_taxi.py @@ -26,7 +26,6 @@ import logging import sys -import typing import apache_beam as beam from apache_beam.io.kafka import ReadFromKafka @@ -97,7 +96,7 @@ def convert_kafka_record_to_dictionary(record): topic='projects/pubsub-public-data/topics/taxirides-realtime'). with_output_types(bytes) | beam.Map(lambda x: (b'', x)).with_output_types( - typing.Tuple[bytes, bytes]) # Kafka write transforms expects KVs. + tuple[bytes, bytes]) # Kafka write transforms expects KVs. | beam.WindowInto(beam.window.FixedWindows(window_size)) | WriteToKafka( producer_config={'bootstrap.servers': bootstrap_servers}, diff --git a/sdks/python/apache_beam/examples/wordcount_xlang_sql.py b/sdks/python/apache_beam/examples/wordcount_xlang_sql.py index 9d7d756f223f..632e90303010 100644 --- a/sdks/python/apache_beam/examples/wordcount_xlang_sql.py +++ b/sdks/python/apache_beam/examples/wordcount_xlang_sql.py @@ -24,7 +24,7 @@ import argparse import logging import re -import typing +from typing import NamedTuple import apache_beam as beam from apache_beam import coders @@ -41,7 +41,7 @@ # # Here we create and register a simple NamedTuple with a single str typed # field named 'word' which we will use below. -MyRow = typing.NamedTuple('MyRow', [('word', str)]) +MyRow = NamedTuple('MyRow', [('word', str)]) coders.registry.register_coder(MyRow, coders.RowCoder) From 02ab9dafd00e7c9d0ad62272de2694a5797e9a04 Mon Sep 17 00:00:00 2001 From: kushmiD Date: Fri, 1 Nov 2024 09:17:36 -0700 Subject: [PATCH 108/181] Support caching in Apache Beam by using relative co_filename paths. (#32979) * The motivation for this change is to support caching in Apache Beam. Apache Beam does the following: - Pickle Python code - Send the pickled source code to "worker" VMs - The workers unpickle and execute the code In the environment that these Beam pipelines execute, the source code is in a temporary directory whose name is random and changes. The source code paths relative to the temporary directory are constant. Using absolute paths prevents pickled code from being cached because the absolute path keeps changing. Using relative paths enables this caching and promises significant resource savings and speed-ups. Additionally the absolute paths leak information about the directory structure of the machine pickling the source code. When the pickled code is passed across the network to another machine, the absolute paths may no longer be valid when the other machine has a different directory structure. The reason for using relative paths rather than omitting the path entirely is because Python uses the co_filename attribute to create stack traces. * The motivation for this change is to support caching in Apache Beam for Google. Apache Beam does the following: - Pickle Python code - Send the pickled source code to "worker" VMs - The workers unpickle and execute the code In the environment that these Beam pipelines execute, the source code is in a temporary directory whose name is random and changes. The source code paths relative to the temporary directory are constant. Using absolute paths prevents pickled code from being cached because the absolute path keeps changing. 
Using relative paths enables this caching and promises significant resource savings and speed-ups. Additionally the absolute paths leak information about the directory structure of the machine pickling the source code. When the pickled code is passed across the network to another machine, the absolute paths may no longer be valid when the other machine has a different directory structure. The reason for using relative paths rather than omitting the path entirely is because Python uses the co_filename attribute to create stack traces. * Simplify. --------- Co-authored-by: Robert Bradshaw --- .../apache_beam/internal/dill_pickler.py | 19 +++++++++++++------ .../apache_beam/internal/pickler_test.py | 5 +++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/sdks/python/apache_beam/internal/dill_pickler.py b/sdks/python/apache_beam/internal/dill_pickler.py index 7f7ac5b214fa..e1d6b7e74e49 100644 --- a/sdks/python/apache_beam/internal/dill_pickler.py +++ b/sdks/python/apache_beam/internal/dill_pickler.py @@ -46,9 +46,15 @@ settings = {'dill_byref': None} -if sys.version_info >= (3, 10) and dill.__version__ == "0.3.1.1": - # Let's make dill 0.3.1.1 support Python 3.11. +patch_save_code = sys.version_info >= (3, 10) and dill.__version__ == "0.3.1.1" + +def get_normalized_path(path): + """Returns a normalized path. This function is intended to be overridden.""" + return path + + +if patch_save_code: # The following function is based on 'save_code' from 'dill' # Author: Mike McKerns (mmckerns @caltech and @uqfoundation) # Copyright (c) 2008-2015 California Institute of Technology. @@ -66,6 +72,7 @@ @dill.register(CodeType) def save_code(pickler, obj): + co_filename = get_normalized_path(obj.co_filename) if hasattr(obj, "co_endlinetable"): # python 3.11a (20 args) args = ( obj.co_argcount, @@ -78,7 +85,7 @@ def save_code(pickler, obj): obj.co_consts, obj.co_names, obj.co_varnames, - obj.co_filename, + co_filename, obj.co_name, obj.co_qualname, obj.co_firstlineno, @@ -100,7 +107,7 @@ def save_code(pickler, obj): obj.co_consts, obj.co_names, obj.co_varnames, - obj.co_filename, + co_filename, obj.co_name, obj.co_qualname, obj.co_firstlineno, @@ -120,7 +127,7 @@ def save_code(pickler, obj): obj.co_consts, obj.co_names, obj.co_varnames, - obj.co_filename, + co_filename, obj.co_name, obj.co_firstlineno, obj.co_linetable, @@ -138,7 +145,7 @@ def save_code(pickler, obj): obj.co_consts, obj.co_names, obj.co_varnames, - obj.co_filename, + co_filename, obj.co_name, obj.co_firstlineno, obj.co_lnotab, diff --git a/sdks/python/apache_beam/internal/pickler_test.py b/sdks/python/apache_beam/internal/pickler_test.py index 824c4c59c0ce..c26a8ee3e653 100644 --- a/sdks/python/apache_beam/internal/pickler_test.py +++ b/sdks/python/apache_beam/internal/pickler_test.py @@ -94,6 +94,11 @@ def test_pickle_rlock(self): self.assertIsInstance(loads(dumps(rlock_instance)), rlock_type) + def test_save_paths(self): + f = loads(dumps(lambda x: x)) + co_filename = f.__code__.co_filename + self.assertTrue(co_filename.endswith('pickler_test.py')) + @unittest.skipIf(NO_MAPPINGPROXYTYPE, 'test if MappingProxyType introduced') def test_dump_and_load_mapping_proxy(self): self.assertEqual( From 61268ef9437c54c3842fb6edc9328f93c1b4d72d Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Fri, 1 Nov 2024 16:01:06 -0400 Subject: [PATCH 109/181] Temporariuly disable TensorRT PostCommit suite (#33002) --- sdks/python/test-suites/dataflow/common.gradle | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index 845791e9c10f..71d44652bc7e 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -548,7 +548,8 @@ task mockAPITests { // TODO: https://github.com/apache/beam/issues/22651 project.tasks.register("inferencePostCommitIT") { dependsOn = [ - 'tensorRTtests', + // Temporarily disabled because of a container issue + // 'tensorRTtests', 'vertexAIInferenceTest', 'mockAPITests', ] From 90c1ee9ac5107491f1ac1e07ff9107c3423b74df Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Fri, 1 Nov 2024 17:38:48 -0400 Subject: [PATCH 110/181] Remove usage of deprecated _serialize (#33000) * Remove usage of deprecated _serialize * Correct assignment * indentation * lint * fmt * fmt --- sdks/python/apache_beam/ml/inference/onnx_inference.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py index e7af114ad431..4ac856456748 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py @@ -116,7 +116,15 @@ def load_model(self) -> ort.InferenceSession: # when path is remote, we should first load into memory then deserialize f = FileSystems.open(self._model_uri, "rb") model_proto = onnx.load(f) - model_proto_bytes = onnx._serialize(model_proto) + model_proto_bytes = model_proto + if not isinstance(model_proto, bytes): + if (hasattr(model_proto, "SerializeToString") and + callable(model_proto.SerializeToString)): + model_proto_bytes = model_proto.SerializeToString() + else: + raise TypeError( + "No SerializeToString method is detected on loaded model. " + + f"Type of model: {type(model_proto)}") ort_session = ort.InferenceSession( model_proto_bytes, sess_options=self._session_options, From eed82f012e96ed66c1d17449493dd39a792d8284 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Fri, 1 Nov 2024 17:09:20 -0700 Subject: [PATCH 111/181] Add a simple validation transform to yaml. 
(#32956) --- sdks/python/apache_beam/yaml/json_utils.py | 61 +++++++++++++++++++ sdks/python/apache_beam/yaml/yaml_mapping.py | 39 +++++++++++- .../apache_beam/yaml/yaml_mapping_test.py | 37 +++++++++++ 3 files changed, 136 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/yaml/json_utils.py b/sdks/python/apache_beam/yaml/json_utils.py index 40e515ee6946..76cc80bc2036 100644 --- a/sdks/python/apache_beam/yaml/json_utils.py +++ b/sdks/python/apache_beam/yaml/json_utils.py @@ -106,6 +106,18 @@ def json_type_to_beam_type(json_type: Dict[str, Any]) -> schema_pb2.FieldType: raise ValueError(f'Unable to convert {json_type} to a Beam schema.') +def beam_schema_to_json_schema( + beam_schema: schema_pb2.Schema) -> Dict[str, Any]: + return { + 'type': 'object', + 'properties': { + field.name: beam_type_to_json_type(field.type) + for field in beam_schema.fields + }, + 'additionalProperties': False + } + + def beam_type_to_json_type(beam_type: schema_pb2.FieldType) -> Dict[str, Any]: type_info = beam_type.WhichOneof("type_info") if type_info == "atomic_type": @@ -267,3 +279,52 @@ def json_formater( convert = row_to_json( schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema))) return lambda row: json.dumps(convert(row), sort_keys=True).encode('utf-8') + + +def _validate_compatible(weak_schema, strong_schema): + if not weak_schema: + return + if weak_schema['type'] != strong_schema['type']: + raise ValueError( + 'Incompatible types: %r vs %r' % + (weak_schema['type'] != strong_schema['type'])) + if weak_schema['type'] == 'array': + _validate_compatible(weak_schema['items'], strong_schema['items']) + elif weak_schema == 'object': + for required in strong_schema.get('required', []): + if required not in weak_schema['properties']: + raise ValueError('Missing or unkown property %r' % required) + for name, spec in weak_schema.get('properties', {}): + if name in strong_schema['properties']: + try: + _validate_compatible(spec, strong_schema['properties'][name]) + except Exception as exn: + raise ValueError('Incompatible schema for %r' % name) from exn + elif not strong_schema.get('additionalProperties'): + raise ValueError( + 'Prohibited property: {property}; ' + 'perhaps additionalProperties: False is missing?') + + +def row_validator(beam_schema: schema_pb2.Schema, + json_schema: Dict[str, Any]) -> Callable[[Any], Any]: + """Returns a callable that will fail on elements not respecting json_schema. + """ + if not json_schema: + return lambda x: None + + # Validate that this compiles, but avoid pickling the validator itself. 
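+  # The validator object itself is only constructed lazily inside validate() below, so it never needs to be pickled along with the returned closure.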
+ _ = jsonschema.validators.validator_for(json_schema)(json_schema) + _validate_compatible(beam_schema_to_json_schema(beam_schema), json_schema) + validator = None + + convert = row_to_json( + schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema))) + + def validate(row): + nonlocal validator + if validator is None: + validator = jsonschema.validators.validator_for(json_schema)(json_schema) + validator.validate(convert(row)) + + return validate diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 377bcac0e31a..960fcdeecf30 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -43,6 +43,7 @@ from apache_beam.typehints.native_type_compatibility import convert_to_beam_type from apache_beam.typehints.row_type import RowTypeConstraint from apache_beam.typehints.schemas import named_fields_from_element_type +from apache_beam.typehints.schemas import schema_from_element_type from apache_beam.utils import python_callable from apache_beam.yaml import json_utils from apache_beam.yaml import options @@ -435,7 +436,8 @@ def _map_errors_to_standard_format(input_type): # TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple. return beam.Map( - lambda x: beam.Row(element=x[0], msg=str(x[1][1]), stack=str(x[1][2])) + lambda x: beam.Row( + element=x[0], msg=str(x[1][1]), stack=''.join(x[1][2])) ).with_output_types( RowTypeConstraint.from_fields([("element", input_type), ("msg", str), ("stack", str)])) @@ -475,6 +477,40 @@ def expand(pcoll, error_handling=None, **kwargs): return expand +class _Validate(beam.PTransform): + """Validates each element of a PCollection against a json schema. + + Args: + schema: A json schema against which to validate each element. + error_handling: Whether and how to handle errors during iteration. + If this is not set, invalid elements will fail the pipeline, otherwise + invalid elements will be passed to the specified error output along + with information about how the schema was invalidated. + """ + def __init__( + self, + schema: Dict[str, Any], + error_handling: Optional[Mapping[str, Any]] = None): + self._schema = schema + self._exception_handling_args = exception_handling_args(error_handling) + + @maybe_with_exception_handling + def expand(self, pcoll): + validator = json_utils.row_validator( + schema_from_element_type(pcoll.element_type), self._schema) + + def invoke_validator(x): + validator(x) + return x + + return pcoll | beam.Map(invoke_validator) + + def with_exception_handling(self, **kwargs): + # It's possible there's an error in iteration... + self._exception_handling_args = kwargs + return self + + class _Explode(beam.PTransform): """Explodes (aka unnest/flatten) one or more fields producing multiple rows. 
@@ -797,6 +833,7 @@ def create_mapping_providers(): 'Partition-python': _Partition, 'Partition-javascript': _Partition, 'Partition-generic': _Partition, + 'ValidateWithSchema': _Validate, }), yaml_provider.SqlBackedProvider({ 'Filter-sql': _SqlFilterTransform, diff --git a/sdks/python/apache_beam/yaml/yaml_mapping_test.py b/sdks/python/apache_beam/yaml/yaml_mapping_test.py index 1b74a765e54b..2c5feec18278 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping_test.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping_test.py @@ -134,6 +134,43 @@ def test_explode(self): beam.Row(a=3, b='y', c=.125, range=2), ])) + def test_validate(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + elements = p | beam.Create([ + beam.Row(key='good', small=[5], nested=beam.Row(big=100)), + beam.Row(key='bad1', small=[500], nested=beam.Row(big=100)), + beam.Row(key='bad2', small=[5], nested=beam.Row(big=1)), + ]) + result = elements | YamlTransform( + ''' + type: ValidateWithSchema + config: + schema: + type: object + properties: + small: + type: array + items: + type: integer + maximum: 10 + nested: + type: object + properties: + big: + type: integer + minimum: 10 + error_handling: + output: bad + ''') + + assert_that( + result['good'] | beam.Map(lambda x: x.key), equal_to(['good'])) + assert_that( + result['bad'] | beam.Map(lambda x: x.element.key), + equal_to(['bad1', 'bad2']), + label='Errors') + def test_validate_explicit_types(self): with self.assertRaisesRegex(TypeError, r'.*violates schema.*'): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( From e85df7942df028969c8d43c64ecb0261e4de65d3 Mon Sep 17 00:00:00 2001 From: Ahmed Abualsaud <65791736+ahmedabu98@users.noreply.github.com> Date: Mon, 4 Nov 2024 12:41:52 +0300 Subject: [PATCH 112/181] Add Managed Iceberg example (#32678) * add iceberg example * runtime dependency * update to let the sink create tables * remove dependencies --- examples/java/build.gradle | 4 + .../cookbook/IcebergTaxiExamples.java | 119 ++++++++++++++++++ 2 files changed, 123 insertions(+) create mode 100644 examples/java/src/main/java/org/apache/beam/examples/cookbook/IcebergTaxiExamples.java diff --git a/examples/java/build.gradle b/examples/java/build.gradle index af91fa83fe91..4f1902cf1679 100644 --- a/examples/java/build.gradle +++ b/examples/java/build.gradle @@ -66,6 +66,8 @@ dependencies { implementation project(":sdks:java:extensions:python") implementation project(":sdks:java:io:google-cloud-platform") implementation project(":sdks:java:io:kafka") + runtimeOnly project(":sdks:java:io:iceberg") + implementation project(":sdks:java:managed") implementation project(":sdks:java:extensions:ml") implementation library.java.avro implementation library.java.bigdataoss_util @@ -100,6 +102,8 @@ dependencies { implementation "org.apache.httpcomponents:httpcore:4.4.13" implementation "com.fasterxml.jackson.core:jackson-annotations:2.14.1" implementation "com.fasterxml.jackson.core:jackson-core:2.14.1" + runtimeOnly library.java.hadoop_client + runtimeOnly library.java.bigdataoss_gcs_connector testImplementation project(path: ":runners:direct-java", configuration: "shadow") testImplementation project(":sdks:java:io:google-cloud-platform") testImplementation project(":sdks:java:extensions:ml") diff --git a/examples/java/src/main/java/org/apache/beam/examples/cookbook/IcebergTaxiExamples.java b/examples/java/src/main/java/org/apache/beam/examples/cookbook/IcebergTaxiExamples.java new file 
mode 100644 index 000000000000..446d11d03be4 --- /dev/null +++ b/examples/java/src/main/java/org/apache/beam/examples/cookbook/IcebergTaxiExamples.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.examples.cookbook; + +import java.util.Arrays; +import java.util.Map; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.io.gcp.pubsub.PubsubIO; +import org.apache.beam.sdk.managed.Managed; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.Validation; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.transforms.Filter; +import org.apache.beam.sdk.transforms.JsonToRow; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; + +/** + * Reads real-time NYC taxi ride information from {@code + * projects/pubsub-public-data/topics/taxirides-realtime} and writes to Iceberg tables using Beam's + * {@link Managed} IcebergIO sink. + * + *
+ * <p>
This is a streaming pipeline that writes records to Iceberg tables dynamically, depending on + * each record's passenger count. New tables are created as needed. We set a triggering frequency of + * 10s; at around this interval, the sink will accumulate records and write them to the appropriate + * table, creating a new snapshot each time. + */ +public class IcebergTaxiExamples { + private static final String TAXI_RIDES_TOPIC = + "projects/pubsub-public-data/topics/taxirides-realtime"; + private static final Schema TAXI_RIDE_INFO_SCHEMA = + Schema.builder() + .addStringField("ride_id") + .addInt32Field("point_idx") + .addDoubleField("latitude") + .addDoubleField("longitude") + .addStringField("timestamp") + .addDoubleField("meter_reading") + .addDoubleField("meter_increment") + .addStringField("ride_status") + .addInt32Field("passenger_count") + .build(); + + public static void main(String[] args) { + IcebergPipelineOptions options = + PipelineOptionsFactory.fromArgs(args).as(IcebergPipelineOptions.class); + options.setProject("apache-beam-testing"); + + // each record's 'passenger_count' value will be substituted in to determine + // its final table destination + // e.g. an event with 3 passengers will be written to 'iceberg_taxi.3_passengers' + String tableIdentifierTemplate = "iceberg_taxi.{passenger_count}_passengers"; + + Map catalogProps = + ImmutableMap.builder() + .put("catalog-impl", options.getCatalogImpl()) + .put("warehouse", options.getWarehouse()) + .build(); + Map icebergWriteConfig = + ImmutableMap.builder() + .put("table", tableIdentifierTemplate) + .put("catalog_name", options.getCatalogName()) + .put("catalog_properties", catalogProps) + .put("triggering_frequency_seconds", 10) + // perform a final filter to only write these two columns + .put("keep", Arrays.asList("ride_id", "meter_reading")) + .build(); + + Pipeline p = Pipeline.create(options); + p + // Read taxi ride data + .apply(PubsubIO.readStrings().fromTopic(TAXI_RIDES_TOPIC)) + // Convert JSON strings to Beam Rows + .apply(JsonToRow.withSchema(TAXI_RIDE_INFO_SCHEMA)) + // Filter to only include drop-offs + .apply(Filter.create().whereFieldName("ride_status", "dropoff"::equals)) + // Write to Iceberg tables + .apply(Managed.write(Managed.ICEBERG).withConfig(icebergWriteConfig)); + p.run(); + } + + public interface IcebergPipelineOptions extends GcpOptions { + @Description("Warehouse location where the table's data will be written to.") + @Default.String("gs://apache-beam-samples/iceberg-examples") + String getWarehouse(); + + void setWarehouse(String warehouse); + + @Description("Fully-qualified name of the catalog class to use.") + @Default.String("org.apache.iceberg.hadoop.HadoopCatalog") + String getCatalogImpl(); + + void setCatalogImpl(String catalogName); + + @Validation.Required + @Default.String("example-catalog") + String getCatalogName(); + + void setCatalogName(String catalogName); + } +} From adc714360579f4a58726be7f8eb5bf01a709cbb9 Mon Sep 17 00:00:00 2001 From: liferoad Date: Mon, 4 Nov 2024 09:55:06 -0500 Subject: [PATCH 113/181] Update code-change-guide.md --- contributor-docs/code-change-guide.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/contributor-docs/code-change-guide.md b/contributor-docs/code-change-guide.md index f0785d3509d0..efb57973592e 100644 --- a/contributor-docs/code-change-guide.md +++ b/contributor-docs/code-change-guide.md @@ -115,6 +115,8 @@ To run a Gradle task, use the command `./gradlew -p ` or th ./gradlew :sdks:java:harness:test ``` +**It is 
recommended to run `./gradlew clean` if you run into some strange errors such as `java.lang.NoClassDefFoundError`.** + #### Beam-specific Gradle project configuration For Apache Beam, one plugin manages everything: `buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin`. @@ -145,7 +147,7 @@ in the Google Cloud documentation. Depending on the languages involved, your `PATH` file needs to have the following elements configured. -* A Java environment that uses a supported Java version, preferably Java 8. +* A Java environment that uses a supported Java version, preferably Java 11. * This environment is needed for all development, because Beam is a Gradle project that uses JVM. * Recommended: To manage Java versions, use [sdkman](https://sdkman.io/install). From 76c5d56ea95ee1497cb80dfe1de1d1abd8d34b15 Mon Sep 17 00:00:00 2001 From: Hai Joey Tran Date: Mon, 4 Nov 2024 10:18:05 -0500 Subject: [PATCH 114/181] Update typecheck err msg (#32880) * Update typecheck err msg * update a few typed_pipeline_test unit tests * Move new logic to only apply to main element * fix tests * remove please comment * update tests again * remove debug str * fix more tests * fix pubsub test * revert accidental pyproject change * revert pyrpoject change * fix ptransform_test tests * add explanatory comments for similar looking code --- sdks/python/apache_beam/io/gcp/pubsub_test.py | 3 +- .../apache_beam/transforms/ptransform.py | 25 +++++- .../apache_beam/transforms/ptransform_test.py | 90 ++++++------------- .../apache_beam/typehints/decorators_test.py | 4 +- .../typehints/typed_pipeline_test.py | 30 +++---- 5 files changed, 66 insertions(+), 86 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py index 2e3e9b301618..73ba8d6abdb6 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub_test.py +++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py @@ -901,7 +901,8 @@ def test_write_messages_with_attributes_error(self, mock_pubsub): options = PipelineOptions([]) options.view_as(StandardOptions).streaming = True - with self.assertRaisesRegex(Exception, r'Type hint violation'): + with self.assertRaisesRegex(Exception, + r'requires.*PubsubMessage.*applied.*str'): with TestPipeline(options=options) as p: _ = ( p diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 6ec741705376..4848dc4aade8 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -497,13 +497,12 @@ def type_check_inputs_or_outputs(self, pvalueish, input_or_output): at_context = ' %s %s' % (input_or_output, context) if context else '' raise TypeCheckError( '{type} type hint violation at {label}{context}: expected {hint}, ' - 'got {actual_type}\nFull type hint:\n{debug_str}'.format( + 'got {actual_type}'.format( type=input_or_output.title(), label=self.label, context=at_context, hint=hint, - actual_type=pvalue_.element_type, - debug_str=type_hints.debug_str())) + actual_type=pvalue_.element_type)) def _infer_output_coder(self, input_type=None, input_coder=None): # type: (...) -> Optional[coders.Coder] @@ -939,7 +938,25 @@ def element_type(side_input): bindings = getcallargs_forhints(argspec_fn, *arg_types, **kwargs_types) hints = getcallargs_forhints( argspec_fn, *input_types[0], **input_types[1]) - for arg, hint in hints.items(): + + # First check the main input. 
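+      # If the main element's hint is violated, the error names the producing transform, truncating its full label to this transform's nesting depth so the message points at a user-visible step.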
+ arg_hints = iter(hints.items()) + element_arg, element_hint = next(arg_hints) + if not typehints.is_consistent_with( + bindings.get(element_arg, typehints.Any), element_hint): + transform_nest_level = self.label.count("/") + split_producer_label = pvalueish.producer.full_label.split("/") + producer_label = "/".join( + split_producer_label[:transform_nest_level + 1]) + raise TypeCheckError( + f"The transform '{self.label}' requires " + f"PCollections of type '{element_hint}' " + f"but was applied to a PCollection of type" + f" '{bindings[element_arg]}' " + f"(produced by the transform '{producer_label}'). ") + + # Now check the side inputs. + for arg, hint in arg_hints: if arg.startswith('__unknown__'): continue if hint is None: diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index 0acea547ccdc..7db017a59158 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -1298,17 +1298,13 @@ class ToUpperCaseWithPrefix(beam.DoFn): def process(self, element, prefix): return [prefix + element.upper()] - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Upper.*requires.*str.*applied.*int'): ( self.p | 'T' >> beam.Create([1, 2, 3]).with_output_types(int) | 'Upper' >> beam.ParDo(ToUpperCaseWithPrefix(), 'hello')) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Upper': " - "requires {} but got {} for element".format(str, int)) - def test_do_fn_pipeline_runtime_type_check_satisfied(self): self.p._options.view_as(TypeOptions).runtime_type_check = True @@ -1335,18 +1331,14 @@ class AddWithNum(beam.DoFn): def process(self, element, num): return [element + num] - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Add.*requires.*int.*applied.*str'): ( self.p | 'T' >> beam.Create(['1', '2', '3']).with_output_types(str) | 'Add' >> beam.ParDo(AddWithNum(), 5)) self.p.run() - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Add': " - "requires {} but got {} for element".format(int, str)) - def test_pardo_does_not_type_check_using_type_hint_decorators(self): @with_input_types(a=int) @with_output_types(typing.List[str]) @@ -1355,17 +1347,13 @@ def int_to_str(a): # The function above is expecting an int for its only parameter. However, it # will receive a str instead, which should result in a raised exception. - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'ToStr.*requires.*int.*applied.*str'): ( self.p | 'S' >> beam.Create(['b', 'a', 'r']).with_output_types(str) | 'ToStr' >> beam.FlatMap(int_to_str)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'ToStr': " - "requires {} but got {} for a".format(int, str)) - def test_pardo_properly_type_checks_using_type_hint_decorators(self): @with_input_types(a=str) @with_output_types(typing.List[str]) @@ -1387,7 +1375,8 @@ def to_all_upper_case(a): def test_pardo_does_not_type_check_using_type_hint_methods(self): # The first ParDo outputs pcoll's of type int, however the second ParDo is # expecting pcoll's of type str instead. 
- with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Upper.*requires.*str.*applied.*int'): ( self.p | 'S' >> beam.Create(['t', 'e', 's', 't']).with_output_types(str) @@ -1398,11 +1387,6 @@ def test_pardo_does_not_type_check_using_type_hint_methods(self): 'Upper' >> beam.FlatMap(lambda x: [x.upper()]).with_input_types( str).with_output_types(str))) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Upper': " - "requires {} but got {} for x".format(str, int)) - def test_pardo_properly_type_checks_using_type_hint_methods(self): # Pipeline should be created successfully without an error d = ( @@ -1419,18 +1403,14 @@ def test_pardo_properly_type_checks_using_type_hint_methods(self): def test_map_does_not_type_check_using_type_hints_methods(self): # The transform before 'Map' has indicated that it outputs PCollections with # int's, while Map is expecting one of str. - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Upper.*requires.*str.*applied.*int'): ( self.p | 'S' >> beam.Create([1, 2, 3, 4]).with_output_types(int) | 'Upper' >> beam.Map(lambda x: x.upper()).with_input_types( str).with_output_types(str)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Upper': " - "requires {} but got {} for x".format(str, int)) - def test_map_properly_type_checks_using_type_hints_methods(self): # No error should be raised if this type-checks properly. d = ( @@ -1449,17 +1429,13 @@ def upper(s): # Hinted function above expects a str at pipeline construction. # However, 'Map' should detect that Create has hinted an int instead. - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Upper.*requires.*str.*applied.*int'): ( self.p | 'S' >> beam.Create([1, 2, 3, 4]).with_output_types(int) | 'Upper' >> beam.Map(upper)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Upper': " - "requires {} but got {} for s".format(str, int)) - def test_map_properly_type_checks_using_type_hints_decorator(self): @with_input_types(a=bool) @with_output_types(int) @@ -1477,7 +1453,8 @@ def bool_to_int(a): def test_filter_does_not_type_check_using_type_hints_method(self): # Filter is expecting an int but instead looks to the 'left' and sees a str # incoming. - with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Below 3.*requires.*int.*applied.*str'): ( self.p | 'Strs' >> beam.Create(['1', '2', '3', '4', '5' @@ -1486,11 +1463,6 @@ def test_filter_does_not_type_check_using_type_hints_method(self): str).with_output_types(str) | 'Below 3' >> beam.Filter(lambda x: x < 3).with_input_types(int)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Below 3': " - "requires {} but got {} for x".format(int, str)) - def test_filter_type_checks_using_type_hints_method(self): # No error should be raised if this type-checks properly. d = ( @@ -1508,17 +1480,13 @@ def more_than_half(a): return a > 0.50 # Func above was hinted to only take a float, yet a str will be passed. 
- with self.assertRaises(typehints.TypeCheckError) as e: + with self.assertRaisesRegex(typehints.TypeCheckError, + r'Half.*requires.*float.*applied.*str'): ( self.p | 'Ints' >> beam.Create(['1', '2', '3', '4']).with_output_types(str) | 'Half' >> beam.Filter(more_than_half)) - self.assertStartswith( - e.exception.args[0], - "Type hint violation for 'Half': " - "requires {} but got {} for a".format(float, str)) - def test_filter_type_checks_using_type_hints_decorator(self): @with_input_types(b=int) def half(b): @@ -2128,14 +2096,10 @@ def test_mean_globally_pipeline_checking_violated(self): self.p | 'C' >> beam.Create(['test']).with_output_types(str) | 'Mean' >> combine.Mean.Globally()) - - expected_msg = \ - "Type hint violation for 'CombinePerKey': " \ - "requires Tuple[TypeVariable[K], Union[, , " \ - ", ]] " \ - "but got Tuple[None, ] for element" - - self.assertStartswith(e.exception.args[0], expected_msg) + err_msg = e.exception.args[0] + assert "CombinePerKey" in err_msg + assert "Tuple[TypeVariable[K]" in err_msg + assert "Tuple[None, " in err_msg def test_mean_globally_runtime_checking_satisfied(self): self.p._options.view_as(TypeOptions).runtime_type_check = True @@ -2195,14 +2159,12 @@ def test_mean_per_key_pipeline_checking_violated(self): typing.Tuple[str, str])) | 'EvenMean' >> combine.Mean.PerKey()) self.p.run() - - expected_msg = \ - "Type hint violation for 'CombinePerKey(MeanCombineFn)': " \ - "requires Tuple[TypeVariable[K], Union[, , " \ - ", ]] " \ - "but got Tuple[, ] for element" - - self.assertStartswith(e.exception.args[0], expected_msg) + err_msg = e.exception.args[0] + assert "CombinePerKey(MeanCombineFn)" in err_msg + assert "requires" in err_msg + assert "Tuple[TypeVariable[K]" in err_msg + assert "applied" in err_msg + assert "Tuple[, ]" in err_msg def test_mean_per_key_runtime_checking_satisfied(self): self.p._options.view_as(TypeOptions).runtime_type_check = True diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py index 3baf9fa8322f..71edc75f31a6 100644 --- a/sdks/python/apache_beam/typehints/decorators_test.py +++ b/sdks/python/apache_beam/typehints/decorators_test.py @@ -409,7 +409,7 @@ def fn(a: int) -> int: return a with self.assertRaisesRegex(TypeCheckError, - r'requires .*int.* but got .*str'): + r'requires .*int.* but was applied .*str'): _ = ['a', 'b', 'c'] | Map(fn) # Same pipeline doesn't raise without annotations on fn. @@ -423,7 +423,7 @@ def fn(a: int) -> int: _ = [1, 2, 3] | Map(fn) # Doesn't raise - correct types. 
with self.assertRaisesRegex(TypeCheckError, - r'requires .*int.* but got .*str'): + r'requires .*int.* but was applied .*str'): _ = ['a', 'b', 'c'] | Map(fn) @decorators.no_annotations diff --git a/sdks/python/apache_beam/typehints/typed_pipeline_test.py b/sdks/python/apache_beam/typehints/typed_pipeline_test.py index 72aed46f5e78..57e7f44f6922 100644 --- a/sdks/python/apache_beam/typehints/typed_pipeline_test.py +++ b/sdks/python/apache_beam/typehints/typed_pipeline_test.py @@ -88,11 +88,11 @@ def process(self, element): self.assertEqual(['1', '2', '3'], sorted(result)) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): ['a', 'b', 'c'] | beam.ParDo(MyDoFn()) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): [1, 2, 3] | (beam.ParDo(MyDoFn()) | 'again' >> beam.ParDo(MyDoFn())) def test_typed_dofn_method(self): @@ -104,11 +104,11 @@ def process(self, element: int) -> typehints.Tuple[str]: self.assertEqual(['1', '2', '3'], sorted(result)) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = ['a', 'b', 'c'] | beam.ParDo(MyDoFn()) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = [1, 2, 3] | (beam.ParDo(MyDoFn()) | 'again' >> beam.ParDo(MyDoFn())) def test_typed_dofn_method_with_class_decorators(self): @@ -124,12 +124,12 @@ def process(self, element: int) -> typehints.Tuple[str]: with self.assertRaisesRegex( typehints.TypeCheckError, - r'requires.*Tuple\[, \].*got.*str'): + r'requires.*Tuple\[, \].*applied.*str'): _ = ['a', 'b', 'c'] | beam.ParDo(MyDoFn()) with self.assertRaisesRegex( typehints.TypeCheckError, - r'requires.*Tuple\[, \].*got.*int'): + r'requires.*Tuple\[, \].*applied.*int'): _ = [1, 2, 3] | (beam.ParDo(MyDoFn()) | 'again' >> beam.ParDo(MyDoFn())) def test_typed_callable_iterable_output(self): @@ -156,11 +156,11 @@ def process(self, element: typehints.Tuple[int, int]) -> \ self.assertEqual(['1', '2', '3'], sorted(result)) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = ['a', 'b', 'c'] | beam.ParDo(my_do_fn) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = [1, 2, 3] | (beam.ParDo(my_do_fn) | 'again' >> beam.ParDo(my_do_fn)) def test_typed_callable_instance(self): @@ -177,11 +177,11 @@ def do_fn(element: typehints.Tuple[int, int]) -> typehints.Generator[str]: self.assertEqual(['1', '2', '3'], sorted(result)) with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = ['a', 'b', 'c'] | pardo with self.assertRaisesRegex(typehints.TypeCheckError, - r'requires.*int.*got.*str'): + r'requires.*int.*applied.*str'): _ = [1, 2, 3] | (pardo | 'again' >> pardo) def test_filter_type_hint(self): @@ -430,7 +430,7 @@ def fn(element: float): return pcoll | beam.ParDo(fn) with self.assertRaisesRegex(typehints.TypeCheckError, - r'ParDo.*requires.*float.*got.*str'): + r'ParDo.*requires.*float.*applied.*str'): _ = ['1', '2', '3'] | MyMap() with self.assertRaisesRegex(typehints.TypeCheckError, r'MyMap.*expected.*str.*got.*bytes'): @@ -632,14 +632,14 @@ def produces_unkown(e): return e @typehints.with_input_types(int) - def requires_int(e): + def accepts_int(e): return e class MyPTransform(beam.PTransform): def 
expand(self, pcoll): unknowns = pcoll | beam.Map(produces_unkown) ints = pcoll | beam.Map(int) - return (unknowns, ints) | beam.Flatten() | beam.Map(requires_int) + return (unknowns, ints) | beam.Flatten() | beam.Map(accepts_int) _ = [1, 2, 3] | MyPTransform() @@ -761,8 +761,8 @@ def test_var_positional_only_side_input_hint(self): with self.assertRaisesRegex( typehints.TypeCheckError, - r'requires Tuple\[Union\[, \], ...\] but ' - r'got Tuple\[Union\[, \], ...\]'): + r'requires.*Tuple\[Union\[, \], ...\].*' + r'applied.*Tuple\[Union\[, \], ...\]'): _ = [1.2] | beam.Map(lambda *_: 'a', 5).with_input_types(int, str) def test_var_keyword_side_input_hint(self): From beba671ad98543015144dbbdeadd19c22200f135 Mon Sep 17 00:00:00 2001 From: liferoad Date: Mon, 4 Nov 2024 10:51:30 -0500 Subject: [PATCH 115/181] Update code-change-guide.md --- contributor-docs/code-change-guide.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/contributor-docs/code-change-guide.md b/contributor-docs/code-change-guide.md index efb57973592e..b4300103454c 100644 --- a/contributor-docs/code-change-guide.md +++ b/contributor-docs/code-change-guide.md @@ -115,8 +115,6 @@ To run a Gradle task, use the command `./gradlew -p ` or th ./gradlew :sdks:java:harness:test ``` -**It is recommended to run `./gradlew clean` if you run into some strange errors such as `java.lang.NoClassDefFoundError`.** - #### Beam-specific Gradle project configuration For Apache Beam, one plugin manages everything: `buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin`. @@ -626,6 +624,11 @@ Tips for using the Dataflow runner: ## Appendix +### Common Issues + +* If you run into some strange errors such as `java.lang.NoClassDefFoundError`, run `./gradlew clean` first +* To run one single Java test with gradle, use `--tests` to filter, for example, `./gradlew :it:google-cloud-platform:WordCountIntegrationTest --tests "org.apache.beam.it.gcp.WordCountIT.testWordCountDataflow"` + ### Directories of snapshot builds * https://repository.apache.org/content/groups/snapshots/org/apache/beam/ Java SDK build (nightly) From c483d4c74af7c09dd688565933420e3d4a8ccec1 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Mon, 4 Nov 2024 13:57:53 -0500 Subject: [PATCH 116/181] apply change to all IOs --- .../beam_PostCommit_Java_Hadoop_Versions.json | 3 ++- sdks/java/io/hadoop-file-system/build.gradle | 4 ++-- sdks/java/io/hadoop-format/build.gradle | 4 ++-- sdks/java/io/hcatalog/build.gradle | 5 +++-- sdks/java/io/iceberg/build.gradle | 8 ++++---- sdks/java/io/parquet/build.gradle | 4 ++-- 6 files changed, 15 insertions(+), 13 deletions(-) diff --git a/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json b/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json index 08c2e40784a9..920c8d132e4a 100644 --- a/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json +++ b/.github/trigger_files/beam_PostCommit_Java_Hadoop_Versions.json @@ -1,3 +1,4 @@ { - "comment": "Modify this file in a trivial way to cause this test suite to run" + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 1 } \ No newline at end of file diff --git a/sdks/java/io/hadoop-file-system/build.gradle b/sdks/java/io/hadoop-file-system/build.gradle index 3fc872bb5d02..fafa8b5c7e34 100644 --- a/sdks/java/io/hadoop-file-system/build.gradle +++ b/sdks/java/io/hadoop-file-system/build.gradle @@ -26,10 +26,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Hadoop File System" ext.summary = "Library to 
read and write Hadoop/HDFS file formats from Beam." def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", "324": "3.2.4", + "336": "3.3.6", + "341": "3.4.1", ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/hadoop-format/build.gradle b/sdks/java/io/hadoop-format/build.gradle index dbb9f8fdd73d..4664005a1fc8 100644 --- a/sdks/java/io/hadoop-format/build.gradle +++ b/sdks/java/io/hadoop-format/build.gradle @@ -30,10 +30,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Hadoop Format" ext.summary = "IO to read data from sources and to write data to sinks that implement Hadoop MapReduce Format." def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", "324": "3.2.4", + "336": "3.3.6", + "341": "3.4.1", ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/hcatalog/build.gradle b/sdks/java/io/hcatalog/build.gradle index c4f1b76ec390..364c10fa738b 100644 --- a/sdks/java/io/hcatalog/build.gradle +++ b/sdks/java/io/hcatalog/build.gradle @@ -30,9 +30,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: HCatalog" ext.summary = "IO to read and write for HCatalog source." def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", + "324": "3.2.4", + "336": "3.3.6", + "341": "3.4.1", ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/iceberg/build.gradle b/sdks/java/io/iceberg/build.gradle index e10c6f38e20f..6754b0aecf50 100644 --- a/sdks/java/io/iceberg/build.gradle +++ b/sdks/java/io/iceberg/build.gradle @@ -29,10 +29,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Iceberg" ext.summary = "Integration with Iceberg data warehouses." def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", - "2102": "2.10.2", - "324": "3.2.4", + "2102": "2.10.2", + "324": "3.2.4", + "336": "3.3.6", + "341": "3.4.1", ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} diff --git a/sdks/java/io/parquet/build.gradle b/sdks/java/io/parquet/build.gradle index e8f1603f0b58..d5f22b31cc56 100644 --- a/sdks/java/io/parquet/build.gradle +++ b/sdks/java/io/parquet/build.gradle @@ -27,10 +27,10 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Parquet" ext.summary = "IO to read and write on Parquet storage format." 
def hadoopVersions = [ - "285": "2.8.5", - "292": "2.9.2", "2102": "2.10.2", "324": "3.2.4", + "336": "3.3.6", + "341": "3.4.1", ] hadoopVersions.each {kv -> configurations.create("hadoopVersion$kv.key")} From 77810d1bf9cff9f8b83746ac7c4f6644ae462517 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Mon, 4 Nov 2024 15:30:27 -0500 Subject: [PATCH 117/181] Disable gradle cache for spark job server shadowjar (#33010) --- .github/trigger_files/beam_PostCommit_Python.json | 2 +- runners/spark/job-server/spark_job_server.gradle | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index 30ee463ad4e9..1eb60f6e4959 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,5 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run.", - "modification": 2 + "modification": 3 } diff --git a/runners/spark/job-server/spark_job_server.gradle b/runners/spark/job-server/spark_job_server.gradle index 5ed5f4277bf4..90109598ed64 100644 --- a/runners/spark/job-server/spark_job_server.gradle +++ b/runners/spark/job-server/spark_job_server.gradle @@ -301,3 +301,7 @@ createCrossLanguageValidatesRunnerTask( "--endpoint localhost:${jobPort}", ], ) + +shadowJar { + outputs.upToDateWhen { false } +} From 14be0f6613b1454ddbbbe42e61883aa46573821b Mon Sep 17 00:00:00 2001 From: Ahmet Altay Date: Mon, 4 Nov 2024 23:56:33 +0000 Subject: [PATCH 118/181] Remove Behalf from powered-by logos. It looks like a defunct company, www.behalf.com link goes to domain parking web site. --- .../site/content/en/case-studies/behalf.md | 19 ------------------ .../static/images/logos/powered-by/behalf.png | Bin 4748 -> 0 bytes 2 files changed, 19 deletions(-) delete mode 100644 website/www/site/content/en/case-studies/behalf.md delete mode 100644 website/www/site/static/images/logos/powered-by/behalf.png diff --git a/website/www/site/content/en/case-studies/behalf.md b/website/www/site/content/en/case-studies/behalf.md deleted file mode 100644 index e5a240a03d4f..000000000000 --- a/website/www/site/content/en/case-studies/behalf.md +++ /dev/null @@ -1,19 +0,0 @@ ---- -title: "Behalf" -icon: /images/logos/powered-by/behalf.png -hasLink: "https://www.behalf.com/" ---- - - diff --git a/website/www/site/static/images/logos/powered-by/behalf.png b/website/www/site/static/images/logos/powered-by/behalf.png deleted file mode 100644 index 346ec880d764e218803d90845af8bf6da9160a16..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4748 zcmV;75_9c|P)7l?g= zI48gyU@~Sqku^qU3)nEm zuA^5lSp91I2U1FO&+l~@V?)MRUm~rzrysyH43=w^uM63w;o>j&D;TVlK6Beq5}gE} zXpoIRb|+!5lzQA!8gt#koLvX^B>k*S(o2a>=?NG7i{|&ig&)eKHU7NYp>$|R^LuF% zT>2?RdZ!n*!lHYAujS9(%bdvnKrs(-O43;y| z-|Wx3{C$0>_ARvv9_1&&U?u!Buz^C+YnD7hSazzh|E%R*wTpXtf=^wj+T@B(i_WI< zU>Vs_g{h7$Mfj=B-^81O|G#K(u{WiEXhOfY;`h>K_|N~$pLg&4d3R{*=a4D7=l3QJ zm+yAy(OC$U$u~46H&To6{(_AB>if@H}b*rd!kVUQN14c!vCi~?@nny z!wRw@`oy3&FZy1D%aT9~d@p6~&%0SiK8)hkVw5(8{x$4<3N1`gYHW;iTCf^hok@%6 zCN4Pog1-+Z_nLu6ccXLnG|@$+s*_`S(i8#Fn`4h;H?lI5Cc3GQ{=!@kXR#Y+hblA1 z?W(>I-Sc}UFDyhKizqWT&nJK0U7J*uNk(+h4Voqp6}|PB3=Cn9sxn68?KX)n8Xb_n zn8SNh7O6!?089d9Lp#|qL3W!&7p0X~#SCYMc2gUx$XF%1;2|&M=;rcgQ?e&TH(e5w z*APIAX%lHoQgl4_LDe*pvqN;y1d;5uJ)vKS9tO)_*l+wetZ|YCrVQ(I;gROn%L1<=H?N7=Kc6C z^e=}mP{Ux^r|6D~45*qfx@zlh=w6>x{gOxbFj#g-_zR>ophC1POrsqA*!%Hax>!%b z^KmTiNcu5tG~Fa#0H|`Tqrcfe)%+BGSMrk7U*5t4d@C 
zpkh9h-9$<&FY?GPyg&t@CQ8-Sv+C%49?-S-PkSn1XUt z_t;;5h)mSwHS<6+L20n@o%F=|cyj95 z{Qh-ADb0iiXt~NJ7_6j9-!8U|ZZ&)aVlY^Z9KS!XCtFj8=#D~z%k~{^F(`HR1)@6) zE!emaF3K)dd4cGTL(`mHlpPTryu?x2cVy0~L-dwFvJV|wph9%;nqaV6lkBb_Tw;Xi zEh8AL77q2*L_`Oo1D z$kBliw%-=p6ZuULz2(rgG^;m}M7r)mbif>#EnAqULv+B}Fhm)mgJ`$ZW(s#mAv%a~ zy6Z7S2N7$RKM9BqB5ZC^g2gBh#^#8Q@*)uCm&+d&L^SKcl!=Rb0$(+WDvcb&<2t1%|#=Gt^h*xb~3TO(n8#e zR9qNDZy%Jhd~xI(7fM?gL~j=~*vgfZg+X-i3d+kDj-)IR5FNZ?=GX~Eo}UZ1(SztM zz@K-yR}Xb2{Z`_4K(I(xQ9puRp>n-%-_sC#Rmx z@4s%SmcCCf8&5Ctv_-mO0T!bzgGzFD@u*10br0#Lk?A?T(qJ)&4x&-%{57FUO`l#S zmL#Vte+a1y$izVDH^E?~)FToR9b`f$1NmmtjB7%88Nu*Wv9_YRf#2D5s|=&r6w~bp72qC`WUKh+ZAuL2uOXV?K=yJprysel3&e zD3_1E-zrW=^7?;A^pZsy5M47A!Kaq&N^YoJw+L^SrKi%7@Y4z+{5RSJQwE}Ig7WVw ze#mukPtwoIN>3@xR8tOLI+`H*!J)1av_Nu3sn}2Npm*&Z^px65Q$*euZ66n7r1JIBfav9z zy52xP1}|g&qbfW%y}Pap9%YieRwf!QppW2jro( z`6)zKr_=_9190L0=Nl)?!U3|4E}K)cBJGv8z`K8pa+74QyvLx*qg$ikZ;-O&?H z7+x+61}pwKeop+EjgBfy65B>ei0JY-J$P0lyduR@ z8KsL7x`F5tP;~)mi#U5Ed*e?*h%Sc_(aAERRHo6PU=UpnD?P?_ z!Bdsr!`r3sSjP;>BO=+g?pZ)|c`7@2WFhCWm6Za?Becs7(dB{YPd#8hhPO-M(~0a! z^1guN5qmmsaY1xR!rSG7e$Z`&NAZ2xUXHutE1tjqE!pzo)9(06OD6z`u7*ov#zncI zTwE%OvnR=UWd2j}>z?UYoZ{kz8dJLR)rKRx_jG|mI%T6Ad!b=Y(t^wWaCW4;kAaA; z07W!y&2^cx0!>o2Cu|&?-0k zucE@E_*Wf&pCnV++g^z1nlWg!%MCqcw2WW~DlGfm0japa5W!W5}XMTECY9@VEB$1g@V>3`9GqnVycrC#}SU|i0Cy!!bTz+c?yyinSBFta1n%~ zIz$JL(3bs(B(Gv9ZdD8+I>?1;;C+(C*4^Fryb4rQfr02P!x6cYjEW9@;RrxG={q(} zZ;vWNnLzZG@o;u@_i%Q^rTFOBO{v@t1Q5NY&{>gvhmIQe^Da|E^cKVG&ktuu(PTK= z+j#8cjopFhfKkQb&;i*YI>_hY>}YNI*?LB|kW$kZhz_1{ME2CE4$;92=-kNA;xq=L z12JNe-A36{>*)|3q(ZXW*wSMsZSx4CgG{IzV`H5iq62A07Ve-zbRY*UPHjtJUR@M{ z=%A2?v!exFwzh_3V^W9?M6oTZLv$byRbgyem_l@*j7@#vAv#clj`rGeJOHADQfQki zR533uctdm`i#5Z_;ya`e9Y{fS^-J46HHZ#WsJsoHCJ@m99jKDobl*fEI#4BRAh3}@bfAh2qTfZofapLP z8sTWe&|QcQN|8J2&_rw(-MX7Dj$n)EC@+FiWDdnLjOT-A5FJRdJHF~$lAkj~bWjFc zqWP9ub36jkfiN`M@RKE<*o%uH5FN;1Q|oAm4&!>PAa)9IWMPdwkKAJlBBJw0!ZbWN zwQ+|G5R{?}EnjSAu}L1{N>U^`KO8eeieeuWxew{;+o1zFcpim^&hkY6hA3_abS5Pd zyF*>HY~0mU;r`+5s4$&Q6MaijJaYwWGP~oep>3UB1=&wD_I%eOnez}jLYx$xOMa-k zROtCVT@7#6VX!Q!%?~X+RtnVEQU$P(xai7~dwy>}mJ4<&%m+WBI~RyC$tx+wiHqKs zk=*lp9g3bq#}_)lOn7cCeDf(1{d0J7DikR_zZcbEzd``eAX0@AF8ZbDCJ^P?3&%DZ zfPLJER9fP^q12-g&+qYnLH|H_vGYI-N+wNkq^jT=K)Ri_ave^U20_~}Sc<%_{Tkj2kmsU~%1Rz-RhNZMYPk|cpuMK5j_l+ zGXv9?!p9wTl*NrGbKp_DV}j%?r;a`~s1w(XVA`Vfv>Q<7o}#!>DR7?Xri}g&y(`$S z_2Ih>+(Z(VTjs<9Km%e4{kQ*ufE*M*zZLx$yknMT31^ zWSZ3=ZE1!JYTa!7=%1pqE~AFVglr$%SPLBDq9Ii1SWq8ZN^NI02p-{uVVM6Lp4voT z*C_kbK!T^ut3r0iS9eHTYT-U6xOJyU3q@!3$}TJU1bx~oC7Q-OGp&jui;(E-KiR9C zG`yj7Md)!zb@QJUy_dFgNrLB2W3~xl(IaRWfX> Date: Tue, 5 Nov 2024 15:57:18 +0100 Subject: [PATCH 119/181] Update snowflake-jdbc to 3.20 (#33018) --- sdks/java/io/snowflake/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/java/io/snowflake/build.gradle b/sdks/java/io/snowflake/build.gradle index 2bdb9a867a34..9be257033edb 100644 --- a/sdks/java/io/snowflake/build.gradle +++ b/sdks/java/io/snowflake/build.gradle @@ -30,7 +30,7 @@ dependencies { implementation project(path: ":sdks:java:extensions:google-cloud-platform-core") permitUnusedDeclared project(path: ":sdks:java:extensions:google-cloud-platform-core") implementation library.java.slf4j_api - implementation group: 'net.snowflake', name: 'snowflake-jdbc', version: '3.12.11' + implementation group: 'net.snowflake', name: 'snowflake-jdbc', version: '3.20.0' implementation group: 'com.opencsv', name: 'opencsv', version: '5.0' implementation 'net.snowflake:snowflake-ingest-sdk:0.9.9' implementation "org.bouncycastle:bcprov-jdk15on:1.70" From 
9baa7ba0182b50adeecfe1b97b215a3d0f4a39bd Mon Sep 17 00:00:00 2001 From: Idan Attias Date: Tue, 5 Nov 2024 18:26:12 +0300 Subject: [PATCH 120/181] Upgrade antlr from 4.7 to 4.13.1 (#33016) * Upgrade antlr from 4.7 to 4.13.1 To allow users of newer versions of antlr, that cannot downgrade, to use apache-beam. Fixes #32696 * Update CHANGES.md --------- Co-authored-by: Idan Attias --- CHANGES.md | 1 + .../groovy/org/apache/beam/gradle/BeamModulePlugin.groovy | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 1a9d2045cbf6..c98504df4d1b 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -87,6 +87,7 @@ * Removed support for Flink 1.15 and 1.16 * Removed support for Python 3.8 * X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)). +* Upgrade antlr from 4.7 to 4.13.1 ([#33016](https://github.com/apache/beam/pull/33016)). ## Bugfixes diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 5af91ec2f056..8d8bf9339c6e 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -665,8 +665,8 @@ class BeamModulePlugin implements Plugin { activemq_junit : "org.apache.activemq.tooling:activemq-junit:$activemq_version", activemq_kahadb_store : "org.apache.activemq:activemq-kahadb-store:$activemq_version", activemq_mqtt : "org.apache.activemq:activemq-mqtt:$activemq_version", - antlr : "org.antlr:antlr4:4.7", - antlr_runtime : "org.antlr:antlr4-runtime:4.7", + antlr : "org.antlr:antlr4:4.13.1", + antlr_runtime : "org.antlr:antlr4-runtime:4.13.1", args4j : "args4j:args4j:2.33", auto_value_annotations : "com.google.auto.value:auto-value-annotations:$autovalue_version", avro : "org.apache.avro:avro:1.11.3", From 689af5ba16ea5cba07783bc25bd21bfa2ab7537d Mon Sep 17 00:00:00 2001 From: Ahmed Abualsaud <65791736+ahmedabu98@users.noreply.github.com> Date: Tue, 5 Nov 2024 13:08:40 -0500 Subject: [PATCH 121/181] [Managed Iceberg] Allow updating partition specs during pipeline runtime (#32879) * allowed updating partition specs at runtime * add to changes md * add to changes md * trigger iceberg integration tests * refresh cached tables; split multiple partition specs into separate manifest files * add test * address comment * clarify changes comment --- .../IO_Iceberg_Integration_Tests.json | 2 +- CHANGES.md | 1 + .../sdk/io/iceberg/AppendFilesToTables.java | 115 ++++++++++++-- .../beam/sdk/io/iceberg/FileWriteResult.java | 5 +- .../sdk/io/iceberg/RecordWriterManager.java | 45 +++--- .../sdk/io/iceberg/SerializableDataFile.java | 17 ++- .../sdk/io/iceberg/WriteToDestinations.java | 2 +- .../io/iceberg/RecordWriterManagerTest.java | 140 +++++++++++++++--- 8 files changed, 260 insertions(+), 67 deletions(-) diff --git a/.github/trigger_files/IO_Iceberg_Integration_Tests.json b/.github/trigger_files/IO_Iceberg_Integration_Tests.json index 62ae7886c573..bbdc3a3910ef 100644 --- a/.github/trigger_files/IO_Iceberg_Integration_Tests.json +++ b/.github/trigger_files/IO_Iceberg_Integration_Tests.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 4 + "modification": 3 } diff --git a/CHANGES.md b/CHANGES.md index c98504df4d1b..cdedce22e975 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -69,6 +69,7 @@ * [Managed Iceberg] Now available in 
Python SDK ([#31495](https://github.com/apache/beam/pull/31495)) * [Managed Iceberg] Add support for TIMESTAMP, TIME, and DATE types ([#32688](https://github.com/apache/beam/pull/32688)) * BigQuery CDC writes are now available in Python SDK, only supported when using StorageWrite API at least once mode ([#32527](https://github.com/apache/beam/issues/32527)) +* [Managed Iceberg] Allow updating table partition specs during pipeline runtime ([#32879](https://github.com/apache/beam/pull/32879)) ## New Features / Improvements diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java index defe4f2a603d..d9768114e7c6 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/AppendFilesToTables.java @@ -17,6 +17,12 @@ */ package org.apache.beam.sdk.io.iceberg; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.metrics.Counter; @@ -29,14 +35,21 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.iceberg.AppendFiles; import org.apache.iceberg.DataFile; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.ManifestFiles; +import org.apache.iceberg.ManifestWriter; +import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Snapshot; import org.apache.iceberg.Table; import org.apache.iceberg.catalog.Catalog; import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.OutputFile; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,9 +58,11 @@ class AppendFilesToTables extends PTransform, PCollection>> { private static final Logger LOG = LoggerFactory.getLogger(AppendFilesToTables.class); private final IcebergCatalogConfig catalogConfig; + private final String manifestFilePrefix; - AppendFilesToTables(IcebergCatalogConfig catalogConfig) { + AppendFilesToTables(IcebergCatalogConfig catalogConfig, String manifestFilePrefix) { this.catalogConfig = catalogConfig; + this.manifestFilePrefix = manifestFilePrefix; } @Override @@ -67,7 +82,7 @@ public String apply(FileWriteResult input) { .apply("Group metadata updates by table", GroupByKey.create()) .apply( "Append metadata updates to tables", - ParDo.of(new AppendFilesToTablesDoFn(catalogConfig))) + ParDo.of(new AppendFilesToTablesDoFn(catalogConfig, manifestFilePrefix))) .setCoder(KvCoder.of(StringUtf8Coder.of(), SnapshotInfo.CODER)); } @@ -75,19 +90,19 @@ private static class AppendFilesToTablesDoFn extends DoFn>, KV> { private final Counter snapshotsCreated = Metrics.counter(AppendFilesToTables.class, "snapshotsCreated"); - private final Counter dataFilesCommitted = - Metrics.counter(AppendFilesToTables.class, "dataFilesCommitted"); private final Distribution committedDataFileByteSize = Metrics.distribution(RecordWriter.class, 
"committedDataFileByteSize"); private final Distribution committedDataFileRecordCount = Metrics.distribution(RecordWriter.class, "committedDataFileRecordCount"); private final IcebergCatalogConfig catalogConfig; + private final String manifestFilePrefix; private transient @MonotonicNonNull Catalog catalog; - private AppendFilesToTablesDoFn(IcebergCatalogConfig catalogConfig) { + private AppendFilesToTablesDoFn(IcebergCatalogConfig catalogConfig, String manifestFilePrefix) { this.catalogConfig = catalogConfig; + this.manifestFilePrefix = manifestFilePrefix; } private Catalog getCatalog() { @@ -97,11 +112,22 @@ private Catalog getCatalog() { return catalog; } + private boolean containsMultiplePartitionSpecs(Iterable fileWriteResults) { + int id = fileWriteResults.iterator().next().getSerializableDataFile().getPartitionSpecId(); + for (FileWriteResult result : fileWriteResults) { + if (id != result.getSerializableDataFile().getPartitionSpecId()) { + return true; + } + } + return false; + } + @ProcessElement public void processElement( @Element KV> element, OutputReceiver> out, - BoundedWindow window) { + BoundedWindow window) + throws IOException { String tableStringIdentifier = element.getKey(); Iterable fileWriteResults = element.getValue(); if (!fileWriteResults.iterator().hasNext()) { @@ -109,24 +135,81 @@ public void processElement( } Table table = getCatalog().loadTable(TableIdentifier.parse(element.getKey())); + + // vast majority of the time, we will simply append data files. + // in the rare case we get a batch that contains multiple partition specs, we will group + // data into manifest files and append. + // note: either way, we must use a single commit operation for atomicity. + if (containsMultiplePartitionSpecs(fileWriteResults)) { + appendManifestFiles(table, fileWriteResults); + } else { + appendDataFiles(table, fileWriteResults); + } + + Snapshot snapshot = table.currentSnapshot(); + LOG.info("Created new snapshot for table '{}': {}", tableStringIdentifier, snapshot); + snapshotsCreated.inc(); + out.outputWithTimestamp( + KV.of(element.getKey(), SnapshotInfo.fromSnapshot(snapshot)), window.maxTimestamp()); + } + + // This works only when all files are using the same partition spec. + private void appendDataFiles(Table table, Iterable fileWriteResults) { AppendFiles update = table.newAppend(); - long numFiles = 0; for (FileWriteResult result : fileWriteResults) { - DataFile dataFile = result.getDataFile(table.spec()); + DataFile dataFile = result.getDataFile(table.specs()); update.appendFile(dataFile); committedDataFileByteSize.update(dataFile.fileSizeInBytes()); committedDataFileRecordCount.update(dataFile.recordCount()); - numFiles++; } - // this commit will create a ManifestFile. we don't need to manually create one. update.commit(); - dataFilesCommitted.inc(numFiles); + } - Snapshot snapshot = table.currentSnapshot(); - LOG.info("Created new snapshot for table '{}': {}", tableStringIdentifier, snapshot); - snapshotsCreated.inc(); - out.outputWithTimestamp( - KV.of(element.getKey(), SnapshotInfo.fromSnapshot(snapshot)), window.maxTimestamp()); + // When a user updates their table partition spec during runtime, we can end up with + // a batch of files where some are written with the old spec and some are written with the new + // spec. + // A table commit is limited to a single partition spec. + // To handle this, we create a manifest file for each partition spec, and group data files + // accordingly. + // Afterward, we append all manifests using a single commit operation. 
+ private void appendManifestFiles(Table table, Iterable fileWriteResults) + throws IOException { + String uuid = UUID.randomUUID().toString(); + Map specs = table.specs(); + + Map> dataFilesBySpec = new HashMap<>(); + for (FileWriteResult result : fileWriteResults) { + DataFile dataFile = result.getDataFile(specs); + dataFilesBySpec.computeIfAbsent(dataFile.specId(), i -> new ArrayList<>()).add(dataFile); + } + + AppendFiles update = table.newAppend(); + for (Map.Entry> entry : dataFilesBySpec.entrySet()) { + int specId = entry.getKey(); + List files = entry.getValue(); + PartitionSpec spec = Preconditions.checkStateNotNull(specs.get(specId)); + ManifestWriter writer = + createManifestWriter(table.location(), uuid, spec, table.io()); + for (DataFile file : files) { + writer.add(file); + committedDataFileByteSize.update(file.fileSizeInBytes()); + committedDataFileRecordCount.update(file.recordCount()); + } + writer.close(); + update.appendManifest(writer.toManifestFile()); + } + update.commit(); + } + + private ManifestWriter createManifestWriter( + String tableLocation, String uuid, PartitionSpec spec, FileIO io) { + String location = + FileFormat.AVRO.addExtension( + String.format( + "%s/metadata/%s-%s-%s.manifest", + tableLocation, manifestFilePrefix, uuid, spec.specId())); + OutputFile outputFile = io.newOutputFile(location); + return ManifestFiles.write(spec, outputFile); } } } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java index c4090d9e7e53..bf00bf8519fc 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/FileWriteResult.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.iceberg; import com.google.auto.value.AutoValue; +import java.util.Map; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.schemas.annotations.SchemaIgnore; @@ -46,9 +47,9 @@ public TableIdentifier getTableIdentifier() { } @SchemaIgnore - public DataFile getDataFile(PartitionSpec spec) { + public DataFile getDataFile(Map specs) { if (cachedDataFile == null) { - cachedDataFile = getSerializableDataFile().createDataFile(spec); + cachedDataFile = getSerializableDataFile().createDataFile(specs); } return cachedDataFile; } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java index 396db7c20f36..12c425993826 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java @@ -25,7 +25,6 @@ import java.util.List; import java.util.Map; import java.util.UUID; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.util.Preconditions; @@ -195,7 +194,9 @@ private RecordWriter createWriter(PartitionKey partitionKey) { private final Map, List> totalSerializableDataFiles = Maps.newHashMap(); - private static final Cache TABLE_CACHE = + + @VisibleForTesting + static final Cache TABLE_CACHE = CacheBuilder.newBuilder().expireAfterAccess(10, TimeUnit.MINUTES).build(); private boolean isClosed = false; @@ -221,22 
+222,28 @@ private RecordWriter createWriter(PartitionKey partitionKey) { private Table getOrCreateTable(TableIdentifier identifier, Schema dataSchema) { @Nullable Table table = TABLE_CACHE.getIfPresent(identifier); if (table == null) { - try { - table = catalog.loadTable(identifier); - } catch (NoSuchTableException e) { + synchronized (TABLE_CACHE) { try { - org.apache.iceberg.Schema tableSchema = - IcebergUtils.beamSchemaToIcebergSchema(dataSchema); - // TODO(ahmedabu98): support creating a table with a specified partition spec - table = catalog.createTable(identifier, tableSchema); - LOG.info("Created Iceberg table '{}' with schema: {}", identifier, tableSchema); - } catch (AlreadyExistsException alreadyExistsException) { - // handle race condition where workers are concurrently creating the same table. - // if running into already exists exception, we perform one last load table = catalog.loadTable(identifier); + } catch (NoSuchTableException e) { + try { + org.apache.iceberg.Schema tableSchema = + IcebergUtils.beamSchemaToIcebergSchema(dataSchema); + // TODO(ahmedabu98): support creating a table with a specified partition spec + table = catalog.createTable(identifier, tableSchema); + LOG.info("Created Iceberg table '{}' with schema: {}", identifier, tableSchema); + } catch (AlreadyExistsException alreadyExistsException) { + // handle race condition where workers are concurrently creating the same table. + // if running into already exists exception, we perform one last load + table = catalog.loadTable(identifier); + } } + TABLE_CACHE.put(identifier, table); } - TABLE_CACHE.put(identifier, table); + } else { + // If fetching from cache, refresh the table to avoid working with stale metadata + // (e.g. partition spec) + table.refresh(); } return table; } @@ -254,15 +261,7 @@ public boolean write(WindowedValue icebergDestination, Row r icebergDestination, destination -> { TableIdentifier identifier = destination.getValue().getTableIdentifier(); - Table table; - try { - table = - TABLE_CACHE.get( - identifier, () -> getOrCreateTable(identifier, row.getSchema())); - } catch (ExecutionException e) { - throw new RuntimeException( - "Error while fetching or creating table: " + identifier, e); - } + Table table = getOrCreateTable(identifier, row.getSchema()); return new DestinationState(destination.getValue(), table); }); diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java index 699d4fa4dfd0..59b456162008 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/SerializableDataFile.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io.iceberg; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; + import com.google.auto.value.AutoValue; import java.nio.ByteBuffer; import java.util.HashMap; @@ -24,7 +26,6 @@ import java.util.Map; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.iceberg.DataFile; import org.apache.iceberg.DataFiles; import org.apache.iceberg.FileFormat; @@ -141,12 +142,14 @@ static SerializableDataFile from(DataFile f, PartitionKey key) { * it from Beam-compatible types. 
*/ @SuppressWarnings("nullness") - DataFile createDataFile(PartitionSpec partitionSpec) { - Preconditions.checkState( - partitionSpec.specId() == getPartitionSpecId(), - "Invalid partition spec id '%s'. This DataFile was originally created with spec id '%s'.", - partitionSpec.specId(), - getPartitionSpecId()); + DataFile createDataFile(Map partitionSpecs) { + PartitionSpec partitionSpec = + checkStateNotNull( + partitionSpecs.get(getPartitionSpecId()), + "This DataFile was originally created with spec id '%s'. Could not find " + + "this among table's partition specs: %s.", + getPartitionSpecId(), + partitionSpecs.keySet()); Metrics dataFileMetrics = new Metrics( diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteToDestinations.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteToDestinations.java index a2d0c320f58f..fb3bf43f3515 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteToDestinations.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/WriteToDestinations.java @@ -74,7 +74,7 @@ public IcebergWriteResult expand(PCollection> input) { // Commit files to tables PCollection> snapshots = - writtenFiles.apply(new AppendFilesToTables(catalogConfig)); + writtenFiles.apply(new AppendFilesToTables(catalogConfig, filePrefix)); return new IcebergWriteResult(input.getPipeline(), snapshots); } diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java index 7adf6defe520..8ced06bc944f 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java @@ -19,24 +19,29 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.either; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.io.IOException; +import java.util.List; import java.util.Map; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.commons.lang3.RandomStringUtils; import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.DataFile; import org.apache.iceberg.FileFormat; import org.apache.iceberg.PartitionKey; import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Table; import org.apache.iceberg.catalog.TableIdentifier; import org.apache.iceberg.hadoop.HadoopCatalog; import org.checkerframework.checker.nullness.qual.Nullable; @@ -73,6 +78,7 @@ public void setUp() { windowedDestination = getWindowedDestination("table_" + testName.getMethodName(), PARTITION_SPEC); catalog = new HadoopCatalog(new Configuration(), warehouse.location); + RecordWriterManager.TABLE_CACHE.invalidateAll(); } private WindowedValue getWindowedDestination( @@ -269,6 +275,25 @@ public void testRequireClosingBeforeFetchingDataFiles() { 
assertThrows(IllegalStateException.class, writerManager::getSerializableDataFiles); } + /** DataFile doesn't implement a .equals() method. Check equality manually. */ + private static void checkDataFileEquality(DataFile d1, DataFile d2) { + assertEquals(d1.path(), d2.path()); + assertEquals(d1.format(), d2.format()); + assertEquals(d1.recordCount(), d2.recordCount()); + assertEquals(d1.partition(), d2.partition()); + assertEquals(d1.specId(), d2.specId()); + assertEquals(d1.keyMetadata(), d2.keyMetadata()); + assertEquals(d1.splitOffsets(), d2.splitOffsets()); + assertEquals(d1.columnSizes(), d2.columnSizes()); + assertEquals(d1.valueCounts(), d2.valueCounts()); + assertEquals(d1.nullValueCounts(), d2.nullValueCounts()); + assertEquals(d1.nanValueCounts(), d2.nanValueCounts()); + assertEquals(d1.equalityFieldIds(), d2.equalityFieldIds()); + assertEquals(d1.fileSequenceNumber(), d2.fileSequenceNumber()); + assertEquals(d1.dataSequenceNumber(), d2.dataSequenceNumber()); + assertEquals(d1.pos(), d2.pos()); + } + @Test public void testSerializableDataFileRoundTripEquality() throws IOException { PartitionKey partitionKey = new PartitionKey(PARTITION_SPEC, ICEBERG_SCHEMA); @@ -288,22 +313,103 @@ public void testSerializableDataFileRoundTripEquality() throws IOException { assertEquals(2L, datafile.recordCount()); DataFile roundTripDataFile = - SerializableDataFile.from(datafile, partitionKey).createDataFile(PARTITION_SPEC); - // DataFile doesn't implement a .equals() method. Check equality manually - assertEquals(datafile.path(), roundTripDataFile.path()); - assertEquals(datafile.format(), roundTripDataFile.format()); - assertEquals(datafile.recordCount(), roundTripDataFile.recordCount()); - assertEquals(datafile.partition(), roundTripDataFile.partition()); - assertEquals(datafile.specId(), roundTripDataFile.specId()); - assertEquals(datafile.keyMetadata(), roundTripDataFile.keyMetadata()); - assertEquals(datafile.splitOffsets(), roundTripDataFile.splitOffsets()); - assertEquals(datafile.columnSizes(), roundTripDataFile.columnSizes()); - assertEquals(datafile.valueCounts(), roundTripDataFile.valueCounts()); - assertEquals(datafile.nullValueCounts(), roundTripDataFile.nullValueCounts()); - assertEquals(datafile.nanValueCounts(), roundTripDataFile.nanValueCounts()); - assertEquals(datafile.equalityFieldIds(), roundTripDataFile.equalityFieldIds()); - assertEquals(datafile.fileSequenceNumber(), roundTripDataFile.fileSequenceNumber()); - assertEquals(datafile.dataSequenceNumber(), roundTripDataFile.dataSequenceNumber()); - assertEquals(datafile.pos(), roundTripDataFile.pos()); + SerializableDataFile.from(datafile, partitionKey) + .createDataFile(ImmutableMap.of(PARTITION_SPEC.specId(), PARTITION_SPEC)); + + checkDataFileEquality(datafile, roundTripDataFile); + } + + /** + * Users may update the table's spec while a write pipeline is running. Sometimes, this can happen + * after converting {@link DataFile} to {@link SerializableDataFile}s. When converting back to + * {@link DataFile} to commit in the {@link AppendFilesToTables} step, we need to make sure to use + * the same {@link PartitionSpec} it was originally created with. + * + *
<p>
This test checks that we're preserving the right {@link PartitionSpec} when such an update + * happens. + */ + @Test + public void testRecreateSerializableDataAfterUpdatingPartitionSpec() throws IOException { + PartitionKey partitionKey = new PartitionKey(PARTITION_SPEC, ICEBERG_SCHEMA); + + Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "abcdef", true).build(); + Row row2 = Row.withSchema(BEAM_SCHEMA).addValues(2, "abcxyz", true).build(); + // same partition for both records (name_trunc=abc, bool=true) + partitionKey.partition(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row)); + + // write some rows + RecordWriter writer = + new RecordWriter(catalog, windowedDestination.getValue(), "test_file_name", partitionKey); + writer.write(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row)); + writer.write(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row2)); + writer.close(); + + // fetch data file and its serializable version + DataFile datafile = writer.getDataFile(); + SerializableDataFile serializableDataFile = SerializableDataFile.from(datafile, partitionKey); + + assertEquals(2L, datafile.recordCount()); + assertEquals(serializableDataFile.getPartitionSpecId(), datafile.specId()); + + // update spec + Table table = catalog.loadTable(windowedDestination.getValue().getTableIdentifier()); + table.updateSpec().addField("id").removeField("bool").commit(); + + Map updatedSpecs = table.specs(); + DataFile roundTripDataFile = serializableDataFile.createDataFile(updatedSpecs); + + checkDataFileEquality(datafile, roundTripDataFile); + } + + @Test + public void testWriterKeepsUpWithUpdatingPartitionSpec() throws IOException { + Table table = catalog.loadTable(windowedDestination.getValue().getTableIdentifier()); + Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "abcdef", true).build(); + Row row2 = Row.withSchema(BEAM_SCHEMA).addValues(2, "abcxyz", true).build(); + + // write some rows + RecordWriterManager writer = + new RecordWriterManager(catalog, "test_prefix", Long.MAX_VALUE, Integer.MAX_VALUE); + writer.write(windowedDestination, row); + writer.write(windowedDestination, row2); + writer.close(); + DataFile dataFile = + writer + .getSerializableDataFiles() + .get(windowedDestination) + .get(0) + .createDataFile(table.specs()); + + // check data file path contains the correct partition components + assertEquals(2L, dataFile.recordCount()); + assertEquals(dataFile.specId(), PARTITION_SPEC.specId()); + assertThat(dataFile.path().toString(), containsString("name_trunc=abc")); + assertThat(dataFile.path().toString(), containsString("bool=true")); + + // table is cached + assertEquals(1, RecordWriterManager.TABLE_CACHE.size()); + + // update spec + table.updateSpec().addField("id").removeField("bool").commit(); + + // write a second data file + // should refresh the table and use the new partition spec + RecordWriterManager writer2 = + new RecordWriterManager(catalog, "test_prefix_2", Long.MAX_VALUE, Integer.MAX_VALUE); + writer2.write(windowedDestination, row); + writer2.write(windowedDestination, row2); + writer2.close(); + + List serializableDataFiles = + writer2.getSerializableDataFiles().get(windowedDestination); + assertEquals(2, serializableDataFiles.size()); + for (SerializableDataFile serializableDataFile : serializableDataFiles) { + assertEquals(table.spec().specId(), serializableDataFile.getPartitionSpecId()); + dataFile = serializableDataFile.createDataFile(table.specs()); + assertEquals(1L, dataFile.recordCount()); + assertThat(dataFile.path().toString(), 
containsString("name_trunc=abc")); + assertThat( + dataFile.path().toString(), either(containsString("id=1")).or(containsString("id=2"))); + } } } From 738a76dd8f1779a30e666b07426002f61f798027 Mon Sep 17 00:00:00 2001 From: Ahmed Abualsaud <65791736+ahmedabu98@users.noreply.github.com> Date: Tue, 5 Nov 2024 15:23:16 -0500 Subject: [PATCH 122/181] [Managed Iceberg] bubble up exceptions due to writer close (#32940) * throw suppressed cache exceptions * add test --- .../beam/sdk/io/iceberg/RecordWriter.java | 4 ++ .../sdk/io/iceberg/RecordWriterManager.java | 25 ++++++-- .../io/iceberg/RecordWriterManagerTest.java | 60 +++++++++++++++++++ 3 files changed, 84 insertions(+), 5 deletions(-) diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriter.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriter.java index 9a3262e19845..7941c13b0dfe 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriter.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriter.java @@ -136,4 +136,8 @@ public long bytesWritten() { public DataFile getDataFile() { return icebergDataWriter.toDataFile(); } + + public String path() { + return absoluteFilename; + } } diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java index 12c425993826..255fce9ece4e 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/RecordWriterManager.java @@ -90,6 +90,7 @@ class DestinationState { final Cache writers; private final List dataFiles = Lists.newArrayList(); @VisibleForTesting final Map writerCounts = Maps.newHashMap(); + private final List exceptions = Lists.newArrayList(); DestinationState(IcebergDestination icebergDestination, Table table) { this.icebergDestination = icebergDestination; @@ -112,11 +113,14 @@ class DestinationState { try { recordWriter.close(); } catch (IOException e) { - throw new RuntimeException( - String.format( - "Encountered an error when closing data writer for table '%s', partition %s", - icebergDestination.getTableIdentifier(), pk), - e); + RuntimeException rethrow = + new RuntimeException( + String.format( + "Encountered an error when closing data writer for table '%s', path: %s", + icebergDestination.getTableIdentifier(), recordWriter.path()), + e); + exceptions.add(rethrow); + throw rethrow; } openWriters--; dataFiles.add(SerializableDataFile.from(recordWriter.getDataFile(), pk)); @@ -282,6 +286,17 @@ public void close() throws IOException { // removing writers from the state's cache will trigger the logic to collect each writer's // data file. 
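// Note: Guava does not propagate exceptions thrown from a cache's removal listener (it only
// logs them), which is why writer-close failures are also recorded in state.exceptions and
// rethrown explicitly below.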
state.writers.invalidateAll(); + // first check for any exceptions swallowed by the cache + if (!state.exceptions.isEmpty()) { + IllegalStateException exception = + new IllegalStateException( + String.format("Encountered %s failed writer(s).", state.exceptions.size())); + for (Exception e : state.exceptions) { + exception.addSuppressed(e); + } + throw exception; + } + if (state.dataFiles.isEmpty()) { continue; } diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java index 8ced06bc944f..2bce390e0992 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/RecordWriterManagerTest.java @@ -42,6 +42,7 @@ import org.apache.iceberg.PartitionKey; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Catalog; import org.apache.iceberg.catalog.TableIdentifier; import org.apache.iceberg.hadoop.HadoopCatalog; import org.checkerframework.checker.nullness.qual.Nullable; @@ -49,6 +50,7 @@ import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.rules.TemporaryFolder; import org.junit.rules.TestName; import org.junit.runner.RunWith; @@ -412,4 +414,62 @@ public void testWriterKeepsUpWithUpdatingPartitionSpec() throws IOException { dataFile.path().toString(), either(containsString("id=1")).or(containsString("id=2"))); } } + + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testWriterExceptionGetsCaught() throws IOException { + RecordWriterManager writerManager = new RecordWriterManager(catalog, "test_file_name", 100, 2); + Row row = Row.withSchema(BEAM_SCHEMA).addValues(1, "abcdef", true).build(); + PartitionKey partitionKey = new PartitionKey(PARTITION_SPEC, ICEBERG_SCHEMA); + partitionKey.partition(IcebergUtils.beamRowToIcebergRecord(ICEBERG_SCHEMA, row)); + + writerManager.write(windowedDestination, row); + + RecordWriterManager.DestinationState state = + writerManager.destinations.get(windowedDestination); + // replace with a failing record writer + FailingRecordWriter failingWriter = + new FailingRecordWriter( + catalog, windowedDestination.getValue(), "test_failing_writer", partitionKey); + state.writers.put(partitionKey, failingWriter); + writerManager.write(windowedDestination, row); + + // this tests that we indeed enter the catch block + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Encountered 1 failed writer(s)"); + try { + writerManager.close(); + } catch (IllegalStateException e) { + // fetch underlying exceptions and validate + Throwable[] underlyingExceptions = e.getSuppressed(); + assertEquals(1, underlyingExceptions.length); + for (Throwable t : underlyingExceptions) { + assertThat( + t.getMessage(), + containsString("Encountered an error when closing data writer for table")); + assertThat( + t.getMessage(), + containsString(windowedDestination.getValue().getTableIdentifier().toString())); + assertThat(t.getMessage(), containsString(failingWriter.path())); + Throwable realCause = t.getCause(); + assertEquals("I am failing!", realCause.getMessage()); + } + + throw e; + } + } + + static class FailingRecordWriter extends RecordWriter { + FailingRecordWriter( + Catalog catalog, IcebergDestination destination, String filename, 
PartitionKey partitionKey) + throws IOException { + super(catalog, destination, filename, partitionKey); + } + + @Override + public void close() throws IOException { + throw new IOException("I am failing!"); + } + } } From 3a941c07bcd3ed9730acc08c8c39217e2dcd8157 Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Tue, 5 Nov 2024 14:32:26 -0800 Subject: [PATCH 123/181] Update Java ExpansionService to use arbitrary PipelineOptions set through an ExpansionRequest --- .../PipelineOptionsTranslation.java | 11 +- .../expansion/service/ExpansionService.java | 67 +++++----- .../service/ExpansionServiceTest.java | 114 +++++++++++++++++- 3 files changed, 155 insertions(+), 37 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PipelineOptionsTranslation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PipelineOptionsTranslation.java index cd6ab7dd414a..de1717f0a45f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PipelineOptionsTranslation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/PipelineOptionsTranslation.java @@ -43,6 +43,9 @@ public class PipelineOptionsTranslation { new ObjectMapper() .registerModules(ObjectMapper.findModules(ReflectHelpers.findClassLoader())); + public static final String PIPELINE_OPTIONS_URN_PREFIX = "beam:option:"; + public static final String PIPELINE_OPTIONS_URN_SUFFIX = ":v1"; + /** Converts the provided {@link PipelineOptions} to a {@link Struct}. */ public static Struct toProto(PipelineOptions options) { Struct.Builder builder = Struct.newBuilder(); @@ -65,9 +68,9 @@ public static Struct toProto(PipelineOptions options) { while (optionsEntries.hasNext()) { Map.Entry entry = optionsEntries.next(); optionsUsingUrns.put( - "beam:option:" + PIPELINE_OPTIONS_URN_PREFIX + CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, entry.getKey()) - + ":v1", + + PIPELINE_OPTIONS_URN_SUFFIX, entry.getValue()); } @@ -92,7 +95,9 @@ public static PipelineOptions fromProto(Struct protoOptions) { mapWithoutUrns.put( CaseFormat.LOWER_UNDERSCORE.to( CaseFormat.LOWER_CAMEL, - optionKey.substring("beam:option:".length(), optionKey.length() - ":v1".length())), + optionKey.substring( + PIPELINE_OPTIONS_URN_PREFIX.length(), + optionKey.length() - PIPELINE_OPTIONS_URN_SUFFIX.length())), optionValue); } return MAPPER.readValue( diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java index 150fe9729573..9c5b5a0ad136 100644 --- a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java @@ -60,7 +60,6 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.PortablePipelineOptions; -import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; @@ -535,7 +534,7 @@ private static void invokeSetter(ConfigT config, @Nullable Object valu } private @MonotonicNonNull Map registeredTransforms; - private final PipelineOptions pipelineOptions; + private final PipelineOptions commandLineOptions; private final @Nullable String 
loopbackAddress; public ExpansionService() { @@ -551,7 +550,7 @@ public ExpansionService(PipelineOptions opts) { } public ExpansionService(PipelineOptions opts, @Nullable String loopbackAddress) { - this.pipelineOptions = opts; + this.commandLineOptions = opts; this.loopbackAddress = loopbackAddress; } @@ -587,12 +586,15 @@ private Map loadRegisteredTransforms() { request.getTransform().getSpec().getUrn()); LOG.debug("Full transform: {}", request.getTransform()); Set existingTransformIds = request.getComponents().getTransformsMap().keySet(); - Pipeline pipeline = - createPipeline(PipelineOptionsTranslation.fromProto(request.getPipelineOptions())); + + PipelineOptions pipelineOptionsFromRequest = + PipelineOptionsTranslation.fromProto(request.getPipelineOptions()); + Pipeline pipeline = createPipeline(pipelineOptionsFromRequest); + boolean isUseDeprecatedRead = - ExperimentalOptions.hasExperiment(pipelineOptions, "use_deprecated_read") + ExperimentalOptions.hasExperiment(commandLineOptions, "use_deprecated_read") || ExperimentalOptions.hasExperiment( - pipelineOptions, "beam_fn_api_use_deprecated_read"); + commandLineOptions, "beam_fn_api_use_deprecated_read"); if (!isUseDeprecatedRead) { ExperimentalOptions.addExperiment( pipeline.getOptions().as(ExperimentalOptions.class), "beam_fn_api"); @@ -629,7 +631,7 @@ private Map loadRegisteredTransforms() { if (transformProvider == null) { if (getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP).equals(urn)) { AllowList allowList = - pipelineOptions.as(ExpansionServiceOptions.class).getJavaClassLookupAllowlist(); + commandLineOptions.as(ExpansionServiceOptions.class).getJavaClassLookupAllowlist(); assert allowList != null; transformProvider = new JavaClassLookupTransformProvider(allowList); } else if (getUrn(SCHEMA_TRANSFORM).equals(urn)) { @@ -671,7 +673,7 @@ private Map loadRegisteredTransforms() { RunnerApi.Environment defaultEnvironment = Environments.createOrGetDefaultEnvironment( pipeline.getOptions().as(PortablePipelineOptions.class)); - if (pipelineOptions.as(ExpansionServiceOptions.class).getAlsoStartLoopbackWorker()) { + if (commandLineOptions.as(ExpansionServiceOptions.class).getAlsoStartLoopbackWorker()) { PortablePipelineOptions externalOptions = PipelineOptionsFactory.create().as(PortablePipelineOptions.class); externalOptions.setDefaultEnvironmentType(Environments.ENVIRONMENT_EXTERNAL); @@ -723,35 +725,34 @@ private Map loadRegisteredTransforms() { } protected Pipeline createPipeline(PipelineOptions requestOptions) { - // TODO: [https://github.com/apache/beam/issues/21064]: implement proper validation - PipelineOptions effectiveOpts = PipelineOptionsFactory.create(); - PortablePipelineOptions portableOptions = effectiveOpts.as(PortablePipelineOptions.class); - PortablePipelineOptions specifiedOptions = pipelineOptions.as(PortablePipelineOptions.class); - Optional.ofNullable(specifiedOptions.getDefaultEnvironmentType()) - .ifPresent(portableOptions::setDefaultEnvironmentType); - Optional.ofNullable(specifiedOptions.getDefaultEnvironmentConfig()) - .ifPresent(portableOptions::setDefaultEnvironmentConfig); - List filesToStage = specifiedOptions.getFilesToStage(); + // We expect the ExpansionRequest to contain a valid set of options to be used for this + // expansion. + // Additionally, we override selected options using options values set via command line or + // ExpansionService wide overrides. 
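+    // Concretely, per the code below: the default environment type/config, files to stage,
+    // experiments, and the expansion service config are taken from the service's own options,
+    // the runner is pinned to a non-runnable placeholder, and all remaining options come from
+    // the request.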
+ + PortablePipelineOptions requestPortablePipelineOptions = + requestOptions.as(PortablePipelineOptions.class); + PortablePipelineOptions commandLinePortablePipelineOptions = + commandLineOptions.as(PortablePipelineOptions.class); + Optional.ofNullable(commandLinePortablePipelineOptions.getDefaultEnvironmentType()) + .ifPresent(requestPortablePipelineOptions::setDefaultEnvironmentType); + Optional.ofNullable(commandLinePortablePipelineOptions.getDefaultEnvironmentConfig()) + .ifPresent(requestPortablePipelineOptions::setDefaultEnvironmentConfig); + List filesToStage = commandLinePortablePipelineOptions.getFilesToStage(); if (filesToStage != null) { - effectiveOpts.as(PortablePipelineOptions.class).setFilesToStage(filesToStage); + requestPortablePipelineOptions + .as(PortablePipelineOptions.class) + .setFilesToStage(filesToStage); } - effectiveOpts + requestPortablePipelineOptions .as(ExperimentalOptions.class) - .setExperiments(pipelineOptions.as(ExperimentalOptions.class).getExperiments()); - effectiveOpts.setRunner(NotRunnableRunner.class); - effectiveOpts + .setExperiments(commandLineOptions.as(ExperimentalOptions.class).getExperiments()); + requestPortablePipelineOptions.setRunner(NotRunnableRunner.class); + requestPortablePipelineOptions .as(ExpansionServiceOptions.class) .setExpansionServiceConfig( - pipelineOptions.as(ExpansionServiceOptions.class).getExpansionServiceConfig()); - // TODO(https://github.com/apache/beam/issues/20090): Figure out the correct subset of options - // to propagate. - if (requestOptions.as(StreamingOptions.class).getUpdateCompatibilityVersion() != null) { - effectiveOpts - .as(StreamingOptions.class) - .setUpdateCompatibilityVersion( - requestOptions.as(StreamingOptions.class).getUpdateCompatibilityVersion()); - } - return Pipeline.create(effectiveOpts); + commandLineOptions.as(ExpansionServiceOptions.class).getExpansionServiceConfig()); + return Pipeline.create(requestOptions); } @Override diff --git a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java index 1c8d515d5c85..9ee0c2c1797b 100644 --- a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java +++ b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.expansion.service; +import static org.apache.beam.sdk.util.construction.PipelineOptionsTranslation.PIPELINE_OPTIONS_URN_PREFIX; +import static org.apache.beam.sdk.util.construction.PipelineOptionsTranslation.PIPELINE_OPTIONS_URN_SUFFIX; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.contains; @@ -49,6 +51,8 @@ import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; @@ -58,15 +62,20 @@ import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Impulse; +import org.apache.beam.sdk.transforms.PTransform; import 
org.apache.beam.sdk.util.ByteStringOutputStream; import org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.Struct; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.Value; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Resources; import org.checkerframework.checker.nullness.qual.Nullable; import org.hamcrest.Matchers; +import org.junit.Assert; import org.junit.Test; /** Tests for {@link ExpansionService}. */ @@ -76,6 +85,7 @@ public class ExpansionServiceTest { private static final String TEST_URN = "test:beam:transforms:count"; + private static final String TEST_OPTIONS_URN = "test:beam:transforms:test_options"; private static final String TEST_NAME = "TestName"; @@ -98,9 +108,59 @@ public class ExpansionServiceTest { @AutoService(ExpansionService.ExpansionServiceRegistrar.class) public static class TestTransformRegistrar implements ExpansionService.ExpansionServiceRegistrar { + static final String EXPECTED_STRING_VALUE = "abcde"; + static final Boolean EXPECTED_BOOLEAN_VALUE = true; + static final Integer EXPECTED_INTEGER_VALUE = 12345; + @Override public Map knownTransforms() { - return ImmutableMap.of(TEST_URN, (spec, options) -> Count.perElement()); + return ImmutableMap.of( + TEST_URN, (spec, options) -> Count.perElement(), + TEST_OPTIONS_URN, + (spec, options) -> + new TestOptionsTransform( + EXPECTED_STRING_VALUE, EXPECTED_BOOLEAN_VALUE, EXPECTED_INTEGER_VALUE)); + } + } + + public interface TestOptions extends PipelineOptions { + String getStringOption(); + + void setStringOption(String value); + + Boolean getBooleanOption(); + + void setBooleanOption(Boolean value); + + Integer getIntegerOption(); + + void setIntegerOption(Integer value); + } + + public static class TestOptionsTransform + extends PTransform, PCollection> { + String expectedStringValue; + + Boolean expectedBooleanValue; + + Integer expectedIntegerValue; + + public TestOptionsTransform( + String expectedStringValue, Boolean expectedBooleanValue, Integer expectedIntegerValue) { + this.expectedStringValue = expectedStringValue; + this.expectedBooleanValue = expectedBooleanValue; + this.expectedIntegerValue = expectedIntegerValue; + } + + @Override + public PCollection expand(PCollection input) { + TestOptions testOption = input.getPipeline().getOptions().as(TestOptions.class); + + Assert.assertEquals(expectedStringValue, testOption.getStringOption()); + Assert.assertEquals(expectedBooleanValue, testOption.getBooleanOption()); + Assert.assertEquals(expectedIntegerValue, testOption.getIntegerOption()); + + return input; } } @@ -146,6 +206,58 @@ public void testConstruct() { } } + @Test + public void testConstructWithPipelineOptions() { + PipelineOptionsFactory.register(TestOptions.class); + Pipeline p = Pipeline.create(); + p.apply(Impulse.create()); + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + String inputPcollId = + Iterables.getOnlyElement( + Iterables.getOnlyElement(pipelineProto.getComponents().getTransformsMap().values()) + .getOutputsMap() + .values()); + + Struct optionsStruct = + Struct.newBuilder() + .putFields( + 
PIPELINE_OPTIONS_URN_PREFIX + "string_option" + PIPELINE_OPTIONS_URN_SUFFIX, + Value.newBuilder() + .setStringValue(TestTransformRegistrar.EXPECTED_STRING_VALUE) + .build()) + .putFields( + PIPELINE_OPTIONS_URN_PREFIX + "boolean_option" + PIPELINE_OPTIONS_URN_SUFFIX, + Value.newBuilder() + .setBoolValue(TestTransformRegistrar.EXPECTED_BOOLEAN_VALUE) + .build()) + .putFields( + PIPELINE_OPTIONS_URN_PREFIX + "integer_option" + PIPELINE_OPTIONS_URN_SUFFIX, + Value.newBuilder() + .setNumberValue(TestTransformRegistrar.EXPECTED_INTEGER_VALUE) + .build()) + .build(); + ExpansionApi.ExpansionRequest request = + ExpansionApi.ExpansionRequest.newBuilder() + .setComponents(pipelineProto.getComponents()) + .setPipelineOptions(optionsStruct) + .setTransform( + RunnerApi.PTransform.newBuilder() + .setUniqueName(TEST_NAME) + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(TEST_OPTIONS_URN)) + .putInputs("input", inputPcollId)) + .setNamespace(TEST_NAMESPACE) + .build(); + ExpansionApi.ExpansionResponse response = expansionService.expand(request); + RunnerApi.PTransform expandedTransform = response.getTransform(); + assertEquals(TEST_NAMESPACE + TEST_NAME, expandedTransform.getUniqueName()); + + // Verify it has the right input. + assertThat(expandedTransform.getInputsMap().values(), contains(inputPcollId)); + + // Verify it has the right output. + assertThat(expandedTransform.getOutputsMap().keySet(), contains("output")); + } + @Test public void testConstructGenerateSequenceWithRegistration() { ExternalTransforms.ExternalConfigurationPayload payload = From 708932149f1a86af87d5d599335157801a98ec9e Mon Sep 17 00:00:00 2001 From: martin trieu Date: Wed, 6 Nov 2024 03:24:13 -0600 Subject: [PATCH 124/181] extract semaphore logic out of WeightBoundedQueue to allow for sharing the weigher (#32905) --- .../worker/StreamingDataflowWorker.java | 2 + .../streaming/WeightedBoundedQueue.java | 45 ++++------- .../worker/streaming/WeightedSemaphore.java | 53 ++++++++++++ .../windmill/client/commits/Commits.java | 36 +++++++++ .../StreamingApplianceWorkCommitter.java | 8 +- .../commits/StreamingEngineWorkCommitter.java | 16 ++-- .../streaming/WeightBoundedQueueTest.java | 81 +++++++++++++++---- .../StreamingEngineWorkCommitterTest.java | 2 + 8 files changed, 184 insertions(+), 59 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedSemaphore.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index ff72add83e4d..6ce60283735f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -65,6 +65,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.appliance.JniWindmillApplianceServer; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; +import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commits; import 
org.apache.beam.runners.dataflow.worker.windmill.client.commits.CompleteCommit; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingApplianceWorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.StreamingEngineWorkCommitter; @@ -199,6 +200,7 @@ private StreamingDataflowWorker( this.workCommitter = windmillServiceEnabled ? StreamingEngineWorkCommitter.builder() + .setCommitByteSemaphore(Commits.maxCommitByteSemaphore()) .setCommitWorkStreamFactory( WindmillStreamPool.create( numCommitThreads, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java index f2893f3e7191..5f039be7b00f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WeightedBoundedQueue.java @@ -18,33 +18,24 @@ package org.apache.beam.runners.dataflow.worker.streaming; import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import java.util.function.Function; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.checkerframework.checker.nullness.qual.Nullable; -/** Bounded set of queues, with a maximum total weight. */ +/** Queue bounded by a {@link WeightedSemaphore}. */ public final class WeightedBoundedQueue { private final LinkedBlockingQueue queue; - private final int maxWeight; - private final Semaphore limit; - private final Function weigher; + private final WeightedSemaphore weightedSemaphore; private WeightedBoundedQueue( - LinkedBlockingQueue linkedBlockingQueue, - int maxWeight, - Semaphore limit, - Function weigher) { + LinkedBlockingQueue linkedBlockingQueue, WeightedSemaphore weightedSemaphore) { this.queue = linkedBlockingQueue; - this.maxWeight = maxWeight; - this.limit = limit; - this.weigher = weigher; + this.weightedSemaphore = weightedSemaphore; } - public static WeightedBoundedQueue create(int maxWeight, Function weigherFn) { - return new WeightedBoundedQueue<>( - new LinkedBlockingQueue<>(), maxWeight, new Semaphore(maxWeight, true), weigherFn); + public static WeightedBoundedQueue create(WeightedSemaphore weightedSemaphore) { + return new WeightedBoundedQueue<>(new LinkedBlockingQueue<>(), weightedSemaphore); } /** @@ -52,15 +43,15 @@ public static WeightedBoundedQueue create(int maxWeight, Function { + private final int maxWeight; + private final Semaphore limit; + private final Function weigher; + + private WeightedSemaphore(int maxWeight, Semaphore limit, Function weigher) { + this.maxWeight = maxWeight; + this.limit = limit; + this.weigher = weigher; + } + + public static WeightedSemaphore create(int maxWeight, Function weigherFn) { + return new WeightedSemaphore<>(maxWeight, new Semaphore(maxWeight, true), weigherFn); + } + + public void acquireUninterruptibly(V value) { + limit.acquireUninterruptibly(computePermits(value)); + } + + public void release(V value) { + limit.release(computePermits(value)); + } + + private int computePermits(V value) { + return Math.min(weigher.apply(value), maxWeight); + } + + public int currentWeight() { + return maxWeight - limit.availablePermits(); + } +} diff 
--git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java new file mode 100644 index 000000000000..498e90f78e29 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/Commits.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.commits; + +import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; + +/** Utility class for commits. */ +@Internal +public final class Commits { + + /** Max bytes of commits queued on the user worker. */ + @VisibleForTesting static final int MAX_QUEUED_COMMITS_BYTES = 500 << 20; // 500MB + + private Commits() {} + + public static WeightedSemaphore maxCommitByteSemaphore() { + return WeightedSemaphore.create(MAX_QUEUED_COMMITS_BYTES, Commit::getSize); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java index 6889764afe69..20b95b0661d0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitter.java @@ -42,7 +42,6 @@ public final class StreamingApplianceWorkCommitter implements WorkCommitter { private static final Logger LOG = LoggerFactory.getLogger(StreamingApplianceWorkCommitter.class); private static final long TARGET_COMMIT_BUNDLE_BYTES = 32 << 20; - private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB private final Consumer commitWorkFn; private final WeightedBoundedQueue commitQueue; @@ -53,9 +52,7 @@ public final class StreamingApplianceWorkCommitter implements WorkCommitter { private StreamingApplianceWorkCommitter( Consumer commitWorkFn, Consumer onCommitComplete) { this.commitWorkFn = commitWorkFn; - this.commitQueue = - WeightedBoundedQueue.create( - MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); + this.commitQueue = 
WeightedBoundedQueue.create(Commits.maxCommitByteSemaphore()); this.commitWorkers = Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() @@ -73,10 +70,9 @@ public static StreamingApplianceWorkCommitter create( } @Override - @SuppressWarnings("FutureReturnValueIgnored") public void start() { if (!commitWorkers.isShutdown()) { - commitWorkers.submit(this::commitLoop); + commitWorkers.execute(this::commitLoop); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java index bf1007bc4bfb..85fa1d67c6c3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java @@ -28,6 +28,7 @@ import javax.annotation.Nullable; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.streaming.WeightedBoundedQueue; +import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; @@ -46,7 +47,6 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter { private static final Logger LOG = LoggerFactory.getLogger(StreamingEngineWorkCommitter.class); private static final int TARGET_COMMIT_BATCH_KEYS = 5; - private static final int MAX_COMMIT_QUEUE_BYTES = 500 << 20; // 500MB private static final String NO_BACKEND_WORKER_TOKEN = ""; private final Supplier> commitWorkStreamFactory; @@ -61,11 +61,10 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter { Supplier> commitWorkStreamFactory, int numCommitSenders, Consumer onCommitComplete, - String backendWorkerToken) { + String backendWorkerToken, + WeightedSemaphore commitByteSemaphore) { this.commitWorkStreamFactory = commitWorkStreamFactory; - this.commitQueue = - WeightedBoundedQueue.create( - MAX_COMMIT_QUEUE_BYTES, commit -> Math.min(MAX_COMMIT_QUEUE_BYTES, commit.getSize())); + this.commitQueue = WeightedBoundedQueue.create(commitByteSemaphore); this.commitSenders = Executors.newFixedThreadPool( numCommitSenders, @@ -90,12 +89,11 @@ public static Builder builder() { } @Override - @SuppressWarnings("FutureReturnValueIgnored") public void start() { Preconditions.checkState( isRunning.compareAndSet(false, true), "Multiple calls to WorkCommitter.start()."); for (int i = 0; i < numCommitSenders; i++) { - commitSenders.submit(this::streamingCommitLoop); + commitSenders.execute(this::streamingCommitLoop); } } @@ -166,6 +164,8 @@ private void streamingCommitLoop() { return; } } + + // take() blocks until a value is available in the commitQueue. 
Preconditions.checkNotNull(initialCommit); if (initialCommit.work().isFailed()) { @@ -258,6 +258,8 @@ public interface Builder { Builder setCommitWorkStreamFactory( Supplier> commitWorkStreamFactory); + Builder setCommitByteSemaphore(WeightedSemaphore commitByteSemaphore); + Builder setNumCommitSenders(int numCommitSenders); Builder setOnCommitComplete(Consumer onCommitComplete); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java index 4f035c88774c..c71001fbeee7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/WeightBoundedQueueTest.java @@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; @@ -30,27 +31,29 @@ @RunWith(JUnit4.class) public class WeightBoundedQueueTest { - @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private static final int MAX_WEIGHT = 10; + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); @Test public void testPut_hasCapacity() { - WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedSemaphore weightedSemaphore = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue = WeightedBoundedQueue.create(weightedSemaphore); int insertedValue = 1; queue.put(insertedValue); - assertEquals(insertedValue, queue.queuedElementsWeight()); + assertEquals(insertedValue, weightedSemaphore.currentWeight()); assertEquals(1, queue.size()); assertEquals(insertedValue, (int) queue.poll()); } @Test public void testPut_noCapacity() throws InterruptedException { - WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedSemaphore weightedSemaphore = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue = WeightedBoundedQueue.create(weightedSemaphore); // Insert value that takes all the capacity into the queue. queue.put(MAX_WEIGHT); @@ -71,7 +74,7 @@ public void testPut_noCapacity() throws InterruptedException { // Should only see the first value in the queue, since the queue is at capacity. thread2 // should be blocked. - assertEquals(MAX_WEIGHT, queue.queuedElementsWeight()); + assertEquals(MAX_WEIGHT, weightedSemaphore.currentWeight()); assertEquals(1, queue.size()); // Poll the queue, pulling off the only value inside and freeing up the capacity in the queue. @@ -80,14 +83,15 @@ public void testPut_noCapacity() throws InterruptedException { // Wait for the putThread which was previously blocked due to the queue being at capacity. 
putThread.join(); - assertEquals(MAX_WEIGHT, queue.queuedElementsWeight()); + assertEquals(MAX_WEIGHT, weightedSemaphore.currentWeight()); assertEquals(1, queue.size()); } @Test public void testPoll() { - WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedSemaphore weightedSemaphore = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue = WeightedBoundedQueue.create(weightedSemaphore); int insertedValue1 = 1; int insertedValue2 = 2; @@ -95,7 +99,7 @@ public void testPoll() { queue.put(insertedValue1); queue.put(insertedValue2); - assertEquals(insertedValue1 + insertedValue2, queue.queuedElementsWeight()); + assertEquals(insertedValue1 + insertedValue2, weightedSemaphore.currentWeight()); assertEquals(2, queue.size()); assertEquals(insertedValue1, (int) queue.poll()); assertEquals(1, queue.size()); @@ -104,7 +108,8 @@ public void testPoll() { @Test public void testPoll_withTimeout() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); int pollWaitTimeMillis = 10000; int insertedValue1 = 1; @@ -132,7 +137,8 @@ public void testPoll_withTimeout() throws InterruptedException { @Test public void testPoll_withTimeout_timesOut() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); int defaultPollResult = -10; int pollWaitTimeMillis = 100; int insertedValue1 = 1; @@ -144,13 +150,17 @@ public void testPoll_withTimeout_timesOut() throws InterruptedException { Thread pollThread = new Thread( () -> { - int polled; + @Nullable Integer polled; try { polled = queue.poll(pollWaitTimeMillis, TimeUnit.MILLISECONDS); - pollResult.set(polled); + if (polled != null) { + pollResult.set(polled); + } } catch (InterruptedException e) { throw new RuntimeException(e); } + + assertNull(polled); }); pollThread.start(); @@ -164,7 +174,8 @@ public void testPoll_withTimeout_timesOut() throws InterruptedException { @Test public void testPoll_emptyQueue() { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); assertNull(queue.poll()); } @@ -172,7 +183,8 @@ public void testPoll_emptyQueue() { @Test public void testTake() throws InterruptedException { WeightedBoundedQueue queue = - WeightedBoundedQueue.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue.create( + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i))); AtomicInteger value = new AtomicInteger(); // Should block until value is available @@ -194,4 +206,39 @@ public void testTake() throws InterruptedException { assertEquals(MAX_WEIGHT, value.get()); } + + @Test + public void testPut_sharedWeigher() throws InterruptedException { + WeightedSemaphore weigher = + WeightedSemaphore.create(MAX_WEIGHT, i -> Math.min(MAX_WEIGHT, i)); + WeightedBoundedQueue queue1 = WeightedBoundedQueue.create(weigher); + WeightedBoundedQueue queue2 = WeightedBoundedQueue.create(weigher); + + // Insert value that takes all the weight into the queue1. + queue1.put(MAX_WEIGHT); + + // Try to insert a value into the queue2. 
This will block since there is no capacity in the + // weigher. + Thread putThread = new Thread(() -> queue2.put(MAX_WEIGHT)); + putThread.start(); + // Should only see the first value in the queue, since the queue is at capacity. putThread + // should be blocked. The weight should be the same however, since queue1 and queue2 are sharing + // the weigher. + Thread.sleep(100); + assertEquals(MAX_WEIGHT, weigher.currentWeight()); + assertEquals(1, queue1.size()); + assertEquals(0, queue2.size()); + + // Poll queue1, pulling off the only value inside and freeing up the capacity in the weigher. + queue1.poll(); + + // Wait for the putThread which was previously blocked due to the weigher being at capacity. + putThread.join(); + + assertEquals(MAX_WEIGHT, weigher.currentWeight()); + assertEquals(1, queue2.size()); + queue2.poll(); + assertEquals(0, queue2.size()); + assertEquals(0, weigher.currentWeight()); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index 546a2883e3b2..c05a4dd340dd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -121,6 +121,7 @@ public void setUp() throws IOException { private WorkCommitter createWorkCommitter(Consumer onCommitComplete) { return StreamingEngineWorkCommitter.builder() + .setCommitByteSemaphore(Commits.maxCommitByteSemaphore()) .setCommitWorkStreamFactory(commitWorkStreamFactory) .setOnCommitComplete(onCommitComplete) .build(); @@ -342,6 +343,7 @@ public void testMultipleCommitSendersSingleStream() { Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); workCommitter = StreamingEngineWorkCommitter.builder() + .setCommitByteSemaphore(Commits.maxCommitByteSemaphore()) .setCommitWorkStreamFactory(commitWorkStreamFactory) .setNumCommitSenders(5) .setOnCommitComplete(completeCommits::add) From 8e61c18b752d01ffe7eb835183aebe5d9a13e908 Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Wed, 6 Nov 2024 09:30:58 -0500 Subject: [PATCH 125/181] Add Flush Interval to default Buffered Logger (#33009) * Add Flush Interval to default Buffered Logger in Python boot.go * have buffered logger set default value * restore deleted line --- sdks/go/container/tools/buffered_logging.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sdks/go/container/tools/buffered_logging.go b/sdks/go/container/tools/buffered_logging.go index 445d19fabfdc..a7b84e56af3a 100644 --- a/sdks/go/container/tools/buffered_logging.go +++ b/sdks/go/container/tools/buffered_logging.go @@ -18,13 +18,15 @@ package tools import ( "context" "log" - "math" "os" "strings" "time" ) -const initialLogSize int = 255 +const ( + initialLogSize int = 255 + defaultFlushInterval time.Duration = 15 * time.Second +) // BufferedLogger is a wrapper around the FnAPI logging client meant to be used // in place of stdout and stderr in bootloader subprocesses. 
Not intended for @@ -41,7 +43,7 @@ type BufferedLogger struct { // NewBufferedLogger returns a new BufferedLogger type by reference. func NewBufferedLogger(logger *Logger) *BufferedLogger { - return &BufferedLogger{logger: logger, lastFlush: time.Now(), flushInterval: time.Duration(math.MaxInt64), periodicFlushContext: context.Background(), now: time.Now} + return &BufferedLogger{logger: logger, lastFlush: time.Now(), flushInterval: defaultFlushInterval, periodicFlushContext: context.Background(), now: time.Now} } // NewBufferedLoggerWithFlushInterval returns a new BufferedLogger type by reference. This type will From 36c19a324289c85c6c44dbf6920bcd31589ad265 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 Nov 2024 09:47:18 -0500 Subject: [PATCH 126/181] Bump github.com/aws/aws-sdk-go-v2/credentials in /sdks (#32996) Bumps [github.com/aws/aws-sdk-go-v2/credentials](https://github.com/aws/aws-sdk-go-v2) from 1.17.41 to 1.17.42. - [Release notes](https://github.com/aws/aws-sdk-go-v2/releases) - [Commits](https://github.com/aws/aws-sdk-go-v2/compare/credentials/v1.17.41...credentials/v1.17.42) --- updated-dependencies: - dependency-name: github.com/aws/aws-sdk-go-v2/credentials dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- sdks/go.mod | 18 +++++++++--------- sdks/go.sum | 36 ++++++++++++++++++------------------ 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/sdks/go.mod b/sdks/go.mod index 81221f98e276..0f35782d52bf 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -30,9 +30,9 @@ require ( cloud.google.com/go/pubsub v1.44.0 cloud.google.com/go/spanner v1.70.0 cloud.google.com/go/storage v1.45.0 - github.com/aws/aws-sdk-go-v2 v1.32.2 + github.com/aws/aws-sdk-go-v2 v1.32.3 github.com/aws/aws-sdk-go-v2/config v1.28.0 - github.com/aws/aws-sdk-go-v2/credentials v1.17.41 + github.com/aws/aws-sdk-go-v2/credentials v1.17.42 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.33 github.com/aws/aws-sdk-go-v2/service/s3 v1.66.0 github.com/aws/smithy-go v1.22.0 @@ -132,18 +132,18 @@ require ( github.com/apache/thrift v0.17.0 // indirect github.com/aws/aws-sdk-go v1.34.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.22 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.22 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 // indirect github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.3 // indirect 
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.24.2 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.32.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.3 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.32.3 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect diff --git a/sdks/go.sum b/sdks/go.sum index a45baf72a02b..0c6a74211be9 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -689,26 +689,26 @@ github.com/aws/aws-sdk-go v1.30.19/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZve github.com/aws/aws-sdk-go v1.34.0 h1:brux2dRrlwCF5JhTL7MUT3WUwo9zfDHZZp3+g3Mvlmo= github.com/aws/aws-sdk-go v1.34.0/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go-v2 v1.7.1/go.mod h1:L5LuPC1ZgDr2xQS7AmIec/Jlc7O/Y1u2KxJyNVab250= -github.com/aws/aws-sdk-go-v2 v1.32.2 h1:AkNLZEyYMLnx/Q/mSKkcMqwNFXMAvFto9bNsHqcTduI= -github.com/aws/aws-sdk-go-v2 v1.32.2/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= +github.com/aws/aws-sdk-go-v2 v1.32.3 h1:T0dRlFBKcdaUPGNtkBSwHZxrtis8CQU17UpNBZYd0wk= +github.com/aws/aws-sdk-go-v2 v1.32.3/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 h1:pT3hpW0cOHRJx8Y0DfJUEQuqPild8jRGmSFmBgvydr0= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6/go.mod h1:j/I2++U0xX+cr44QjHay4Cvxj6FUbnxrgmqN3H1jTZA= github.com/aws/aws-sdk-go-v2/config v1.5.0/go.mod h1:RWlPOAW3E3tbtNAqTwvSW54Of/yP3oiZXMI0xfUdjyA= github.com/aws/aws-sdk-go-v2/config v1.28.0 h1:FosVYWcqEtWNxHn8gB/Vs6jOlNwSoyOCA/g/sxyySOQ= github.com/aws/aws-sdk-go-v2/config v1.28.0/go.mod h1:pYhbtvg1siOOg8h5an77rXle9tVG8T+BWLWAo7cOukc= github.com/aws/aws-sdk-go-v2/credentials v1.3.1/go.mod h1:r0n73xwsIVagq8RsxmZbGSRQFj9As3je72C2WzUIToc= -github.com/aws/aws-sdk-go-v2/credentials v1.17.41 h1:7gXo+Axmp+R4Z+AK8YFQO0ZV3L0gizGINCOWxSLY9W8= -github.com/aws/aws-sdk-go-v2/credentials v1.17.41/go.mod h1:u4Eb8d3394YLubphT4jLEwN1rLNq2wFOlT6OuxFwPzU= +github.com/aws/aws-sdk-go-v2/credentials v1.17.42 h1:sBP0RPjBU4neGpIYyx8mkU2QqLPl5u9cmdTWVzIpHkM= +github.com/aws/aws-sdk-go-v2/credentials v1.17.42/go.mod h1:FwZBfU530dJ26rv9saAbxa9Ej3eF/AK0OAY86k13n4M= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.3.0/go.mod h1:2LAuqPx1I6jNfaGDucWfA2zqQCYCOMCDHiCOciALyNw= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17 h1:TMH3f/SCAWdNtXXVPPu5D6wrr4G5hI1rAxbcocKfC7Q= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.17/go.mod h1:1ZRXLdTpzdJb9fwTMXiLipENRxkGMTn1sfKexGllQCw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.18 h1:68jFVtt3NulEzojFesM/WVarlFpCaXLKaBxDpzkQ9OQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.18/go.mod h1:Fjnn5jQVIo6VyedMc0/EhPpfNlPl7dHV916O6B+49aE= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.3.2/go.mod h1:qaqQiHSrOUVOfKe6fhgQ6UzhxjwqVW8aHNegd6Ws4w4= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.33 h1:X+4YY5kZRI/cOoSMVMGTqFXHAMg1bvvay7IBcqHpybQ= 
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.33/go.mod h1:DPynzu+cn92k5UQ6tZhX+wfTB4ah6QDU/NgdHqatmvk= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21 h1:UAsR3xA31QGf79WzpG/ixT9FZvQlh5HY1NRqSHBNOCk= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21/go.mod h1:JNr43NFf5L9YaG3eKTm7HQzls9J+A9YYcGI5Quh1r2Y= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21 h1:6jZVETqmYCadGFvrYEQfC5fAQmlo80CeL5psbno6r0s= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21/go.mod h1:1SR0GbLlnN3QUmYaflZNiH1ql+1qrSiB2vwcJ+4UM60= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.22 h1:Jw50LwEkVjuVzE1NzkhNKkBf9cRN7MtE1F/b2cOKTUM= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.22/go.mod h1:Y/SmAyPcOTmpeVaWSzSKiILfXTVJwrGmYZhcRbhWuEY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.22 h1:981MHwBaRZM7+9QSR6XamDzF/o7ouUGxFzr+nVSIhrs= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.22/go.mod h1:1RA1+aBEfn+CAB/Mh0MB6LsdCYCnjZm7tKXtnk499ZQ= github.com/aws/aws-sdk-go-v2/internal/ini v1.1.1/go.mod h1:Zy8smImhTdOETZqfyn01iNOe0CNggVbPjCajyaz6Gvg= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= @@ -720,8 +720,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0/go.mod h1: github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2 h1:4FMHqLfk0efmTqhXVRL5xYRqlEBNBiRI7N6w4jsEdd4= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2/go.mod h1:LWoqeWlK9OZeJxsROW2RqrSPvQHKTpp69r/iDjwsSaw= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.1/go.mod h1:zceowr5Z1Nh2WVP8bf/3ikB41IZW59E4yIYbg+pC6mw= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2 h1:s7NA1SOw8q/5c0wr8477yOPp0z+uBaXBnLE0XYb0POA= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.2/go.mod h1:fnjjWyAW/Pj5HYOxl9LJqWtEwS7W2qgcRLWP+uWbss0= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.3 h1:qcxX0JYlgWH3hpPUnd6U0ikcl6LLA9sLkXE2w1fpMvY= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.3/go.mod h1:cLSNEmI45soc+Ef8K/L+8sEA3A3pYFEYf5B5UI+6bH4= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.1/go.mod h1:6EQZIwNNvHpq/2/QSJnp4+ECvqIy55w95Ofs0ze+nGQ= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2 h1:t7iUP9+4wdc5lt3E41huP+GvQZJD38WLsgVp4iOtAjg= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2/go.mod h1:/niFCtmuQNxqx9v8WAPq5qh7EH25U4BF6tjoyq9bObM= @@ -729,13 +729,13 @@ github.com/aws/aws-sdk-go-v2/service/s3 v1.11.1/go.mod h1:XLAGFrEjbvMCLvAtWLLP32 github.com/aws/aws-sdk-go-v2/service/s3 v1.66.0 h1:xA6XhTF7PE89BCNHJbQi8VvPzcgMtmGC5dr8S8N7lHk= github.com/aws/aws-sdk-go-v2/service/s3 v1.66.0/go.mod h1:cB6oAuus7YXRZhWCc1wIwPywwZ1XwweNp2TVAEGYeB8= github.com/aws/aws-sdk-go-v2/service/sso v1.3.1/go.mod h1:J3A3RGUvuCZjvSuZEcOpHDnzZP/sKbhDWV2T1EOzFIM= -github.com/aws/aws-sdk-go-v2/service/sso v1.24.2 h1:bSYXVyUzoTHoKalBmwaZxs97HU9DWWI3ehHSAMa7xOk= -github.com/aws/aws-sdk-go-v2/service/sso v1.24.2/go.mod h1:skMqY7JElusiOUjMJMOv1jJsP7YUg7DrhgqZZWuzu1U= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2 h1:AhmO1fHINP9vFYUE0LHzCWg/LfUWUF+zFPEcY9QXb7o= 
-github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.2/go.mod h1:o8aQygT2+MVP0NaV6kbdE1YnnIM8RRVQzoeUH45GOdI= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.3 h1:UTpsIf0loCIWEbrqdLb+0RxnTXfWh2vhw4nQmFi4nPc= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.3/go.mod h1:FZ9j3PFHHAR+w0BSEjK955w5YD2UwB/l/H0yAK3MJvI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.3 h1:2YCmIXv3tmiItw0LlYf6v7gEHebLY45kBEnPezbUKyU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.3/go.mod h1:u19stRyNPxGhj6dRm+Cdgu6N75qnbW7+QN0q0dsAk58= github.com/aws/aws-sdk-go-v2/service/sts v1.6.0/go.mod h1:q7o0j7d7HrJk/vr9uUt3BVRASvcU7gYZB9PUgPiByXg= -github.com/aws/aws-sdk-go-v2/service/sts v1.32.2 h1:CiS7i0+FUe+/YY1GvIBLLrR/XNGZ4CtM1Ll0XavNuVo= -github.com/aws/aws-sdk-go-v2/service/sts v1.32.2/go.mod h1:HtaiBI8CjYoNVde8arShXb94UbQQi9L4EMr6D+xGBwo= +github.com/aws/aws-sdk-go-v2/service/sts v1.32.3 h1:wVnQ6tigGsRqSWDEEyH6lSAJ9OyFUsSnbaUWChuSGzs= +github.com/aws/aws-sdk-go-v2/service/sts v1.32.3/go.mod h1:VZa9yTFyj4o10YGsmDO4gbQJUvvhY72fhumT8W4LqsE= github.com/aws/smithy-go v1.6.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= github.com/aws/smithy-go v1.22.0 h1:uunKnWlcoL3zO7q+gG2Pk53joueEOsnNB28QdMsmiMM= github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= From deeddd1d9a4c57c594af02bbc4c2dd8dc6f71fb5 Mon Sep 17 00:00:00 2001 From: Steven van Rossum Date: Wed, 6 Nov 2024 16:38:47 +0100 Subject: [PATCH 127/181] [KafkaIO] Determine partition backlog using endOffsets instead of seek2End and position (#32889) * Determine partition backlog using endOffsets instead of seekToEnd and position * Remove offset consumer assignments * Explicitly update partitions and start/end offsets for relevant mock consumers * Clean up partition and offset updates in tests --- .../meta/provider/kafka/KafkaTestTable.java | 14 ++--- .../sdk/io/kafka/KafkaUnboundedReader.java | 32 ++++++----- .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 10 +++- .../apache/beam/sdk/io/kafka/KafkaIOTest.java | 22 +++---- .../apache/beam/sdk/io/kafka/KafkaMocks.java | 57 +++++++------------ .../sdk/io/kafka/ReadFromKafkaDoFnTest.java | 4 ++ 6 files changed, 66 insertions(+), 73 deletions(-) diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestTable.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestTable.java index 372e77c54c67..d0f6427a262e 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestTable.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestTable.java @@ -36,7 +36,6 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.MockConsumer; @@ -138,10 +137,6 @@ public synchronized void assign(final Collection assigned) { .collect(Collectors.toList()); super.assign(realPartitions); assignedPartitions.set(ImmutableList.copyOf(realPartitions)); - for (TopicPartition tp : 
realPartitions) { - updateBeginningOffsets(ImmutableMap.of(tp, 0L)); - updateEndOffsets(ImmutableMap.of(tp, (long) kafkaRecords.get(tp).size())); - } } // Override offsetsForTimes() in order to look up the offsets by timestamp. @Override @@ -163,9 +158,12 @@ public synchronized Map offsetsForTimes( } }; - for (String topic : getTopics()) { - consumer.updatePartitions(topic, partitionInfoMap.get(topic)); - } + partitionInfoMap.forEach(consumer::updatePartitions); + consumer.updateBeginningOffsets( + kafkaRecords.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> 0L))); + consumer.updateEndOffsets( + kafkaRecords.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size()))); Runnable recordEnqueueTask = new Runnable() { diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java index 6ce6c7d5d233..d86a5d0ce686 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java @@ -138,7 +138,6 @@ public boolean start() throws IOException { name, spec.getOffsetConsumerConfig(), spec.getConsumerConfig()); offsetConsumer = spec.getConsumerFactoryFn().apply(offsetConsumerConfig); - ConsumerSpEL.evaluateAssign(offsetConsumer, topicPartitions); // Fetch offsets once before running periodically. updateLatestOffsets(); @@ -711,23 +710,28 @@ private void setupInitialOffset(PartitionState pState) { // Called from setupInitialOffset() at the start and then periodically from offsetFetcher thread. private void updateLatestOffsets() { Consumer offsetConsumer = Preconditions.checkStateNotNull(this.offsetConsumer); - for (PartitionState p : partitionStates) { - try { - Instant fetchTime = Instant.now(); - ConsumerSpEL.evaluateSeek2End(offsetConsumer, p.topicPartition); - long offset = offsetConsumer.position(p.topicPartition); - p.setLatestOffset(offset, fetchTime); - } catch (Exception e) { - if (closed.get()) { // Ignore the exception if the reader is closed. - break; - } + List topicPartitions = + Preconditions.checkStateNotNull(source.getSpec().getTopicPartitions()); + Instant fetchTime = Instant.now(); + try { + Map endOffsets = offsetConsumer.endOffsets(topicPartitions); + for (PartitionState p : partitionStates) { + p.setLatestOffset( + Preconditions.checkStateNotNull( + endOffsets.get(p.topicPartition), + "No end offset found for partition %s.", + p.topicPartition), + fetchTime); + } + } catch (Exception e) { + if (!closed.get()) { // Ignore the exception if the reader is closed. LOG.warn( - "{}: exception while fetching latest offset for partition {}. will be retried.", + "{}: exception while fetching latest offset for partitions {}. will be retried.", this, - p.topicPartition, + topicPartitions, e); - // Don't update the latest offset. } + // Don't update the latest offset. 
} LOG.debug("{}: backlog {}", this, getSplitBacklogBytes()); diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 4bda8cf28d4e..fe2d7a64a37f 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -19,6 +19,7 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -253,13 +254,16 @@ private static class KafkaLatestOffsetEstimator Consumer offsetConsumer, TopicPartition topicPartition) { this.offsetConsumer = offsetConsumer; this.topicPartition = topicPartition; - ConsumerSpEL.evaluateAssign(this.offsetConsumer, ImmutableList.of(this.topicPartition)); memoizedBacklog = Suppliers.memoizeWithExpiration( () -> { synchronized (offsetConsumer) { - ConsumerSpEL.evaluateSeek2End(offsetConsumer, topicPartition); - return offsetConsumer.position(topicPartition); + return Preconditions.checkStateNotNull( + offsetConsumer + .endOffsets(Collections.singleton(topicPartition)) + .get(topicPartition), + "No end offset found for partition %s.", + topicPartition); } }, 1, diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java index 764e406f71cb..e614320db150 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java @@ -77,7 +77,7 @@ import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; import org.apache.beam.sdk.io.kafka.KafkaIO.Read.FakeFlinkPipelineOptions; -import org.apache.beam.sdk.io.kafka.KafkaMocks.PositionErrorConsumerFactory; +import org.apache.beam.sdk.io.kafka.KafkaMocks.EndOffsetErrorConsumerFactory; import org.apache.beam.sdk.io.kafka.KafkaMocks.SendErrorProducerFactory; import org.apache.beam.sdk.metrics.DistributionResult; import org.apache.beam.sdk.metrics.Lineage; @@ -267,10 +267,6 @@ private static MockConsumer mkMockConsumer( public synchronized void assign(final Collection assigned) { super.assign(assigned); assignedPartitions.set(ImmutableList.copyOf(assigned)); - for (TopicPartition tp : assigned) { - updateBeginningOffsets(ImmutableMap.of(tp, 0L)); - updateEndOffsets(ImmutableMap.of(tp, (long) records.get(tp).size())); - } } // Override offsetsForTimes() in order to look up the offsets by timestamp. @Override @@ -290,9 +286,12 @@ public synchronized Map offsetsForTimes( } }; - for (String topic : topics) { - consumer.updatePartitions(topic, partitionMap.get(topic)); - } + partitionMap.forEach(consumer::updatePartitions); + consumer.updateBeginningOffsets( + records.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> 0L))); + consumer.updateEndOffsets( + records.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size()))); // MockConsumer does not maintain any relationship between partition seek position and the // records added. e.g. 
if we add 10 records to a partition and then seek to end of the @@ -1525,13 +1524,14 @@ public void testUnboundedReaderLogsCommitFailure() throws Exception { List topics = ImmutableList.of("topic_a"); - PositionErrorConsumerFactory positionErrorConsumerFactory = new PositionErrorConsumerFactory(); + EndOffsetErrorConsumerFactory endOffsetErrorConsumerFactory = + new EndOffsetErrorConsumerFactory(); UnboundedSource, KafkaCheckpointMark> source = KafkaIO.read() .withBootstrapServers("myServer1:9092,myServer2:9092") .withTopics(topics) - .withConsumerFactoryFn(positionErrorConsumerFactory) + .withConsumerFactoryFn(endOffsetErrorConsumerFactory) .withKeyDeserializer(IntegerDeserializer.class) .withValueDeserializer(LongDeserializer.class) .makeSource(); @@ -1540,7 +1540,7 @@ public void testUnboundedReaderLogsCommitFailure() throws Exception { reader.start(); - unboundedReaderExpectedLogs.verifyWarn("exception while fetching latest offset for partition"); + unboundedReaderExpectedLogs.verifyWarn("exception while fetching latest offset for partitions"); reader.close(); } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMocks.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMocks.java index 0844d71e7105..1303f1da3bcd 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMocks.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaMocks.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.kafka; import java.io.Serializable; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; @@ -27,8 +28,8 @@ import org.apache.beam.sdk.values.KV; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; -import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; import org.apache.kafka.clients.producer.Callback; import org.apache.kafka.clients.producer.MockProducer; import org.apache.kafka.clients.producer.Producer; @@ -66,51 +67,33 @@ public Producer apply(Map input) { } } - public static final class PositionErrorConsumer extends MockConsumer { - - public PositionErrorConsumer() { - super(null); - } - - @Override - public synchronized long position(TopicPartition partition) { - throw new KafkaException("fakeException"); - } - - @Override - public synchronized List partitionsFor(String topic) { - return Collections.singletonList( - new PartitionInfo("topic_a", 1, new Node(1, "myServer1", 9092), null, null)); - } - } - - public static final class PositionErrorConsumerFactory + public static final class EndOffsetErrorConsumerFactory implements SerializableFunction, Consumer> { - public PositionErrorConsumerFactory() {} + public EndOffsetErrorConsumerFactory() {} @Override public MockConsumer apply(Map input) { + final MockConsumer consumer; if (input.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) { - return new PositionErrorConsumer(); - } else { - MockConsumer consumer = - new MockConsumer(null) { + consumer = + new MockConsumer(OffsetResetStrategy.EARLIEST) { @Override - public synchronized long position(TopicPartition partition) { - return 1L; - } - - @Override - public synchronized ConsumerRecords poll(long timeout) { - return ConsumerRecords.empty(); + public synchronized Map endOffsets( + Collection partitions) { + throw new KafkaException("fakeException"); } }; - consumer.updatePartitions( 
- "topic_a", - Collections.singletonList( - new PartitionInfo("topic_a", 1, new Node(1, "myServer1", 9092), null, null))); - return consumer; + } else { + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST); } + consumer.updatePartitions( + "topic_a", + Collections.singletonList( + new PartitionInfo("topic_a", 1, new Node(1, "myServer1", 9092), null, null))); + consumer.updateBeginningOffsets( + Collections.singletonMap(new TopicPartition("topic_a", 1), 0L)); + consumer.updateEndOffsets(Collections.singletonMap(new TopicPartition("topic_a", 1), 0L)); + return consumer; } } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java index 3189bbb140f0..52c141685760 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java @@ -205,6 +205,8 @@ public SimpleMockKafkaConsumer( OffsetResetStrategy offsetResetStrategy, TopicPartition topicPartition) { super(offsetResetStrategy); this.topicPartition = topicPartition; + updateBeginningOffsets(ImmutableMap.of(topicPartition, 0L)); + updateEndOffsets(ImmutableMap.of(topicPartition, Long.MAX_VALUE)); } public void reset() { @@ -214,6 +216,8 @@ public void reset() { this.startOffsetForTime = KV.of(0L, Instant.now()); this.stopOffsetForTime = KV.of(Long.MAX_VALUE, null); this.numOfRecordsPerPoll = 0L; + updateBeginningOffsets(ImmutableMap.of(topicPartition, 0L)); + updateEndOffsets(ImmutableMap.of(topicPartition, Long.MAX_VALUE)); } public void setRemoved() { From bf2574bd962d439723aba7a60ac502c6c78ec1ae Mon Sep 17 00:00:00 2001 From: Steven van Rossum Date: Wed, 6 Nov 2024 16:39:12 +0100 Subject: [PATCH 128/181] [KafkaIO] Remove unused property, assignment in finalize will not be observed (#32920) * Remove unused property, assignment in finalize will not be observed * Remove unobservable use of isClosed --- .../org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index fe2d7a64a37f..7c2064883488 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -248,7 +248,6 @@ private static class KafkaLatestOffsetEstimator private final Consumer offsetConsumer; private final TopicPartition topicPartition; private final Supplier memoizedBacklog; - private boolean closed; KafkaLatestOffsetEstimator( Consumer offsetConsumer, TopicPartition topicPartition) { @@ -274,7 +273,6 @@ private static class KafkaLatestOffsetEstimator protected void finalize() { try { Closeables.close(offsetConsumer, true); - closed = true; LOG.info("Offset Estimator consumer was closed for {}", topicPartition); } catch (Exception anyException) { LOG.warn("Failed to close offset consumer for {}", topicPartition); @@ -285,10 +283,6 @@ protected void finalize() { public long estimate() { return memoizedBacklog.get(); } - - public boolean isClosed() { - return closed; - } } @GetInitialRestriction @@ -377,7 +371,7 @@ public OffsetRangeTracker restrictionTracker( TopicPartition topicPartition = kafkaSourceDescriptor.getTopicPartition(); 
KafkaLatestOffsetEstimator offsetEstimator = offsetEstimatorCacheInstance.get(topicPartition); - if (offsetEstimator == null || offsetEstimator.isClosed()) { + if (offsetEstimator == null) { Map updatedConsumerConfig = overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); From ad2af883219a587d7c80a6e477d859a265579316 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Wed, 6 Nov 2024 11:58:20 -0500 Subject: [PATCH 129/181] Add buildSrc to trigger path of Java PreCommit (#33029) --- .github/workflows/beam_PreCommit_Java.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/beam_PreCommit_Java.yml b/.github/workflows/beam_PreCommit_Java.yml index 772eab98c343..20dafca72a57 100644 --- a/.github/workflows/beam_PreCommit_Java.yml +++ b/.github/workflows/beam_PreCommit_Java.yml @@ -19,6 +19,7 @@ on: tags: ['v*'] branches: ['master', 'release-*'] paths: + - "buildSrc/**" - 'model/**' - 'sdks/java/**' - 'runners/**' From 13049a5857c88a1a9bfaca05b968f94ec498d09c Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Wed, 6 Nov 2024 11:58:37 -0500 Subject: [PATCH 130/181] Revert "Upgrade antlr from 4.7 to 4.13.1 (#33016)" (#33028) This reverts commit 9baa7ba0182b50adeecfe1b97b215a3d0f4a39bd. --- CHANGES.md | 1 - .../groovy/org/apache/beam/gradle/BeamModulePlugin.groovy | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index cdedce22e975..261fafc024f3 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -88,7 +88,6 @@ * Removed support for Flink 1.15 and 1.16 * Removed support for Python 3.8 * X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)). -* Upgrade antlr from 4.7 to 4.13.1 ([#33016](https://github.com/apache/beam/pull/33016)). ## Bugfixes diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 8d8bf9339c6e..5af91ec2f056 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -665,8 +665,8 @@ class BeamModulePlugin implements Plugin { activemq_junit : "org.apache.activemq.tooling:activemq-junit:$activemq_version", activemq_kahadb_store : "org.apache.activemq:activemq-kahadb-store:$activemq_version", activemq_mqtt : "org.apache.activemq:activemq-mqtt:$activemq_version", - antlr : "org.antlr:antlr4:4.13.1", - antlr_runtime : "org.antlr:antlr4-runtime:4.13.1", + antlr : "org.antlr:antlr4:4.7", + antlr_runtime : "org.antlr:antlr4-runtime:4.7", args4j : "args4j:args4j:2.33", auto_value_annotations : "com.google.auto.value:auto-value-annotations:$autovalue_version", avro : "org.apache.avro:avro:1.11.3", From eeebae1bda6b211463e53a4e4ca469bfa9763399 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Wed, 6 Nov 2024 12:54:08 -0500 Subject: [PATCH 131/181] Add Kafka 3 to and remove Kafka 0.x and 1.x from compatibility test (#32981) * Add Kafka 3.1.2 IT target Signed-off-by: Jeffrey Kinard * Remove kafka 0.x and 1.x from compatibilty test --------- Signed-off-by: Jeffrey Kinard Co-authored-by: Jeffrey Kinard --- sdks/java/io/kafka/build.gradle | 7 +----- sdks/java/io/kafka/kafka-01103/build.gradle | 24 --------------------- sdks/java/io/kafka/kafka-100/build.gradle | 24 --------------------- sdks/java/io/kafka/kafka-111/build.gradle | 24 --------------------- settings.gradle.kts | 8 ++----- 5 files changed, 3 insertions(+), 84 deletions(-) delete mode 100644 
sdks/java/io/kafka/kafka-01103/build.gradle delete mode 100644 sdks/java/io/kafka/kafka-100/build.gradle delete mode 100644 sdks/java/io/kafka/kafka-111/build.gradle diff --git a/sdks/java/io/kafka/build.gradle b/sdks/java/io/kafka/build.gradle index ec4654bd88df..c2f056b0b7cb 100644 --- a/sdks/java/io/kafka/build.gradle +++ b/sdks/java/io/kafka/build.gradle @@ -35,9 +35,6 @@ ext { } def kafkaVersions = [ - '01103': "0.11.0.3", - '100': "1.0.0", - '111': "1.1.1", '201': "2.0.1", '211': "2.1.1", '222': "2.2.2", @@ -139,15 +136,13 @@ task kafkaVersionsCompatibilityTest { description = 'Runs KafkaIO with different Kafka client APIs' def testNames = createTestList(kafkaVersions, "Test") dependsOn testNames - dependsOn (":sdks:java:io:kafka:kafka-01103:kafkaVersion01103BatchIT") - dependsOn (":sdks:java:io:kafka:kafka-100:kafkaVersion100BatchIT") - dependsOn (":sdks:java:io:kafka:kafka-111:kafkaVersion111BatchIT") dependsOn (":sdks:java:io:kafka:kafka-201:kafkaVersion201BatchIT") dependsOn (":sdks:java:io:kafka:kafka-211:kafkaVersion211BatchIT") dependsOn (":sdks:java:io:kafka:kafka-222:kafkaVersion222BatchIT") dependsOn (":sdks:java:io:kafka:kafka-231:kafkaVersion231BatchIT") dependsOn (":sdks:java:io:kafka:kafka-241:kafkaVersion241BatchIT") dependsOn (":sdks:java:io:kafka:kafka-251:kafkaVersion251BatchIT") + dependsOn (":sdks:java:io:kafka:kafka-312:kafkaVersion312BatchIT") } static def createTestList(Map prefixMap, String suffix) { diff --git a/sdks/java/io/kafka/kafka-01103/build.gradle b/sdks/java/io/kafka/kafka-01103/build.gradle deleted file mode 100644 index 3a74bf04ef22..000000000000 --- a/sdks/java/io/kafka/kafka-01103/build.gradle +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -project.ext { - delimited="0.11.0.3" - undelimited="01103" - sdfCompatible=false -} - -apply from: "../kafka-integration-test.gradle" \ No newline at end of file diff --git a/sdks/java/io/kafka/kafka-100/build.gradle b/sdks/java/io/kafka/kafka-100/build.gradle deleted file mode 100644 index bd5fa67b1cfc..000000000000 --- a/sdks/java/io/kafka/kafka-100/build.gradle +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -project.ext { - delimited="1.0.0" - undelimited="100" - sdfCompatible=false -} - -apply from: "../kafka-integration-test.gradle" diff --git a/sdks/java/io/kafka/kafka-111/build.gradle b/sdks/java/io/kafka/kafka-111/build.gradle deleted file mode 100644 index c2b0c8f82827..000000000000 --- a/sdks/java/io/kafka/kafka-111/build.gradle +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -project.ext { - delimited="1.1.1" - undelimited="111" - sdfCompatible=false -} - -apply from: "../kafka-integration-test.gradle" \ No newline at end of file diff --git a/settings.gradle.kts b/settings.gradle.kts index a38f69dac09e..ca30a5ea750a 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -333,6 +333,8 @@ project(":beam-test-gha").projectDir = file(".github") include("beam-validate-runner") project(":beam-validate-runner").projectDir = file(".test-infra/validate-runner") include("com.google.api.gax.batching") +include("sdks:java:io:kafka:kafka-312") +findProject(":sdks:java:io:kafka:kafka-312")?.name = "kafka-312" include("sdks:java:io:kafka:kafka-251") findProject(":sdks:java:io:kafka:kafka-251")?.name = "kafka-251" include("sdks:java:io:kafka:kafka-241") @@ -345,12 +347,6 @@ include("sdks:java:io:kafka:kafka-211") findProject(":sdks:java:io:kafka:kafka-211")?.name = "kafka-211" include("sdks:java:io:kafka:kafka-201") findProject(":sdks:java:io:kafka:kafka-201")?.name = "kafka-201" -include("sdks:java:io:kafka:kafka-111") -findProject(":sdks:java:io:kafka:kafka-111")?.name = "kafka-111" -include("sdks:java:io:kafka:kafka-100") -findProject(":sdks:java:io:kafka:kafka-100")?.name = "kafka-100" -include("sdks:java:io:kafka:kafka-01103") -findProject(":sdks:java:io:kafka:kafka-01103")?.name = "kafka-01103" include("sdks:java:managed") findProject(":sdks:java:managed")?.name = "managed" include("sdks:java:io:iceberg") From 81f35ab62298a2ec9fadeded82461b363b6401db Mon Sep 17 00:00:00 2001 From: Damon Date: Wed, 6 Nov 2024 12:06:52 -0800 Subject: [PATCH 132/181] Distroless python sdk (#32960) * Enable Python distroless container image variants * Fix missing entrypoint * Revert testing using validatescontainer.sh * Create validateDistrolessContainerTests * Refactor for reusable gradle methods * Revert back * Finalize gradle * Migrate distroless build to its own gradle task * Remove gradle distroless build 
task * Add base target * Build docker image directly in test * Revert back to using plugin --- sdks/python/container/Dockerfile | 26 ++++++++++- sdks/python/container/common.gradle | 9 +++- sdks/python/test-suites/dataflow/build.gradle | 6 +++ .../python/test-suites/dataflow/common.gradle | 45 +++++++++++++++++++ sdks/python/test-suites/gradle.properties | 3 ++ 5 files changed, 87 insertions(+), 2 deletions(-) diff --git a/sdks/python/container/Dockerfile b/sdks/python/container/Dockerfile index 7bea6229668f..f3d22a4b5bc6 100644 --- a/sdks/python/container/Dockerfile +++ b/sdks/python/container/Dockerfile @@ -103,9 +103,33 @@ RUN if [ "$pull_licenses" = "true" ] ; then \ python /tmp/license_scripts/pull_licenses_py.py ; \ fi -FROM beam +FROM beam as base ARG pull_licenses COPY --from=third_party_licenses /opt/apache/beam/third_party_licenses /opt/apache/beam/third_party_licenses RUN if [ "$pull_licenses" != "true" ] ; then \ rm -rf /opt/apache/beam/third_party_licenses ; \ fi + +ARG TARGETARCH +FROM gcr.io/distroless/python3-debian12:latest-${TARGETARCH} as distroless +ARG py_version + +# Contains header files needed by the Python interpreter. +COPY --from=base /usr/local/include /usr/local/include + +# Contains the Python interpreter executables. +COPY --from=base /usr/local/bin /usr/local/bin + +# Contains the Python library dependencies. +COPY --from=base /usr/local/lib /usr/local/lib + +# Python standard library modules. +COPY --from=base /usr/lib/python${py_version} /usr/lib/python${py_version} + +# Contains the boot entrypoint and related files such as licenses. +COPY --from=base /opt /opt + +ENV PATH "$PATH:/usr/local/bin" + +# Despite the ENTRYPOINT set above, need to reset since deriving the layer derives from a different image. +ENTRYPOINT ["/opt/apache/beam/boot"] diff --git a/sdks/python/container/common.gradle b/sdks/python/container/common.gradle index 0175778a6301..885662362894 100644 --- a/sdks/python/container/common.gradle +++ b/sdks/python/container/common.gradle @@ -71,10 +71,16 @@ def copyLauncherDependencies = tasks.register("copyLauncherDependencies", Copy) } def pushContainers = project.rootProject.hasProperty(["isRelease"]) || project.rootProject.hasProperty("push-containers") +def baseBuildTarget = 'base' +def buildTarget = project.findProperty('container-build-target') ?: 'base' +var imageName = project.docker_image_default_repo_prefix + "python${project.ext.pythonVersion}_sdk" +if (buildTarget != baseBuildTarget) { + imageName += "_${buildTarget}" +} docker { name containerImageName( - name: project.docker_image_default_repo_prefix + "python${project.ext.pythonVersion}_sdk", + name: imageName, root: project.rootProject.hasProperty(["docker-repository-root"]) ? 
project.rootProject["docker-repository-root"] : project.docker_image_default_repo_root, @@ -90,6 +96,7 @@ docker { platform(*project.containerPlatforms()) load project.useBuildx() && !pushContainers push pushContainers + target buildTarget } dockerPrepare.dependsOn copyLauncherDependencies diff --git a/sdks/python/test-suites/dataflow/build.gradle b/sdks/python/test-suites/dataflow/build.gradle index 04a79683fd36..4500b395b0a6 100644 --- a/sdks/python/test-suites/dataflow/build.gradle +++ b/sdks/python/test-suites/dataflow/build.gradle @@ -60,6 +60,12 @@ task validatesContainerTests { } } +task validatesDistrolessContainerTests { + getVersionsAsList('distroless_python_versions').each { + dependsOn.add(":sdks:python:test-suites:dataflow:py${getVersionSuffix(it)}:validatesDistrolessContainer") + } +} + task examplesPostCommit { getVersionsAsList('dataflow_examples_postcommit_py_versions').each { dependsOn.add(":sdks:python:test-suites:dataflow:py${getVersionSuffix(it)}:examples") diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index 71d44652bc7e..cd0db4a62f77 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -380,6 +380,51 @@ task validatesContainer() { } } +/** + * Validates the distroless (https://github.com/GoogleContainerTools/distroless) variant of the Python SDK container + * image (sdks/python/container/Dockerfile). + * To test a single version of Python: + * ./gradlew :sdks:python:test-suites:dataflow:py311:validatesDistrolessContainer + * See https://cwiki.apache.org/confluence/display/BEAM/Python+Tips#PythonTips-VirtualEnvironmentSetup + * for more information on setting up different Python versions. + */ +task validatesDistrolessContainer() { + def pyversion = "${project.ext.pythonVersion.replace('.', '')}" + def buildTarget = 'distroless' + def repository = "us.gcr.io/apache-beam-testing/${System.getenv('USER')}" + def tag = java.time.Instant.now().getEpochSecond() + def imageURL = "${repository}/beam_python${project.ext.pythonVersion}_sdk_${buildTarget}:${tag}" + project.rootProject.ext['docker-repository-root'] = repository + project.rootProject.ext['container-build-target'] = buildTarget + project.rootProject.ext['docker-tag'] = tag + if (project.rootProject.hasProperty('dry-run')) { + println "Running in dry run mode: imageURL: ${imageURL}, pyversion: ${pyversion}, buildTarget: ${buildTarget}, repository: ${repository}, tag: ${tag}, envdir: ${envdir}" + return + } + dependsOn 'initializeForDataflowJob' + dependsOn ":sdks:python:container:py${pyversion}:docker" + dependsOn ":sdks:python:container:py${pyversion}:dockerPush" + def testTarget = "apache_beam/examples/wordcount_it_test.py::WordCountIT::test_wordcount_it" + def argMap = [ + "output": "gs://temp-storage-for-end-to-end-tests/py-it-cloud/output", + "project": "apache-beam-testing", + "region": "us-central1", + "runner": "TestDataflowRunner", + "sdk_container_image": "${imageURL}", + "sdk_location": "container", + "staging_location": "gs://temp-storage-for-end-to-end-tests/staging-it", + "temp_location": "gs://temp-storage-for-end-to-end-tests/temp-it", + ] + def cmdArgs = mapToArgString(argMap) + doLast { + exec { + workingDir = "${rootDir}/sdks/python" + executable 'sh' + args '-c', ". 
${envdir}/bin/activate && pytest ${testTarget} --test-pipeline-options=\"${cmdArgs}\"" + } + } +} + task validatesContainerARM() { def pyversion = "${project.ext.pythonVersion.replace('.', '')}" dependsOn 'initializeForDataflowJob' diff --git a/sdks/python/test-suites/gradle.properties b/sdks/python/test-suites/gradle.properties index d027cd3144d3..08266c4b0dd5 100644 --- a/sdks/python/test-suites/gradle.properties +++ b/sdks/python/test-suites/gradle.properties @@ -54,3 +54,6 @@ prism_examples_postcommit_py_versions=3.9,3.12 # cross language postcommit python test suites cross_language_validates_py_versions=3.9,3.12 + +# Python versions to support distroless variants +distroless_python_versions=3.9,3.10,3.11,3.12 From e598df7e3ccb4543d02e7d4b9c524d3f471228cd Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Wed, 6 Nov 2024 22:46:45 -0500 Subject: [PATCH 133/181] Remove antlr from dependency from java-core pom (#33030) * Remove antlr from dependency of jaca-core pom * Fix to use java-core shadowjar as source to build direct-runner-java --- runners/direct-java/build.gradle | 22 ++++++++++++---------- sdks/java/core/build.gradle | 1 - 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/runners/direct-java/build.gradle b/runners/direct-java/build.gradle index c357b8a04328..404b864c9c31 100644 --- a/runners/direct-java/build.gradle +++ b/runners/direct-java/build.gradle @@ -22,12 +22,12 @@ plugins { id 'org.apache.beam.module' } // Shade away runner execution utilities till because this causes ServiceLoader conflicts with // TransformPayloadTranslatorRegistrar amongst other runners. This only happens in the DirectRunner // because it is likely to appear on the classpath of another runner. -def dependOnProjects = [ - ":runners:core-java", - ":runners:local-java", - ":runners:java-fn-execution", - ":sdks:java:core", - ] +def dependOnProjectsAndConfigs = [ + ":runners:core-java":null, + ":runners:local-java":null, + ":runners:java-fn-execution":null, + ":sdks:java:core":"shadow", +] applyJavaNature( automaticModuleName: 'org.apache.beam.runners.direct', @@ -36,8 +36,8 @@ applyJavaNature( ], shadowClosure: { dependencies { - dependOnProjects.each { - include(project(path: it, configuration: "shadow")) + dependOnProjectsAndConfigs.each { + include(project(path: it.key, configuration: "shadow")) } } }, @@ -63,8 +63,10 @@ configurations { dependencies { shadow library.java.vendored_guava_32_1_2_jre shadow project(path: ":model:pipeline", configuration: "shadow") - dependOnProjects.each { - implementation project(it) + dependOnProjectsAndConfigs.each { + // For projects producing shadowjar, use the packaged jar as dependency to + // handle redirected packages from it + implementation project(path: it.key, configuration: it.value) } shadow library.java.vendored_grpc_1_60_1 shadow library.java.joda_time diff --git a/sdks/java/core/build.gradle b/sdks/java/core/build.gradle index e150c22de62d..07144c8de053 100644 --- a/sdks/java/core/build.gradle +++ b/sdks/java/core/build.gradle @@ -81,7 +81,6 @@ dependencies { shadow library.java.vendored_grpc_1_60_1 shadow library.java.vendored_guava_32_1_2_jre shadow library.java.byte_buddy - shadow library.java.antlr_runtime shadow library.java.commons_compress shadow library.java.commons_lang3 testImplementation library.java.mockito_inline From 9ed818732814b7b265ff2d843da5b59d1823ff59 Mon Sep 17 00:00:00 2001 From: tvalentyn Date: Thu, 7 Nov 2024 06:41:58 -0800 Subject: [PATCH 134/181] Update run_inference_vllm.ipynb (#33032) Enable verbose logging for 
Dataflow launch. --- examples/notebooks/beam-ml/run_inference_vllm.ipynb | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/notebooks/beam-ml/run_inference_vllm.ipynb b/examples/notebooks/beam-ml/run_inference_vllm.ipynb index fea953bc1e66..5729ac200782 100644 --- a/examples/notebooks/beam-ml/run_inference_vllm.ipynb +++ b/examples/notebooks/beam-ml/run_inference_vllm.ipynb @@ -484,6 +484,7 @@ " def process(self, element, *args, **kwargs):\n", " yield \"Input: {input}, Output: {output}\".format(input=element.example, output=element.inference)\n", "\n", + "logging.getLogger().setLevel(logging.INFO) # Output additional Dataflow Job metadata and launch logs. \n", "prompts = [\n", " \"Hello, my name is\",\n", " \"The president of the United States is\",\n", From 3a841b1ec53df94d34a32e8efc613db3c604bde8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 7 Nov 2024 09:48:43 -0500 Subject: [PATCH 135/181] Bump cloud.google.com/go/pubsub from 1.44.0 to 1.45.1 in /sdks (#32952) Bumps [cloud.google.com/go/pubsub](https://github.com/googleapis/google-cloud-go) from 1.44.0 to 1.45.1. - [Release notes](https://github.com/googleapis/google-cloud-go/releases) - [Changelog](https://github.com/googleapis/google-cloud-go/blob/main/CHANGES.md) - [Commits](https://github.com/googleapis/google-cloud-go/compare/pubsub/v1.44.0...pubsub/v1.45.1) --- updated-dependencies: - dependency-name: cloud.google.com/go/pubsub dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- sdks/go.mod | 2 +- sdks/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdks/go.mod b/sdks/go.mod index 0f35782d52bf..369b0c8d13a4 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -27,7 +27,7 @@ require ( cloud.google.com/go/bigtable v1.33.0 cloud.google.com/go/datastore v1.19.0 cloud.google.com/go/profiler v0.4.1 - cloud.google.com/go/pubsub v1.44.0 + cloud.google.com/go/pubsub v1.45.1 cloud.google.com/go/spanner v1.70.0 cloud.google.com/go/storage v1.45.0 github.com/aws/aws-sdk-go-v2 v1.32.3 diff --git a/sdks/go.sum b/sdks/go.sum index 0c6a74211be9..071c318d1241 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -451,8 +451,8 @@ cloud.google.com/go/pubsub v1.26.0/go.mod h1:QgBH3U/jdJy/ftjPhTkyXNj543Tin1pRYcd cloud.google.com/go/pubsub v1.27.1/go.mod h1:hQN39ymbV9geqBnfQq6Xf63yNhUAhv9CZhzp5O6qsW0= cloud.google.com/go/pubsub v1.28.0/go.mod h1:vuXFpwaVoIPQMGXqRyUQigu/AX1S3IWugR9xznmcXX8= cloud.google.com/go/pubsub v1.30.0/go.mod h1:qWi1OPS0B+b5L+Sg6Gmc9zD1Y+HaM0MdUr7LsupY1P4= -cloud.google.com/go/pubsub v1.44.0 h1:pLaMJVDTlnUDIKT5L0k53YyLszfBbGoUBo/IqDK/fEI= -cloud.google.com/go/pubsub v1.44.0/go.mod h1:BD4a/kmE8OePyHoa1qAHEw1rMzXX+Pc8Se54T/8mc3I= +cloud.google.com/go/pubsub v1.45.1 h1:ZC/UzYcrmK12THWn1P72z+Pnp2vu/zCZRXyhAfP1hJY= +cloud.google.com/go/pubsub v1.45.1/go.mod h1:3bn7fTmzZFwaUjllitv1WlsNMkqBgGUb3UdMhI54eCc= cloud.google.com/go/pubsublite v1.5.0/go.mod h1:xapqNQ1CuLfGi23Yda/9l4bBCKz/wC3KIJ5gKcxveZg= cloud.google.com/go/pubsublite v1.6.0/go.mod h1:1eFCS0U11xlOuMFV/0iBqw3zP12kddMeCbj/F3FSj9k= cloud.google.com/go/pubsublite v1.7.0/go.mod h1:8hVMwRXfDfvGm3fahVbtDbiLePT3gpoiJYJY+vxWxVM= From 3c8bb54147f7220e5c4ae4534d2b5adfe03d47f7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 7 Nov 2024 10:45:20 -0500 
Subject: [PATCH 136/181] Bump github.com/aws/aws-sdk-go-v2/service/s3 in /sdks (#33036) Bumps [github.com/aws/aws-sdk-go-v2/service/s3](https://github.com/aws/aws-sdk-go-v2) from 1.66.0 to 1.66.3. - [Release notes](https://github.com/aws/aws-sdk-go-v2/releases) - [Commits](https://github.com/aws/aws-sdk-go-v2/compare/service/s3/v1.66.0...service/s3/v1.66.3) --- updated-dependencies: - dependency-name: github.com/aws/aws-sdk-go-v2/service/s3 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- sdks/go.mod | 16 ++++++++-------- sdks/go.sum | 32 ++++++++++++++++---------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/sdks/go.mod b/sdks/go.mod index 369b0c8d13a4..15ab3248ba70 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -30,11 +30,11 @@ require ( cloud.google.com/go/pubsub v1.45.1 cloud.google.com/go/spanner v1.70.0 cloud.google.com/go/storage v1.45.0 - github.com/aws/aws-sdk-go-v2 v1.32.3 + github.com/aws/aws-sdk-go-v2 v1.32.4 github.com/aws/aws-sdk-go-v2/config v1.28.0 github.com/aws/aws-sdk-go-v2/credentials v1.17.42 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.33 - github.com/aws/aws-sdk-go-v2/service/s3 v1.66.0 + github.com/aws/aws-sdk-go-v2/service/s3 v1.66.3 github.com/aws/smithy-go v1.22.0 github.com/docker/go-connections v0.5.0 github.com/dustin/go-humanize v1.0.1 @@ -133,14 +133,14 @@ require ( github.com/aws/aws-sdk-go v1.34.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.18 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.22 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.22 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.23 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.23 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.3 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.4 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.24.3 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.3 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.32.3 // indirect diff --git a/sdks/go.sum b/sdks/go.sum index 071c318d1241..61f9a1b586fb 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -689,8 +689,8 @@ github.com/aws/aws-sdk-go v1.30.19/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZve github.com/aws/aws-sdk-go v1.34.0 h1:brux2dRrlwCF5JhTL7MUT3WUwo9zfDHZZp3+g3Mvlmo= github.com/aws/aws-sdk-go 
v1.34.0/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go-v2 v1.7.1/go.mod h1:L5LuPC1ZgDr2xQS7AmIec/Jlc7O/Y1u2KxJyNVab250= -github.com/aws/aws-sdk-go-v2 v1.32.3 h1:T0dRlFBKcdaUPGNtkBSwHZxrtis8CQU17UpNBZYd0wk= -github.com/aws/aws-sdk-go-v2 v1.32.3/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= +github.com/aws/aws-sdk-go-v2 v1.32.4 h1:S13INUiTxgrPueTmrm5DZ+MiAo99zYzHEFh1UNkOxNE= +github.com/aws/aws-sdk-go-v2 v1.32.4/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 h1:pT3hpW0cOHRJx8Y0DfJUEQuqPild8jRGmSFmBgvydr0= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6/go.mod h1:j/I2++U0xX+cr44QjHay4Cvxj6FUbnxrgmqN3H1jTZA= github.com/aws/aws-sdk-go-v2/config v1.5.0/go.mod h1:RWlPOAW3E3tbtNAqTwvSW54Of/yP3oiZXMI0xfUdjyA= @@ -705,29 +705,29 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.18/go.mod h1:Fjnn5jQVIo6Vyed github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.3.2/go.mod h1:qaqQiHSrOUVOfKe6fhgQ6UzhxjwqVW8aHNegd6Ws4w4= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.33 h1:X+4YY5kZRI/cOoSMVMGTqFXHAMg1bvvay7IBcqHpybQ= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.33/go.mod h1:DPynzu+cn92k5UQ6tZhX+wfTB4ah6QDU/NgdHqatmvk= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.22 h1:Jw50LwEkVjuVzE1NzkhNKkBf9cRN7MtE1F/b2cOKTUM= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.22/go.mod h1:Y/SmAyPcOTmpeVaWSzSKiILfXTVJwrGmYZhcRbhWuEY= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.22 h1:981MHwBaRZM7+9QSR6XamDzF/o7ouUGxFzr+nVSIhrs= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.22/go.mod h1:1RA1+aBEfn+CAB/Mh0MB6LsdCYCnjZm7tKXtnk499ZQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23 h1:A2w6m6Tmr+BNXjDsr7M90zkWjsu4JXHwrzPg235STs4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23/go.mod h1:35EVp9wyeANdujZruvHiQUAo9E3vbhnIO1mTCAxMlY0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.23 h1:pgYW9FCabt2M25MoHYCfMrVY2ghiiBKYWUVXfwZs+sU= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.23/go.mod h1:c48kLgzO19wAu3CPkDWC28JbaJ+hfQlsdl7I2+oqIbk= github.com/aws/aws-sdk-go-v2/internal/ini v1.1.1/go.mod h1:Zy8smImhTdOETZqfyn01iNOe0CNggVbPjCajyaz6Gvg= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21 h1:7edmS3VOBDhK00b/MwGtGglCm7hhwNYnjJs/PgFdMQE= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.21/go.mod h1:Q9o5h4HoIWG8XfzxqiuK/CGUbepCJ8uTlaE3bAbxytQ= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.23 h1:1SZBDiRzzs3sNhOMVApyWPduWYGAX0imGy06XiBnCAM= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.23/go.mod h1:i9TkxgbZmHVh2S0La6CAXtnyFhlCX/pJ0JsOvBAS6Mk= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.1/go.mod h1:v33JQ57i2nekYTA70Mb+O18KeH4KqhdqxTJZNK1zdRE= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 h1:TToQNkvGguu209puTojY/ozlqy2d/SFNcoLIqTFi42g= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0/go.mod h1:0jp+ltwkf+SwG2fm/PKo8t4y8pJSgOCO4D8Lz3k0aHQ= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2 
h1:4FMHqLfk0efmTqhXVRL5xYRqlEBNBiRI7N6w4jsEdd4= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.2/go.mod h1:LWoqeWlK9OZeJxsROW2RqrSPvQHKTpp69r/iDjwsSaw= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.4 h1:aaPpoG15S2qHkWm4KlEyF01zovK1nW4BBbyXuHNSE90= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.4/go.mod h1:eD9gS2EARTKgGr/W5xwgY/ik9z/zqpW+m/xOQbVxrMk= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.1/go.mod h1:zceowr5Z1Nh2WVP8bf/3ikB41IZW59E4yIYbg+pC6mw= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.3 h1:qcxX0JYlgWH3hpPUnd6U0ikcl6LLA9sLkXE2w1fpMvY= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.3/go.mod h1:cLSNEmI45soc+Ef8K/L+8sEA3A3pYFEYf5B5UI+6bH4= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4 h1:tHxQi/XHPK0ctd/wdOw0t7Xrc2OxcRCnVzv8lwWPu0c= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4/go.mod h1:4GQbF1vJzG60poZqWatZlhP31y8PGCCVTvIGPdaaYJ0= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.1/go.mod h1:6EQZIwNNvHpq/2/QSJnp4+ECvqIy55w95Ofs0ze+nGQ= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2 h1:t7iUP9+4wdc5lt3E41huP+GvQZJD38WLsgVp4iOtAjg= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.2/go.mod h1:/niFCtmuQNxqx9v8WAPq5qh7EH25U4BF6tjoyq9bObM= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.4 h1:E5ZAVOmI2apR8ADb72Q63KqwwwdW1XcMeXIlrZ1Psjg= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.4/go.mod h1:wezzqVUOVVdk+2Z/JzQT4NxAU0NbhRe5W8pIE72jsWI= github.com/aws/aws-sdk-go-v2/service/s3 v1.11.1/go.mod h1:XLAGFrEjbvMCLvAtWLLP32yTv8GpBquCApZEycDLunI= -github.com/aws/aws-sdk-go-v2/service/s3 v1.66.0 h1:xA6XhTF7PE89BCNHJbQi8VvPzcgMtmGC5dr8S8N7lHk= -github.com/aws/aws-sdk-go-v2/service/s3 v1.66.0/go.mod h1:cB6oAuus7YXRZhWCc1wIwPywwZ1XwweNp2TVAEGYeB8= +github.com/aws/aws-sdk-go-v2/service/s3 v1.66.3 h1:neNOYJl72bHrz9ikAEED4VqWyND/Po0DnEx64RW6YM4= +github.com/aws/aws-sdk-go-v2/service/s3 v1.66.3/go.mod h1:TMhLIyRIyoGVlaEMAt+ITMbwskSTpcGsCPDq91/ihY0= github.com/aws/aws-sdk-go-v2/service/sso v1.3.1/go.mod h1:J3A3RGUvuCZjvSuZEcOpHDnzZP/sKbhDWV2T1EOzFIM= github.com/aws/aws-sdk-go-v2/service/sso v1.24.3 h1:UTpsIf0loCIWEbrqdLb+0RxnTXfWh2vhw4nQmFi4nPc= github.com/aws/aws-sdk-go-v2/service/sso v1.24.3/go.mod h1:FZ9j3PFHHAR+w0BSEjK955w5YD2UwB/l/H0yAK3MJvI= From c7e4db76a8796639b411bb9f260a87a17248316f Mon Sep 17 00:00:00 2001 From: scwhittle Date: Thu, 7 Nov 2024 18:13:57 +0100 Subject: [PATCH 137/181] Add equals/hashcode to test data class to avoid direct runner warning log (#32883) --- .../apache/beam/sdk/transforms/WithKeysTest.java | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java index 296a53f48e80..fd178f8e7649 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Objects; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -226,5 +227,19 @@ 
public long getNum() { public String getStr() { return this.str; } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Pojo)) { + return false; + } + Pojo pojo = (Pojo) o; + return num == pojo.num && Objects.equals(str, pojo.str); + } + + @Override + public int hashCode() { + return Objects.hash(num, str); + } } } From 674249919c51236041b635a3f864a5ba595d916e Mon Sep 17 00:00:00 2001 From: tvalentyn Date: Thu, 7 Nov 2024 11:54:34 -0800 Subject: [PATCH 138/181] Use sys.executable to find python command. (#33033) * Use sys.executable to find python command. * Make linter happy --------- Co-authored-by: Danny McCormick --- sdks/python/apache_beam/ml/inference/vllm_inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index b86d33ec16b1..799083d16ceb 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -21,6 +21,7 @@ import logging import os import subprocess +import sys import threading import time import uuid @@ -118,7 +119,7 @@ def __init__(self, model_name: str, vllm_server_kwargs: Dict[str, str]): def start_server(self, retries=3): if not self._server_started: server_cmd = [ - 'python', + sys.executable, '-m', 'vllm.entrypoints.openai.api_server', '--model', From 5ebfd82c4ce9b533fe27d27bfcc2ddfbdecff532 Mon Sep 17 00:00:00 2001 From: Jeffrey Kinard Date: Thu, 7 Nov 2024 17:21:49 -0500 Subject: [PATCH 139/181] [yaml] SpannerIO docs and minor improvments Signed-off-by: Jeffrey Kinard --- .../beam/sdk/io/gcp/spanner/SpannerIO.java | 4 +- .../SpannerReadSchemaTransformProvider.java | 137 ++++++++++++------ .../SpannerWriteSchemaTransformProvider.java | 77 ++++------ sdks/python/apache_beam/yaml/standard_io.yaml | 2 + 4 files changed, 130 insertions(+), 90 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index d9dde11a3081..a6cf7ebb12a5 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -291,8 +291,8 @@ * grouped into batches. The default maximum size of the batch is set to 1MB or 5000 mutated cells, * or 500 rows (whichever is reached first). To override this use {@link * Write#withBatchSizeBytes(long) withBatchSizeBytes()}, {@link Write#withMaxNumMutations(long) - * withMaxNumMutations()} or {@link Write#withMaxNumMutations(long) withMaxNumRows()}. Setting - * either to a small value or zero disables batching. + * withMaxNumMutations()} or {@link Write#withMaxNumRows(long) withMaxNumRows()}. Setting either to + * a small value or zero disables batching. * *

Note that the maximum diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java index 5cd9cb47b696..76440b1ebf1a 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java @@ -41,6 +41,7 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +/** A provider for reading from Cloud Spanner using a Schema Transform Provider. */ @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) @@ -54,43 +55,81 @@ * *

The transformation leverages the {@link SpannerIO} to perform the read operation and maps the * results to Beam rows, preserving the schema. - * - *

Example usage in a YAML pipeline using query: - * - *

{@code
- * pipeline:
- *   transforms:
- *     - type: ReadFromSpanner
- *       name: ReadShipments
- *       # Columns: shipment_id, customer_id, shipment_date, shipment_cost, customer_name, customer_email
- *       config:
- *         project_id: 'apache-beam-testing'
- *         instance_id: 'shipment-test'
- *         database_id: 'shipment'
- *         query: 'SELECT * FROM shipments'
- * }
- * - *

Example usage in a YAML pipeline using a table and columns: - * - *

{@code
- * pipeline:
- *   transforms:
- *     - type: ReadFromSpanner
- *       name: ReadShipments
- *       # Columns: shipment_id, customer_id, shipment_date, shipment_cost, customer_name, customer_email
- *       config:
- *         project_id: 'apache-beam-testing'
- *         instance_id: 'shipment-test'
- *         database_id: 'shipment'
- *         table: 'shipments'
- *         columns: ['customer_id', 'customer_name']
- * }
*/ @AutoService(SchemaTransformProvider.class) public class SpannerReadSchemaTransformProvider extends TypedSchemaTransformProvider< SpannerReadSchemaTransformProvider.SpannerReadSchemaTransformConfiguration> { + @Override + public String identifier() { + return "beam:schematransform:org.apache.beam:spanner_read:v1"; + } + + @Override + public String description() { + return "Performs a Bulk read from Google Cloud Spanner using a specified SQL query or " + + "by directly accessing a single table and its columns.\n" + + "\n" + + "Both Query and Read APIs are supported. See more information about " + + "
reading from Cloud Spanner.\n" + + "\n" + + "Example configuration for performing a read using a SQL query: ::\n" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " query: 'SELECT * FROM table'\n" + + "\n" + + "It is also possible to read a table by specifying a table name and a list of columns. For " + + "example, the following configuration will perform a read on an entire table: ::\n" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " table: 'my-table'\n" + + " columns: ['col1', 'col2']\n" + + "\n" + + "Additionally, to read using a " + + "Secondary Index, specify the index name: ::" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " table: 'my-table'\n" + + " index: 'my-index'\n" + + " columns: ['col1', 'col2']\n" + + "\n" + + "### Advanced Usage\n" + + "\n" + + "Reads by default use the " + + "PartitionQuery API which enforces some limitations on the type of queries that can be used so that " + + "the data can be read in parallel. If the query is not supported by the PartitionQuery API, then you " + + "can specify a non-partitioned read by setting batching to false.\n" + + "\n" + + "For example: ::" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " batching: false\n" + + " ...\n" + + "\n" + + "Note: See " + + "SpannerIO for more advanced information."; + } + static class SpannerSchemaTransformRead extends SchemaTransform implements Serializable { private final SpannerReadSchemaTransformConfiguration configuration; @@ -113,6 +152,12 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } else { read = read.withTable(configuration.getTableId()).withColumns(configuration.getColumns()); } + if (!Strings.isNullOrEmpty(configuration.getIndex())) { + read = read.withIndex(configuration.getIndex()); + } + if (Boolean.FALSE.equals(configuration.getBatching())) { + read = read.withBatching(false); + } PCollection spannerRows = input.getPipeline().apply(read); Schema schema = spannerRows.getSchema(); PCollection rows = @@ -124,11 +169,6 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } } - @Override - public String identifier() { - return "beam:schematransform:org.apache.beam:spanner_read:v1"; - } - @Override public List inputCollectionNames() { return Collections.emptyList(); @@ -157,6 +197,10 @@ public abstract static class Builder { public abstract Builder setColumns(List columns); + public abstract Builder setIndex(String index); + + public abstract Builder setBatching(Boolean batching); + public abstract SpannerReadSchemaTransformConfiguration build(); } @@ -193,16 +237,16 @@ public static Builder builder() { .Builder(); } - @SchemaFieldDescription("Specifies the GCP project ID.") - @Nullable - public abstract String getProjectId(); - @SchemaFieldDescription("Specifies the Cloud Spanner instance.") public abstract String getInstanceId(); @SchemaFieldDescription("Specifies the Cloud Spanner database.") public abstract String getDatabaseId(); + @SchemaFieldDescription("Specifies the GCP project ID.") + @Nullable + public abstract String getProjectId(); + @SchemaFieldDescription("Specifies the Cloud Spanner table.") @Nullable 
public abstract String getTableId(); @@ -211,9 +255,20 @@ public static Builder builder() { @Nullable public abstract String getQuery(); - @SchemaFieldDescription("Specifies the columns to read from the table.") + @SchemaFieldDescription( + "Specifies the columns to read from the table. This parameter is required when table is specified.") @Nullable public abstract List getColumns(); + + @SchemaFieldDescription( + "Specifies the Index to read from. This parameter can only be specified when using table.") + @Nullable + public abstract String getIndex(); + + @SchemaFieldDescription( + "Set to false to disable batching. Useful when using a query that is not compatible with the PartitionQuery API. Defaults to true.") + @Nullable + public abstract Boolean getBatching(); } @Override diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java index 9f079c78f886..8601da09ea09 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteSchemaTransformProvider.java @@ -67,49 +67,37 @@ *

The transformation uses the {@link SpannerIO} to perform the write operation and provides * options to handle failed mutations, either by throwing an error, or passing the failed mutation * further in the pipeline for dealing with accordingly. - * - *

Example usage in a YAML pipeline without error handling: - * - *

{@code
- * pipeline:
- *   transforms:
- *     - type: WriteToSpanner
- *       name: WriteShipments
- *       config:
- *         project_id: 'apache-beam-testing'
- *         instance_id: 'shipment-test'
- *         database_id: 'shipment'
- *         table_id: 'shipments'
- *
- * }
- * - *

Example usage in a YAML pipeline using error handling: - * - *

{@code
- * pipeline:
- *   transforms:
- *     - type: WriteToSpanner
- *       name: WriteShipments
- *       config:
- *         project_id: 'apache-beam-testing'
- *         instance_id: 'shipment-test'
- *         database_id: 'shipment'
- *         table_id: 'shipments'
- *         error_handling:
- *           output: 'errors'
- *
- *     - type: WriteToJson
- *       input: WriteSpanner.my_error_output
- *       config:
- *          path: errors.json
- *
- * }
*/ @AutoService(SchemaTransformProvider.class) public class SpannerWriteSchemaTransformProvider extends TypedSchemaTransformProvider< SpannerWriteSchemaTransformProvider.SpannerWriteSchemaTransformConfiguration> { + @Override + public String identifier() { + return "beam:schematransform:org.apache.beam:spanner_write:v1"; + } + + @Override + public String description() { + return "Performs a bulk write to a Google Cloud Spanner table.\n" + + "\n" + + "Example configuration for performing a write to a single table: ::\n" + + "\n" + + " pipeline:\n" + + " transforms:\n" + + " - type: ReadFromSpanner\n" + + " config:\n" + + " project_id: 'my-project-id'\n" + + " instance_id: 'my-instance-id'\n" + + " database_id: 'my-database'\n" + + " table: 'my-table'\n" + + "\n" + + "Note: See " + + "SpannerIO for more advanced information."; + } + @Override protected Class configurationClass() { return SpannerWriteSchemaTransformConfiguration.class; @@ -225,11 +213,6 @@ public PCollectionRowTuple expand(@NonNull PCollectionRowTuple input) { } } - @Override - public String identifier() { - return "beam:schematransform:org.apache.beam:spanner_write:v1"; - } - @Override public List inputCollectionNames() { return Collections.singletonList("input"); @@ -244,10 +227,6 @@ public List outputCollectionNames() { @DefaultSchema(AutoValueSchema.class) public abstract static class SpannerWriteSchemaTransformConfiguration implements Serializable { - @SchemaFieldDescription("Specifies the GCP project.") - @Nullable - public abstract String getProjectId(); - @SchemaFieldDescription("Specifies the Cloud Spanner instance.") public abstract String getInstanceId(); @@ -257,7 +236,11 @@ public abstract static class SpannerWriteSchemaTransformConfiguration implements @SchemaFieldDescription("Specifies the Cloud Spanner table.") public abstract String getTableId(); - @SchemaFieldDescription("Specifies how to handle errors.") + @SchemaFieldDescription("Specifies the GCP project.") + @Nullable + public abstract String getProjectId(); + + @SchemaFieldDescription("Whether and how to handle write errors.") @Nullable public abstract ErrorHandling getErrorHandling(); diff --git a/sdks/python/apache_beam/yaml/standard_io.yaml b/sdks/python/apache_beam/yaml/standard_io.yaml index 4de36b3dc9e0..400ab07a41fa 100644 --- a/sdks/python/apache_beam/yaml/standard_io.yaml +++ b/sdks/python/apache_beam/yaml/standard_io.yaml @@ -271,6 +271,8 @@ table: 'table_id' query: 'query' columns: 'columns' + index: 'index' + batching: 'batching' 'WriteToSpanner': project: 'project_id' instance: 'instance_id' From a797dc254a2f5c0ca148033aecc0c21b72171dfe Mon Sep 17 00:00:00 2001 From: Jeffrey Kinard Date: Thu, 7 Nov 2024 19:00:01 -0500 Subject: [PATCH 140/181] [yaml] Normalize error_handling docs on API catalog Signed-off-by: Jeffrey Kinard --- .../apache_beam/yaml/generate_yaml_docs.py | 24 +++++++++++++++++++ sdks/python/apache_beam/yaml/yaml_mapping.py | 5 ++++ 2 files changed, 29 insertions(+) diff --git a/sdks/python/apache_beam/yaml/generate_yaml_docs.py b/sdks/python/apache_beam/yaml/generate_yaml_docs.py index 84a5e62f0abd..4088e17afe2c 100644 --- a/sdks/python/apache_beam/yaml/generate_yaml_docs.py +++ b/sdks/python/apache_beam/yaml/generate_yaml_docs.py @@ -20,13 +20,17 @@ import itertools import re +import docstring_parser import yaml from apache_beam.portability.api import schema_pb2 +from apache_beam.typehints import schemas from apache_beam.utils import subprocess_server +from apache_beam.utils.python_callable import PythonCallableWithSource 
from apache_beam.version import __version__ as beam_version from apache_beam.yaml import json_utils from apache_beam.yaml import yaml_provider +from apache_beam.yaml.yaml_mapping import ErrorHandlingConfig def _singular(name): @@ -135,8 +139,28 @@ def maybe_row_parameters(t): def maybe_optional(t): return " (Optional)" if t.nullable else "" + def normalize_error_handling(f): + doc = docstring_parser.parse( + ErrorHandlingConfig.__doc__, docstring_parser.DocstringStyle.GOOGLE) + if f.name == "error_handling": + f = schema_pb2.Field( + name="error_handling", + type=schema_pb2.FieldType( + row_type=schema_pb2.RowType( + schema=schema_pb2.Schema( + fields=[ + schemas.schema_field( + param.arg_name, + PythonCallableWithSource.load_from_expression( + param.type_name), + param.description) for param in doc.params + ]))), + description=f.description) + return f + def lines(): for f in schema.fields: + f = normalize_error_handling(f) yield ''.join([ f'**{f.name}** `{pretty_type(f.type)}`', maybe_optional(f.type), diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 960fcdeecf30..5c14b0f5ea79 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -418,6 +418,11 @@ def checking_func(row): class ErrorHandlingConfig(NamedTuple): + """Class to define Error Handling parameters. + + Args: + output (str): Name to use for the output error collection + """ output: str # TODO: Other parameters are valid here too, but not common to Java. From e5cdd5bfd0bbef9deb226ef58a611bbd632752b8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 8 Nov 2024 11:24:53 -0500 Subject: [PATCH 141/181] Bump golang.org/x/sys from 0.26.0 to 0.27.0 in /sdks (#33049) Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.26.0 to 0.27.0. - [Commits](https://github.com/golang/sys/compare/v0.26.0...v0.27.0) --- updated-dependencies: - dependency-name: golang.org/x/sys dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- sdks/go.mod | 2 +- sdks/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdks/go.mod b/sdks/go.mod index 15ab3248ba70..3a5a891a88af 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -56,7 +56,7 @@ require ( golang.org/x/net v0.30.0 golang.org/x/oauth2 v0.23.0 golang.org/x/sync v0.8.0 - golang.org/x/sys v0.26.0 + golang.org/x/sys v0.27.0 golang.org/x/text v0.19.0 google.golang.org/api v0.203.0 google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53 diff --git a/sdks/go.sum b/sdks/go.sum index 61f9a1b586fb..b31e1db22fef 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -1524,8 +1524,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= From 159c1a4f67aa8b634fdb4ab03afde432ce8ed0d2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 8 Nov 2024 11:28:02 -0500 Subject: [PATCH 142/181] Bump cloud.google.com/go/datastore from 1.19.0 to 1.20.0 in /sdks (#33051) Bumps [cloud.google.com/go/datastore](https://github.com/googleapis/google-cloud-go) from 1.19.0 to 1.20.0. - [Release notes](https://github.com/googleapis/google-cloud-go/releases) - [Changelog](https://github.com/googleapis/google-cloud-go/blob/main/documentai/CHANGES.md) - [Commits](https://github.com/googleapis/google-cloud-go/compare/kms/v1.19.0...kms/v1.20.0) --- updated-dependencies: - dependency-name: cloud.google.com/go/datastore dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- sdks/go.mod | 2 +- sdks/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdks/go.mod b/sdks/go.mod index 3a5a891a88af..ff711cbe91b0 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -25,7 +25,7 @@ go 1.21.0 require ( cloud.google.com/go/bigquery v1.63.1 cloud.google.com/go/bigtable v1.33.0 - cloud.google.com/go/datastore v1.19.0 + cloud.google.com/go/datastore v1.20.0 cloud.google.com/go/profiler v0.4.1 cloud.google.com/go/pubsub v1.45.1 cloud.google.com/go/spanner v1.70.0 diff --git a/sdks/go.sum b/sdks/go.sum index b31e1db22fef..c24cb10126c8 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -240,8 +240,8 @@ cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7 cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/datastore v1.10.0/go.mod h1:PC5UzAmDEkAmkfaknstTYbNpgE49HAgW2J1gcgUfmdM= cloud.google.com/go/datastore v1.11.0/go.mod h1:TvGxBIHCS50u8jzG+AW/ppf87v1of8nwzFNgEZU1D3c= -cloud.google.com/go/datastore v1.19.0 h1:p5H3bUQltOa26GcMRAxPoNwoqGkq5v8ftx9/ZBB35MI= -cloud.google.com/go/datastore v1.19.0/go.mod h1:KGzkszuj87VT8tJe67GuB+qLolfsOt6bZq/KFuWaahc= +cloud.google.com/go/datastore v1.20.0 h1:NNpXoyEqIJmZFc0ACcwBEaXnmscUpcG4NkKnbCePmiM= +cloud.google.com/go/datastore v1.20.0/go.mod h1:uFo3e+aEpRfHgtp5pp0+6M0o147KoPaYNaPAKpfh8Ew= cloud.google.com/go/datastream v1.2.0/go.mod h1:i/uTP8/fZwgATHS/XFu0TcNUhuA0twZxxQ3EyCUQMwo= cloud.google.com/go/datastream v1.3.0/go.mod h1:cqlOX8xlyYF/uxhiKn6Hbv6WjwPPuI9W2M9SAXwaLLQ= cloud.google.com/go/datastream v1.4.0/go.mod h1:h9dpzScPhDTs5noEMQVWP8Wx8AFBRyS0s8KWPx/9r0g= From 5851415587481fee66f1abcac32f613160cf9fa4 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Fri, 8 Nov 2024 14:37:40 -0500 Subject: [PATCH 143/181] Don't pin to latest tag (#33058) * Don't pin to latest tag Its usually better to suggest a custom tag to avoid overwriting containers from old jobs * Apply suggestions from code review Co-authored-by: tvalentyn --------- Co-authored-by: tvalentyn --- examples/notebooks/beam-ml/run_inference_vllm.ipynb | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/notebooks/beam-ml/run_inference_vllm.ipynb b/examples/notebooks/beam-ml/run_inference_vllm.ipynb index 5729ac200782..e9f1e53a452b 100644 --- a/examples/notebooks/beam-ml/run_inference_vllm.ipynb +++ b/examples/notebooks/beam-ml/run_inference_vllm.ipynb @@ -352,11 +352,11 @@ "\n", "1. In the sidebar, click **Files** to open the **Files** pane.\n", "2. In an environment with Docker installed, download the file **VllmDockerfile** file to an empty folder.\n", - "3. Run the following commands. Replace `` with a valid [Artifact Registry](https://cloud.google.com/artifact-registry/docs/overview) repository.\n", + "3. Run the following commands. Replace `:` with a valid [Artifact Registry](https://cloud.google.com/artifact-registry/docs/overview) repository and tag.\n", "\n", " ```\n", - " docker build -t \":latest\" -f VllmDockerfile ./\n", - " docker image push \":latest\"\n", + " docker build -t \":\" -f VllmDockerfile ./\n", + " docker image push \":\"\n", " ```" ], "metadata": { @@ -373,7 +373,8 @@ "First, define the pipeline options that you want to use to launch the Dataflow job. 
Before running the next cell, replace the following variables:\n", "\n", "- ``: the name of a valid [Google Cloud Storage](https://cloud.google.com/storage?e=48754805&hl=en) bucket. Don't include a `gs://` prefix or trailing slashes.\n", - "- ``: the name of the Google Artifact Registry repository that you used in the previous step. Don't include the `latest` tag, because this tag is automatically appended as part of the cell.\n", + "- ``: the name of the Google Artifact Registry repository that you used in the previous step. \n", + "- ``: image tag used in the previous step. Prefer a versioned tag or SHA instead of :latest tag or mutable tags.\n", "- ``: the name of the Google Cloud project that you created your bucket and Artifact Registry repository in.\n", "\n", "This workflow uses the following Dataflow service option: `worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx`. When you use this service option, Dataflow to installs a T4 GPU that uses a `5xx` series Nvidia driver on each worker machine. The 5xx driver is required to run vLLM jobs." @@ -396,7 +397,7 @@ "options = PipelineOptions()\n", "\n", "BUCKET_NAME = '' # Replace with your bucket name.\n", - "CONTAINER_LOCATION = '' # Replace with your container location ( from the previous step)\n", + "CONTAINER_IMAGE = ':' # Replace with the image repository and tag from the previous step.\n", "PROJECT_NAME = '' # Replace with your GCP project\n", "\n", "options.view_as(GoogleCloudOptions).project = PROJECT_NAME\n", @@ -428,7 +429,7 @@ "# Choose a machine type compatible with GPU type\n", "options.view_as(WorkerOptions).machine_type = \"n1-standard-4\"\n", "\n", - "options.view_as(WorkerOptions).worker_harness_container_image = '%s:latest' % CONTAINER_LOCATION" + "options.view_as(WorkerOptions).sdk_container_image = CONTAINER_IMAGE" ], "metadata": { "id": "kXy9FRYVCSjq" From 97fa43ba21dac48710d1ffb463718d0a26780c4b Mon Sep 17 00:00:00 2001 From: liferoad Date: Fri, 8 Nov 2024 16:44:15 -0500 Subject: [PATCH 144/181] Fixed the broken beam python on flink with PortableRunner --- runners/flink/job-server-container/Dockerfile | 4 +++- .../www/site/content/en/documentation/runners/flink.md | 9 +++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/runners/flink/job-server-container/Dockerfile b/runners/flink/job-server-container/Dockerfile index cbb73512400e..5f19aa0dc851 100644 --- a/runners/flink/job-server-container/Dockerfile +++ b/runners/flink/job-server-container/Dockerfile @@ -28,4 +28,6 @@ COPY target/LICENSE /opt/apache/beam/ COPY target/NOTICE /opt/apache/beam/ WORKDIR /opt/apache/beam -ENTRYPOINT ["./flink-job-server.sh"] + +# Add a conditional check for a mounted volume. This allows passing flink configs. 
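+# For example, starting the container with `-v <local-flink-conf-dir>:/flink-conf` makes that
+# directory visible inside the image; the entrypoint below then passes it to the job server
+# via `--flink-conf-dir /flink-conf`, and otherwise falls back to the built-in defaults.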
+ENTRYPOINT ["/bin/sh", "-c", "if [ -d \"/flink-conf\" ]; then /opt/apache/beam/flink-job-server.sh --flink-conf-dir /flink-conf; else /opt/apache/beam/flink-job-server.sh; fi"] diff --git a/website/www/site/content/en/documentation/runners/flink.md b/website/www/site/content/en/documentation/runners/flink.md index e9522d76e832..94bf394c6b11 100644 --- a/website/www/site/content/en/documentation/runners/flink.md +++ b/website/www/site/content/en/documentation/runners/flink.md @@ -207,12 +207,17 @@ To run a pipeline on an embedded Flink cluster: {{< /paragraph >}} {{< paragraph class="language-portable" >}} -(1) Start the JobService endpoint: `docker run --net=host apache/beam_flink1.10_job_server:latest` +(1) Start the JobService endpoint: `docker run --net=host apache/beam_flink1.18_job_server:latest` {{< /paragraph >}} {{< paragraph class="language-portable" >}} The JobService is the central instance where you submit your Beam pipeline to. The JobService will create a Flink job for the pipeline and execute the job. +Note that you might see the error message like `Caused by: java.io.IOException: Insufficient number of network buffers:...`, +which can be fixed by passing a Flink configuration file to change the default ones. +One example can be found [here](https://github.com/apache/beam/blob/master/runners/flink/src/test/resources/flink-conf.yaml). +Then start the JobService endpoint by mounting a local configuration directory to `/flink`: +`docker run --net=host -v :/flink-conf beam-flink-runner apache/beam_flink1.18_job_server:latest` {{< /paragraph >}} {{< paragraph class="language-portable" >}} @@ -243,7 +248,7 @@ To run on a separate [Flink cluster](https://ci.apache.org/projects/flink/flink- {{< /paragraph >}} {{< paragraph class="language-portable" >}} -(2) Start JobService with Flink Rest endpoint: `docker run --net=host apache/beam_flink1.10_job_server:latest --flink-master=localhost:8081`. +(2) Start JobService with Flink Rest endpoint: `docker run --net=host apache/beam_flink1.18_job_server:latest --flink-master=localhost:8081`. {{< /paragraph >}} {{< paragraph class="language-portable" >}} From 4b7bcd8305680fe82495803089a7275ce63a1253 Mon Sep 17 00:00:00 2001 From: liferoad Date: Fri, 8 Nov 2024 17:57:36 -0500 Subject: [PATCH 145/181] Polished the doc --- .../www/site/content/en/documentation/runners/flink.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/website/www/site/content/en/documentation/runners/flink.md b/website/www/site/content/en/documentation/runners/flink.md index 94bf394c6b11..af73751c256a 100644 --- a/website/www/site/content/en/documentation/runners/flink.md +++ b/website/www/site/content/en/documentation/runners/flink.md @@ -212,11 +212,11 @@ To run a pipeline on an embedded Flink cluster: {{< paragraph class="language-portable" >}} The JobService is the central instance where you submit your Beam pipeline to. -The JobService will create a Flink job for the pipeline and execute the job. -Note that you might see the error message like `Caused by: java.io.IOException: Insufficient number of network buffers:...`, -which can be fixed by passing a Flink configuration file to change the default ones. -One example can be found [here](https://github.com/apache/beam/blob/master/runners/flink/src/test/resources/flink-conf.yaml). -Then start the JobService endpoint by mounting a local configuration directory to `/flink`: +It creates a Flink job from your pipeline and executes it. 
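+For reference, once the JobService is up you can submit a Python pipeline to it with the
+`PortableRunner` by pointing `--job_endpoint` at the service. The following is a sketch using
+the wordcount example shipped with the SDK, assuming the default endpoint `localhost:8099`
+and an arbitrary local output path:
+
+```
+python -m apache_beam.examples.wordcount \
+  --output /tmp/counts \
+  --runner PortableRunner \
+  --job_endpoint localhost:8099 \
+  --environment_type LOOPBACK
+```
+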
+You might encounter an error message like `Caused by: java.io.IOException: Insufficient number of network buffers:...`. +This can be resolved by providing a Flink configuration file to override the default settings. +You can find an example configuration file [here](https://github.com/apache/beam/blob/master/runners/flink/src/test/resources/flink-conf.yaml). +To start the Job Service endpoint with your custom configuration, mount a local directory containing your Flink configuration to the `/flink-conf` path in the Docker container: `docker run --net=host -v :/flink-conf beam-flink-runner apache/beam_flink1.18_job_server:latest` {{< /paragraph >}} From 47740d07ed9e0824b4975ca4cadece74aae38302 Mon Sep 17 00:00:00 2001 From: scwhittle Date: Mon, 11 Nov 2024 17:57:54 +0100 Subject: [PATCH 146/181] Simplify MoreFutures.supplyAsync and MoreFutures.runAsync using a wrapper instead of chained stages and multiple completions. (#33042) This ensures that what is joined, interacts with ForkJoinPool execution and avoids the possibility that async scheduled future is not joined. --- .../org/apache/beam/sdk/util/MoreFutures.java | 60 ++++++++----------- 1 file changed, 24 insertions(+), 36 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java index cd38da100a79..0999f2ad0771 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/MoreFutures.java @@ -45,9 +45,6 @@ *
  • Return {@link CompletableFuture} only to the producer of a future value. * */ -@SuppressWarnings({ - "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) -}) public class MoreFutures { /** @@ -99,22 +96,18 @@ public static boolean isCancelled(CompletionStage future) { */ public static CompletionStage supplyAsync( ThrowingSupplier supplier, ExecutorService executorService) { - CompletableFuture result = new CompletableFuture<>(); - - CompletionStage wrapper = - CompletableFuture.runAsync( - () -> { - try { - result.complete(supplier.get()); - } catch (InterruptedException e) { - result.completeExceptionally(e); - Thread.currentThread().interrupt(); - } catch (Throwable t) { - result.completeExceptionally(t); - } - }, - executorService); - return wrapper.thenCompose(nothing -> result); + return CompletableFuture.supplyAsync( + () -> { + try { + return supplier.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new CompletionException(e); + } catch (Throwable t) { + throw new CompletionException(t); + } + }, + executorService); } /** @@ -132,23 +125,18 @@ public static CompletionStage supplyAsync(ThrowingSupplier supplier) { */ public static CompletionStage runAsync( ThrowingRunnable runnable, ExecutorService executorService) { - CompletableFuture result = new CompletableFuture<>(); - - CompletionStage wrapper = - CompletableFuture.runAsync( - () -> { - try { - runnable.run(); - result.complete(null); - } catch (InterruptedException e) { - result.completeExceptionally(e); - Thread.currentThread().interrupt(); - } catch (Throwable t) { - result.completeExceptionally(t); - } - }, - executorService); - return wrapper.thenCompose(nothing -> result); + return CompletableFuture.runAsync( + () -> { + try { + runnable.run(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new CompletionException(e); + } catch (Throwable t) { + throw new CompletionException(t); + } + }, + executorService); } /** From c6549e71ea454a7efb80c890bacb257813a36be8 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 11 Nov 2024 12:09:39 -0500 Subject: [PATCH 147/181] Fix exception sampling logic (#33076) * Fix exception sampling logic * Allow subclasses of BaseException * Mock cython if not present * Use correct types/names * format --- sdks/python/apache_beam/runners/common.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 8a4f26c18e88..c43870d55ebb 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -1504,8 +1504,7 @@ def process(self, windowed_value): return [] def _maybe_sample_exception( - self, exn: BaseException, - windowed_value: Optional[WindowedValue]) -> None: + self, exc_info: Tuple, windowed_value: Optional[WindowedValue]) -> None: if self.execution_context is None: return @@ -1516,7 +1515,7 @@ def _maybe_sample_exception( output_sampler.sample_exception( windowed_value, - exn, + exc_info, self.transform_id, self.execution_context.instruction_id) From 073aac9d8b42067b8d9b969d1a203ef2a7a561b7 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Mon, 11 Nov 2024 13:02:53 -0500 Subject: [PATCH 148/181] Fix antlr analyzeClassesDependencies (#33080) --- sdks/java/core/build.gradle | 1 + 1 file changed, 1 insertion(+) diff --git a/sdks/java/core/build.gradle b/sdks/java/core/build.gradle index 07144c8de053..a8dfbf42f970 100644 --- a/sdks/java/core/build.gradle 
+++ b/sdks/java/core/build.gradle @@ -73,6 +73,7 @@ dependencies { antlr library.java.antlr // antlr is used to generate code from sdks/java/core/src/main/antlr/ permitUnusedDeclared library.java.antlr + permitUsedUndeclared library.java.antlr_runtime // Required to load constants from the model, e.g. max timestamp for global window shadow project(path: ":model:pipeline", configuration: "shadow") shadow project(path: ":model:fn-execution", configuration: "shadow") From fca0bea5e9fd9bff31c784b66085d0196ad04678 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Mon, 11 Nov 2024 10:12:29 -0800 Subject: [PATCH 149/181] Fix cleanup timer timestamp to not exceed max allowed timestamp (#33037) This fixes an exception during drain on jobs with GlobalWindows + AllowedLateness > 24h + @OnExpiredWindows callback --- .../dataflow/worker/SimpleParDoFn.java | 16 ++- .../worker/UserParDoFnFactoryTest.java | 102 ++++++++++++++++++ 2 files changed, 114 insertions(+), 4 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java index a413c2c03dbe..558848f488a7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java @@ -77,6 +77,7 @@ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) public class SimpleParDoFn implements ParDoFn { + // TODO: Remove once Distributions has shipped. @VisibleForTesting static final String OUTPUTS_PER_ELEMENT_EXPERIMENT = "outputs_per_element_counter"; @@ -174,6 +175,7 @@ private boolean hasExperiment(String experiment) { /** Simple state tracker to calculate PerElementOutputCount counter. */ private interface OutputsPerElementTracker { + void onOutput(); void onProcessElement(); @@ -182,6 +184,7 @@ private interface OutputsPerElementTracker { } private class OutputsPerElementTrackerImpl implements OutputsPerElementTracker { + private long outputsPerElement; private final Counter counter; @@ -214,6 +217,7 @@ private void reset() { /** No-op {@link OutputsPerElementTracker} implementation used when the counter is disabled. */ private static class NoopOutputsPerElementTracker implements OutputsPerElementTracker { + private NoopOutputsPerElementTracker() {} public static final OutputsPerElementTracker INSTANCE = new NoopOutputsPerElementTracker(); @@ -516,10 +520,14 @@ private void registerStateCleanup( private Instant earliestAllowableCleanupTime( BoundedWindow window, WindowingStrategy windowingStrategy) { - return window - .maxTimestamp() - .plus(windowingStrategy.getAllowedLateness()) - .plus(Duration.millis(1L)); + Instant cleanupTime = + window + .maxTimestamp() + .plus(windowingStrategy.getAllowedLateness()) + .plus(Duration.millis(1L)); + return cleanupTime.isAfter(BoundedWindow.TIMESTAMP_MAX_VALUE) + ? 
BoundedWindow.TIMESTAMP_MAX_VALUE + : cleanupTime; } /** diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java index ff114ef2f078..c1e5000f03da 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.theInstance; @@ -153,6 +154,21 @@ private static class TestStatefulDoFn extends DoFn, Void> { public void processElement(ProcessContext c) {} } + private static class TestStatefulDoFnWithWindowExpiration + extends DoFn, Void> { + + public static final String STATE_ID = "state-id"; + + @StateId(STATE_ID) + private final StateSpec> spec = StateSpecs.value(StringUtf8Coder.of()); + + @ProcessElement + public void processElement(ProcessContext c) {} + + @OnWindowExpiration + public void onWindowExpiration() {} + } + private static final TupleTag MAIN_OUTPUT = new TupleTag<>("1"); private UserParDoFnFactory factory = UserParDoFnFactory.createDefault(); @@ -373,6 +389,92 @@ public void testCleanupRegistered() throws Exception { firstWindow.maxTimestamp().plus(Duration.millis(1L))); } + /** + * Regression test for global window + OnWindowExpiration + allowed lateness > max allowed time + */ + @Test + public void testCleanupTimerForGlobalWindowWithAllowedLateness() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + CounterSet counters = new CounterSet(); + DoFn initialFn = new TestStatefulDoFnWithWindowExpiration(); + Duration allowedLateness = Duration.standardDays(2); + CloudObject cloudObject = + getCloudObject( + initialFn, WindowingStrategy.globalDefault().withAllowedLateness(allowedLateness)); + + StateInternals stateInternals = InMemoryStateInternals.forKey("dummy"); + + TimerInternals timerInternals = mock(TimerInternals.class); + + DataflowStepContext stepContext = mock(DataflowStepContext.class); + when(stepContext.timerInternals()).thenReturn(timerInternals); + DataflowStepContext userStepContext = mock(DataflowStepContext.class); + when(stepContext.namespacedToUser()).thenReturn(userStepContext); + when(stepContext.stateInternals()).thenReturn(stateInternals); + when(userStepContext.stateInternals()).thenReturn((StateInternals) stateInternals); + + DataflowExecutionContext executionContext = + mock(DataflowExecutionContext.class); + TestOperationContext operationContext = TestOperationContext.create(counters); + when(executionContext.getStepContext(operationContext)).thenReturn(stepContext); + when(executionContext.getSideInputReader(any(), any(), any())) + .thenReturn(NullSideInputReader.empty()); + + ParDoFn parDoFn = + factory.create( + options, + cloudObject, + Collections.emptyList(), + MAIN_OUTPUT, + ImmutableMap.of(MAIN_OUTPUT, 0), + executionContext, + operationContext); + + Receiver rcvr = new OutputReceiver(); + parDoFn.startBundle(rcvr); + + GlobalWindow globalWindow = GlobalWindow.INSTANCE; + parDoFn.processElement( + WindowedValue.of("foo", new 
Instant(1), globalWindow, PaneInfo.NO_FIRING)); + + assertThat( + globalWindow.maxTimestamp().plus(allowedLateness), + greaterThan(BoundedWindow.TIMESTAMP_MAX_VALUE)); + verify(stepContext) + .setStateCleanupTimer( + SimpleParDoFn.CLEANUP_TIMER_ID, + globalWindow, + GlobalWindow.Coder.INSTANCE, + BoundedWindow.TIMESTAMP_MAX_VALUE, + BoundedWindow.TIMESTAMP_MAX_VALUE.minus(Duration.millis(1))); + + StateNamespace globalWindowNamespace = + StateNamespaces.window(GlobalWindow.Coder.INSTANCE, globalWindow); + StateTag> tag = + StateTags.tagForSpec( + TestStatefulDoFnWithWindowExpiration.STATE_ID, StateSpecs.value(StringUtf8Coder.of())); + + when(userStepContext.getNextFiredTimer((Coder) GlobalWindow.Coder.INSTANCE)).thenReturn(null); + when(stepContext.getNextFiredTimer((Coder) GlobalWindow.Coder.INSTANCE)) + .thenReturn( + TimerData.of( + SimpleParDoFn.CLEANUP_TIMER_ID, + globalWindowNamespace, + BoundedWindow.TIMESTAMP_MAX_VALUE, + BoundedWindow.TIMESTAMP_MAX_VALUE.minus(Duration.millis(1)), + TimeDomain.EVENT_TIME)) + .thenReturn(null); + + // Set up non-empty state. We don't mock + verify calls to clear() but instead + // check that state is actually empty. We mustn't care how it is accomplished. + stateInternals.state(globalWindowNamespace, tag).write("first"); + + // And this should clean up the second window + parDoFn.processTimers(); + + assertThat(stateInternals.state(globalWindowNamespace, tag).read(), nullValue()); + } + @Test public void testCleanupWorks() throws Exception { PipelineOptions options = PipelineOptionsFactory.create(); From 97ae5aa8d70519921da761201c1e85113b5cdf26 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Mon, 11 Nov 2024 13:42:25 -0500 Subject: [PATCH 150/181] Remove nondeterminism from expansion tests (#33082) --- sdks/python/apache_beam/transforms/external_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sdks/python/apache_beam/transforms/external_test.py b/sdks/python/apache_beam/transforms/external_test.py index fe2914a08699..c95a5d19f0cd 100644 --- a/sdks/python/apache_beam/transforms/external_test.py +++ b/sdks/python/apache_beam/transforms/external_test.py @@ -52,6 +52,7 @@ from apache_beam.typehints.native_type_compatibility import convert_to_beam_type from apache_beam.utils import proto_utils from apache_beam.utils.subprocess_server import JavaJarServer +from apache_beam.utils.subprocess_server import SubprocessServer # Protect against environments where apitools library is not available. 
# pylint: disable=wrong-import-order, wrong-import-position @@ -718,6 +719,9 @@ def test_implicit_builder_with_constructor_method(self): class JavaJarExpansionServiceTest(unittest.TestCase): + def setUp(self): + SubprocessServer._cache._live_owners = set() + def test_classpath(self): with tempfile.TemporaryDirectory() as temp_dir: try: From 126f278505b339c32e403fbcb8b9cb34bfaea81e Mon Sep 17 00:00:00 2001 From: scwhittle Date: Mon, 11 Nov 2024 22:00:49 +0100 Subject: [PATCH 151/181] Optimize proto PubsubMessage to/from conversion to beam PubsubMessage (#32973) --- sdks/python/apache_beam/io/gcp/pubsub.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py index b6f801c63f79..9e006dbeda93 100644 --- a/sdks/python/apache_beam/io/gcp/pubsub.py +++ b/sdks/python/apache_beam/io/gcp/pubsub.py @@ -126,7 +126,7 @@ def _from_proto_str(proto_msg: bytes) -> 'PubsubMessage': """ msg = pubsub.types.PubsubMessage.deserialize(proto_msg) # Convert ScalarMapContainer to dict. - attributes = dict((key, msg.attributes[key]) for key in msg.attributes) + attributes = dict(msg.attributes) return PubsubMessage( msg.data, attributes, @@ -151,10 +151,8 @@ def _to_proto_str(self, for_publish=False): https://cloud.google.com/pubsub/docs/reference/rpc/google.pubsub.v1#google.pubsub.v1.PubsubMessage containing the payload of this object. """ - msg = pubsub.types.PubsubMessage() if len(self.data) > (10_000_000): raise ValueError('A pubsub message data field must not exceed 10MB') - msg.data = self.data if self.attributes: if len(self.attributes) > 100: @@ -167,19 +165,25 @@ def _to_proto_str(self, for_publish=False): if len(value) > 1024: raise ValueError( 'A pubsub message attribute value must not exceed 1024 bytes') - msg.attributes[key] = value + message_id = None + publish_time = None if not for_publish: if self.message_id: - msg.message_id = self.message_id + message_id = self.message_id if self.publish_time: - msg.publish_time = self.publish_time + publish_time = self.publish_time if len(self.ordering_key) > 1024: raise ValueError( 'A pubsub message ordering key must not exceed 1024 bytes.') - msg.ordering_key = self.ordering_key + msg = pubsub.types.PubsubMessage( + data=self.data, + attributes=self.attributes, + message_id=message_id, + publish_time=publish_time, + ordering_key=self.ordering_key) serialized = pubsub.types.PubsubMessage.serialize(msg) if len(serialized) > (10_000_000): raise ValueError( @@ -193,7 +197,7 @@ def _from_message(msg: Any) -> 'PubsubMessage': https://googleapis.github.io/google-cloud-python/latest/pubsub/subscriber/api/message.html """ # Convert ScalarMapContainer to dict. - attributes = dict((key, msg.attributes[key]) for key in msg.attributes) + attributes = dict(msg.attributes) pubsubmessage = PubsubMessage(msg.data, attributes) if msg.message_id: pubsubmessage.message_id = msg.message_id From d760383817dce7d5d0d6ec5b4a500d948b041204 Mon Sep 17 00:00:00 2001 From: scwhittle Date: Mon, 11 Nov 2024 22:03:07 +0100 Subject: [PATCH 152/181] Changes to reduce memory pinned while iterating through state backed iterable: (#32961) - remove reference to completed encoded input page from decoder once we have read it. 
- re-read from cache after loading the next page to give eviction a chance to remove blocks --- .../beam/sdk/fn/stream/DataStreams.java | 6 ++ .../harness/state/StateFetchingIterators.java | 83 ++++++++++--------- 2 files changed, 52 insertions(+), 37 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java index b0d29e2295a8..2c6b61e62121 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/DataStreams.java @@ -202,6 +202,12 @@ public WeightedList decodeFromChunkBoundaryToChunkBoundary() { T next = next(); rvals.add(next); } + // We don't support seeking backwards so release the memory of the last + // page if it is completed. + if (inbound.currentStream.available() == 0) { + inbound.position = 0; + inbound.currentStream = EMPTY_STREAM; + } // Uses the size of the ByteString as an approximation for the heap size occupied by the // page, considering an overhead of {@link BYTES_LIST_ELEMENT_OVERHEAD} for each element. diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index 3b9fccfa2a5e..81a2aa6d1cc6 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -105,7 +105,7 @@ public long getWeight() { // many different state subcaches. return 0; } - }; + } /** A mutable iterable that supports prefetch and is backed by a cache. */ static class CachingStateIterable extends PrefetchableIterables.Default { @@ -138,8 +138,8 @@ public long getWeight() { private static long sumWeight(List> blocks) { try { long sum = 0; - for (int i = 0; i < blocks.size(); ++i) { - sum = Math.addExact(sum, blocks.get(i).getWeight()); + for (Block block : blocks) { + sum = Math.addExact(sum, block.getWeight()); } return sum; } catch (ArithmeticException e) { @@ -437,50 +437,59 @@ public boolean hasNext() { if (currentBlock.getValues().size() > currentCachedBlockValueIndex) { return true; } - if (currentBlock.getNextToken() == null) { + final ByteString nextToken = currentBlock.getNextToken(); + if (nextToken == null) { return false; } - Blocks existing = cache.peek(IterableCacheKey.INSTANCE); - boolean isFirstBlock = ByteString.EMPTY.equals(currentBlock.getNextToken()); + // Release the block while we are loading the next one. 
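+        // Dropping the reference here means this iterator no longer pins the previous page's
+        // values in memory while the next page is fetched over the State API, leaving the
+        // shared cache free to evict them.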
+ currentBlock = + Block.fromValues(new WeightedList<>(Collections.emptyList(), 0L), ByteString.EMPTY); + + @Nullable Blocks existing = cache.peek(IterableCacheKey.INSTANCE); + boolean isFirstBlock = ByteString.EMPTY.equals(nextToken); if (existing == null) { - currentBlock = loadNextBlock(currentBlock.getNextToken()); + currentBlock = loadNextBlock(nextToken); if (isFirstBlock) { cache.put( IterableCacheKey.INSTANCE, new BlocksPrefix<>(Collections.singletonList(currentBlock))); } + } else if (isFirstBlock) { + currentBlock = existing.getBlocks().get(0); } else { - if (isFirstBlock) { - currentBlock = existing.getBlocks().get(0); - } else { - checkState( - existing instanceof BlocksPrefix, - "Unexpected blocks type %s, expected a %s.", - existing.getClass(), - BlocksPrefix.class); - List> blocks = existing.getBlocks(); - int currentBlockIndex = 0; - for (; currentBlockIndex < blocks.size(); ++currentBlockIndex) { - if (currentBlock - .getNextToken() - .equals(blocks.get(currentBlockIndex).getNextToken())) { - break; - } + checkState( + existing instanceof BlocksPrefix, + "Unexpected blocks type %s, expected a %s.", + existing.getClass(), + BlocksPrefix.class); + List> blocks = existing.getBlocks(); + int currentBlockIndex = 0; + for (; currentBlockIndex < blocks.size(); ++currentBlockIndex) { + if (nextToken.equals(blocks.get(currentBlockIndex).getNextToken())) { + break; } - // Load the next block from cache if it was found. - if (currentBlockIndex + 1 < blocks.size()) { - currentBlock = blocks.get(currentBlockIndex + 1); - } else { - // Otherwise load the block from state API. - currentBlock = loadNextBlock(currentBlock.getNextToken()); - - // Append this block to the existing set of blocks if it is logically the next one. - if (currentBlockIndex == blocks.size() - 1) { - List> newBlocks = new ArrayList<>(currentBlockIndex + 1); - newBlocks.addAll(blocks); - newBlocks.add(currentBlock); - cache.put(IterableCacheKey.INSTANCE, new BlocksPrefix<>(newBlocks)); - } + } + // Take the next block from the cache if it was found. + if (currentBlockIndex + 1 < blocks.size()) { + currentBlock = blocks.get(currentBlockIndex + 1); + } else { + // Otherwise load the block from state API. + // Remove references on the cached values while we are loading the next block. + existing = null; + blocks = null; + currentBlock = loadNextBlock(nextToken); + existing = cache.peek(IterableCacheKey.INSTANCE); + // Append this block to the existing set of blocks if it is logically the next one + // according to the + // tokens. 
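+          // If the cached prefix was evicted or no longer ends at this token, skip the append
+          // so the cache never holds a prefix with a gap in it.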
+ if (existing != null + && !existing.getBlocks().isEmpty() + && nextToken.equals( + existing.getBlocks().get(existing.getBlocks().size() - 1).getNextToken())) { + List> newBlocks = new ArrayList<>(currentBlockIndex + 1); + newBlocks.addAll(existing.getBlocks()); + newBlocks.add(currentBlock); + cache.put(IterableCacheKey.INSTANCE, new BlocksPrefix<>(newBlocks)); } } } From 43d27ed52aff45e47ee48e1e1aef5ba2abfb0a94 Mon Sep 17 00:00:00 2001 From: Minbo Bae <49642083+baeminbo@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:46:22 -0800 Subject: [PATCH 153/181] Add a guide to build custom Beam Python SDK image (#33048) * Add a guide to build custom Beam Python SDK image * Address review comments --- .../en/documentation/runtime/environments.md | 4 +- .../sdks/python-sdk-image-build.md | 306 ++++++++++++++++++ .../partials/section-menu/en/sdks.html | 1 + 3 files changed, 310 insertions(+), 1 deletion(-) create mode 100644 website/www/site/content/en/documentation/sdks/python-sdk-image-build.md diff --git a/website/www/site/content/en/documentation/runtime/environments.md b/website/www/site/content/en/documentation/runtime/environments.md index a048c21046ba..48039d50a10b 100644 --- a/website/www/site/content/en/documentation/runtime/environments.md +++ b/website/www/site/content/en/documentation/runtime/environments.md @@ -105,7 +105,9 @@ This method requires building image artifacts from Beam source. For additional i 2. Customize the `Dockerfile` for a given language, typically `sdks//container/Dockerfile` directory (e.g. the [Dockerfile for Python](https://github.com/apache/beam/blob/master/sdks/python/container/Dockerfile). -3. Return to the root Beam directory and run the Gradle `docker` target for your image. +3. Return to the root Beam directory and run the Gradle `docker` target for your + image. For self-contained instructions on building a container image, + follow [this guide](/documentation/sdks/python-sdk-image-build). ``` cd $BEAM_WORKDIR diff --git a/website/www/site/content/en/documentation/sdks/python-sdk-image-build.md b/website/www/site/content/en/documentation/sdks/python-sdk-image-build.md new file mode 100644 index 000000000000..f456a686afea --- /dev/null +++ b/website/www/site/content/en/documentation/sdks/python-sdk-image-build.md @@ -0,0 +1,306 @@ + + +# Building Beam Python SDK Image Guide + +There are two options to build Beam Python SDK image. If you only need to modify +[the Python SDK boot entrypoint binary](https://github.com/apache/beam/blob/master/sdks/python/container/boot.go), +read [Update Boot Entrypoint Application Only](#update-boot-entrypoint-application-only). +If you need to build a Beam Python SDK image fully, +read [Build Beam Python SDK Image Fully](#build-beam-python-sdk-image-fully). + + +## Update Boot Entrypoint Application Only. + +If you only need to make a change to [the Python SDK boot entrypoint binary](https://github.com/apache/beam/blob/master/sdks/python/container/boot.go). You +can rebuild the boot application only and include the updated boot application +in the preexisting image. +Read [the Python container Dockerfile](https://github.com/apache/beam/blob/master/sdks/python/container/Dockerfile) +for reference. + +```shell +# From beam repo root, make changes to boot.go. +your_editor sdks/python/container/boot.go + +# Rebuild the entrypoint +./gradlew :sdks:python:container:gobuild + +cd sdks/python/container/build/target/launcher/linux_amd64 + +# Create a simple Dockerfile to use custom boot entrypoint. 
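+# The Dockerfile layers the freshly built `boot` binary from this directory on top of the
+# released SDK base image, so the image contents stay unchanged apart from the entrypoint.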
+cat >Dockerfile <//beam_python3.10_sdk:2.60.0-custom-boot +docker push us-central1-docker.pkg.dev///beam_python3.10_sdk:2.60.0-custom-boot +``` + +You can build a docker image if your local environment has Java, Python, Golang +and Docker installation. Try +`./gradlew :sdks:python:container:py:docker`. For example, +`:sdks:python:container:py310:docker` builds `apache/beam_python3.10_sdk` +locally if successful. You can follow this guide building a custom image from +a VM if the build fails in your local environment. + +## Build Beam Python SDK Image Fully + +This section introduces a way to build everything from the scratch. + +### Prepare VM + +Prepare a VM with Debian 11. This guide was tested on Debian 11. + +#### Google Compute Engine + +An option to create a Debian 11 VM is using a GCE instance. + +```shell +gcloud compute instances create beam-builder \ + --zone=us-central1-a \ + --image-project=debian-cloud \ + --image-family=debian-11 \ + --machine-type=n1-standard-8 \ + --boot-disk-size=20GB \ + --scopes=cloud-platform +``` + +Login to the VM. All the following steps are executed inside the VM. + +```shell +gcloud compute ssh beam-builder --zone=us-central1-a --tunnel-through-iap +``` + +Update the apt package list. + +```shell +sudo apt-get update +``` + +> [!NOTE] +> * A high CPU machine is recommended to reduce the compile time. +> * The image build needs a large disk. The build will fail with "no space left + on device" with the default disk size 10GB. +> * The `cloud-platform` is recommended to avoid permission issues with Google + Cloud Artifact Registry. You can use the default scopes if you don't push + the image to Google Cloud Artifact Registry. +> * Use a zone in the region of your docker repository of Artifact Registry if + you push the image to Artifact Registry. + +### Prerequisite Packages + +#### Java + +You need Java to run Gradle tasks. + +```shell +sudo apt-get install -y openjdk-11-jdk +``` + +#### Golang + +Download and install. Reference: https://go.dev/doc/install. + +```shell +# Download and install +curl -OL https://go.dev/dl/go1.23.2.linux-amd64.tar.gz +sudo rm -rf /usr/local/go && sudo tar -C /usr/local -xzf go1.23.2.linux-amd64.tar.gz + +# Add go to PATH. +export PATH=:/usr/local/go/bin:$PATH +``` + +Confirm the Golang version + +```shell +go version +``` + +Expected output: + +```text +go version go1.23.2 linux/amd64 +``` + +> [!NOTE] +> Old Go version (e.g. 1.16) will fail at `:sdks:python:container:goBuild`. + +#### Python + +This guide uses Pyenv to manage multiple Python versions. +Reference: https://realpython.com/intro-to-pyenv/#build-dependencies + +```shell +# Install dependencies +sudo apt-get install -y make build-essential libssl-dev zlib1g-dev \ +libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev \ +libncursesw5-dev xz-utils tk-dev libffi-dev liblzma-dev + +# Install Pyenv +curl https://pyenv.run | bash + +# Add pyenv to PATH. +export PATH="$HOME/.pyenv/bin:$PATH" +eval "$(pyenv init -)" +eval "$(pyenv virtualenv-init -)" +``` + +Install Python 3.9 and set the Python version. This will take several minutes. + +```shell +pyenv install 3.9 +pyenv global 3.9 +``` + +Confirm the python version. 
+ +```shell +python --version +``` + +Expected output example: + +```text +Python 3.9.17 +``` + +> [!NOTE] +> You can use a different Python version for building with [ +`-PpythonVersion` option](https://github.com/apache/beam/blob/v2.60.0/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy#L2956-L2961) +> to Gradle task run. Otherwise, you should have `python3.9` in the build +> environment for Apache Beam 2.60.0 or later (python3.8 for older Apache Beam +> versions). If you use the wrong version, the Gradle task +`:sdks:python:setupVirtualenv` fails. + +#### Docker + +Install Docker +following [the reference](https://docs.docker.com/engine/install/debian/#install-using-the-repository). + +```shell +# Add GPG keys. +sudo apt-get update +sudo apt-get install ca-certificates curl +sudo install -m 0755 -d /etc/apt/keyrings +sudo curl -fsSL https://download.docker.com/linux/debian/gpg -o /etc/apt/keyrings/docker.asc +sudo chmod a+r /etc/apt/keyrings/docker.asc + +# Add the Apt repository. +echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/debian \ + $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | \ + sudo tee /etc/apt/sources.list.d/docker.list > /dev/null +sudo apt-get update + +# Install docker packages. +sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin +``` + +You need to run `docker` command without the root privilege in Beam Python SDK +image build. You can do this +by [adding your account to the docker group](https://docs.docker.com/engine/install/linux-postinstall/). + +```shell +sudo usermod -aG docker $USER +newgrp docker +``` + +Confirm if you can run a container without the root privilege. + +```shell +docker run hello-world +``` + +#### Git + +Git is not necessary for building Python SDK image. Git is just used to download +the Apache Beam code in this guide. + +```shell +sudo apt-get install -y git +``` + +### Build Beam Python SDK Image + +Download Apache Beam +from [the Github repository](https://github.com/apache/beam). + +```shell +git clone https://github.com/apache/beam beam +cd beam +``` + +Make changes to the Apache Beam code. + +Run the Gradle task to start Docker image build. This will take several minutes. +You can run `:sdks:python:container:py:docker` to build an image +for different Python version. +See [the supported Python version list](https://github.com/apache/beam/tree/master/sdks/python/container). +For example, `py310` is for Python 3.10. + +```shell +./gradlew :sdks:python:container:py310:docker +``` + +If the build is successful, you can see the built image locally. + +```shell +docker images +``` + +Expected output: + +```text +REPOSITORY TAG IMAGE ID CREATED SIZE +apache/beam_python3.10_sdk 2.60.0 33db45f57f25 About a minute ago 2.79GB +``` + +> [!NOTE] +> If you run the build in your local environment and Gradle task +`:sdks:python:setupVirtualenv` fails by an incompatible python version, please +> try with `-PpythonVersion` with the Python version installed in your local +> environment (e.g. `-PpythonVersion=3.10`) + +### Push to Repository + +You may push the custom image to a image repository. The image can be used +for [Dataflow custom container](https://cloud.google.com/dataflow/docs/guides/run-custom-container#usage). + +#### Google Cloud Artifact Registry + +You can push the image to Artifact Registry. 
No additional authentication is
+necessary if you use Google Compute Engine.
+
+```shell
+docker tag apache/beam_python3.10_sdk:2.60.0 us-central1-docker.pkg.dev///beam_python3.10_sdk:2.60.0-custom
+docker push us-central1-docker.pkg.dev///beam_python3.10_sdk:2.60.0-custom
+```
+
+If you push an image from an environment other than a VM in Google Cloud, you
+should configure [docker authentication with
+`gcloud`](https://cloud.google.com/artifact-registry/docs/docker/authentication#gcloud-helper)
+before `docker push`.
+
+#### Docker Hub
+
+You can push to your Docker Hub repository
+after [docker login](https://docs.docker.com/reference/cli/docker/login/).
+
+```shell
+docker tag apache/beam_python3.10_sdk:2.60.0 /beam_python3.10_sdk:2.60.0-custom
+docker push /beam_python3.10_sdk:2.60.0-custom
+```
+
diff --git a/website/www/site/layouts/partials/section-menu/en/sdks.html b/website/www/site/layouts/partials/section-menu/en/sdks.html
index ea48eb6f40d9..243bbd92a465 100644
--- a/website/www/site/layouts/partials/section-menu/en/sdks.html
+++ b/website/www/site/layouts/partials/section-menu/en/sdks.html
@@ -44,6 +44,7 @@
  • Managing pipeline dependencies
  • Python multi-language pipelines quickstart
  • Python Unrecoverable Errors
+  • Python SDK image build
  • From 785ec0705cea56c15ba4422905120d21183daba2 Mon Sep 17 00:00:00 2001 From: scwhittle Date: Tue, 12 Nov 2024 11:36:51 +0100 Subject: [PATCH 154/181] Support poisioning instruction ids to prevent the FnApi data stream from blocking on failed instructions (#32857) --- .../fn/data/BeamFnDataGrpcMultiplexer.java | 129 ++++++++++++++---- .../data/BeamFnDataGrpcMultiplexerTest.java | 4 +- .../harness/control/ProcessBundleHandler.java | 89 ++++++------ .../fn/harness/data/BeamFnDataClient.java | 11 +- .../fn/harness/data/BeamFnDataGrpcClient.java | 8 ++ .../PTransformRunnerFactoryTestContext.java | 5 + .../control/ProcessBundleHandlerTest.java | 1 + .../data/BeamFnDataGrpcClientTest.java | 90 ++++++++++++ 8 files changed, 262 insertions(+), 75 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java index e946022c4e36..aa0dea80b0a1 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexer.java @@ -17,11 +17,13 @@ */ package org.apache.beam.sdk.fn.data; +import java.time.Duration; import java.util.HashSet; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.function.Consumer; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.pipeline.v1.Endpoints; @@ -30,6 +32,8 @@ import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; @@ -49,13 +53,20 @@ */ public class BeamFnDataGrpcMultiplexer implements AutoCloseable { private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataGrpcMultiplexer.class); + private static final Duration POISONED_INSTRUCTION_ID_CACHE_TIMEOUT = Duration.ofMinutes(20); private final Endpoints.@Nullable ApiServiceDescriptor apiServiceDescriptor; private final StreamObserver inboundObserver; private final StreamObserver outboundObserver; - private final ConcurrentMap< + private final ConcurrentHashMap< /*instructionId=*/ String, CompletableFuture>> receivers; - private final ConcurrentMap erroredInstructionIds; + private final Cache poisonedInstructionIds; + + private static class PoisonedException extends RuntimeException { + public PoisonedException() { + super("Instruction poisoned"); + } + }; public BeamFnDataGrpcMultiplexer( Endpoints.@Nullable ApiServiceDescriptor apiServiceDescriptor, @@ -64,7 +75,8 @@ public BeamFnDataGrpcMultiplexer( baseOutboundObserverFactory) { this.apiServiceDescriptor = apiServiceDescriptor; this.receivers = new ConcurrentHashMap<>(); - this.erroredInstructionIds = new ConcurrentHashMap<>(); + this.poisonedInstructionIds = + 
CacheBuilder.newBuilder().expireAfterWrite(POISONED_INSTRUCTION_ID_CACHE_TIMEOUT).build(); this.inboundObserver = new InboundObserver(); this.outboundObserver = outboundObserverFactory.outboundObserverFor(baseOutboundObserverFactory, inboundObserver); @@ -87,11 +99,6 @@ public StreamObserver getOutboundObserver() { return outboundObserver; } - private CompletableFuture> receiverFuture( - String instructionId) { - return receivers.computeIfAbsent(instructionId, (unused) -> new CompletableFuture<>()); - } - /** * Registers a consumer for the specified instruction id. * @@ -99,17 +106,63 @@ private CompletableFuture> receiverF * instruction ids ensuring that the receiver will only see {@link BeamFnApi.Elements} with a * single instruction id. * - *
    The caller must {@link #unregisterConsumer unregister the consumer} when they no longer wish - * to receive messages. + *
    The caller must either {@link #unregisterConsumer unregister the consumer} when all messages + * have been processed or {@link #poisonInstructionId(String) poison the instruction} if messages + * for the instruction should be dropped. */ public void registerConsumer( String instructionId, CloseableFnDataReceiver receiver) { - receiverFuture(instructionId).complete(receiver); + receivers.compute( + instructionId, + (unused, existing) -> { + if (existing != null) { + if (!existing.complete(receiver)) { + throw new IllegalArgumentException("Instruction id was registered twice"); + } + return existing; + } + if (poisonedInstructionIds.getIfPresent(instructionId) != null) { + throw new IllegalArgumentException("Instruction id was poisoned"); + } + return CompletableFuture.completedFuture(receiver); + }); } - /** Unregisters a consumer. */ + /** Unregisters a previously registered consumer. */ public void unregisterConsumer(String instructionId) { - receivers.remove(instructionId); + @Nullable + CompletableFuture> receiverFuture = + receivers.remove(instructionId); + if (receiverFuture != null && !receiverFuture.isDone()) { + // The future must have been inserted by the inbound observer since registerConsumer completes + // the future. + throw new IllegalArgumentException("Unregistering consumer which was not registered."); + } + } + + /** + * Poisons an instruction id. + * + *
    Any records for the instruction on the inbound observer will be dropped for the next {@link + * #POISONED_INSTRUCTION_ID_CACHE_TIMEOUT}. + */ + public void poisonInstructionId(String instructionId) { + poisonedInstructionIds.put(instructionId, Boolean.TRUE); + @Nullable + CompletableFuture> receiverFuture = + receivers.remove(instructionId); + if (receiverFuture != null) { + // Completing exceptionally has no effect if the future was already notified. In that case + // whatever registered the receiver needs to handle cancelling it. + receiverFuture.completeExceptionally(new PoisonedException()); + if (!receiverFuture.isCompletedExceptionally()) { + try { + receiverFuture.get().close(); + } catch (Exception e) { + LOG.warn("Unexpected error closing existing observer"); + } + } + } } @VisibleForTesting @@ -210,27 +263,42 @@ public void onNext(BeamFnApi.Elements value) { } private void forwardToConsumerForInstructionId(String instructionId, BeamFnApi.Elements value) { - if (erroredInstructionIds.containsKey(instructionId)) { - LOG.debug("Ignoring inbound data for failed instruction {}", instructionId); - return; - } - CompletableFuture> consumerFuture = - receiverFuture(instructionId); - if (!consumerFuture.isDone()) { - LOG.debug( - "Received data for instruction {} without consumer ready. " - + "Waiting for consumer to be registered.", - instructionId); - } CloseableFnDataReceiver consumer; try { - consumer = consumerFuture.get(); - + CompletableFuture> consumerFuture = + receivers.computeIfAbsent( + instructionId, + (unused) -> { + if (poisonedInstructionIds.getIfPresent(instructionId) != null) { + throw new PoisonedException(); + } + LOG.debug( + "Received data for instruction {} without consumer ready. " + + "Waiting for consumer to be registered.", + instructionId); + return new CompletableFuture<>(); + }); + // The consumer may not be registered until the bundle processor is fully constructed so we + // conservatively set + // a high timeout. Poisoning will prevent this for occurring for consumers that will not be + // registered. + consumer = consumerFuture.get(3, TimeUnit.HOURS); /* * TODO: On failure we should fail any bundles that were impacted eagerly * instead of relying on the Runner harness to do all the failure handling. */ - } catch (ExecutionException | InterruptedException e) { + } catch (TimeoutException e) { + LOG.error( + "Timed out waiting to observe consumer data stream for instruction {}", + instructionId, + e); + outboundObserver.onError(e); + return; + } catch (ExecutionException | InterruptedException | PoisonedException e) { + if (e instanceof PoisonedException || e.getCause() instanceof PoisonedException) { + LOG.debug("Received data for poisoned instruction {}. 
Dropping input.", instructionId); + return; + } LOG.error( "Client interrupted during handling of data for instruction {}", instructionId, e); outboundObserver.onError(e); @@ -240,10 +308,11 @@ private void forwardToConsumerForInstructionId(String instructionId, BeamFnApi.E outboundObserver.onError(e); return; } + try { consumer.accept(value); } catch (Exception e) { - erroredInstructionIds.put(instructionId, true); + poisonInstructionId(instructionId); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java index 3a7a0d5a8935..37580824b558 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataGrpcMultiplexerTest.java @@ -280,6 +280,7 @@ public void testFailedProcessingCausesAdditionalInboundDataToBeIgnored() throws DESCRIPTOR, OutboundObserverFactory.clientDirect(), inboundObserver -> TestStreams.withOnNext(outboundValues::add).build()); + final AtomicBoolean closed = new AtomicBoolean(); multiplexer.registerConsumer( DATA_INSTRUCTION_ID, new CloseableFnDataReceiver() { @@ -290,7 +291,7 @@ public void flush() throws Exception { @Override public void close() throws Exception { - fail("Unexpected call"); + closed.set(true); } @Override @@ -320,6 +321,7 @@ public void accept(BeamFnApi.Elements input) throws Exception { dataInboundValues, Matchers.contains( BeamFnApi.Elements.newBuilder().addData(data.setTransformId("A").build()).build())); + assertTrue(closed.get()); } @Test diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index c91d5ba71b89..0d517503b12d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -64,7 +64,6 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse; -import org.apache.beam.model.pipeline.v1.Endpoints; import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor; import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.model.pipeline.v1.RunnerApi.Coder; @@ -93,6 +92,7 @@ import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.TextFormat; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheLoader; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; @@ -108,20 +108,19 @@ import org.slf4j.LoggerFactory; /** - * Processes {@link BeamFnApi.ProcessBundleRequest}s and {@link - * BeamFnApi.ProcessBundleSplitRequest}s. + * Processes {@link ProcessBundleRequest}s and {@link BeamFnApi.ProcessBundleSplitRequest}s. * *
    {@link BeamFnApi.ProcessBundleSplitRequest}s use a {@link BundleProcessorCache cache} to * find/create a {@link BundleProcessor}. The creation of a {@link BundleProcessor} uses the - * associated {@link BeamFnApi.ProcessBundleDescriptor} definition; creating runners for each {@link + * associated {@link ProcessBundleDescriptor} definition; creating runners for each {@link * RunnerApi.FunctionSpec}; wiring them together based upon the {@code input} and {@code output} map * definitions. The {@link BundleProcessor} executes the DAG based graph by starting all runners in * reverse topological order, and finishing all runners in forward topological order. * *
    {@link BeamFnApi.ProcessBundleSplitRequest}s finds an {@code active} {@link BundleProcessor} - * associated with a currently processing {@link BeamFnApi.ProcessBundleRequest} and uses it to - * perform a split request. See breaking the - * fusion barrier for further details. + * associated with a currently processing {@link ProcessBundleRequest} and uses it to perform a + * split request. See breaking the fusion + * barrier for further details. */ @SuppressWarnings({ "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) @@ -153,7 +152,7 @@ public class ProcessBundleHandler { } private final PipelineOptions options; - private final Function fnApiRegistry; + private final Function fnApiRegistry; private final BeamFnDataClient beamFnDataClient; private final BeamFnStateGrpcClientCache beamFnStateGrpcClientCache; private final FinalizeBundleHandler finalizeBundleHandler; @@ -170,7 +169,7 @@ public class ProcessBundleHandler { public ProcessBundleHandler( PipelineOptions options, Set runnerCapabilities, - Function fnApiRegistry, + Function fnApiRegistry, BeamFnDataClient beamFnDataClient, BeamFnStateGrpcClientCache beamFnStateGrpcClientCache, FinalizeBundleHandler finalizeBundleHandler, @@ -197,7 +196,7 @@ public ProcessBundleHandler( ProcessBundleHandler( PipelineOptions options, Set runnerCapabilities, - Function fnApiRegistry, + Function fnApiRegistry, BeamFnDataClient beamFnDataClient, BeamFnStateGrpcClientCache beamFnStateGrpcClientCache, FinalizeBundleHandler finalizeBundleHandler, @@ -216,7 +215,7 @@ public ProcessBundleHandler( this.runnerCapabilities = runnerCapabilities; this.runnerAcceptsShortIds = runnerCapabilities.contains( - BeamUrns.getUrn(RunnerApi.StandardRunnerProtocols.Enum.MONITORING_INFO_SHORT_IDS)); + BeamUrns.getUrn(StandardRunnerProtocols.Enum.MONITORING_INFO_SHORT_IDS)); this.executionStateSampler = executionStateSampler; this.urnToPTransformRunnerFactoryMap = urnToPTransformRunnerFactoryMap; this.defaultPTransformRunnerFactory = @@ -232,7 +231,7 @@ private void createRunnerAndConsumersForPTransformRecursively( String pTransformId, PTransform pTransform, Supplier processBundleInstructionId, - Supplier> cacheTokens, + Supplier> cacheTokens, Supplier> bundleCache, ProcessBundleDescriptor processBundleDescriptor, SetMultimap pCollectionIdsToConsumingPTransforms, @@ -242,7 +241,7 @@ private void createRunnerAndConsumersForPTransformRecursively( PTransformFunctionRegistry finishFunctionRegistry, Consumer addResetFunction, Consumer addTearDownFunction, - BiConsumer> addDataEndpoint, + BiConsumer> addDataEndpoint, Consumer> addTimerEndpoint, Consumer addBundleProgressReporter, BundleSplitListener splitListener, @@ -499,28 +498,29 @@ public BundleFinalizer getBundleFinalizer() { * Processes a bundle, running the start(), process(), and finish() functions. This function is * required to be reentrant. 
*/ - public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest request) + public BeamFnApi.InstructionResponse.Builder processBundle(InstructionRequest request) throws Exception { - BeamFnApi.ProcessBundleResponse.Builder response = BeamFnApi.ProcessBundleResponse.newBuilder(); - - BundleProcessor bundleProcessor = - bundleProcessorCache.get( - request, - () -> { - try { - return createBundleProcessor( - request.getProcessBundle().getProcessBundleDescriptorId(), - request.getProcessBundle()); - } catch (IOException e) { - throw new RuntimeException(e); - } - }); + @Nullable BundleProcessor bundleProcessor = null; try { + bundleProcessor = + Preconditions.checkNotNull( + bundleProcessorCache.get( + request, + () -> { + try { + return createBundleProcessor( + request.getProcessBundle().getProcessBundleDescriptorId(), + request.getProcessBundle()); + } catch (IOException e) { + throw new RuntimeException(e); + } + })); + PTransformFunctionRegistry startFunctionRegistry = bundleProcessor.getStartFunctionRegistry(); PTransformFunctionRegistry finishFunctionRegistry = bundleProcessor.getFinishFunctionRegistry(); ExecutionStateTracker stateTracker = bundleProcessor.getStateTracker(); - + ProcessBundleResponse.Builder response = ProcessBundleResponse.newBuilder(); try (HandleStateCallsForBundle beamFnStateClient = bundleProcessor.getBeamFnStateClient()) { stateTracker.start(request.getInstructionId()); try { @@ -596,12 +596,17 @@ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.Instruction request.getProcessBundle().getProcessBundleDescriptorId(), bundleProcessor); return BeamFnApi.InstructionResponse.newBuilder().setProcessBundle(response); } catch (Exception e) { - // Make sure we clean up from the active set of bundle processors. LOG.debug( - "Discard bundleProcessor for {} after exception: {}", + "Error processing bundle {} with bundleProcessor for {} after exception: {}", + request.getInstructionId(), request.getProcessBundle().getProcessBundleDescriptorId(), e.getMessage()); - bundleProcessorCache.discard(bundleProcessor); + if (bundleProcessor != null) { + // Make sure we clean up from the active set of bundle processors. + bundleProcessorCache.discard(bundleProcessor); + } + // Ensure that if more data arrives for the instruction it is discarded. + beamFnDataClient.poisonInstructionId(request.getInstructionId()); throw e; } } @@ -643,7 +648,7 @@ private void embedOutboundElementsIfApplicable( } } - public BeamFnApi.InstructionResponse.Builder progress(BeamFnApi.InstructionRequest request) + public BeamFnApi.InstructionResponse.Builder progress(InstructionRequest request) throws Exception { BundleProcessor bundleProcessor = bundleProcessorCache.find(request.getProcessBundleProgress().getInstructionId()); @@ -727,7 +732,7 @@ private Map finalMonitoringData(BundleProcessor bundleProces } /** Splits an active bundle. 
*/ - public BeamFnApi.InstructionResponse.Builder trySplit(BeamFnApi.InstructionRequest request) { + public BeamFnApi.InstructionResponse.Builder trySplit(InstructionRequest request) { BundleProcessor bundleProcessor = bundleProcessorCache.find(request.getProcessBundleSplit().getInstructionId()); BeamFnApi.ProcessBundleSplitResponse.Builder response = @@ -772,8 +777,8 @@ public void discard() { } private BundleProcessor createBundleProcessor( - String bundleId, BeamFnApi.ProcessBundleRequest processBundleRequest) throws IOException { - BeamFnApi.ProcessBundleDescriptor bundleDescriptor = fnApiRegistry.apply(bundleId); + String bundleId, ProcessBundleRequest processBundleRequest) throws IOException { + ProcessBundleDescriptor bundleDescriptor = fnApiRegistry.apply(bundleId); SetMultimap pCollectionIdsToConsumingPTransforms = HashMultimap.create(); BundleProgressReporter.InMemory bundleProgressReporterAndRegistrar = @@ -799,8 +804,7 @@ private BundleProcessor createBundleProcessor( List tearDownFunctions = new ArrayList<>(); // Build a multimap of PCollection ids to PTransform ids which consume said PCollections - for (Map.Entry entry : - bundleDescriptor.getTransformsMap().entrySet()) { + for (Map.Entry entry : bundleDescriptor.getTransformsMap().entrySet()) { for (String pCollectionId : entry.getValue().getInputsMap().values()) { pCollectionIdsToConsumingPTransforms.put(pCollectionId, entry.getKey()); } @@ -848,8 +852,7 @@ public void afterBundleCommit(Instant callbackExpiry, Callback callback) { runnerCapabilities); // Create a BeamFnStateClient - for (Map.Entry entry : - bundleDescriptor.getTransformsMap().entrySet()) { + for (Map.Entry entry : bundleDescriptor.getTransformsMap().entrySet()) { // Skip anything which isn't a root. // Also force data output transforms to be unconditionally instantiated (see BEAM-10450). @@ -1090,7 +1093,7 @@ public static BundleProcessor create( abstract HandleStateCallsForBundle getBeamFnStateClient(); - abstract List getInboundEndpointApiServiceDescriptors(); + abstract List getInboundEndpointApiServiceDescriptors(); abstract List> getInboundDataEndpoints(); @@ -1117,7 +1120,7 @@ synchronized List getCacheTokens() { synchronized Cache getBundleCache() { if (this.bundleCache == null) { this.bundleCache = - new Caches.ClearableCache<>( + new ClearableCache<>( Caches.subCache(getProcessWideCache(), "Bundle", this.instructionId)); } return this.bundleCache; @@ -1264,7 +1267,7 @@ public void close() throws Exception { } @Override - public CompletableFuture handle(BeamFnApi.StateRequest.Builder requestBuilder) { + public CompletableFuture handle(StateRequest.Builder requestBuilder) { throw new IllegalStateException( String.format( "State API calls are unsupported because the " diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java index 75f3a24301c9..94d59d0fcb62 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataClient.java @@ -55,10 +55,19 @@ void registerReceiver( * successfully. * *
    It is expected that if a bundle fails during processing then the failure will become visible - * to the {@link BeamFnDataClient} during a future {@link FnDataReceiver#accept} invocation. + * to the {@link BeamFnDataClient} during a future {@link FnDataReceiver#accept} invocation or via + * a call to {@link #poisonInstructionId}. */ void unregisterReceiver(String instructionId, List apiServiceDescriptors); + /** + * Poisons the instruction id, indicating that future data arriving for it should be discarded. + * Unregisters the receiver if was registered. + * + * @param instructionId + */ + void poisonInstructionId(String instructionId); + /** * Creates a {@link BeamFnDataOutboundAggregator} for buffering and sending outbound data and * timers over the data plane. It is important that {@link diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java index 981b115c58e7..cd1ac26e364d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java @@ -82,6 +82,14 @@ public void unregisterReceiver( } } + @Override + public void poisonInstructionId(String instructionId) { + LOG.debug("Poisoning instruction {}", instructionId); + for (BeamFnDataGrpcMultiplexer client : multiplexerCache.values()) { + client.poisonInstructionId(instructionId); + } + } + @Override public BeamFnDataOutboundAggregator createOutboundAggregator( ApiServiceDescriptor apiServiceDescriptor, diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java index 9328dc86c009..acfd3bb70202 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PTransformRunnerFactoryTestContext.java @@ -92,6 +92,11 @@ public BeamFnDataOutboundAggregator createOutboundAggregator( boolean collectElementsIfNoFlushes) { throw new UnsupportedOperationException("Unexpected call during test."); } + + @Override + public void poisonInstructionId(String instructionId) { + throw new UnsupportedOperationException("Unexpected call during test."); + } }) .beamFnStateClient( new BeamFnStateClient() { diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index 2d1e323707f7..95b404aa6203 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -1516,6 +1516,7 @@ public void testDataProcessingExceptionsArePropagated() throws Exception { // Ensure that we unregister during successful processing verify(beamFnDataClient).registerReceiver(eq("instructionId"), any(), any()); + verify(beamFnDataClient).poisonInstructionId(eq("instructionId")); verifyNoMoreInteractions(beamFnDataClient); } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java index 3489fe766891..514cf61ded40 
100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java @@ -23,14 +23,17 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import org.apache.beam.model.fnexecution.v1.BeamFnApi; @@ -281,6 +284,93 @@ public StreamObserver data( } } + @Test + public void testForInboundConsumerThatIsPoisoned() throws Exception { + CountDownLatch waitForClientToConnect = new CountDownLatch(1); + CountDownLatch receivedAElement = new CountDownLatch(1); + Collection> inboundValuesA = new ConcurrentLinkedQueue<>(); + Collection inboundServerValues = new ConcurrentLinkedQueue<>(); + AtomicReference> outboundServerObserver = + new AtomicReference<>(); + CallStreamObserver inboundServerObserver = + TestStreams.withOnNext(inboundServerValues::add).build(); + + Endpoints.ApiServiceDescriptor apiServiceDescriptor = + Endpoints.ApiServiceDescriptor.newBuilder() + .setUrl(this.getClass().getName() + "-" + UUID.randomUUID()) + .build(); + Server server = + InProcessServerBuilder.forName(apiServiceDescriptor.getUrl()) + .addService( + new BeamFnDataGrpc.BeamFnDataImplBase() { + @Override + public StreamObserver data( + StreamObserver outboundObserver) { + outboundServerObserver.set(outboundObserver); + waitForClientToConnect.countDown(); + return inboundServerObserver; + } + }) + .build(); + server.start(); + + try { + ManagedChannel channel = + InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build(); + + BeamFnDataGrpcClient clientFactory = + new BeamFnDataGrpcClient( + PipelineOptionsFactory.create(), + (Endpoints.ApiServiceDescriptor descriptor) -> channel, + OutboundObserverFactory.trivial()); + + BeamFnDataInboundObserver observerA = + BeamFnDataInboundObserver.forConsumers( + Arrays.asList( + DataEndpoint.create( + TRANSFORM_ID_A, + CODER, + (WindowedValue elem) -> { + receivedAElement.countDown(); + inboundValuesA.add(elem); + })), + Collections.emptyList()); + CompletableFuture future = + CompletableFuture.runAsync( + () -> { + try { + observerA.awaitCompletion(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + clientFactory.registerReceiver( + INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), observerA); + + waitForClientToConnect.await(); + outboundServerObserver.get().onNext(ELEMENTS_B_1); + clientFactory.poisonInstructionId(INSTRUCTION_ID_B); + + outboundServerObserver.get().onNext(ELEMENTS_B_1); + outboundServerObserver.get().onNext(ELEMENTS_A_1); + assertTrue(receivedAElement.await(5, TimeUnit.SECONDS)); + + clientFactory.poisonInstructionId(INSTRUCTION_ID_A); + try { + future.get(); + fail(); // We expect the awaitCompletion to fail due to closing. 
+ } catch (Exception ignored) { + } + + outboundServerObserver.get().onNext(ELEMENTS_A_2); + + assertThat(inboundValuesA, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"))); + } finally { + server.shutdownNow(); + } + } + @Test public void testForOutboundConsumer() throws Exception { CountDownLatch waitForInboundServerValuesCompletion = new CountDownLatch(2); From 682eaeff69d944dc1ed399db9e4ceeacdc72e710 Mon Sep 17 00:00:00 2001 From: Steven van Rossum Date: Tue, 12 Nov 2024 13:55:03 +0100 Subject: [PATCH 155/181] [KafkaIO] Fix potential data race in ReadFromKafkaDoFn.AverageRecordSize (#33073) * Add comments clarifying offets and record size calculation --- .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 44 ++++++++++++++++--- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 7c2064883488..add76c9682a0 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -27,6 +27,8 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.TimeUnit; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors; import org.apache.beam.sdk.io.kafka.KafkaIOUtils.MovingAvg; @@ -338,13 +340,18 @@ public WatermarkEstimator newWatermarkEstimator( public double getSize( @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange offsetRange) throws Exception { + // If present, estimates the record size to offset gap ratio. Compacted topics may hold less + // records than the estimated offset range due to record deletion within a partition. final LoadingCache avgRecordSize = Preconditions.checkStateNotNull(this.avgRecordSize); - double numRecords = + // The tracker estimates the offset range by subtracting the last claimed position from the + // currently observed end offset for the partition belonging to this split. + double estimatedOffsetRange = restrictionTracker(kafkaSourceDescriptor, offsetRange).getProgress().getWorkRemaining(); // Before processing elements, we don't have a good estimated size of records and offset gap. + // Return the estimated offset range without scaling by a size to gap ratio. if (!avgRecordSize.asMap().containsKey(kafkaSourceDescriptor.getTopicPartition())) { - return numRecords; + return estimatedOffsetRange; } if (offsetEstimatorCache != null) { for (Map.Entry tp : @@ -353,7 +360,12 @@ public double getSize( } } - return avgRecordSize.get(kafkaSourceDescriptor.getTopicPartition()).getTotalSize(numRecords); + // When processing elements, a moving average estimates the size of records and offset gap. + // Return the estimated offset range scaled by the estimated size to gap ratio. + return estimatedOffsetRange + * avgRecordSize + .get(kafkaSourceDescriptor.getTopicPartition()) + .estimateRecordByteSizeToOffsetCountRatio(); } @NewTracker @@ -665,8 +677,15 @@ private Map overrideBootstrapServersConfig( return config; } + // TODO: Collapse the two moving average trackers into a single accumulator using a single Guava + // AtomicDouble. 
Note that this requires that a single thread will call update and that while get + // may be called by multiple threads the method must only load the accumulator itself. + @ThreadSafe private static class AverageRecordSize { + @GuardedBy("this") private MovingAvg avgRecordSize; + + @GuardedBy("this") private MovingAvg avgRecordGap; public AverageRecordSize() { @@ -674,13 +693,26 @@ public AverageRecordSize() { this.avgRecordGap = new MovingAvg(); } - public void update(int recordSize, long gap) { + public synchronized void update(int recordSize, long gap) { avgRecordSize.update(recordSize); avgRecordGap.update(gap); } - public double getTotalSize(double numRecords) { - return avgRecordSize.get() * numRecords / (1 + avgRecordGap.get()); + public double estimateRecordByteSizeToOffsetCountRatio() { + double avgRecordSize; + double avgRecordGap; + + synchronized (this) { + avgRecordSize = this.avgRecordSize.get(); + avgRecordGap = this.avgRecordGap.get(); + } + + // The offset increases between records in a batch fetched from a compacted topic may be + // greater than 1. Compacted topics only store records with the greatest offset per key per + // partition, the records in between are deleted and will not be observed by a consumer. + // The observed gap between offsets is used to estimate the number of records that are likely + // to be observed for the provided number of records. + return avgRecordSize / (1 + avgRecordGap); } } From c4143315571e6481a0b3e976d6feeecffec904d4 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 12 Nov 2024 10:10:33 -0500 Subject: [PATCH 156/181] Revert "Distroless python sdk (#32960)" This reverts commit 81f35ab62298a2ec9fadeded82461b363b6401db. --- sdks/python/container/Dockerfile | 26 +---------- sdks/python/container/common.gradle | 9 +--- sdks/python/test-suites/dataflow/build.gradle | 6 --- .../python/test-suites/dataflow/common.gradle | 45 ------------------- sdks/python/test-suites/gradle.properties | 3 -- 5 files changed, 2 insertions(+), 87 deletions(-) diff --git a/sdks/python/container/Dockerfile b/sdks/python/container/Dockerfile index f3d22a4b5bc6..7bea6229668f 100644 --- a/sdks/python/container/Dockerfile +++ b/sdks/python/container/Dockerfile @@ -103,33 +103,9 @@ RUN if [ "$pull_licenses" = "true" ] ; then \ python /tmp/license_scripts/pull_licenses_py.py ; \ fi -FROM beam as base +FROM beam ARG pull_licenses COPY --from=third_party_licenses /opt/apache/beam/third_party_licenses /opt/apache/beam/third_party_licenses RUN if [ "$pull_licenses" != "true" ] ; then \ rm -rf /opt/apache/beam/third_party_licenses ; \ fi - -ARG TARGETARCH -FROM gcr.io/distroless/python3-debian12:latest-${TARGETARCH} as distroless -ARG py_version - -# Contains header files needed by the Python interpreter. -COPY --from=base /usr/local/include /usr/local/include - -# Contains the Python interpreter executables. -COPY --from=base /usr/local/bin /usr/local/bin - -# Contains the Python library dependencies. -COPY --from=base /usr/local/lib /usr/local/lib - -# Python standard library modules. -COPY --from=base /usr/lib/python${py_version} /usr/lib/python${py_version} - -# Contains the boot entrypoint and related files such as licenses. -COPY --from=base /opt /opt - -ENV PATH "$PATH:/usr/local/bin" - -# Despite the ENTRYPOINT set above, need to reset since deriving the layer derives from a different image. 
-ENTRYPOINT ["/opt/apache/beam/boot"] diff --git a/sdks/python/container/common.gradle b/sdks/python/container/common.gradle index 885662362894..0175778a6301 100644 --- a/sdks/python/container/common.gradle +++ b/sdks/python/container/common.gradle @@ -71,16 +71,10 @@ def copyLauncherDependencies = tasks.register("copyLauncherDependencies", Copy) } def pushContainers = project.rootProject.hasProperty(["isRelease"]) || project.rootProject.hasProperty("push-containers") -def baseBuildTarget = 'base' -def buildTarget = project.findProperty('container-build-target') ?: 'base' -var imageName = project.docker_image_default_repo_prefix + "python${project.ext.pythonVersion}_sdk" -if (buildTarget != baseBuildTarget) { - imageName += "_${buildTarget}" -} docker { name containerImageName( - name: imageName, + name: project.docker_image_default_repo_prefix + "python${project.ext.pythonVersion}_sdk", root: project.rootProject.hasProperty(["docker-repository-root"]) ? project.rootProject["docker-repository-root"] : project.docker_image_default_repo_root, @@ -96,7 +90,6 @@ docker { platform(*project.containerPlatforms()) load project.useBuildx() && !pushContainers push pushContainers - target buildTarget } dockerPrepare.dependsOn copyLauncherDependencies diff --git a/sdks/python/test-suites/dataflow/build.gradle b/sdks/python/test-suites/dataflow/build.gradle index 4500b395b0a6..04a79683fd36 100644 --- a/sdks/python/test-suites/dataflow/build.gradle +++ b/sdks/python/test-suites/dataflow/build.gradle @@ -60,12 +60,6 @@ task validatesContainerTests { } } -task validatesDistrolessContainerTests { - getVersionsAsList('distroless_python_versions').each { - dependsOn.add(":sdks:python:test-suites:dataflow:py${getVersionSuffix(it)}:validatesDistrolessContainer") - } -} - task examplesPostCommit { getVersionsAsList('dataflow_examples_postcommit_py_versions').each { dependsOn.add(":sdks:python:test-suites:dataflow:py${getVersionSuffix(it)}:examples") diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index cd0db4a62f77..71d44652bc7e 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -380,51 +380,6 @@ task validatesContainer() { } } -/** - * Validates the distroless (https://github.com/GoogleContainerTools/distroless) variant of the Python SDK container - * image (sdks/python/container/Dockerfile). - * To test a single version of Python: - * ./gradlew :sdks:python:test-suites:dataflow:py311:validatesDistrolessContainer - * See https://cwiki.apache.org/confluence/display/BEAM/Python+Tips#PythonTips-VirtualEnvironmentSetup - * for more information on setting up different Python versions. 
- */ -task validatesDistrolessContainer() { - def pyversion = "${project.ext.pythonVersion.replace('.', '')}" - def buildTarget = 'distroless' - def repository = "us.gcr.io/apache-beam-testing/${System.getenv('USER')}" - def tag = java.time.Instant.now().getEpochSecond() - def imageURL = "${repository}/beam_python${project.ext.pythonVersion}_sdk_${buildTarget}:${tag}" - project.rootProject.ext['docker-repository-root'] = repository - project.rootProject.ext['container-build-target'] = buildTarget - project.rootProject.ext['docker-tag'] = tag - if (project.rootProject.hasProperty('dry-run')) { - println "Running in dry run mode: imageURL: ${imageURL}, pyversion: ${pyversion}, buildTarget: ${buildTarget}, repository: ${repository}, tag: ${tag}, envdir: ${envdir}" - return - } - dependsOn 'initializeForDataflowJob' - dependsOn ":sdks:python:container:py${pyversion}:docker" - dependsOn ":sdks:python:container:py${pyversion}:dockerPush" - def testTarget = "apache_beam/examples/wordcount_it_test.py::WordCountIT::test_wordcount_it" - def argMap = [ - "output": "gs://temp-storage-for-end-to-end-tests/py-it-cloud/output", - "project": "apache-beam-testing", - "region": "us-central1", - "runner": "TestDataflowRunner", - "sdk_container_image": "${imageURL}", - "sdk_location": "container", - "staging_location": "gs://temp-storage-for-end-to-end-tests/staging-it", - "temp_location": "gs://temp-storage-for-end-to-end-tests/temp-it", - ] - def cmdArgs = mapToArgString(argMap) - doLast { - exec { - workingDir = "${rootDir}/sdks/python" - executable 'sh' - args '-c', ". ${envdir}/bin/activate && pytest ${testTarget} --test-pipeline-options=\"${cmdArgs}\"" - } - } -} - task validatesContainerARM() { def pyversion = "${project.ext.pythonVersion.replace('.', '')}" dependsOn 'initializeForDataflowJob' diff --git a/sdks/python/test-suites/gradle.properties b/sdks/python/test-suites/gradle.properties index 08266c4b0dd5..d027cd3144d3 100644 --- a/sdks/python/test-suites/gradle.properties +++ b/sdks/python/test-suites/gradle.properties @@ -54,6 +54,3 @@ prism_examples_postcommit_py_versions=3.9,3.12 # cross language postcommit python test suites cross_language_validates_py_versions=3.9,3.12 - -# Python versions to support distroless variants -distroless_python_versions=3.9,3.10,3.11,3.12 From 93a3dea312a38ed37872562fe91b997a3d87990b Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 12 Nov 2024 10:11:33 -0500 Subject: [PATCH 157/181] Trigger post commit python test. 
--- .github/trigger_files/beam_PostCommit_Python.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index 1eb60f6e4959..9e1d1e1b80dd 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,5 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run.", - "modification": 3 + "modification": 4 } From 941e5421fbb3c3b46746b5c1e7ec152f95220e1d Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Tue, 12 Nov 2024 13:00:30 -0500 Subject: [PATCH 158/181] Update Enrichment Handlers to PEP 585 typing (#33087) * Update Enrichment Handlers to PEP 585 typing * BT test * linting --- .../enrichment_handlers/bigquery.py | 30 +++++++++---------- .../enrichment_handlers/bigtable.py | 5 ++-- .../enrichment_handlers/bigtable_it_test.py | 9 ++---- .../feast_feature_store.py | 7 ++--- .../feast_feature_store_it_test.py | 2 +- .../vertex_ai_feature_store.py | 5 ++-- 6 files changed, 25 insertions(+), 33 deletions(-) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py index ea98fb6b0bbd..06b40bf38cc1 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -14,11 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from collections.abc import Callable +from collections.abc import Mapping from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Mapping from typing import Optional from typing import Union @@ -30,7 +28,7 @@ from apache_beam.transforms.enrichment import EnrichmentSourceHandler QueryFn = Callable[[beam.Row], str] -ConditionValueFn = Callable[[beam.Row], List[Any]] +ConditionValueFn = Callable[[beam.Row], list[Any]] def _validate_bigquery_metadata( @@ -54,8 +52,8 @@ def _validate_bigquery_metadata( "`condition_value_fn`") -class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, List[Row]], - Union[Row, List[Row]]]): +class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, list[Row]], + Union[Row, list[Row]]]): """Enrichment handler for Google Cloud BigQuery. Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment` @@ -83,8 +81,8 @@ def __init__( *, table_name: str = "", row_restriction_template: str = "", - fields: Optional[List[str]] = None, - column_names: Optional[List[str]] = None, + fields: Optional[list[str]] = None, + column_names: Optional[list[str]] = None, condition_value_fn: Optional[ConditionValueFn] = None, query_fn: Optional[QueryFn] = None, min_batch_size: int = 1, @@ -107,10 +105,10 @@ def __init__( row_restriction_template (str): A template string for the `WHERE` clause in the BigQuery query with placeholders (`{}`) to dynamically filter rows based on input data. - fields: (Optional[List[str]]) List of field names present in the input + fields: (Optional[list[str]]) List of field names present in the input `beam.Row`. These are used to construct the WHERE clause (if `condition_value_fn` is not provided). - column_names: (Optional[List[str]]) Names of columns to select from the + column_names: (Optional[list[str]]) Names of columns to select from the BigQuery table. 
If not provided, all columns (`*`) are selected. condition_value_fn: (Optional[Callable[[beam.Row], Any]]) A function that takes a `beam.Row` and returns a list of value to populate in the @@ -179,11 +177,11 @@ def create_row_key(self, row: beam.Row): return (tuple(row_dict[field] for field in self.fields)) raise ValueError("Either fields or condition_value_fn must be specified") - def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): - if isinstance(request, List): + def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): + if isinstance(request, list): values = [] responses = [] - requests_map: Dict[Any, Any] = {} + requests_map: dict[Any, Any] = {} batch_size = len(request) raw_query = self.query_template if batch_size > 1: @@ -230,8 +228,8 @@ def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs): def __exit__(self, exc_type, exc_val, exc_tb): self.client.close() - def get_cache_key(self, request: Union[beam.Row, List[beam.Row]]): - if isinstance(request, List): + def get_cache_key(self, request: Union[beam.Row, list[beam.Row]]): + if isinstance(request, list): cache_keys = [] for req in request: req_dict = req._asdict() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py index ddb62c2f60d5..c251ab05ecab 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py @@ -15,9 +15,8 @@ # limitations under the License. # import logging +from collections.abc import Callable from typing import Any -from typing import Callable -from typing import Dict from typing import Optional from google.api_core.exceptions import NotFound @@ -115,7 +114,7 @@ def __call__(self, request: beam.Row, *args, **kwargs): Args: request: the input `beam.Row` to enrich. """ - response_dict: Dict[str, Any] = {} + response_dict: dict[str, Any] = {} row_key_str: str = "" try: if self._row_key_fn: diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py index 79d73178e94e..6bf57cefacbe 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py @@ -18,10 +18,7 @@ import datetime import logging import unittest -from typing import Dict -from typing import List from typing import NamedTuple -from typing import Tuple from unittest.mock import MagicMock import pytest @@ -57,8 +54,8 @@ class ValidateResponse(beam.DoFn): def __init__( self, n_fields: int, - fields: List[str], - enriched_fields: Dict[str, List[str]], + fields: list[str], + enriched_fields: dict[str, list[str]], include_timestamp: bool = False, ): self.n_fields = n_fields @@ -88,7 +85,7 @@ def process(self, element: beam.Row, *args, **kwargs): "Response from bigtable should contain a %s column_family with " "%s columns." 
% (column_family, columns)) if (self._include_timestamp and - not isinstance(element_dict[column_family][key][0], Tuple)): # type: ignore[arg-type] + not isinstance(element_dict[column_family][key][0], tuple)): raise BeamAssertException( "Response from bigtable should contain timestamp associated with " "its value.") diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py index dc2a71786f65..f8e8b4db1d7f 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store.py @@ -16,11 +16,10 @@ # import logging import tempfile +from collections.abc import Callable +from collections.abc import Mapping from pathlib import Path from typing import Any -from typing import Callable -from typing import List -from typing import Mapping from typing import Optional import apache_beam as beam @@ -95,7 +94,7 @@ class FeastFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row, def __init__( self, feature_store_yaml_path: str, - feature_names: Optional[List[str]] = None, + feature_names: Optional[list[str]] = None, feature_service_name: Optional[str] = "", full_feature_names: Optional[bool] = False, entity_id: str = "", diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py index 89cb39c2c19c..9c4dab3d68b8 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/feast_feature_store_it_test.py @@ -22,8 +22,8 @@ """ import unittest +from collections.abc import Mapping from typing import Any -from typing import Mapping import pytest diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py index 753b04e1793d..b6de3aa1c826 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py @@ -15,7 +15,6 @@ # limitations under the License. # import logging -from typing import List import proto from google.api_core.exceptions import NotFound @@ -209,7 +208,7 @@ def __init__( api_endpoint: str, feature_store_id: str, entity_type_id: str, - feature_ids: List[str], + feature_ids: list[str], row_key: str, *, exception_level: ExceptionLevel = ExceptionLevel.WARN, @@ -224,7 +223,7 @@ def __init__( Vertex AI Feature Store (Legacy). feature_store_id (str): The id of the Vertex AI Feature Store (Legacy). entity_type_id (str): The entity type of the feature store. - feature_ids (List[str]): A list of feature-ids to fetch + feature_ids (list[str]): A list of feature-ids to fetch from the Feature Store. row_key (str): The row key field name containing the entity id for the feature values. 
From 628348b8bafb3856b26aed6d2bd20b97f938aad0 Mon Sep 17 00:00:00 2001 From: Ahmed Abualsaud <65791736+ahmedabu98@users.noreply.github.com> Date: Tue, 12 Nov 2024 13:50:47 -0500 Subject: [PATCH 159/181] Managed BigQueryIO (#31486) * managed bigqueryio * spotless * move managed dependency to test only * cleanup after merging snake_case PR * choose write method based on boundedness and pipeline options * rename bigquery write config class * spotless * change read output tag to 'output' * spotless * revert logic that depends on DataflowServiceOptions. switching BQ methods can instead be done in Dataflow service side * spotless * fix typo * separate BQ write config to a new class * fix doc * resolve after syncing to HEAD * spotless * fork on batch/streaming * cleanup * spotless * move forking logic to BQ schematransform side * add file loads translation and tests; add test checks that the correct transform is chosen * set top-level wrapper to be the underlying managed BQ transform urn; change tests to verify underlying transform name * move unit tests to respectvie schematransform test classes * expose to Python SDK as well --- .../beam_PostCommit_Java_DataflowV2.json | 3 +- ...am_PostCommit_Python_Xlang_Gcp_Direct.json | 2 +- .../pipeline/v1/external_transforms.proto | 4 + .../io/google-cloud-platform/build.gradle | 1 + .../expansion-service/build.gradle | 3 + ...oadsWriteSchemaTransformConfiguration.java | 72 ----- ...FileLoadsWriteSchemaTransformProvider.java | 256 ----------------- ...ueryDirectReadSchemaTransformProvider.java | 33 ++- ...QueryFileLoadsSchemaTransformProvider.java | 137 +++++++++ .../BigQuerySchemaTransformTranslation.java | 81 ++++++ ...torageWriteApiSchemaTransformProvider.java | 226 +-------------- .../providers/BigQueryWriteConfiguration.java | 218 ++++++++++++++ .../BigQueryWriteSchemaTransformProvider.java | 87 ++++++ ...LoadsWriteSchemaTransformProviderTest.java | 265 ------------------ ...yFileLoadsSchemaTransformProviderTest.java | 146 ++++++++++ .../bigquery/providers/BigQueryManagedIT.java | 153 ++++++++++ ...igQuerySchemaTransformTranslationTest.java | 205 ++++++++++++++ ...geWriteApiSchemaTransformProviderTest.java | 83 +++--- .../org/apache/beam/sdk/managed/Managed.java | 11 +- .../ManagedSchemaTransformProvider.java | 37 ++- .../managed/ManagedTransformConstants.java | 18 ++ .../ManagedSchemaTransformProviderTest.java | 3 +- sdks/python/apache_beam/transforms/managed.py | 8 +- 23 files changed, 1187 insertions(+), 865 deletions(-) delete mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformConfiguration.java delete mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProvider.java create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslation.java create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteConfiguration.java create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteSchemaTransformProvider.java delete mode 100644 
sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProviderTest.java create mode 100644 sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProviderTest.java create mode 100644 sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java create mode 100644 sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslationTest.java diff --git a/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json b/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json index a03c067d2c4e..1efc8e9e4405 100644 --- a/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json +++ b/.github/trigger_files/beam_PostCommit_Java_DataflowV2.json @@ -1,3 +1,4 @@ { - "comment": "Modify this file in a trivial way to cause this test suite to run" + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 1 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json index b26833333238..e3d6056a5de9 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 2 + "modification": 1 } diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto index b03350966d6c..f102e82bafa6 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto @@ -70,6 +70,10 @@ message ManagedTransforms { "beam:schematransform:org.apache.beam:kafka_read:v1"]; KAFKA_WRITE = 3 [(org.apache.beam.model.pipeline.v1.beam_urn) = "beam:schematransform:org.apache.beam:kafka_write:v1"]; + BIGQUERY_READ = 4 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:bigquery_storage_read:v1"]; + BIGQUERY_WRITE = 5 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:bigquery_write:v1"]; } } diff --git a/sdks/java/io/google-cloud-platform/build.gradle b/sdks/java/io/google-cloud-platform/build.gradle index 3e322d976c1a..2acce3e94cc2 100644 --- a/sdks/java/io/google-cloud-platform/build.gradle +++ b/sdks/java/io/google-cloud-platform/build.gradle @@ -159,6 +159,7 @@ dependencies { testImplementation project(path: ":sdks:java:extensions:google-cloud-platform-core", configuration: "testRuntimeMigration") testImplementation project(path: ":sdks:java:extensions:protobuf", configuration: "testRuntimeMigration") testImplementation project(path: ":runners:direct-java", configuration: "shadow") + testImplementation project(":sdks:java:managed") testImplementation project(path: ":sdks:java:io:common") testImplementation project(path: ":sdks:java:testing:test-utils") testImplementation library.java.commons_math3 diff --git a/sdks/java/io/google-cloud-platform/expansion-service/build.gradle b/sdks/java/io/google-cloud-platform/expansion-service/build.gradle index 1288d91964e1..f6c6f07d0cdf 100644 --- 
a/sdks/java/io/google-cloud-platform/expansion-service/build.gradle +++ b/sdks/java/io/google-cloud-platform/expansion-service/build.gradle @@ -36,6 +36,9 @@ dependencies { permitUnusedDeclared project(":sdks:java:io:google-cloud-platform") // BEAM-11761 implementation project(":sdks:java:extensions:schemaio-expansion-service") permitUnusedDeclared project(":sdks:java:extensions:schemaio-expansion-service") // BEAM-11761 + implementation project(":sdks:java:managed") + permitUnusedDeclared project(":sdks:java:managed") // BEAM-11761 + runtimeOnly library.java.slf4j_jdk14 } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformConfiguration.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformConfiguration.java deleted file mode 100644 index f634b5ec6f60..000000000000 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformConfiguration.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.gcp.bigquery; - -import com.google.auto.value.AutoValue; -import org.apache.beam.sdk.schemas.AutoValueSchema; -import org.apache.beam.sdk.schemas.annotations.DefaultSchema; - -/** - * Configuration for writing to BigQuery. - * - *

<p>This class is meant to be used with {@link BigQueryFileLoadsWriteSchemaTransformProvider}. - * - *

    Internal only: This class is actively being worked on, and it will likely change. We - * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam - * repository. - */ -@DefaultSchema(AutoValueSchema.class) -@AutoValue -public abstract class BigQueryFileLoadsWriteSchemaTransformConfiguration { - - /** Instantiates a {@link BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder}. */ - public static Builder builder() { - return new AutoValue_BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder(); - } - - /** - * Writes to the given table specification. See {@link BigQueryIO.Write#to(String)}} for the - * expected format. - */ - public abstract String getTableSpec(); - - /** Specifies whether the table should be created if it does not exist. */ - public abstract String getCreateDisposition(); - - /** Specifies what to do with existing data in the table, in case the table already exists. */ - public abstract String getWriteDisposition(); - - @AutoValue.Builder - public abstract static class Builder { - - /** - * Writes to the given table specification. See {@link BigQueryIO.Write#to(String)}} for the - * expected format. - */ - public abstract Builder setTableSpec(String value); - - /** Specifies whether the table should be created if it does not exist. */ - public abstract Builder setCreateDisposition(String value); - - /** Specifies what to do with existing data in the table, in case the table already exists. */ - public abstract Builder setWriteDisposition(String value); - - /** Builds the {@link BigQueryFileLoadsWriteSchemaTransformConfiguration} configuration. */ - public abstract BigQueryFileLoadsWriteSchemaTransformConfiguration build(); - } -} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProvider.java deleted file mode 100644 index 3212e2a30348..000000000000 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProvider.java +++ /dev/null @@ -1,256 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.gcp.bigquery; - -import com.google.api.services.bigquery.model.Table; -import com.google.api.services.bigquery.model.TableReference; -import com.google.api.services.bigquery.model.TableRow; -import com.google.api.services.bigquery.model.TableSchema; -import com.google.auto.service.AutoService; -import java.io.IOException; -import java.util.Collections; -import java.util.List; -import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.io.InvalidConfigurationException; -import org.apache.beam.sdk.schemas.transforms.SchemaTransform; -import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; -import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; -import org.apache.beam.sdk.transforms.MapElements; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionRowTuple; -import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; - -/** - * An implementation of {@link TypedSchemaTransformProvider} for BigQuery write jobs configured - * using {@link BigQueryFileLoadsWriteSchemaTransformConfiguration}. - * - *

    Internal only: This class is actively being worked on, and it will likely change. We - * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam - * repository. - */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -@Internal -@AutoService(SchemaTransformProvider.class) -public class BigQueryFileLoadsWriteSchemaTransformProvider - extends TypedSchemaTransformProvider { - - private static final String IDENTIFIER = - "beam:schematransform:org.apache.beam:bigquery_fileloads_write:v1"; - static final String INPUT_TAG = "INPUT"; - - /** Returns the expected class of the configuration. */ - @Override - protected Class configurationClass() { - return BigQueryFileLoadsWriteSchemaTransformConfiguration.class; - } - - /** Returns the expected {@link SchemaTransform} of the configuration. */ - @Override - protected SchemaTransform from(BigQueryFileLoadsWriteSchemaTransformConfiguration configuration) { - return new BigQueryWriteSchemaTransform(configuration); - } - - /** Implementation of the {@link TypedSchemaTransformProvider} identifier method. */ - @Override - public String identifier() { - return IDENTIFIER; - } - - /** - * Implementation of the {@link TypedSchemaTransformProvider} inputCollectionNames method. Since a - * single is expected, this returns a list with a single name. - */ - @Override - public List inputCollectionNames() { - return Collections.singletonList(INPUT_TAG); - } - - /** - * Implementation of the {@link TypedSchemaTransformProvider} outputCollectionNames method. Since - * no output is expected, this returns an empty list. - */ - @Override - public List outputCollectionNames() { - return Collections.emptyList(); - } - - /** - * A {@link SchemaTransform} that performs {@link BigQueryIO.Write}s based on a {@link - * BigQueryFileLoadsWriteSchemaTransformConfiguration}. - */ - protected static class BigQueryWriteSchemaTransform extends SchemaTransform { - /** An instance of {@link BigQueryServices} used for testing. 
*/ - private BigQueryServices testBigQueryServices = null; - - private final BigQueryFileLoadsWriteSchemaTransformConfiguration configuration; - - BigQueryWriteSchemaTransform(BigQueryFileLoadsWriteSchemaTransformConfiguration configuration) { - this.configuration = configuration; - } - - @Override - public void validate(PipelineOptions options) { - if (!configuration.getCreateDisposition().equals(CreateDisposition.CREATE_NEVER.name())) { - return; - } - - BigQueryOptions bigQueryOptions = options.as(BigQueryOptions.class); - - BigQueryServices bigQueryServices = new BigQueryServicesImpl(); - if (testBigQueryServices != null) { - bigQueryServices = testBigQueryServices; - } - - DatasetService datasetService = bigQueryServices.getDatasetService(bigQueryOptions); - TableReference tableReference = BigQueryUtils.toTableReference(configuration.getTableSpec()); - - try { - Table table = datasetService.getTable(tableReference); - if (table == null) { - throw new NullPointerException(); - } - - if (table.getSchema() == null) { - throw new InvalidConfigurationException( - String.format("could not fetch schema for table: %s", configuration.getTableSpec())); - } - - } catch (NullPointerException | InterruptedException | IOException ex) { - throw new InvalidConfigurationException( - String.format( - "could not fetch table %s, error: %s", - configuration.getTableSpec(), ex.getMessage())); - } - } - - @Override - public PCollectionRowTuple expand(PCollectionRowTuple input) { - validate(input); - PCollection rowPCollection = input.get(INPUT_TAG); - Schema schema = rowPCollection.getSchema(); - BigQueryIO.Write write = toWrite(schema); - if (testBigQueryServices != null) { - write = write.withTestServices(testBigQueryServices); - } - - PCollection tableRowPCollection = - rowPCollection.apply( - MapElements.into(TypeDescriptor.of(TableRow.class)).via(BigQueryUtils::toTableRow)); - tableRowPCollection.apply(write); - return PCollectionRowTuple.empty(input.getPipeline()); - } - - /** Instantiates a {@link BigQueryIO.Write} from a {@link Schema}. */ - BigQueryIO.Write toWrite(Schema schema) { - TableSchema tableSchema = BigQueryUtils.toTableSchema(schema); - CreateDisposition createDisposition = - CreateDisposition.valueOf(configuration.getCreateDisposition()); - WriteDisposition writeDisposition = - WriteDisposition.valueOf(configuration.getWriteDisposition()); - - return BigQueryIO.writeTableRows() - .to(configuration.getTableSpec()) - .withCreateDisposition(createDisposition) - .withWriteDisposition(writeDisposition) - .withSchema(tableSchema); - } - - /** Setter for testing using {@link BigQueryServices}. */ - @VisibleForTesting - void setTestBigQueryServices(BigQueryServices testBigQueryServices) { - this.testBigQueryServices = testBigQueryServices; - } - - /** Validate a {@link PCollectionRowTuple} input. 
*/ - void validate(PCollectionRowTuple input) { - if (!input.has(INPUT_TAG)) { - throw new IllegalArgumentException( - String.format( - "%s %s is missing expected tag: %s", - getClass().getSimpleName(), input.getClass().getSimpleName(), INPUT_TAG)); - } - - PCollection rowInput = input.get(INPUT_TAG); - Schema sourceSchema = rowInput.getSchema(); - - if (sourceSchema == null) { - throw new IllegalArgumentException( - String.format("%s is null for input of tag: %s", Schema.class, INPUT_TAG)); - } - - if (!configuration.getCreateDisposition().equals(CreateDisposition.CREATE_NEVER.name())) { - return; - } - - BigQueryOptions bigQueryOptions = input.getPipeline().getOptions().as(BigQueryOptions.class); - - BigQueryServices bigQueryServices = new BigQueryServicesImpl(); - if (testBigQueryServices != null) { - bigQueryServices = testBigQueryServices; - } - - DatasetService datasetService = bigQueryServices.getDatasetService(bigQueryOptions); - TableReference tableReference = BigQueryUtils.toTableReference(configuration.getTableSpec()); - - try { - Table table = datasetService.getTable(tableReference); - if (table == null) { - throw new NullPointerException(); - } - - TableSchema tableSchema = table.getSchema(); - if (tableSchema == null) { - throw new NullPointerException(); - } - - Schema destinationSchema = BigQueryUtils.fromTableSchema(tableSchema); - if (destinationSchema == null) { - throw new NullPointerException(); - } - - validateMatching(sourceSchema, destinationSchema); - - } catch (NullPointerException | InterruptedException | IOException e) { - throw new InvalidConfigurationException( - String.format( - "could not validate input for create disposition: %s and table: %s, error: %s", - configuration.getCreateDisposition(), - configuration.getTableSpec(), - e.getMessage())); - } - } - - void validateMatching(Schema sourceSchema, Schema destinationSchema) { - if (!sourceSchema.equals(destinationSchema)) { - throw new IllegalArgumentException( - String.format( - "source and destination schema mismatch for table: %s", - configuration.getTableSpec())); - } - } - } -} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryDirectReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryDirectReadSchemaTransformProvider.java index 8b8e8179ce7d..15b1b01d7f6c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryDirectReadSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryDirectReadSchemaTransformProvider.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io.gcp.bigquery.providers; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; @@ -26,6 +27,7 @@ import java.util.Collections; import java.util.List; import javax.annotation.Nullable; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead; @@ -33,7 +35,9 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; import 
org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryDirectReadSchemaTransformProvider.BigQueryDirectReadSchemaTransformConfiguration; import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaRegistry; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; @@ -62,7 +66,7 @@ public class BigQueryDirectReadSchemaTransformProvider extends TypedSchemaTransformProvider { - private static final String OUTPUT_TAG = "OUTPUT_ROWS"; + public static final String OUTPUT_TAG = "output"; @Override protected Class configurationClass() { @@ -76,7 +80,7 @@ protected SchemaTransform from(BigQueryDirectReadSchemaTransformConfiguration co @Override public String identifier() { - return "beam:schematransform:org.apache.beam:bigquery_storage_read:v1"; + return getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_READ); } @Override @@ -139,6 +143,10 @@ public static Builder builder() { @Nullable public abstract List getSelectedFields(); + @SchemaFieldDescription("Use this Cloud KMS key to encrypt your data") + @Nullable + public abstract String getKmsKey(); + @Nullable /** Builder for the {@link BigQueryDirectReadSchemaTransformConfiguration}. */ @AutoValue.Builder @@ -151,6 +159,8 @@ public abstract static class Builder { public abstract Builder setSelectedFields(List selectedFields); + public abstract Builder setKmsKey(String kmsKey); + /** Builds a {@link BigQueryDirectReadSchemaTransformConfiguration} instance. */ public abstract BigQueryDirectReadSchemaTransformConfiguration build(); } @@ -161,7 +171,7 @@ public abstract static class Builder { * BigQueryDirectReadSchemaTransformConfiguration} and instantiated by {@link * BigQueryDirectReadSchemaTransformProvider}. 
*/ - protected static class BigQueryDirectReadSchemaTransform extends SchemaTransform { + public static class BigQueryDirectReadSchemaTransform extends SchemaTransform { private BigQueryServices testBigQueryServices = null; private final BigQueryDirectReadSchemaTransformConfiguration configuration; @@ -172,6 +182,20 @@ protected static class BigQueryDirectReadSchemaTransform extends SchemaTransform this.configuration = configuration; } + public Row getConfigurationRow() { + try { + // To stay consistent with our SchemaTransform configuration naming conventions, + // we sort lexicographically + return SchemaRegistry.createDefault() + .getToRowFunction(BigQueryDirectReadSchemaTransformConfiguration.class) + .apply(configuration) + .sorted() + .toSnakeCase(); + } catch (NoSuchSchemaException e) { + throw new RuntimeException(e); + } + } + @VisibleForTesting public void setBigQueryServices(BigQueryServices testBigQueryServices) { this.testBigQueryServices = testBigQueryServices; @@ -211,6 +235,9 @@ BigQueryIO.TypedRead createDirectReadTransform() { } else { read = read.fromQuery(configuration.getQuery()); } + if (!Strings.isNullOrEmpty(configuration.getKmsKey())) { + read = read.withKmsKey(configuration.getKmsKey()); + } if (this.testBigQueryServices != null) { read = read.withTestServices(testBigQueryServices); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java new file mode 100644 index 000000000000..092cf42a29a4 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import com.google.auto.service.AutoService; +import java.util.Collections; +import java.util.List; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransform; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; + +/** + * An implementation of {@link TypedSchemaTransformProvider} for BigQuery write jobs configured + * using {@link org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteConfiguration}. + * + *

    Internal only: This class is actively being worked on, and it will likely change. We + * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam + * repository. + */ +@SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/20497) +}) +@Internal +@AutoService(SchemaTransformProvider.class) +public class BigQueryFileLoadsSchemaTransformProvider + extends TypedSchemaTransformProvider { + + static final String INPUT_TAG = "input"; + + @Override + protected SchemaTransform from(BigQueryWriteConfiguration configuration) { + return new BigQueryFileLoadsSchemaTransform(configuration); + } + + @Override + public String identifier() { + return "beam:schematransform:org.apache.beam:bigquery_fileloads:v1"; + } + + @Override + public List inputCollectionNames() { + return Collections.singletonList(INPUT_TAG); + } + + @Override + public List outputCollectionNames() { + return Collections.emptyList(); + } + + public static class BigQueryFileLoadsSchemaTransform extends SchemaTransform { + /** An instance of {@link BigQueryServices} used for testing. */ + private BigQueryServices testBigQueryServices = null; + + private final BigQueryWriteConfiguration configuration; + + BigQueryFileLoadsSchemaTransform(BigQueryWriteConfiguration configuration) { + configuration.validate(); + this.configuration = configuration; + } + + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + PCollection rowPCollection = input.getSinglePCollection(); + BigQueryIO.Write write = toWrite(input.getPipeline().getOptions()); + rowPCollection.apply(write); + + return PCollectionRowTuple.empty(input.getPipeline()); + } + + BigQueryIO.Write toWrite(PipelineOptions options) { + BigQueryIO.Write write = + BigQueryIO.write() + .to(configuration.getTable()) + .withMethod(BigQueryIO.Write.Method.FILE_LOADS) + .withFormatFunction(BigQueryUtils.toTableRow()) + // TODO(https://github.com/apache/beam/issues/33074) BatchLoad's + // createTempFilePrefixView() doesn't pick up the pipeline option + .withCustomGcsTempLocation( + ValueProvider.StaticValueProvider.of(options.getTempLocation())) + .withWriteDisposition(WriteDisposition.WRITE_APPEND) + .useBeamSchema(); + + if (!Strings.isNullOrEmpty(configuration.getCreateDisposition())) { + CreateDisposition createDisposition = + CreateDisposition.valueOf(configuration.getCreateDisposition().toUpperCase()); + write = write.withCreateDisposition(createDisposition); + } + if (!Strings.isNullOrEmpty(configuration.getWriteDisposition())) { + WriteDisposition writeDisposition = + WriteDisposition.valueOf(configuration.getWriteDisposition().toUpperCase()); + write = write.withWriteDisposition(writeDisposition); + } + if (!Strings.isNullOrEmpty(configuration.getKmsKey())) { + write = write.withKmsKey(configuration.getKmsKey()); + } + if (testBigQueryServices != null) { + write = write.withTestServices(testBigQueryServices); + } + + return write; + } + + /** Setter for testing using {@link BigQueryServices}. 
*/ + @VisibleForTesting + void setTestBigQueryServices(BigQueryServices testBigQueryServices) { + this.testBigQueryServices = testBigQueryServices; + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslation.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslation.java new file mode 100644 index 000000000000..555df0d0a2b8 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslation.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryDirectReadSchemaTransformProvider.BigQueryDirectReadSchemaTransform; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteSchemaTransformProvider.BigQueryWriteSchemaTransform; + +import com.google.auto.service.AutoService; +import java.util.Map; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.construction.PTransformTranslation; +import org.apache.beam.sdk.util.construction.TransformPayloadTranslatorRegistrar; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; + +public class BigQuerySchemaTransformTranslation { + public static class BigQueryStorageReadSchemaTransformTranslator + extends SchemaTransformTranslation.SchemaTransformPayloadTranslator< + BigQueryDirectReadSchemaTransform> { + @Override + public SchemaTransformProvider provider() { + return new BigQueryDirectReadSchemaTransformProvider(); + } + + @Override + public Row toConfigRow(BigQueryDirectReadSchemaTransform transform) { + return transform.getConfigurationRow(); + } + } + + public static class BigQueryWriteSchemaTransformTranslator + extends SchemaTransformTranslation.SchemaTransformPayloadTranslator< + BigQueryWriteSchemaTransform> { + @Override + public SchemaTransformProvider provider() { + return new BigQueryWriteSchemaTransformProvider(); + } + + @Override + public Row toConfigRow(BigQueryWriteSchemaTransform transform) { + return transform.getConfigurationRow(); + } + } + + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class ReadWriteRegistrar implements TransformPayloadTranslatorRegistrar { + @Override + @SuppressWarnings({ + "rawtypes", + }) + public Map< + ? extends Class, + ? 
extends PTransformTranslation.TransformPayloadTranslator> + getTransformPayloadTranslators() { + return ImmutableMap + ., PTransformTranslation.TransformPayloadTranslator>builder() + .put( + BigQueryDirectReadSchemaTransform.class, + new BigQueryStorageReadSchemaTransformTranslator()) + .put(BigQueryWriteSchemaTransform.class, new BigQueryWriteSchemaTransformTranslator()) + .build(); + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java index c1c06fc592f4..c45433aaf0e7 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java @@ -17,20 +17,16 @@ */ package org.apache.beam.sdk.io.gcp.bigquery.providers; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteConfiguration.DYNAMIC_DESTINATIONS; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import com.google.api.services.bigquery.model.TableConstraints; import com.google.api.services.bigquery.model.TableSchema; import com.google.auto.service.AutoService; -import com.google.auto.value.AutoValue; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Optional; -import javax.annotation.Nullable; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.Method; @@ -42,15 +38,11 @@ import org.apache.beam.sdk.io.gcp.bigquery.RowMutationInformation; import org.apache.beam.sdk.io.gcp.bigquery.TableDestination; import org.apache.beam.sdk.io.gcp.bigquery.WriteResult; -import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransformConfiguration; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.annotations.DefaultSchema; -import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; @@ -65,12 +57,11 @@ import org.apache.beam.sdk.values.ValueInSingleWindow; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.Duration; /** * An implementation of {@link TypedSchemaTransformProvider} for BigQuery Storage Write API jobs - * 
configured via {@link BigQueryStorageWriteApiSchemaTransformConfiguration}. + * configured via {@link BigQueryWriteConfiguration}. * *

    Internal only: This class is actively being worked on, and it will likely change. We * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam @@ -81,7 +72,7 @@ }) @AutoService(SchemaTransformProvider.class) public class BigQueryStorageWriteApiSchemaTransformProvider - extends TypedSchemaTransformProvider { + extends TypedSchemaTransformProvider { private static final Integer DEFAULT_TRIGGER_FREQUENCY_SECS = 5; private static final Duration DEFAULT_TRIGGERING_FREQUENCY = Duration.standardSeconds(DEFAULT_TRIGGER_FREQUENCY_SECS); @@ -89,7 +80,6 @@ public class BigQueryStorageWriteApiSchemaTransformProvider private static final String FAILED_ROWS_TAG = "FailedRows"; private static final String FAILED_ROWS_WITH_ERRORS_TAG = "FailedRowsWithErrors"; // magic string that tells us to write to dynamic destinations - protected static final String DYNAMIC_DESTINATIONS = "DYNAMIC_DESTINATIONS"; protected static final String ROW_PROPERTY_MUTATION_INFO = "row_mutation_info"; protected static final String ROW_PROPERTY_MUTATION_TYPE = "mutation_type"; protected static final String ROW_PROPERTY_MUTATION_SQN = "change_sequence_number"; @@ -100,14 +90,13 @@ public class BigQueryStorageWriteApiSchemaTransformProvider .build(); @Override - protected SchemaTransform from( - BigQueryStorageWriteApiSchemaTransformConfiguration configuration) { + protected SchemaTransform from(BigQueryWriteConfiguration configuration) { return new BigQueryStorageWriteApiSchemaTransform(configuration); } @Override public String identifier() { - return String.format("beam:schematransform:org.apache.beam:bigquery_storage_write:v2"); + return "beam:schematransform:org.apache.beam:bigquery_storage_write:v2"; } @Override @@ -130,201 +119,17 @@ public List outputCollectionNames() { return Arrays.asList(FAILED_ROWS_TAG, FAILED_ROWS_WITH_ERRORS_TAG, "errors"); } - /** Configuration for writing to BigQuery with Storage Write API. 
*/ - @DefaultSchema(AutoValueSchema.class) - @AutoValue - public abstract static class BigQueryStorageWriteApiSchemaTransformConfiguration { - - static final Map CREATE_DISPOSITIONS = - ImmutableMap.builder() - .put(CreateDisposition.CREATE_IF_NEEDED.name(), CreateDisposition.CREATE_IF_NEEDED) - .put(CreateDisposition.CREATE_NEVER.name(), CreateDisposition.CREATE_NEVER) - .build(); - - static final Map WRITE_DISPOSITIONS = - ImmutableMap.builder() - .put(WriteDisposition.WRITE_TRUNCATE.name(), WriteDisposition.WRITE_TRUNCATE) - .put(WriteDisposition.WRITE_EMPTY.name(), WriteDisposition.WRITE_EMPTY) - .put(WriteDisposition.WRITE_APPEND.name(), WriteDisposition.WRITE_APPEND) - .build(); - - @AutoValue - public abstract static class ErrorHandling { - @SchemaFieldDescription("The name of the output PCollection containing failed writes.") - public abstract String getOutput(); - - public static Builder builder() { - return new AutoValue_BigQueryStorageWriteApiSchemaTransformProvider_BigQueryStorageWriteApiSchemaTransformConfiguration_ErrorHandling - .Builder(); - } - - @AutoValue.Builder - public abstract static class Builder { - public abstract Builder setOutput(String output); - - public abstract ErrorHandling build(); - } - } - - public void validate() { - String invalidConfigMessage = "Invalid BigQuery Storage Write configuration: "; - - // validate output table spec - checkArgument( - !Strings.isNullOrEmpty(this.getTable()), - invalidConfigMessage + "Table spec for a BigQuery Write must be specified."); - - // if we have an input table spec, validate it - if (!this.getTable().equals(DYNAMIC_DESTINATIONS)) { - checkNotNull(BigQueryHelpers.parseTableSpec(this.getTable())); - } - - // validate create and write dispositions - if (!Strings.isNullOrEmpty(this.getCreateDisposition())) { - checkNotNull( - CREATE_DISPOSITIONS.get(this.getCreateDisposition().toUpperCase()), - invalidConfigMessage - + "Invalid create disposition (%s) was specified. Available dispositions are: %s", - this.getCreateDisposition(), - CREATE_DISPOSITIONS.keySet()); - } - if (!Strings.isNullOrEmpty(this.getWriteDisposition())) { - checkNotNull( - WRITE_DISPOSITIONS.get(this.getWriteDisposition().toUpperCase()), - invalidConfigMessage - + "Invalid write disposition (%s) was specified. Available dispositions are: %s", - this.getWriteDisposition(), - WRITE_DISPOSITIONS.keySet()); - } - - if (this.getErrorHandling() != null) { - checkArgument( - !Strings.isNullOrEmpty(this.getErrorHandling().getOutput()), - invalidConfigMessage + "Output must not be empty if error handling specified."); - } - - if (this.getAutoSharding() != null - && this.getAutoSharding() - && this.getNumStreams() != null) { - checkArgument( - this.getNumStreams() == 0, - invalidConfigMessage - + "Cannot set a fixed number of streams when auto-sharding is enabled. Please pick only one of the two options."); - } - } - - /** - * Instantiates a {@link BigQueryStorageWriteApiSchemaTransformConfiguration.Builder} instance. - */ - public static Builder builder() { - return new AutoValue_BigQueryStorageWriteApiSchemaTransformProvider_BigQueryStorageWriteApiSchemaTransformConfiguration - .Builder(); - } - - @SchemaFieldDescription( - "The bigquery table to write to. Format: [${PROJECT}:]${DATASET}.${TABLE}") - public abstract String getTable(); - - @SchemaFieldDescription( - "Optional field that specifies whether the job is allowed to create new tables. 
" - + "The following values are supported: CREATE_IF_NEEDED (the job may create the table), CREATE_NEVER (" - + "the job must fail if the table does not exist already).") - @Nullable - public abstract String getCreateDisposition(); - - @SchemaFieldDescription( - "Specifies the action that occurs if the destination table already exists. " - + "The following values are supported: " - + "WRITE_TRUNCATE (overwrites the table data), " - + "WRITE_APPEND (append the data to the table), " - + "WRITE_EMPTY (job must fail if the table is not empty).") - @Nullable - public abstract String getWriteDisposition(); - - @SchemaFieldDescription( - "Determines how often to 'commit' progress into BigQuery. Default is every 5 seconds.") - @Nullable - public abstract Long getTriggeringFrequencySeconds(); - - @SchemaFieldDescription( - "This option enables lower latency for insertions to BigQuery but may ocassionally " - + "duplicate data elements.") - @Nullable - public abstract Boolean getUseAtLeastOnceSemantics(); - - @SchemaFieldDescription( - "This option enables using a dynamically determined number of Storage Write API streams to write to " - + "BigQuery. Only applicable to unbounded data.") - @Nullable - public abstract Boolean getAutoSharding(); - - @SchemaFieldDescription( - "Specifies the number of write streams that the Storage API sink will use. " - + "This parameter is only applicable when writing unbounded data.") - @Nullable - public abstract Integer getNumStreams(); - - @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.") - @Nullable - public abstract ErrorHandling getErrorHandling(); - - @SchemaFieldDescription( - "This option enables the use of BigQuery CDC functionality. The expected PCollection" - + " should contain Beam Rows with a schema wrapping the record to be inserted and" - + " adding the CDC info similar to: {row_mutation_info: {mutation_type:\"...\", " - + "change_sequence_number:\"...\"}, record: {...}}") - @Nullable - public abstract Boolean getUseCdcWrites(); - - @SchemaFieldDescription( - "If CREATE_IF_NEEDED disposition is set, BigQuery table(s) will be created with this" - + " columns as primary key. Required when CDC writes are enabled with CREATE_IF_NEEDED.") - @Nullable - public abstract List getPrimaryKey(); - - /** Builder for {@link BigQueryStorageWriteApiSchemaTransformConfiguration}. */ - @AutoValue.Builder - public abstract static class Builder { - - public abstract Builder setTable(String table); - - public abstract Builder setCreateDisposition(String createDisposition); - - public abstract Builder setWriteDisposition(String writeDisposition); - - public abstract Builder setTriggeringFrequencySeconds(Long seconds); - - public abstract Builder setUseAtLeastOnceSemantics(Boolean use); - - public abstract Builder setAutoSharding(Boolean autoSharding); - - public abstract Builder setNumStreams(Integer numStreams); - - public abstract Builder setErrorHandling(ErrorHandling errorHandling); - - public abstract Builder setUseCdcWrites(Boolean cdcWrites); - - public abstract Builder setPrimaryKey(List pkColumns); - - /** Builds a {@link BigQueryStorageWriteApiSchemaTransformConfiguration} instance. 
*/ - public abstract BigQueryStorageWriteApiSchemaTransformProvider - .BigQueryStorageWriteApiSchemaTransformConfiguration - build(); - } - } - /** * A {@link SchemaTransform} for BigQuery Storage Write API, configured with {@link - * BigQueryStorageWriteApiSchemaTransformConfiguration} and instantiated by {@link + * BigQueryWriteConfiguration} and instantiated by {@link * BigQueryStorageWriteApiSchemaTransformProvider}. */ - protected static class BigQueryStorageWriteApiSchemaTransform extends SchemaTransform { + public static class BigQueryStorageWriteApiSchemaTransform extends SchemaTransform { private BigQueryServices testBigQueryServices = null; - private final BigQueryStorageWriteApiSchemaTransformConfiguration configuration; + private final BigQueryWriteConfiguration configuration; - BigQueryStorageWriteApiSchemaTransform( - BigQueryStorageWriteApiSchemaTransformConfiguration configuration) { + BigQueryStorageWriteApiSchemaTransform(BigQueryWriteConfiguration configuration) { configuration.validate(); this.configuration = configuration; } @@ -420,8 +225,7 @@ public TableConstraints getTableConstraints(String destination) { @Override public PCollectionRowTuple expand(PCollectionRowTuple input) { // Check that the input exists - checkArgument(input.has(INPUT_ROWS_TAG), "Missing expected input tag: %s", INPUT_ROWS_TAG); - PCollection inputRows = input.get(INPUT_ROWS_TAG); + PCollection inputRows = input.getSinglePCollection(); BigQueryIO.Write write = createStorageWriteApiTransform(inputRows.getSchema()); @@ -540,18 +344,18 @@ BigQueryIO.Write createStorageWriteApiTransform(Schema schema) { if (!Strings.isNullOrEmpty(configuration.getCreateDisposition())) { CreateDisposition createDisposition = - BigQueryStorageWriteApiSchemaTransformConfiguration.CREATE_DISPOSITIONS.get( - configuration.getCreateDisposition().toUpperCase()); + CreateDisposition.valueOf(configuration.getCreateDisposition().toUpperCase()); write = write.withCreateDisposition(createDisposition); } if (!Strings.isNullOrEmpty(configuration.getWriteDisposition())) { WriteDisposition writeDisposition = - BigQueryStorageWriteApiSchemaTransformConfiguration.WRITE_DISPOSITIONS.get( - configuration.getWriteDisposition().toUpperCase()); + WriteDisposition.valueOf(configuration.getWriteDisposition().toUpperCase()); write = write.withWriteDisposition(writeDisposition); } - + if (!Strings.isNullOrEmpty(configuration.getKmsKey())) { + write = write.withKmsKey(configuration.getKmsKey()); + } if (this.testBigQueryServices != null) { write = write.withTestServices(testBigQueryServices); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteConfiguration.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteConfiguration.java new file mode 100644 index 000000000000..4296da7e0cd5 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteConfiguration.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; + +import com.google.auto.value.AutoValue; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import javax.annotation.Nullable; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; +import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; + +/** + * Configuration for writing to BigQuery with SchemaTransforms. Used by {@link + * BigQueryStorageWriteApiSchemaTransformProvider} and {@link + * BigQueryFileLoadsSchemaTransformProvider}. + */ +@DefaultSchema(AutoValueSchema.class) +@AutoValue +public abstract class BigQueryWriteConfiguration { + protected static final String DYNAMIC_DESTINATIONS = "DYNAMIC_DESTINATIONS"; + + @AutoValue + public abstract static class ErrorHandling { + @SchemaFieldDescription("The name of the output PCollection containing failed writes.") + public abstract String getOutput(); + + public static Builder builder() { + return new AutoValue_BigQueryWriteConfiguration_ErrorHandling.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setOutput(String output); + + public abstract ErrorHandling build(); + } + } + + public void validate() { + String invalidConfigMessage = "Invalid BigQuery Storage Write configuration: "; + + // validate output table spec + checkArgument( + !Strings.isNullOrEmpty(this.getTable()), + invalidConfigMessage + "Table spec for a BigQuery Write must be specified."); + + // if we have an input table spec, validate it + if (!this.getTable().equals(DYNAMIC_DESTINATIONS)) { + checkNotNull(BigQueryHelpers.parseTableSpec(this.getTable())); + } + + // validate create and write dispositions + String createDisposition = getCreateDisposition(); + if (createDisposition != null && !createDisposition.isEmpty()) { + List createDispositions = + Arrays.stream(BigQueryIO.Write.CreateDisposition.values()) + .map(c -> c.name()) + .collect(Collectors.toList()); + Preconditions.checkArgument( + createDispositions.contains(createDisposition.toUpperCase()), + "Invalid create disposition (%s) was specified. 
Available dispositions are: %s", + createDisposition, + createDispositions); + } + String writeDisposition = getWriteDisposition(); + if (writeDisposition != null && !writeDisposition.isEmpty()) { + List writeDispostions = + Arrays.stream(BigQueryIO.Write.WriteDisposition.values()) + .map(w -> w.name()) + .collect(Collectors.toList()); + Preconditions.checkArgument( + writeDispostions.contains(writeDisposition.toUpperCase()), + "Invalid write disposition (%s) was specified. Available dispositions are: %s", + writeDisposition, + writeDispostions); + } + + ErrorHandling errorHandling = getErrorHandling(); + if (errorHandling != null) { + checkArgument( + !Strings.isNullOrEmpty(errorHandling.getOutput()), + invalidConfigMessage + "Output must not be empty if error handling specified."); + } + + Boolean autoSharding = getAutoSharding(); + Integer numStreams = getNumStreams(); + if (autoSharding != null && autoSharding && numStreams != null) { + checkArgument( + numStreams == 0, + invalidConfigMessage + + "Cannot set a fixed number of streams when auto-sharding is enabled. Please pick only one of the two options."); + } + } + + /** Instantiates a {@link BigQueryWriteConfiguration.Builder} instance. */ + public static Builder builder() { + return new AutoValue_BigQueryWriteConfiguration.Builder(); + } + + @SchemaFieldDescription( + "The bigquery table to write to. Format: [${PROJECT}:]${DATASET}.${TABLE}") + public abstract String getTable(); + + @SchemaFieldDescription( + "Optional field that specifies whether the job is allowed to create new tables. " + + "The following values are supported: CREATE_IF_NEEDED (the job may create the table), CREATE_NEVER (" + + "the job must fail if the table does not exist already).") + @Nullable + public abstract String getCreateDisposition(); + + @SchemaFieldDescription( + "Specifies the action that occurs if the destination table already exists. " + + "The following values are supported: " + + "WRITE_TRUNCATE (overwrites the table data), " + + "WRITE_APPEND (append the data to the table), " + + "WRITE_EMPTY (job must fail if the table is not empty).") + @Nullable + public abstract String getWriteDisposition(); + + @SchemaFieldDescription( + "Determines how often to 'commit' progress into BigQuery. Default is every 5 seconds.") + @Nullable + public abstract Long getTriggeringFrequencySeconds(); + + @SchemaFieldDescription( + "This option enables lower latency for insertions to BigQuery but may ocassionally " + + "duplicate data elements.") + @Nullable + public abstract Boolean getUseAtLeastOnceSemantics(); + + @SchemaFieldDescription( + "This option enables using a dynamically determined number of Storage Write API streams to write to " + + "BigQuery. Only applicable to unbounded data.") + @Nullable + public abstract Boolean getAutoSharding(); + + @SchemaFieldDescription( + "Specifies the number of write streams that the Storage API sink will use. " + + "This parameter is only applicable when writing unbounded data.") + @Nullable + public abstract Integer getNumStreams(); + + @SchemaFieldDescription("Use this Cloud KMS key to encrypt your data") + @Nullable + public abstract String getKmsKey(); + + @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.") + @Nullable + public abstract ErrorHandling getErrorHandling(); + + @SchemaFieldDescription( + "This option enables the use of BigQuery CDC functionality. 
The expected PCollection" + + " should contain Beam Rows with a schema wrapping the record to be inserted and" + + " adding the CDC info similar to: {row_mutation_info: {mutation_type:\"...\", " + + "change_sequence_number:\"...\"}, record: {...}}") + @Nullable + public abstract Boolean getUseCdcWrites(); + + @SchemaFieldDescription( + "If CREATE_IF_NEEDED disposition is set, BigQuery table(s) will be created with this" + + " columns as primary key. Required when CDC writes are enabled with CREATE_IF_NEEDED.") + @Nullable + public abstract List getPrimaryKey(); + + /** Builder for {@link BigQueryWriteConfiguration}. */ + @AutoValue.Builder + public abstract static class Builder { + + public abstract Builder setTable(String table); + + public abstract Builder setCreateDisposition(String createDisposition); + + public abstract Builder setWriteDisposition(String writeDisposition); + + public abstract Builder setTriggeringFrequencySeconds(Long seconds); + + public abstract Builder setUseAtLeastOnceSemantics(Boolean use); + + public abstract Builder setAutoSharding(Boolean autoSharding); + + public abstract Builder setNumStreams(Integer numStreams); + + public abstract Builder setKmsKey(String kmsKey); + + public abstract Builder setErrorHandling(ErrorHandling errorHandling); + + public abstract Builder setUseCdcWrites(Boolean cdcWrites); + + public abstract Builder setPrimaryKey(List pkColumns); + + /** Builds a {@link BigQueryWriteConfiguration} instance. */ + public abstract BigQueryWriteConfiguration build(); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteSchemaTransformProvider.java new file mode 100644 index 000000000000..abab169d6932 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteSchemaTransformProvider.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; + +import com.google.auto.service.AutoService; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.SchemaRegistry; +import org.apache.beam.sdk.schemas.transforms.SchemaTransform; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; + +/** + * A BigQuery Write SchemaTransformProvider that routes to either {@link + * BigQueryFileLoadsSchemaTransformProvider} or {@link + * BigQueryStorageWriteApiSchemaTransformProvider}. + * + *

    Internal only. Used by the Managed Transform layer. + */ +@Internal +@AutoService(SchemaTransformProvider.class) +public class BigQueryWriteSchemaTransformProvider + extends TypedSchemaTransformProvider { + @Override + public String identifier() { + return getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_WRITE); + } + + @Override + protected SchemaTransform from(BigQueryWriteConfiguration configuration) { + return new BigQueryWriteSchemaTransform(configuration); + } + + public static class BigQueryWriteSchemaTransform extends SchemaTransform { + private final BigQueryWriteConfiguration configuration; + + BigQueryWriteSchemaTransform(BigQueryWriteConfiguration configuration) { + configuration.validate(); + this.configuration = configuration; + } + + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + if (input.getSinglePCollection().isBounded().equals(PCollection.IsBounded.BOUNDED)) { + return input.apply(new BigQueryFileLoadsSchemaTransformProvider().from(configuration)); + } else { // UNBOUNDED + return input.apply( + new BigQueryStorageWriteApiSchemaTransformProvider().from(configuration)); + } + } + + public Row getConfigurationRow() { + try { + // To stay consistent with our SchemaTransform configuration naming conventions, + // we sort lexicographically + return SchemaRegistry.createDefault() + .getToRowFunction(BigQueryWriteConfiguration.class) + .apply(configuration) + .sorted() + .toSnakeCase(); + } catch (NoSuchSchemaException e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProviderTest.java deleted file mode 100644 index dd8bb9fc8664..000000000000 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryFileLoadsWriteSchemaTransformProviderTest.java +++ /dev/null @@ -1,265 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.sdk.io.gcp.bigquery; - -import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryFileLoadsWriteSchemaTransformProvider.INPUT_TAG; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; - -import com.google.api.services.bigquery.model.TableReference; -import com.google.api.services.bigquery.model.TableRow; -import com.google.api.services.bigquery.model.TableSchema; -import java.io.IOException; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryFileLoadsWriteSchemaTransformProvider.BigQueryWriteSchemaTransform; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; -import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; -import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; -import org.apache.beam.sdk.io.gcp.testing.FakeJobService; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.Schema.Field; -import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.io.InvalidConfigurationException; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.display.DisplayData; -import org.apache.beam.sdk.transforms.display.DisplayData.Identifier; -import org.apache.beam.sdk.transforms.display.DisplayData.Item; -import org.apache.beam.sdk.values.PCollectionRowTuple; -import org.apache.beam.sdk.values.Row; -import org.apache.commons.lang3.tuple.Pair; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Test for {@link BigQueryFileLoadsWriteSchemaTransformProvider}. 
*/ -@RunWith(JUnit4.class) -public class BigQueryFileLoadsWriteSchemaTransformProviderTest { - - private static final String PROJECT = "fakeproject"; - private static final String DATASET = "fakedataset"; - private static final String TABLE_ID = "faketable"; - - private static final TableReference TABLE_REFERENCE = - new TableReference().setProjectId(PROJECT).setDatasetId(DATASET).setTableId(TABLE_ID); - - private static final Schema SCHEMA = - Schema.of(Field.of("name", FieldType.STRING), Field.of("number", FieldType.INT64)); - - private static final TableSchema TABLE_SCHEMA = BigQueryUtils.toTableSchema(SCHEMA); - - private static final List ROWS = - Arrays.asList( - Row.withSchema(SCHEMA).withFieldValue("name", "a").withFieldValue("number", 1L).build(), - Row.withSchema(SCHEMA).withFieldValue("name", "b").withFieldValue("number", 2L).build(), - Row.withSchema(SCHEMA).withFieldValue("name", "c").withFieldValue("number", 3L).build()); - - private static final BigQueryOptions OPTIONS = - TestPipeline.testingPipelineOptions().as(BigQueryOptions.class); - private final FakeDatasetService fakeDatasetService = new FakeDatasetService(); - private final FakeJobService fakeJobService = new FakeJobService(); - private final TemporaryFolder temporaryFolder = new TemporaryFolder(); - private final FakeBigQueryServices fakeBigQueryServices = - new FakeBigQueryServices() - .withJobService(fakeJobService) - .withDatasetService(fakeDatasetService); - - @Before - public void setUp() throws IOException, InterruptedException { - FakeDatasetService.setUp(); - fakeDatasetService.createDataset(PROJECT, DATASET, "", "", null); - temporaryFolder.create(); - OPTIONS.setProject(PROJECT); - OPTIONS.setTempLocation(temporaryFolder.getRoot().getAbsolutePath()); - } - - @After - public void tearDown() { - temporaryFolder.delete(); - } - - @Rule public transient TestPipeline p = TestPipeline.fromOptions(OPTIONS); - - @Test - public void testLoad() throws IOException, InterruptedException { - BigQueryFileLoadsWriteSchemaTransformProvider provider = - new BigQueryFileLoadsWriteSchemaTransformProvider(); - BigQueryFileLoadsWriteSchemaTransformConfiguration configuration = - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) - .setWriteDisposition(WriteDisposition.WRITE_TRUNCATE.name()) - .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) - .build(); - BigQueryWriteSchemaTransform schemaTransform = - (BigQueryWriteSchemaTransform) provider.from(configuration); - schemaTransform.setTestBigQueryServices(fakeBigQueryServices); - String tag = provider.inputCollectionNames().get(0); - PCollectionRowTuple input = - PCollectionRowTuple.of(tag, p.apply(Create.of(ROWS).withRowSchema(SCHEMA))); - input.apply(schemaTransform); - - p.run(); - - assertNotNull(fakeDatasetService.getTable(TABLE_REFERENCE)); - assertEquals(ROWS.size(), fakeDatasetService.getAllRows(PROJECT, DATASET, TABLE_ID).size()); - } - - @Test - public void testValidatePipelineOptions() { - List< - Pair< - BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder, - Class>> - cases = - Arrays.asList( - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec("project.doesnot.exist") - .setCreateDisposition(CreateDisposition.CREATE_NEVER.name()) - .setWriteDisposition(WriteDisposition.WRITE_APPEND.name()), - InvalidConfigurationException.class), - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - 
.setTableSpec(String.format("%s.%s.%s", PROJECT, DATASET, "doesnotexist")) - .setCreateDisposition(CreateDisposition.CREATE_NEVER.name()) - .setWriteDisposition(WriteDisposition.WRITE_EMPTY.name()), - InvalidConfigurationException.class), - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec("project.doesnot.exist") - .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) - .setWriteDisposition(WriteDisposition.WRITE_APPEND.name()), - null)); - for (Pair< - BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder, Class> - caze : cases) { - BigQueryWriteSchemaTransform transform = transformFrom(caze.getLeft().build()); - if (caze.getRight() != null) { - assertThrows(caze.getRight(), () -> transform.validate(p.getOptions())); - } else { - transform.validate(p.getOptions()); - } - } - } - - @Test - public void testToWrite() { - List< - Pair< - BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder, - BigQueryIO.Write>> - cases = - Arrays.asList( - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) - .setCreateDisposition(CreateDisposition.CREATE_NEVER.name()) - .setWriteDisposition(WriteDisposition.WRITE_EMPTY.name()), - BigQueryIO.writeTableRows() - .to(TABLE_REFERENCE) - .withCreateDisposition(CreateDisposition.CREATE_NEVER) - .withWriteDisposition(WriteDisposition.WRITE_EMPTY) - .withSchema(TABLE_SCHEMA)), - Pair.of( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) - .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) - .setWriteDisposition(WriteDisposition.WRITE_TRUNCATE.name()), - BigQueryIO.writeTableRows() - .to(TABLE_REFERENCE) - .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED) - .withWriteDisposition(WriteDisposition.WRITE_TRUNCATE) - .withSchema(TABLE_SCHEMA))); - for (Pair< - BigQueryFileLoadsWriteSchemaTransformConfiguration.Builder, BigQueryIO.Write> - caze : cases) { - BigQueryWriteSchemaTransform transform = transformFrom(caze.getLeft().build()); - Map gotDisplayData = DisplayData.from(transform.toWrite(SCHEMA)).asMap(); - Map wantDisplayData = DisplayData.from(caze.getRight()).asMap(); - Set keys = new HashSet<>(); - keys.addAll(gotDisplayData.keySet()); - keys.addAll(wantDisplayData.keySet()); - for (Identifier key : keys) { - Item got = null; - Item want = null; - if (gotDisplayData.containsKey(key)) { - got = gotDisplayData.get(key); - } - if (wantDisplayData.containsKey(key)) { - want = wantDisplayData.get(key); - } - assertEquals(want, got); - } - } - } - - @Test - public void validatePCollectionRowTupleInput() { - PCollectionRowTuple empty = PCollectionRowTuple.empty(p); - PCollectionRowTuple valid = - PCollectionRowTuple.of( - INPUT_TAG, p.apply("CreateRowsWithValidSchema", Create.of(ROWS)).setRowSchema(SCHEMA)); - - PCollectionRowTuple invalid = - PCollectionRowTuple.of( - INPUT_TAG, - p.apply( - "CreateRowsWithInvalidSchema", - Create.of( - Row.nullRow( - Schema.builder().addNullableField("name", FieldType.STRING).build())))); - - BigQueryWriteSchemaTransform transform = - transformFrom( - BigQueryFileLoadsWriteSchemaTransformConfiguration.builder() - .setTableSpec(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) - .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) - .setWriteDisposition(WriteDisposition.WRITE_APPEND.name()) - .build()); - - assertThrows(IllegalArgumentException.class, () -> transform.validate(empty)); - - 
assertThrows(IllegalStateException.class, () -> transform.validate(invalid)); - - transform.validate(valid); - - p.run(); - } - - private BigQueryWriteSchemaTransform transformFrom( - BigQueryFileLoadsWriteSchemaTransformConfiguration configuration) { - BigQueryFileLoadsWriteSchemaTransformProvider provider = - new BigQueryFileLoadsWriteSchemaTransformProvider(); - BigQueryWriteSchemaTransform transform = - (BigQueryWriteSchemaTransform) provider.from(configuration); - - transform.setTestBigQueryServices(fakeBigQueryServices); - - return transform; - } -} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProviderTest.java new file mode 100644 index 000000000000..897d95da3b13 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProviderTest.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import com.google.api.services.bigquery.model.TableReference; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryOptions; +import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryFileLoadsSchemaTransformProvider.BigQueryFileLoadsSchemaTransform; +import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; +import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; +import org.apache.beam.sdk.io.gcp.testing.FakeJobService; +import org.apache.beam.sdk.managed.Managed; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test for {@link BigQueryFileLoadsSchemaTransformProvider}. 
*/ +@RunWith(JUnit4.class) +public class BigQueryFileLoadsSchemaTransformProviderTest { + + private static final String PROJECT = "fakeproject"; + private static final String DATASET = "fakedataset"; + private static final String TABLE_ID = "faketable"; + + private static final TableReference TABLE_REFERENCE = + new TableReference().setProjectId(PROJECT).setDatasetId(DATASET).setTableId(TABLE_ID); + + private static final Schema SCHEMA = + Schema.of(Field.of("name", FieldType.STRING), Field.of("number", FieldType.INT64)); + + private static final List ROWS = + Arrays.asList( + Row.withSchema(SCHEMA).withFieldValue("name", "a").withFieldValue("number", 1L).build(), + Row.withSchema(SCHEMA).withFieldValue("name", "b").withFieldValue("number", 2L).build(), + Row.withSchema(SCHEMA).withFieldValue("name", "c").withFieldValue("number", 3L).build()); + + private static final BigQueryOptions OPTIONS = + TestPipeline.testingPipelineOptions().as(BigQueryOptions.class); + private final FakeDatasetService fakeDatasetService = new FakeDatasetService(); + private final FakeJobService fakeJobService = new FakeJobService(); + private final TemporaryFolder temporaryFolder = new TemporaryFolder(); + private final FakeBigQueryServices fakeBigQueryServices = + new FakeBigQueryServices() + .withJobService(fakeJobService) + .withDatasetService(fakeDatasetService); + + @Before + public void setUp() throws IOException, InterruptedException { + FakeDatasetService.setUp(); + fakeDatasetService.createDataset(PROJECT, DATASET, "", "", null); + temporaryFolder.create(); + OPTIONS.setProject(PROJECT); + OPTIONS.setTempLocation(temporaryFolder.getRoot().getAbsolutePath()); + } + + @After + public void tearDown() { + temporaryFolder.delete(); + } + + @Rule public transient TestPipeline p = TestPipeline.fromOptions(OPTIONS); + + @Test + public void testLoad() throws IOException, InterruptedException { + BigQueryFileLoadsSchemaTransformProvider provider = + new BigQueryFileLoadsSchemaTransformProvider(); + BigQueryWriteConfiguration configuration = + BigQueryWriteConfiguration.builder() + .setTable(BigQueryHelpers.toTableSpec(TABLE_REFERENCE)) + .setWriteDisposition(WriteDisposition.WRITE_TRUNCATE.name()) + .setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED.name()) + .build(); + BigQueryFileLoadsSchemaTransform schemaTransform = + (BigQueryFileLoadsSchemaTransform) provider.from(configuration); + schemaTransform.setTestBigQueryServices(fakeBigQueryServices); + String tag = provider.inputCollectionNames().get(0); + PCollectionRowTuple input = + PCollectionRowTuple.of(tag, p.apply(Create.of(ROWS).withRowSchema(SCHEMA))); + input.apply(schemaTransform); + + p.run(); + + assertNotNull(fakeDatasetService.getTable(TABLE_REFERENCE)); + assertEquals(ROWS.size(), fakeDatasetService.getAllRows(PROJECT, DATASET, TABLE_ID).size()); + } + + @Test + public void testManagedChoosesFileLoadsForBoundedWrites() { + PCollection batchInput = p.apply(Create.of(ROWS)).setRowSchema(SCHEMA); + batchInput.apply( + Managed.write(Managed.BIGQUERY) + .withConfig(ImmutableMap.of("table", "project.dataset.table"))); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List writeTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> + tr.getUniqueName() + .contains(BigQueryFileLoadsSchemaTransform.class.getSimpleName())) + .collect(Collectors.toList()); + assertThat(writeTransformProto.size(), greaterThan(0)); + p.enableAbandonedNodeEnforcement(false); + } +} diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java new file mode 100644 index 000000000000..63727107a651 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.LongStream; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.io.gcp.testing.BigqueryClient; +import org.apache.beam.sdk.managed.Managed; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PeriodicImpulse; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** This class tests the execution of {@link Managed} BigQueryIO. 
*/ +@RunWith(JUnit4.class) +public class BigQueryManagedIT { + @Rule public TestName testName = new TestName(); + @Rule public transient TestPipeline writePipeline = TestPipeline.create(); + @Rule public transient TestPipeline readPipeline = TestPipeline.create(); + + private static final Schema SCHEMA = + Schema.of( + Schema.Field.of("str", Schema.FieldType.STRING), + Schema.Field.of("number", Schema.FieldType.INT64)); + + private static final List ROWS = + LongStream.range(0, 20) + .mapToObj( + i -> + Row.withSchema(SCHEMA) + .withFieldValue("str", Long.toString(i)) + .withFieldValue("number", i) + .build()) + .collect(Collectors.toList()); + + private static final BigqueryClient BQ_CLIENT = new BigqueryClient("BigQueryManagedIT"); + + private static final String PROJECT = + TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + private static final String BIG_QUERY_DATASET_ID = "bigquery_managed_" + System.nanoTime(); + + @BeforeClass + public static void setUpTestEnvironment() throws IOException, InterruptedException { + // Create one BQ dataset for all test cases. + BQ_CLIENT.createNewDataset(PROJECT, BIG_QUERY_DATASET_ID, null); + } + + @AfterClass + public static void cleanup() { + BQ_CLIENT.deleteDataset(PROJECT, BIG_QUERY_DATASET_ID); + } + + @Test + public void testBatchFileLoadsWriteRead() { + String table = + String.format("%s:%s.%s", PROJECT, BIG_QUERY_DATASET_ID, testName.getMethodName()); + Map config = ImmutableMap.of("table", table); + + // file loads requires a GCS temp location + String tempLocation = writePipeline.getOptions().as(TestPipelineOptions.class).getTempRoot(); + writePipeline.getOptions().setTempLocation(tempLocation); + + // batch write + PCollectionRowTuple.of("input", getInput(writePipeline, false)) + .apply(Managed.write(Managed.BIGQUERY).withConfig(config)); + writePipeline.run().waitUntilFinish(); + + // read and validate + PCollection outputRows = + readPipeline + .apply(Managed.read(Managed.BIGQUERY).withConfig(config)) + .getSinglePCollection(); + PAssert.that(outputRows).containsInAnyOrder(ROWS); + readPipeline.run().waitUntilFinish(); + } + + @Test + public void testStreamingStorageWriteRead() { + String table = + String.format("%s:%s.%s", PROJECT, BIG_QUERY_DATASET_ID, testName.getMethodName()); + Map config = ImmutableMap.of("table", table); + + // streaming write + PCollectionRowTuple.of("input", getInput(writePipeline, true)) + .apply(Managed.write(Managed.BIGQUERY).withConfig(config)); + writePipeline.run().waitUntilFinish(); + + // read and validate + PCollection outputRows = + readPipeline + .apply(Managed.read(Managed.BIGQUERY).withConfig(config)) + .getSinglePCollection(); + PAssert.that(outputRows).containsInAnyOrder(ROWS); + readPipeline.run().waitUntilFinish(); + } + + public PCollection getInput(Pipeline p, boolean isStreaming) { + if (isStreaming) { + return p.apply( + PeriodicImpulse.create() + .startAt(new Instant(0)) + .stopAt(new Instant(19)) + .withInterval(Duration.millis(1))) + .apply( + MapElements.into(TypeDescriptors.rows()) + .via( + i -> + Row.withSchema(SCHEMA) + .withFieldValue("str", Long.toString(i.getMillis())) + .withFieldValue("number", i.getMillis()) + .build())) + .setRowSchema(SCHEMA); + } + return p.apply(Create.of(ROWS)).setRowSchema(SCHEMA); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslationTest.java 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslationTest.java new file mode 100644 index 000000000000..822c607aa3c9 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQuerySchemaTransformTranslationTest.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryDirectReadSchemaTransformProvider.BigQueryDirectReadSchemaTransform; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQuerySchemaTransformTranslation.BigQueryStorageReadSchemaTransformTranslator; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQuerySchemaTransformTranslation.BigQueryWriteSchemaTransformTranslator; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteSchemaTransformProvider.BigQueryWriteSchemaTransform; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.util.construction.BeamUrns; +import org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class BigQuerySchemaTransformTranslationTest { + static final BigQueryWriteSchemaTransformProvider WRITE_PROVIDER = + new BigQueryWriteSchemaTransformProvider(); + static final BigQueryDirectReadSchemaTransformProvider READ_PROVIDER = + new BigQueryDirectReadSchemaTransformProvider(); + static final Row WRITE_CONFIG_ROW = + Row.withSchema(WRITE_PROVIDER.configurationSchema()) + .withFieldValue("table", "project:dataset.table") + .withFieldValue("create_disposition", "create_never") + .withFieldValue("write_disposition", "write_append") + 
.withFieldValue("triggering_frequency_seconds", 5L) + .withFieldValue("use_at_least_once_semantics", false) + .withFieldValue("auto_sharding", false) + .withFieldValue("num_streams", 5) + .withFieldValue("error_handling", null) + .build(); + static final Row READ_CONFIG_ROW = + Row.withSchema(READ_PROVIDER.configurationSchema()) + .withFieldValue("query", null) + .withFieldValue("table_spec", "apache-beam-testing.samples.weather_stations") + .withFieldValue("row_restriction", "col < 5") + .withFieldValue("selected_fields", Arrays.asList("col1", "col2", "col3")) + .build(); + + @Test + public void testRecreateWriteTransformFromRow() { + BigQueryWriteSchemaTransform writeTransform = + (BigQueryWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG_ROW); + + BigQueryWriteSchemaTransformTranslator translator = + new BigQueryWriteSchemaTransformTranslator(); + Row translatedRow = translator.toConfigRow(writeTransform); + + BigQueryWriteSchemaTransform writeTransformFromRow = + translator.fromConfigRow(translatedRow, PipelineOptionsFactory.create()); + + assertEquals(WRITE_CONFIG_ROW, writeTransformFromRow.getConfigurationRow()); + } + + @Test + public void testWriteTransformProtoTranslation() + throws InvalidProtocolBufferException, IOException { + // First build a pipeline + Pipeline p = Pipeline.create(); + Schema inputSchema = Schema.builder().addByteArrayField("b").build(); + PCollection input = + p.apply( + Create.of( + Collections.singletonList( + Row.withSchema(inputSchema).addValue(new byte[] {1, 2, 3}).build()))) + .setRowSchema(inputSchema); + + BigQueryWriteSchemaTransform writeTransform = + (BigQueryWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG_ROW); + PCollectionRowTuple.of("input", input).apply(writeTransform); + + // Then translate the pipeline to a proto and extract KafkaWriteSchemaTransform proto + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List writeTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> { + RunnerApi.FunctionSpec spec = tr.getSpec(); + try { + return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM)) + && SchemaTransformPayload.parseFrom(spec.getPayload()) + .getIdentifier() + .equals(WRITE_PROVIDER.identifier()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + assertEquals(1, writeTransformProto.size()); + RunnerApi.FunctionSpec spec = writeTransformProto.get(0).getSpec(); + + // Check that the proto contains correct values + SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload()); + Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema()); + assertEquals(WRITE_PROVIDER.configurationSchema(), schemaFromSpec); + Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput()); + + assertEquals(WRITE_CONFIG_ROW, rowFromSpec); + + // Use the information in the proto to recreate the KafkaWriteSchemaTransform + BigQueryWriteSchemaTransformTranslator translator = + new BigQueryWriteSchemaTransformTranslator(); + BigQueryWriteSchemaTransform writeTransformFromSpec = + translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create()); + + assertEquals(WRITE_CONFIG_ROW, writeTransformFromSpec.getConfigurationRow()); + } + + @Test + public void testReCreateReadTransformFromRow() { + BigQueryDirectReadSchemaTransform readTransform = + (BigQueryDirectReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG_ROW); + + 
BigQueryStorageReadSchemaTransformTranslator translator = + new BigQueryStorageReadSchemaTransformTranslator(); + Row row = translator.toConfigRow(readTransform); + + BigQueryDirectReadSchemaTransform readTransformFromRow = + translator.fromConfigRow(row, PipelineOptionsFactory.create()); + + assertEquals(READ_CONFIG_ROW, readTransformFromRow.getConfigurationRow()); + } + + @Test + public void testReadTransformProtoTranslation() + throws InvalidProtocolBufferException, IOException { + // First build a pipeline + Pipeline p = Pipeline.create(); + + BigQueryDirectReadSchemaTransform readTransform = + (BigQueryDirectReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG_ROW); + + PCollectionRowTuple.empty(p).apply(readTransform); + + // Then translate the pipeline to a proto and extract KafkaReadSchemaTransform proto + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List readTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> { + RunnerApi.FunctionSpec spec = tr.getSpec(); + try { + return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM)) + && SchemaTransformPayload.parseFrom(spec.getPayload()) + .getIdentifier() + .equals(READ_PROVIDER.identifier()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + assertEquals(1, readTransformProto.size()); + RunnerApi.FunctionSpec spec = readTransformProto.get(0).getSpec(); + + // Check that the proto contains correct values + SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload()); + Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema()); + assertEquals(READ_PROVIDER.configurationSchema(), schemaFromSpec); + Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput()); + assertEquals(READ_CONFIG_ROW, rowFromSpec); + + // Use the information in the proto to recreate the KafkaReadSchemaTransform + BigQueryStorageReadSchemaTransformTranslator translator = + new BigQueryStorageReadSchemaTransformTranslator(); + BigQueryDirectReadSchemaTransform readTransformFromSpec = + translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create()); + + assertEquals(READ_CONFIG_ROW, readTransformFromSpec.getConfigurationRow()); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java index 87ba2961461a..7b59552bbbe4 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io.gcp.bigquery.providers; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; @@ -32,13 +34,14 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; +import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.sdk.PipelineResult; 
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransform; -import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransformConfiguration; import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; import org.apache.beam.sdk.io.gcp.testing.FakeJobService; +import org.apache.beam.sdk.managed.Managed; import org.apache.beam.sdk.metrics.MetricNameFilter; import org.apache.beam.sdk.metrics.MetricQueryResults; import org.apache.beam.sdk.metrics.MetricResult; @@ -50,13 +53,16 @@ import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.util.construction.PipelineTranslation; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -108,15 +114,14 @@ public void setUp() throws Exception { @Test public void testInvalidConfig() { - List invalidConfigs = + List invalidConfigs = Arrays.asList( - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() - .setTable("not_a_valid_table_spec"), - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + BigQueryWriteConfiguration.builder().setTable("not_a_valid_table_spec"), + BigQueryWriteConfiguration.builder() .setTable("project:dataset.table") .setCreateDisposition("INVALID_DISPOSITION")); - for (BigQueryStorageWriteApiSchemaTransformConfiguration.Builder config : invalidConfigs) { + for (BigQueryWriteConfiguration.Builder config : invalidConfigs) { assertThrows( Exception.class, () -> { @@ -125,13 +130,11 @@ public void testInvalidConfig() { } } - public PCollectionRowTuple runWithConfig( - BigQueryStorageWriteApiSchemaTransformConfiguration config) { + public PCollectionRowTuple runWithConfig(BigQueryWriteConfiguration config) { return runWithConfig(config, ROWS); } - public PCollectionRowTuple runWithConfig( - BigQueryStorageWriteApiSchemaTransformConfiguration config, List inputRows) { + public PCollectionRowTuple runWithConfig(BigQueryWriteConfiguration config, List inputRows) { BigQueryStorageWriteApiSchemaTransformProvider provider = new BigQueryStorageWriteApiSchemaTransformProvider(); @@ -176,8 +179,8 @@ public boolean rowEquals(Row expectedRow, TableRow actualRow) { @Test public void testSimpleWrite() throws Exception { String tableSpec = "project:dataset.simple_write"; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build(); + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder().setTable(tableSpec).build(); runWithConfig(config, ROWS); p.run().waitUntilFinish(); @@ -189,9 +192,9 @@ public void testSimpleWrite() throws Exception { @Test public void testWriteToDynamicDestinations() throws 
Exception { - String dynamic = BigQueryStorageWriteApiSchemaTransformProvider.DYNAMIC_DESTINATIONS; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(dynamic).build(); + String dynamic = BigQueryWriteConfiguration.DYNAMIC_DESTINATIONS; + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder().setTable(dynamic).build(); String baseTableSpec = "project:dataset.dynamic_write_"; @@ -273,8 +276,8 @@ public void testCDCWrites() throws Exception { String tableSpec = "project:dataset.cdc_write"; List primaryKeyColumns = ImmutableList.of("name"); - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() .setUseAtLeastOnceSemantics(true) .setTable(tableSpec) .setUseCdcWrites(true) @@ -304,9 +307,9 @@ public void testCDCWrites() throws Exception { @Test public void testCDCWriteToDynamicDestinations() throws Exception { List primaryKeyColumns = ImmutableList.of("name"); - String dynamic = BigQueryStorageWriteApiSchemaTransformProvider.DYNAMIC_DESTINATIONS; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + String dynamic = BigQueryWriteConfiguration.DYNAMIC_DESTINATIONS; + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() .setUseAtLeastOnceSemantics(true) .setTable(dynamic) .setUseCdcWrites(true) @@ -338,8 +341,8 @@ public void testCDCWriteToDynamicDestinations() throws Exception { @Test public void testInputElementCount() throws Exception { String tableSpec = "project:dataset.input_count"; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build(); + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder().setTable(tableSpec).build(); runWithConfig(config); PipelineResult result = p.run(); @@ -368,13 +371,11 @@ public void testInputElementCount() throws Exception { @Test public void testFailedRows() throws Exception { String tableSpec = "project:dataset.write_with_fail"; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() .setTable(tableSpec) .setErrorHandling( - BigQueryStorageWriteApiSchemaTransformConfiguration.ErrorHandling.builder() - .setOutput("FailedRows") - .build()) + BigQueryWriteConfiguration.ErrorHandling.builder().setOutput("FailedRows").build()) .build(); String failValue = "fail_me"; @@ -420,13 +421,11 @@ public void testFailedRows() throws Exception { @Test public void testErrorCount() throws Exception { String tableSpec = "project:dataset.error_count"; - BigQueryStorageWriteApiSchemaTransformConfiguration config = - BigQueryStorageWriteApiSchemaTransformConfiguration.builder() + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() .setTable(tableSpec) .setErrorHandling( - BigQueryStorageWriteApiSchemaTransformConfiguration.ErrorHandling.builder() - .setOutput("FailedRows") - .build()) + BigQueryWriteConfiguration.ErrorHandling.builder().setOutput("FailedRows").build()) .build(); Function shouldFailRow = @@ -456,4 +455,24 @@ public void testErrorCount() throws Exception { assertEquals(expectedCount, count.getAttempted()); } } + + @Test + public void 
testManagedChoosesStorageApiForUnboundedWrites() { + PCollection batchInput = + p.apply(TestStream.create(SCHEMA).addElements(ROWS.get(0)).advanceWatermarkToInfinity()); + batchInput.apply( + Managed.write(Managed.BIGQUERY) + .withConfig(ImmutableMap.of("table", "project.dataset.table"))); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List writeTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> + tr.getUniqueName() + .contains(BigQueryStorageWriteApiSchemaTransform.class.getSimpleName())) + .collect(Collectors.toList()); + assertThat(writeTransformProto.size(), greaterThan(0)); + p.enableAbandonedNodeEnforcement(false); + } } diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java index 8477726686ee..8e7e0862eff4 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java @@ -86,17 +86,20 @@ public class Managed { // TODO: Dynamically generate a list of supported transforms public static final String ICEBERG = "iceberg"; public static final String KAFKA = "kafka"; + public static final String BIGQUERY = "bigquery"; // Supported SchemaTransforms public static final Map READ_TRANSFORMS = ImmutableMap.builder() .put(ICEBERG, getUrn(ExternalTransforms.ManagedTransforms.Urns.ICEBERG_READ)) .put(KAFKA, getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_READ)) + .put(BIGQUERY, getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_READ)) .build(); public static final Map WRITE_TRANSFORMS = ImmutableMap.builder() .put(ICEBERG, getUrn(ExternalTransforms.ManagedTransforms.Urns.ICEBERG_WRITE)) .put(KAFKA, getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_WRITE)) + .put(BIGQUERY, getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_WRITE)) .build(); /** @@ -104,7 +107,9 @@ public class Managed { * supported managed sources are: * *

- *   • {@link Managed#ICEBERG} : Read from Apache Iceberg
+ *   • {@link Managed#ICEBERG} : Read from Apache Iceberg tables
+ *   • {@link Managed#KAFKA} : Read from Apache Kafka topics
+ *   • {@link Managed#BIGQUERY} : Read from GCP BigQuery tables
 *
    */ public static ManagedTransform read(String source) { @@ -124,7 +129,9 @@ public static ManagedTransform read(String source) { * managed sinks are: * *
- *   • {@link Managed#ICEBERG} : Write to Apache Iceberg
+ *   • {@link Managed#ICEBERG} : Write to Apache Iceberg tables
+ *   • {@link Managed#KAFKA} : Write to Apache Kafka topics
+ *   • {@link Managed#BIGQUERY} : Write to GCP BigQuery tables
 *
    */ public static ManagedTransform write(String sink) { diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java index 6f97983d3260..b705306b9478 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProvider.java @@ -117,7 +117,7 @@ protected void validate() { "Please specify a config or a config URL, but not both."); } - public @Nullable String resolveUnderlyingConfig() { + private Map resolveUnderlyingConfig() { String yamlTransformConfig = getConfig(); // If YAML string is empty, then attempt to read from YAML file if (Strings.isNullOrEmpty(yamlTransformConfig)) { @@ -131,7 +131,8 @@ protected void validate() { throw new RuntimeException(e); } } - return yamlTransformConfig; + + return YamlUtils.yamlStringToMap(yamlTransformConfig); } } @@ -152,34 +153,34 @@ protected SchemaTransform from(ManagedConfig managedConfig) { static class ManagedSchemaTransform extends SchemaTransform { private final ManagedConfig managedConfig; - private final Row underlyingTransformConfig; + private final Row underlyingRowConfig; private final SchemaTransformProvider underlyingTransformProvider; ManagedSchemaTransform( ManagedConfig managedConfig, SchemaTransformProvider underlyingTransformProvider) { // parse config before expansion to check if it matches underlying transform's config schema Schema transformConfigSchema = underlyingTransformProvider.configurationSchema(); - Row underlyingTransformConfig; + Row underlyingRowConfig; try { - underlyingTransformConfig = getRowConfig(managedConfig, transformConfigSchema); + underlyingRowConfig = getRowConfig(managedConfig, transformConfigSchema); } catch (Exception e) { throw new IllegalArgumentException( "Encountered an error when retrieving a Row configuration", e); } - this.managedConfig = managedConfig; - this.underlyingTransformConfig = underlyingTransformConfig; + this.underlyingRowConfig = underlyingRowConfig; this.underlyingTransformProvider = underlyingTransformProvider; + this.managedConfig = managedConfig; } @Override public PCollectionRowTuple expand(PCollectionRowTuple input) { LOG.debug( - "Building transform \"{}\" with Row configuration: {}", + "Building transform \"{}\" with configuration: {}", underlyingTransformProvider.identifier(), - underlyingTransformConfig); + underlyingRowConfig); - return input.apply(underlyingTransformProvider.from(underlyingTransformConfig)); + return input.apply(underlyingTransformProvider.from(underlyingRowConfig)); } public ManagedConfig getManagedConfig() { @@ -201,16 +202,14 @@ Row getConfigurationRow() { } } + // May return an empty row (perhaps the underlying transform doesn't have any required + // parameters) @VisibleForTesting static Row getRowConfig(ManagedConfig config, Schema transformSchema) { - // May return an empty row (perhaps the underlying transform doesn't have any required - // parameters) - String yamlConfig = config.resolveUnderlyingConfig(); - Map configMap = YamlUtils.yamlStringToMap(yamlConfig); - - // The config Row object will be used to build the underlying SchemaTransform. 
- // If a mapping for the SchemaTransform exists, we use it to update parameter names and align - // with the underlying config schema + Map configMap = config.resolveUnderlyingConfig(); + // Build a config Row that will be used to build the underlying SchemaTransform. + // If a mapping for the SchemaTransform exists, we use it to update parameter names to align + // with the underlying SchemaTransform config schema Map mapping = MAPPINGS.get(config.getTransformIdentifier()); if (mapping != null && configMap != null) { Map remappedConfig = new HashMap<>(); @@ -227,7 +226,7 @@ static Row getRowConfig(ManagedConfig config, Schema transformSchema) { return YamlUtils.toBeamRow(configMap, transformSchema, false); } - // We load providers seperately, after construction, to prevent the + // We load providers separately, after construction, to prevent the // 'ManagedSchemaTransformProvider' from being initialized in a recursive loop // when being loaded using 'AutoValue'. synchronized Map getAllProviders() { diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java index 4cf752747be5..30476a30d373 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java @@ -50,9 +50,27 @@ public class ManagedTransformConstants { private static final Map KAFKA_WRITE_MAPPINGS = ImmutableMap.builder().put("data_format", "format").build(); + private static final Map BIGQUERY_READ_MAPPINGS = + ImmutableMap.builder() + .put("table", "table_spec") + .put("fields", "selected_fields") + .build(); + + private static final Map BIGQUERY_WRITE_MAPPINGS = + ImmutableMap.builder() + .put("at_least_once", "use_at_least_once_semantics") + .put("triggering_frequency", "triggering_frequency_seconds") + .build(); + public static final Map> MAPPINGS = ImmutableMap.>builder() .put(getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_READ), KAFKA_READ_MAPPINGS) .put(getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_WRITE), KAFKA_WRITE_MAPPINGS) + .put( + getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_READ), + BIGQUERY_READ_MAPPINGS) + .put( + getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_WRITE), + BIGQUERY_WRITE_MAPPINGS) .build(); } diff --git a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProviderTest.java b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProviderTest.java index e9edf8751e34..a287ec6260ce 100644 --- a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProviderTest.java +++ b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformProviderTest.java @@ -88,8 +88,7 @@ public void testGetConfigRowFromYamlFile() throws URISyntaxException { .withFieldValue("extra_integer", 123) .build(); Row configRow = - ManagedSchemaTransformProvider.getRowConfig( - config, new TestSchemaTransformProvider().configurationSchema()); + ManagedSchemaTransformProvider.getRowConfig(config, TestSchemaTransformProvider.SCHEMA); assertEquals(expectedRow, configRow); } diff --git a/sdks/python/apache_beam/transforms/managed.py b/sdks/python/apache_beam/transforms/managed.py index 22ee15b1de1c..cbcb6de56ed7 100644 --- a/sdks/python/apache_beam/transforms/managed.py +++ b/sdks/python/apache_beam/transforms/managed.py 
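For orientation (not part of the patch): the tests above exercise the new Managed BigQuery entry points end to end. The sketch below is a minimal, hypothetical usage example that mirrors those tests; the table spec and element values are placeholders, and it assumes the beam-sdks-java-managed and google-cloud-platform IO modules are on the classpath. As implemented by BigQueryWriteSchemaTransform earlier in this patch, bounded inputs route to the file-loads path and unbounded inputs to the Storage Write API.

import java.util.Map;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;

public class ManagedBigQueryExample {
  public static void main(String[] args) {
    // Placeholder destination. For reads, Managed remaps "table" to the provider's
    // "table_spec" field (see BIGQUERY_READ_MAPPINGS above); for writes it is used as-is.
    Map<String, Object> config = ImmutableMap.of("table", "my-project:my_dataset.my_table");

    Schema schema = Schema.builder().addStringField("name").addInt64Field("number").build();

    // Write: bounded input, so the Managed BIGQUERY_WRITE transform routes to file loads.
    Pipeline writePipeline = Pipeline.create();
    PCollection<Row> rows =
        writePipeline
            .apply(Create.of(Row.withSchema(schema).addValues("a", 1L).build()))
            .setRowSchema(schema);
    PCollectionRowTuple.of("input", rows)
        .apply(Managed.write(Managed.BIGQUERY).withConfig(config));
    writePipeline.run().waitUntilFinish();

    // Read: the BIGQUERY_READ URN resolves to the BigQuery direct-read SchemaTransform.
    Pipeline readPipeline = Pipeline.create();
    PCollection<Row> output =
        readPipeline
            .apply(Managed.read(Managed.BIGQUERY).withConfig(config))
            .getSinglePCollection();
    // 'output' can be processed further downstream as needed.
    readPipeline.run().waitUntilFinish();
  }
}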
@@ -77,12 +77,16 @@ ICEBERG = "iceberg" KAFKA = "kafka" +BIGQUERY = "bigquery" _MANAGED_IDENTIFIER = "beam:transform:managed:v1" _EXPANSION_SERVICE_JAR_TARGETS = { "sdks:java:io:expansion-service:shadowJar": [KAFKA, ICEBERG], + "sdks:java:io:google-cloud-platform:expansion-service:shadowJar": [ + BIGQUERY + ] } -__all__ = ["ICEBERG", "KAFKA", "Read", "Write"] +__all__ = ["ICEBERG", "KAFKA", "BIGQUERY", "Read", "Write"] class Read(PTransform): @@ -90,6 +94,7 @@ class Read(PTransform): _READ_TRANSFORMS = { ICEBERG: ManagedTransforms.Urns.ICEBERG_READ.urn, KAFKA: ManagedTransforms.Urns.KAFKA_READ.urn, + BIGQUERY: ManagedTransforms.Urns.BIGQUERY_READ.urn } def __init__( @@ -130,6 +135,7 @@ class Write(PTransform): _WRITE_TRANSFORMS = { ICEBERG: ManagedTransforms.Urns.ICEBERG_WRITE.urn, KAFKA: ManagedTransforms.Urns.KAFKA_WRITE.urn, + BIGQUERY: ManagedTransforms.Urns.BIGQUERY_WRITE.urn } def __init__( From 26049437ebbe9cd1ef2574e2a704cc3403fe871a Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Tue, 12 Nov 2024 14:28:38 -0500 Subject: [PATCH 160/181] Switch to use ConcurrentMap for StringSetData (#33057) * Switch to use ConcurrentMap for StringSetData * address comments --- .../runners/core/metrics/StringSetData.java | 23 +++++----- .../core/metrics/StringSetCellTest.java | 44 +++++++++++++++++++ 2 files changed, 57 insertions(+), 10 deletions(-) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java index 4fc5d3beca31..5f9bb6392ec2 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/StringSetData.java @@ -20,8 +20,8 @@ import com.google.auto.value.AutoValue; import java.io.Serializable; import java.util.Arrays; -import java.util.HashSet; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import org.apache.beam.sdk.metrics.StringSetResult; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; @@ -54,13 +54,13 @@ public static StringSetData create(Set set) { if (set.isEmpty()) { return empty(); } - HashSet combined = new HashSet<>(); + Set combined = ConcurrentHashMap.newKeySet(); long stringSize = addUntilCapacity(combined, 0L, set); return new AutoValue_StringSetData(combined, stringSize); } /** Returns a {@link StringSetData} which is made from the given set in place. */ - private static StringSetData createInPlace(HashSet set, long stringSize) { + private static StringSetData createInPlace(Set set, long stringSize) { return new AutoValue_StringSetData(set, stringSize); } @@ -76,11 +76,12 @@ public static StringSetData empty() { *

    >Should only be used by {@link StringSetCell#add}. */ public StringSetData addAll(String... strings) { - HashSet combined; - if (this.stringSet() instanceof HashSet) { - combined = (HashSet) this.stringSet(); + Set combined; + if (this.stringSet() instanceof ConcurrentHashMap.KeySetView) { + combined = this.stringSet(); } else { - combined = new HashSet<>(this.stringSet()); + combined = ConcurrentHashMap.newKeySet(); + combined.addAll(this.stringSet()); } long stringSize = addUntilCapacity(combined, this.stringSize(), Arrays.asList(strings)); return StringSetData.createInPlace(combined, stringSize); @@ -95,7 +96,8 @@ public StringSetData combine(StringSetData other) { } else if (other.stringSet().isEmpty()) { return this; } else { - HashSet combined = new HashSet<>(this.stringSet()); + Set combined = ConcurrentHashMap.newKeySet(); + combined.addAll(this.stringSet()); long stringSize = addUntilCapacity(combined, this.stringSize(), other.stringSet()); return StringSetData.createInPlace(combined, stringSize); } @@ -105,7 +107,8 @@ public StringSetData combine(StringSetData other) { * Combines this {@link StringSetData} with others, all original StringSetData are left intact. */ public StringSetData combine(Iterable others) { - HashSet combined = new HashSet<>(this.stringSet()); + Set combined = ConcurrentHashMap.newKeySet(); + combined.addAll(this.stringSet()); long stringSize = this.stringSize(); for (StringSetData other : others) { stringSize = addUntilCapacity(combined, stringSize, other.stringSet()); @@ -120,7 +123,7 @@ public StringSetResult extractResult() { /** Add strings into set until reach capacity. Return the all string size of added set. */ private static long addUntilCapacity( - HashSet combined, long currentSize, Iterable others) { + Set combined, long currentSize, Iterable others) { if (currentSize > STRING_SET_SIZE_LIMIT) { // already at capacity return currentSize; diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java index f78ed01603fb..9497bbe43d0e 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/StringSetCellTest.java @@ -20,7 +20,13 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.sdk.metrics.MetricName; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.junit.Assert; @@ -94,4 +100,42 @@ public void testReset() { assertThat(stringSetCell.getCumulative(), equalTo(StringSetData.empty())); assertThat(stringSetCell.getDirty(), equalTo(new DirtyState())); } + + @Test(timeout = 5000) + public void testStringSetCellConcurrentAddRetrieval() throws InterruptedException { + StringSetCell cell = new StringSetCell(MetricName.named("namespace", "name")); + AtomicBoolean finished = new AtomicBoolean(false); + Thread increment = + new Thread( + () -> { + for (long i = 0; !finished.get(); ++i) { + cell.add(String.valueOf(i)); + try { + Thread.sleep(1); + } catch (InterruptedException e) { + break; + } + } + }); + increment.start(); + Instant start = 
Instant.now(); + try { + while (true) { + Set s = cell.getCumulative().stringSet(); + List snapshot = new ArrayList<>(s); + if (Instant.now().isAfter(start.plusSeconds(3)) && snapshot.size() > 0) { + finished.compareAndSet(false, true); + break; + } + } + } finally { + increment.interrupt(); + increment.join(); + } + + Set s = cell.getCumulative().stringSet(); + for (long i = 0; i < s.size(); ++i) { + assertTrue(s.contains(String.valueOf(i))); + } + } } From c03a5e09445d7d4279bb6450b0749362ca233d93 Mon Sep 17 00:00:00 2001 From: johnjcasey <95318300+johnjcasey@users.noreply.github.com> Date: Tue, 12 Nov 2024 15:27:25 -0500 Subject: [PATCH 161/181] Change dead partition detection to only look at the current topic (#33089) * Change dead partition detection to only look at the current topic, instead of looking at all topics * spotless * fix test, simplify existence check --- .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 21 +++++-------------- .../sdk/io/kafka/ReadFromKafkaDoFnTest.java | 11 ++++++++++ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index add76c9682a0..4d7aa6b32aef 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -21,11 +21,9 @@ import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.concurrent.TimeUnit; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; @@ -463,7 +461,8 @@ public ProcessContinuation processElement( // and move to process the next element. if (rawRecords.isEmpty()) { if (!topicPartitionExists( - kafkaSourceDescriptor.getTopicPartition(), consumer.listTopics())) { + kafkaSourceDescriptor.getTopicPartition(), + consumer.partitionsFor(kafkaSourceDescriptor.getTopic()))) { return ProcessContinuation.stop(); } if (timestampPolicy != null) { @@ -557,20 +556,10 @@ public ProcessContinuation processElement( } private boolean topicPartitionExists( - TopicPartition topicPartition, Map> topicListMap) { + TopicPartition topicPartition, List partitionInfos) { // Check if the current TopicPartition still exists. 
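      // The check below asks the consumer for partition metadata of this split's topic only
      // (Consumer#partitionsFor) and treats the partition as still existing when any returned
      // PartitionInfo carries the same partition number, rather than collecting every
      // TopicPartition returned by listTopics() across the whole cluster.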
- Set existingTopicPartitions = new HashSet<>(); - for (List topicPartitionList : topicListMap.values()) { - topicPartitionList.forEach( - partitionInfo -> { - existingTopicPartitions.add( - new TopicPartition(partitionInfo.topic(), partitionInfo.partition())); - }); - } - if (!existingTopicPartitions.contains(topicPartition)) { - return false; - } - return true; + return partitionInfos.stream() + .anyMatch(partitionInfo -> partitionInfo.partition() == (topicPartition.partition())); } // see https://github.com/apache/beam/issues/25962 diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java index 52c141685760..cbff0f896619 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java @@ -252,6 +252,17 @@ public synchronized Map> listTopics() { topicPartition.topic(), topicPartition.partition(), null, null, null))); } + @Override + public synchronized List partitionsFor(String partition) { + if (this.isRemoved) { + return ImmutableList.of(); + } else { + return ImmutableList.of( + new PartitionInfo( + topicPartition.topic(), topicPartition.partition(), null, null, null)); + } + } + @Override public synchronized void assign(Collection partitions) { assertTrue(Iterables.getOnlyElement(partitions).equals(this.topicPartition)); From 9394f8561d29509a6e67fcd66254197bd0d46b2e Mon Sep 17 00:00:00 2001 From: Chris Ashcraft Date: Tue, 12 Nov 2024 18:58:38 -0600 Subject: [PATCH 162/181] [JdbcIO] Adding disableAutoCommit flag (#32988) * adding disableAutoCommit flag to ReadFn --------- Co-authored-by: Chris Ashcraft --- CHANGES.md | 1 + .../org/apache/beam/sdk/io/jdbc/JdbcIO.java | 86 +++++++++++++++++-- .../jdbc/JdbcReadSchemaTransformProvider.java | 9 ++ .../sdk/io/jdbc/JdbcSchemaIOProvider.java | 11 +++ sdks/python/apache_beam/io/jdbc.py | 5 ++ 5 files changed, 106 insertions(+), 6 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 261fafc024f3..c5731bcff313 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -94,6 +94,7 @@ * Fixed X (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * (Java) Fixed tearDown not invoked when DoFn throws on Portable Runners ([#18592](https://github.com/apache/beam/issues/18592), [#31381](https://github.com/apache/beam/issues/31381)). * (Java) Fixed protobuf error with MapState.remove() in Dataflow Streaming Java Legacy Runner without Streaming Engine ([#32892](https://github.com/apache/beam/issues/32892)). +* Adding flag to support conditionally disabling auto-commit in JdbcIO ReadFn ([#31111](https://github.com/apache/beam/issues/31111)) ## Security Fixes * Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)). 
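A minimal usage sketch of the new flag (the Informix driver class, JDBC URL, query, and coder are illustrative placeholders, and a Pipeline named pipeline is assumed to exist): Informix errors out if setAutoCommit(false) is called on a non-logged database, so such a read would opt out by passing false, while PostgreSQL readers keep the default of true to retain cursor streaming.

    PCollection<String> names =
        pipeline.apply(
            JdbcIO.<String>read()
                .withDataSourceConfiguration(
                    JdbcIO.DataSourceConfiguration.create(
                        "com.informix.jdbc.IfxDriver",
                        "jdbc:informix-sqli://dbhost:9088/testdb:INFORMIXSERVER=ol_server"))
                .withQuery("SELECT name FROM customers")
                .withRowMapper(rs -> rs.getString("name"))
                .withCoder(StringUtf8Coder.of())
                // Leave autocommit untouched; the default (true) disables it as before.
                .withDisableAutoCommit(false));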
diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java index 2f164fa3bb78..946c07f55763 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java @@ -333,6 +333,7 @@ public static Read read() { return new AutoValue_JdbcIO_Read.Builder() .setFetchSize(DEFAULT_FETCH_SIZE) .setOutputParallelization(true) + .setDisableAutoCommit(DEFAULT_DISABLE_AUTO_COMMIT) .build(); } @@ -341,6 +342,7 @@ public static ReadRows readRows() { return new AutoValue_JdbcIO_ReadRows.Builder() .setFetchSize(DEFAULT_FETCH_SIZE) .setOutputParallelization(true) + .setDisableAutoCommit(DEFAULT_DISABLE_AUTO_COMMIT) .setStatementPreparator(ignored -> {}) .build(); } @@ -356,6 +358,7 @@ public static ReadAll readAll() { return new AutoValue_JdbcIO_ReadAll.Builder() .setFetchSize(DEFAULT_FETCH_SIZE) .setOutputParallelization(true) + .setDisableAutoCommit(DEFAULT_DISABLE_AUTO_COMMIT) .build(); } @@ -372,6 +375,7 @@ public static ReadWithPartitions read .setPartitionColumnType(partitioningColumnType) .setNumPartitions(DEFAULT_NUM_PARTITIONS) .setFetchSize(DEFAULT_FETCH_SIZE) + .setDisableAutoCommit(DEFAULT_DISABLE_AUTO_COMMIT) .setUseBeamSchema(false) .build(); } @@ -389,6 +393,7 @@ public static ReadWithPartitions read .setPartitionsHelper(partitionsHelper) .setNumPartitions(DEFAULT_NUM_PARTITIONS) .setFetchSize(DEFAULT_FETCH_SIZE) + .setDisableAutoCommit(DEFAULT_DISABLE_AUTO_COMMIT) .setUseBeamSchema(false) .build(); } @@ -400,6 +405,7 @@ public static ReadWithPartitions readWithPartitions() { private static final long DEFAULT_BATCH_SIZE = 1000L; private static final long DEFAULT_MAX_BATCH_BUFFERING_DURATION = 200L; private static final int DEFAULT_FETCH_SIZE = 50_000; + private static final boolean DEFAULT_DISABLE_AUTO_COMMIT = true; // Default values used from fluent backoff. private static final Duration DEFAULT_INITIAL_BACKOFF = Duration.standardSeconds(1); private static final Duration DEFAULT_MAX_CUMULATIVE_BACKOFF = Duration.standardDays(1000); @@ -733,6 +739,9 @@ public abstract static class ReadRows extends PTransform expand(PBegin input) { ValueProvider query = checkStateNotNull(getQuery(), "withQuery() is required"); @@ -816,6 +836,7 @@ public PCollection expand(PBegin input) { .withCoder(RowCoder.of(schema)) .withRowMapper(SchemaUtil.BeamRowMapper.of(schema)) .withFetchSize(getFetchSize()) + .withDisableAutoCommit(getDisableAutoCommit()) .withOutputParallelization(getOutputParallelization()) .withStatementPreparator(checkStateNotNull(getStatementPreparator()))); rows.setRowSchema(schema); @@ -872,6 +893,9 @@ public abstract static class Read extends PTransform> @Pure abstract boolean getOutputParallelization(); + @Pure + abstract boolean getDisableAutoCommit(); + @Pure abstract Builder toBuilder(); @@ -892,6 +916,8 @@ abstract Builder setDataSourceProviderFn( abstract Builder setOutputParallelization(boolean outputParallelization); + abstract Builder setDisableAutoCommit(boolean disableAutoCommit); + abstract Read build(); } @@ -958,6 +984,15 @@ public Read withOutputParallelization(boolean outputParallelization) { return toBuilder().setOutputParallelization(outputParallelization).build(); } + /** + * Whether to disable auto commit on read. Defaults to true if not provided. The need for this + * config varies depending on the database platform. 
Informix requires this to be set to false + * while Postgres requires this to be set to true. + */ + public Read withDisableAutoCommit(boolean disableAutoCommit) { + return toBuilder().setDisableAutoCommit(disableAutoCommit).build(); + } + @Override public PCollection expand(PBegin input) { ValueProvider query = checkArgumentNotNull(getQuery(), "withQuery() is required"); @@ -974,6 +1009,7 @@ public PCollection expand(PBegin input) { .withRowMapper(rowMapper) .withFetchSize(getFetchSize()) .withOutputParallelization(getOutputParallelization()) + .withDisableAutoCommit(getDisableAutoCommit()) .withParameterSetter( (element, preparedStatement) -> { if (getStatementPreparator() != null) { @@ -1029,6 +1065,8 @@ public abstract static class ReadAll abstract boolean getOutputParallelization(); + abstract boolean getDisableAutoCommit(); + abstract Builder toBuilder(); @AutoValue.Builder @@ -1049,6 +1087,8 @@ abstract Builder setParameterSetter( abstract Builder setOutputParallelization(boolean outputParallelization); + abstract Builder setDisableAutoCommit(boolean disableAutoCommit); + abstract ReadAll build(); } @@ -1127,6 +1167,15 @@ public ReadAll withOutputParallelization(boolean outputPara return toBuilder().setOutputParallelization(outputParallelization).build(); } + /** + * Whether to disable auto commit on read. Defaults to true if not provided. The need for this + * config varies depending on the database platform. Informix requires this to be set to false + * while Postgres requires this to be set to true. + */ + public ReadAll withDisableAutoCommit(boolean disableAutoCommit) { + return toBuilder().setDisableAutoCommit(disableAutoCommit).build(); + } + private @Nullable Coder inferCoder( CoderRegistry registry, SchemaRegistry schemaRegistry) { if (getCoder() != null) { @@ -1173,7 +1222,8 @@ public PCollection expand(PCollection input) { checkStateNotNull(getQuery()), checkStateNotNull(getParameterSetter()), checkStateNotNull(getRowMapper()), - getFetchSize()))) + getFetchSize(), + getDisableAutoCommit()))) .setCoder(coder); if (getOutputParallelization()) { @@ -1254,6 +1304,9 @@ public abstract static class ReadWithPartitions @Pure abstract @Nullable JdbcReadWithPartitionsHelper getPartitionsHelper(); + @Pure + abstract boolean getDisableAutoCommit(); + @Pure abstract Builder toBuilder(); @@ -1287,6 +1340,8 @@ abstract Builder setPartitionColumnType( abstract Builder setPartitionsHelper( JdbcReadWithPartitionsHelper partitionsHelper); + abstract Builder setDisableAutoCommit(boolean disableAutoCommit); + abstract ReadWithPartitions build(); } @@ -1337,6 +1392,16 @@ public ReadWithPartitions withFetchSize(int fetchSize) { return toBuilder().setFetchSize(fetchSize).build(); } + /** + * Whether to disable auto commit on read. Defaults to true if not provided. The need for this + * config varies depending on the database platform. Informix requires this to be set to false + * while Postgres requires this to be set to true. + */ + public ReadWithPartitions withDisableAutoCommit( + boolean disableAutoCommit) { + return toBuilder().setDisableAutoCommit(disableAutoCommit).build(); + } + /** Data output type is {@link Row}, and schema is auto-inferred from the database. 
*/ public ReadWithPartitions withRowOutput() { return toBuilder().setUseBeamSchema(true).build(); @@ -1419,7 +1484,8 @@ && getLowerBound() instanceof Comparable) { .withQuery(query) .withDataSourceProviderFn(dataSourceProviderFn) .withRowMapper(checkStateNotNull(partitionsHelper)) - .withFetchSize(getFetchSize())) + .withFetchSize(getFetchSize()) + .withDisableAutoCommit(getDisableAutoCommit())) .apply( MapElements.via( new SimpleFunction< @@ -1487,7 +1553,8 @@ public KV> apply( .withRowMapper(rowMapper) .withFetchSize(getFetchSize()) .withParameterSetter(checkStateNotNull(partitionsHelper)) - .withOutputParallelization(false); + .withOutputParallelization(false) + .withDisableAutoCommit(getDisableAutoCommit()); if (getUseBeamSchema()) { checkStateNotNull(schema); @@ -1537,6 +1604,7 @@ private static class ReadFn extends DoFn parameterSetter; private final RowMapper rowMapper; private final int fetchSize; + private final boolean disableAutoCommit; private @Nullable DataSource dataSource; private @Nullable Connection connection; @@ -1546,12 +1614,14 @@ private ReadFn( ValueProvider query, PreparedStatementSetter parameterSetter, RowMapper rowMapper, - int fetchSize) { + int fetchSize, + boolean disableAutoCommit) { this.dataSourceProviderFn = dataSourceProviderFn; this.query = query; this.parameterSetter = parameterSetter; this.rowMapper = rowMapper; this.fetchSize = fetchSize; + this.disableAutoCommit = disableAutoCommit; } @Setup @@ -1577,8 +1647,12 @@ public void processElement(ProcessContext context) throws Exception { Connection connection = getConnection(); // PostgreSQL requires autocommit to be disabled to enable cursor streaming // see https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor - LOG.info("Autocommit has been disabled"); - connection.setAutoCommit(false); + // This option is configurable as Informix will error + // if calling setAutoCommit on a non-logged database + if (disableAutoCommit) { + LOG.info("Autocommit has been disabled"); + connection.setAutoCommit(false); + } try (PreparedStatement statement = connection.prepareStatement( query.get(), ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)) { diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java index 0139207235a0..435bfc138b5b 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java @@ -117,6 +117,10 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { if (outputParallelization != null) { readRows = readRows.withOutputParallelization(outputParallelization); } + Boolean disableAutoCommit = config.getDisableAutoCommit(); + if (disableAutoCommit != null) { + readRows = readRows.withDisableAutoCommit(disableAutoCommit); + } return PCollectionRowTuple.of("output", input.getPipeline().apply(readRows)); } } @@ -174,6 +178,9 @@ public abstract static class JdbcReadSchemaTransformConfiguration implements Ser @Nullable public abstract Boolean getOutputParallelization(); + @Nullable + public abstract Boolean getDisableAutoCommit(); + @Nullable public abstract String getDriverJars(); @@ -238,6 +245,8 @@ public abstract static class Builder { public abstract Builder setOutputParallelization(Boolean value); + public abstract Builder setDisableAutoCommit(Boolean value); + public 
abstract Builder setDriverJars(String value); public abstract JdbcReadSchemaTransformConfiguration build(); diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java index 4b5dc0d7e24a..30012465eb9e 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java @@ -65,6 +65,7 @@ public Schema configurationSchema() { .addNullableField("readQuery", FieldType.STRING) .addNullableField("writeStatement", FieldType.STRING) .addNullableField("fetchSize", FieldType.INT16) + .addNullableField("disableAutoCommit", FieldType.BOOLEAN) .addNullableField("outputParallelization", FieldType.BOOLEAN) .addNullableField("autosharding", FieldType.BOOLEAN) // Partitioning support. If you specify a partition column we will use that instead of @@ -140,6 +141,11 @@ public PCollection expand(PBegin input) { readRows = readRows.withFetchSize(fetchSize); } + @Nullable Boolean disableAutoCommit = config.getBoolean("disableAutoCommit"); + if (disableAutoCommit != null) { + readRows = readRows.withDisableAutoCommit(disableAutoCommit); + } + return input.apply(readRows); } else { @@ -163,6 +169,11 @@ public PCollection expand(PBegin input) { readRows = readRows.withOutputParallelization(outputParallelization); } + @Nullable Boolean disableAutoCommit = config.getBoolean("disableAutoCommit"); + if (disableAutoCommit != null) { + readRows = readRows.withDisableAutoCommit(disableAutoCommit); + } + return input.apply(readRows); } } diff --git a/sdks/python/apache_beam/io/jdbc.py b/sdks/python/apache_beam/io/jdbc.py index 3fef1f5fee35..d4ece0c7bc29 100644 --- a/sdks/python/apache_beam/io/jdbc.py +++ b/sdks/python/apache_beam/io/jdbc.py @@ -125,6 +125,7 @@ def default_io_expansion_service(classpath=None): ('read_query', typing.Optional[str]), ('write_statement', typing.Optional[str]), ('fetch_size', typing.Optional[np.int16]), + ('disable_autocommit', typing.Optional[bool]), ('output_parallelization', typing.Optional[bool]), ('autosharding', typing.Optional[bool]), ('partition_column', typing.Optional[str]), @@ -236,6 +237,7 @@ def __init__( write_statement=statement, read_query=None, fetch_size=None, + disable_autocommit=None, output_parallelization=None, autosharding=autosharding, max_connections=max_connections, @@ -286,6 +288,7 @@ def __init__( username, password, query=None, + disable_autocommit=None, output_parallelization=None, fetch_size=None, partition_column=None, @@ -305,6 +308,7 @@ def __init__( :param username: database username :param password: database password :param query: sql query to be executed + :param disable_autocommit: disable autocommit on read :param output_parallelization: is output parallelization on :param fetch_size: how many rows to fetch :param partition_column: enable partitioned reads by splitting on this @@ -350,6 +354,7 @@ def __init__( write_statement=None, read_query=query, fetch_size=fetch_size, + disable_autocommit=disable_autocommit, output_parallelization=output_parallelization, autosharding=None, max_connections=max_connections, From a6daf6dc89d0e3ac7ea55defecbeeb477b39d72c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 09:42:58 -0500 Subject: [PATCH 163/181] Bump cloud.google.com/go/bigquery from 1.63.1 to 1.64.0 in /sdks (#32993) Bumps 
[cloud.google.com/go/bigquery](https://github.com/googleapis/google-cloud-go) from 1.63.1 to 1.64.0. - [Release notes](https://github.com/googleapis/google-cloud-go/releases) - [Changelog](https://github.com/googleapis/google-cloud-go/blob/main/CHANGES.md) - [Commits](https://github.com/googleapis/google-cloud-go/compare/bigquery/v1.63.1...spanner/v1.64.0) --- updated-dependencies: - dependency-name: cloud.google.com/go/bigquery dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- sdks/go.mod | 2 +- sdks/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdks/go.mod b/sdks/go.mod index ff711cbe91b0..35882c132bb1 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -23,7 +23,7 @@ module github.com/apache/beam/sdks/v2 go 1.21.0 require ( - cloud.google.com/go/bigquery v1.63.1 + cloud.google.com/go/bigquery v1.64.0 cloud.google.com/go/bigtable v1.33.0 cloud.google.com/go/datastore v1.20.0 cloud.google.com/go/profiler v0.4.1 diff --git a/sdks/go.sum b/sdks/go.sum index c24cb10126c8..b18c381fb072 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -133,8 +133,8 @@ cloud.google.com/go/bigquery v1.47.0/go.mod h1:sA9XOgy0A8vQK9+MWhEQTY6Tix87M/Zur cloud.google.com/go/bigquery v1.48.0/go.mod h1:QAwSz+ipNgfL5jxiaK7weyOhzdoAy1zFm0Nf1fysJac= cloud.google.com/go/bigquery v1.49.0/go.mod h1:Sv8hMmTFFYBlt/ftw2uN6dFdQPzBlREY9yBh7Oy7/4Q= cloud.google.com/go/bigquery v1.50.0/go.mod h1:YrleYEh2pSEbgTBZYMJ5SuSr0ML3ypjRB1zgf7pvQLU= -cloud.google.com/go/bigquery v1.63.1 h1:/6syiWrSpardKNxdvldS5CUTRJX1iIkSPXCjLjiGL+g= -cloud.google.com/go/bigquery v1.63.1/go.mod h1:ufaITfroCk17WTqBhMpi8CRjsfHjMX07pDrQaRKKX2o= +cloud.google.com/go/bigquery v1.64.0 h1:vSSZisNyhr2ioJE1OuYBQrnrpB7pIhRQm4jfjc7E/js= +cloud.google.com/go/bigquery v1.64.0/go.mod h1:gy8Ooz6HF7QmA+TRtX8tZmXBKH5mCFBwUApGAb3zI7Y= cloud.google.com/go/bigtable v1.33.0 h1:2BDaWLRAwXO14DJL/u8crbV2oUbMZkIa2eGq8Yao1bk= cloud.google.com/go/bigtable v1.33.0/go.mod h1:HtpnH4g25VT1pejHRtInlFPnN5sjTxbQlsYBjh9t5l0= cloud.google.com/go/billing v1.4.0/go.mod h1:g9IdKBEFlItS8bTtlrZdVLWSSdSyFUZKXNS02zKMOZY= From 87e251a95ab9922857c904cfce8c6f79ab041e7b Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Wed, 13 Nov 2024 10:16:49 -0500 Subject: [PATCH 164/181] Add Kafka tag to review bot (#32975) --- .github/REVIEWERS.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/REVIEWERS.yml b/.github/REVIEWERS.yml index 38adde6a7820..dba969180c45 100644 --- a/.github/REVIEWERS.yml +++ b/.github/REVIEWERS.yml @@ -61,6 +61,12 @@ labels: reviewers: - svetakvsundhar exclusionList: [] + - name: kafka + reviewers: + - johnjcasey + - fozzie15 + - Dippatel98 + - sjvanrossum - name: Build reviewers: - damccorm From f25ac698341867c704a6223c7c6187155e952af9 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Wed, 13 Nov 2024 10:17:11 -0500 Subject: [PATCH 165/181] Lineage support for JdbcIO (#33062) * Lineage support for JdbcIO * Report table the pipeline read from and write to * add logs and documentation --- sdks/java/io/jdbc/build.gradle | 2 + .../org/apache/beam/sdk/io/jdbc/JdbcIO.java | 46 ++- .../org/apache/beam/sdk/io/jdbc/JdbcUtil.java | 261 ++++++++++++++++++ .../apache/beam/sdk/io/jdbc/JdbcIOTest.java | 18 +- .../apache/beam/sdk/io/jdbc/JdbcUtilTest.java | 101 +++++++ 5 files changed, 420 insertions(+), 8 deletions(-) diff --git a/sdks/java/io/jdbc/build.gradle 
b/sdks/java/io/jdbc/build.gradle index 2015bf173978..8c5fa685fdad 100644 --- a/sdks/java/io/jdbc/build.gradle +++ b/sdks/java/io/jdbc/build.gradle @@ -48,6 +48,8 @@ dependencies { testImplementation library.java.testcontainers_mysql testImplementation library.java.testcontainers_postgresql testImplementation 'mysql:mysql-connector-java:8.0.22' + // TODO(https://github.com/apache/beam/issues/31678) HikariCP 5.x requires Java11+ + testImplementation 'com.zaxxer:HikariCP:4.0.3' testRuntimeOnly library.java.slf4j_jdk14 testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java index 946c07f55763..ab2e3e07e817 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java @@ -39,6 +39,7 @@ import java.util.Collection; import java.util.HashSet; import java.util.List; +import java.util.Objects; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; @@ -55,6 +56,7 @@ import org.apache.beam.sdk.io.jdbc.JdbcUtil.PartitioningFn; import org.apache.beam.sdk.io.jdbc.SchemaUtil.FieldWithIndex; import org.apache.beam.sdk.metrics.Distribution; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; @@ -93,6 +95,7 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.sdk.values.TypeDescriptors.TypeVariableExtractor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.apache.commons.dbcp2.BasicDataSource; import org.apache.commons.dbcp2.DataSourceConnectionFactory; import org.apache.commons.dbcp2.PoolableConnectionFactory; @@ -1608,6 +1611,7 @@ private static class ReadFn extends DoFn dataSourceProviderFn, @@ -1630,10 +1634,26 @@ public void setup() throws Exception { } private Connection getConnection() throws SQLException { - if (this.connection == null) { - this.connection = checkStateNotNull(this.dataSource).getConnection(); + Connection connection = this.connection; + if (connection == null) { + DataSource validSource = checkStateNotNull(this.dataSource); + connection = checkStateNotNull(validSource).getConnection(); + this.connection = connection; + + // report Lineage if not haven't done so + String table = JdbcUtil.extractTableFromReadQuery(query.get()); + if (!table.equals(reportedLineage)) { + JdbcUtil.FQNComponents fqn = JdbcUtil.FQNComponents.of(validSource); + if (fqn == null) { + fqn = JdbcUtil.FQNComponents.of(connection); + } + if (fqn != null) { + fqn.reportLineage(Lineage.getSources(), table); + reportedLineage = table; + } + } } - return this.connection; + return connection; } @ProcessElement @@ -2645,6 +2665,7 @@ abstract Builder setMaxBatchBufferingDuration( private @Nullable DataSource dataSource; private @Nullable Connection connection; private @Nullable PreparedStatement preparedStatement; + private @Nullable String reportedLineage; private static @Nullable FluentBackoff retryBackOff; public WriteFn(WriteFnSpec spec) { @@ -2677,11 +2698,28 @@ public void setup() { private Connection getConnection() throws SQLException { Connection connection = this.connection; if (connection == null) { - 
connection = checkStateNotNull(dataSource).getConnection(); + DataSource validSource = checkStateNotNull(dataSource); + connection = validSource.getConnection(); connection.setAutoCommit(false); preparedStatement = connection.prepareStatement(checkStateNotNull(spec.getStatement()).get()); this.connection = connection; + + // report Lineage if haven't done so + String table = spec.getTable(); + if (Strings.isNullOrEmpty(table) && spec.getStatement() != null) { + table = JdbcUtil.extractTableFromWriteQuery(spec.getStatement().get()); + } + if (!Objects.equals(table, reportedLineage)) { + JdbcUtil.FQNComponents fqn = JdbcUtil.FQNComponents.of(validSource); + if (fqn == null) { + fqn = JdbcUtil.FQNComponents.of(connection); + } + if (fqn != null) { + fqn.reportLineage(Lineage.getSinks(), table); + reportedLineage = table; + } + } } return connection; } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java index b3f46492f745..c0f7d68899b3 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java @@ -19,12 +19,18 @@ import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; +import com.google.auto.value.AutoValue; import java.io.File; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.URI; import java.net.URL; import java.nio.channels.ReadableByteChannel; import java.nio.channels.WritableByteChannel; import java.nio.file.Paths; +import java.sql.Connection; +import java.sql.DatabaseMetaData; import java.sql.Date; import java.sql.JDBCType; import java.sql.PreparedStatement; @@ -33,6 +39,7 @@ import java.sql.Time; import java.sql.Timestamp; import java.util.ArrayList; +import java.util.Arrays; import java.util.Calendar; import java.util.Collection; import java.util.EnumMap; @@ -40,12 +47,17 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Properties; import java.util.TimeZone; import java.util.UUID; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.IntStream; +import javax.sql.DataSource; import org.apache.beam.sdk.io.FileSystems; import org.apache.beam.sdk.io.fs.ResourceId; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.logicaltypes.FixedPrecisionNumeric; import org.apache.beam.sdk.schemas.logicaltypes.MicrosInstant; @@ -57,6 +69,8 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Splitter; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; @@ -563,4 +577,251 @@ public KV> mapRow(ResultSet resultSet) throws Excep } } }); + + @AutoValue + abstract static class JdbcUrl { + abstract String getScheme(); + + abstract @Nullable String getHostAndPort(); + + abstract String 
getDatabase(); + + /** + * Parse Jdbc Url String and return an {@link JdbcUrl} object, or return null for unsupported + * formats. + * + *

+   * <p>Example of supported format:
+   *
+   * <ul>
+   *   <li>"jdbc:postgresql://localhost:5432/postgres"
+   *   <li>"jdbc:mysql://127.0.0.1:3306/db"
+   *   <li>"jdbc:oracle:thin:HR/hr@localhost:5221:orcl"
+   *   <li>"jdbc:derby:memory:testDB;create=true"
+   *   <li>"jdbc:oracle:thin:@//myhost.example.com:1521/my_service"
+   *   <li>"jdbc:mysql:///cloud_sql" (GCP CloudSQL, supported if Connection name setup via
+   *       HikariDataSource)
+   * </ul>
    + */ + static @Nullable JdbcUrl of(String url) { + if (Strings.isNullOrEmpty(url) || !url.startsWith("jdbc:")) { + return null; + } + String cleanUri = url.substring(5); + + // 1. Resolve the scheme + // handle sub-schemes e.g. oracle:thin (RAC) + int start = cleanUri.indexOf("//"); + if (start != -1) { + List subschemes = Splitter.on(':').splitToList(cleanUri.substring(0, start)); + cleanUri = subschemes.get(0) + ":" + cleanUri.substring(start); + } else { + // not a URI format e.g. oracle:thin (non-RAC); derby in memory + if (cleanUri.startsWith("derby:")) { + String scheme = "derby"; + int endUrl = cleanUri.indexOf(";"); + if (endUrl == -1) { + endUrl = cleanUri.length(); + } + List components = + Splitter.on(':').splitToList(cleanUri.substring("derby:".length(), endUrl)); + if (components.size() < 2) { + return null; + } + return new AutoValue_JdbcUtil_JdbcUrl(scheme, components.get(0), components.get(1)); + } else if (cleanUri.startsWith("oracle:thin:")) { + String scheme = "oracle"; + + int startHost = cleanUri.indexOf("@"); + if (startHost == -1) { + return null; + } + List components = Splitter.on(':').splitToList(cleanUri.substring(startHost + 1)); + if (components.size() < 3) { + return null; + } + return new AutoValue_JdbcUtil_JdbcUrl( + scheme, components.get(0) + ":" + components.get(1), components.get(2)); + } else { + return null; + } + } + + URI uri = URI.create(cleanUri); + String scheme = uri.getScheme(); + + // 2. resolve database + @Nullable String path = uri.getPath(); + if (path != null && path.startsWith("/")) { + path = path.substring(1); + } + if (path == null) { + return null; + } + + // 3. resolve host and port + // treat as self-managed SQL instance + @Nullable String hostAndPort = null; + @Nullable String host = uri.getHost(); + if (host != null) { + int port = uri.getPort(); + hostAndPort = port != -1 ? host + ":" + port : null; + } + return new AutoValue_JdbcUtil_JdbcUrl(scheme, hostAndPort, path); + } + } + + /** Jdbc fully qualified name components. */ + @AutoValue + abstract static class FQNComponents { + abstract String getScheme(); + + abstract Iterable getSegments(); + + void reportLineage(Lineage lineage, @Nullable String table) { + ImmutableList.Builder builder = ImmutableList.builder().addAll(getSegments()); + if (table != null && !table.isEmpty()) { + builder.add(table); + } + lineage.add(getScheme(), builder.build()); + } + + /** Fail-safely extract FQN from supported DataSource. Return null if failed. 
*/ + static @Nullable FQNComponents of(DataSource dataSource) { + // Supported case CloudSql using HikariDataSource + // Had to retrieve properties via Reflection to avoid introduce mandatory Hikari dependencies + String maybeSqlInstance; + String url; + try { + Class hikariClass = Class.forName("com.zaxxer.hikari.HikariDataSource"); + if (!hikariClass.isInstance(dataSource)) { + return null; + } + Method getProperties = hikariClass.getMethod("getDataSourceProperties"); + Properties properties = (Properties) getProperties.invoke(dataSource); + if (properties == null) { + return null; + } + maybeSqlInstance = properties.getProperty("cloudSqlInstance"); + if (maybeSqlInstance == null) { + // not a cloudSqlInstance + return null; + } + Method getUrl = hikariClass.getMethod("getJdbcUrl"); + url = (String) getUrl.invoke(dataSource); + if (url == null) { + return null; + } + } catch (ClassNotFoundException + | InvocationTargetException + | IllegalAccessException + | NoSuchMethodException e) { + return null; + } + + JdbcUrl jdbcUrl = JdbcUrl.of(url); + if (jdbcUrl == null) { + LOG.info("Failed to parse JdbcUrl {}. Lineage will not be reported.", url); + return null; + } + + String scheme = "cloudsql_" + jdbcUrl.getScheme(); + ImmutableList.Builder segments = ImmutableList.builder(); + List sqlInstance = Arrays.asList(maybeSqlInstance.split(":")); + if (sqlInstance.size() > 3) { + // project name contains ":" + segments + .add(String.join(":", sqlInstance.subList(0, sqlInstance.size() - 2))) + .add(sqlInstance.get(sqlInstance.size() - 2)) + .add(sqlInstance.get(sqlInstance.size() - 1)); + } else { + segments.addAll(Arrays.asList(maybeSqlInstance.split(":"))); + } + segments.add(jdbcUrl.getDatabase()); + return new AutoValue_JdbcUtil_FQNComponents(scheme, segments.build()); + } + + /** Fail-safely extract FQN from an active connection. Return null if failed. */ + static @Nullable FQNComponents of(Connection connection) { + try { + DatabaseMetaData metadata = connection.getMetaData(); + if (metadata == null) { + // usually not-null, but can be null when running a mock + return null; + } + String url = metadata.getURL(); + if (url == null) { + // usually not-null, but can be null when running a mock + return null; + } + return of(url); + } catch (Exception e) { + // suppressed + return null; + } + } + + /** + * Fail-safely parse FQN from a Jdbc URL. Return null if failed. + * + *

+   * <p>e.g.
+   *
+   * <p>jdbc:postgresql://localhost:5432/postgres -> (postgresql, [localhost:5432, postgres])
+   *
+   * <p>
    jdbc:mysql://127.0.0.1:3306/db -> (mysql, [127.0.0.1:3306, db]) + */ + @VisibleForTesting + static @Nullable FQNComponents of(String url) { + JdbcUrl jdbcUrl = JdbcUrl.of(url); + if (jdbcUrl == null || jdbcUrl.getHostAndPort() == null) { + LOG.info("Failed to parse JdbcUrl {}. Lineage will not be reported.", url); + return null; + } + String hostAndPort = jdbcUrl.getHostAndPort(); + if (hostAndPort == null) { + LOG.info("Failed to parse host/port from JdbcUrl {}. Lineage will not be reported.", url); + return null; + } + + return new AutoValue_JdbcUtil_FQNComponents( + jdbcUrl.getScheme(), ImmutableList.of(hostAndPort, jdbcUrl.getDatabase())); + } + } + + private static final Pattern READ_STATEMENT_PATTERN = + Pattern.compile( + "SELECT\\s+.+?\\s+FROM\\s+\\[?(?[^\\s\\[\\]]+)\\]?", Pattern.CASE_INSENSITIVE); + + private static final Pattern WRITE_STATEMENT_PATTERN = + Pattern.compile( + "INSERT\\s+INTO\\s+\\[?(?[^\\s\\[\\]]+)\\]?", Pattern.CASE_INSENSITIVE); + + /** Extract table name a SELECT statement. Return empty string if fail to extract. */ + static String extractTableFromReadQuery(@Nullable String query) { + if (query == null) { + return ""; + } + Matcher matchRead = READ_STATEMENT_PATTERN.matcher(query); + if (matchRead.find()) { + String matched = matchRead.group("tableName"); + if (matched != null) { + return matched; + } + } + return ""; + } + + /** Extract table name from an INSERT statement. Return empty string if fail to extract. */ + static String extractTableFromWriteQuery(@Nullable String query) { + if (query == null) { + return ""; + } + Matcher matchRead = WRITE_STATEMENT_PATTERN.matcher(query); + if (matchRead.find()) { + String matched = matchRead.group("tableName"); + if (matched != null) { + return matched; + } + } + return ""; + } } diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java index 013fc7996a95..a04f8c4e762f 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java @@ -21,6 +21,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertEquals; @@ -71,6 +72,7 @@ import org.apache.beam.sdk.io.jdbc.JdbcIO.DataSourceConfiguration; import org.apache.beam.sdk.io.jdbc.JdbcIO.PoolableDataSourceProvider; import org.apache.beam.sdk.io.jdbc.JdbcUtil.PartitioningFn; +import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.logicaltypes.FixedPrecisionNumeric; @@ -243,7 +245,10 @@ public void testRead() { Iterable expectedValues = TestRow.getExpectedValues(0, EXPECTED_ROW_COUNT); PAssert.that(rows).containsInAnyOrder(expectedValues); - pipeline.run(); + PipelineResult result = pipeline.run(); + assertThat( + Lineage.query(result.metrics(), Lineage.Type.SOURCE), + hasItem(Lineage.getFqName("derby", ImmutableList.of("memory", "testDB", READ_TABLE_NAME)))); } @Test @@ -263,7 +268,10 @@ public void testReadWithSingleStringParameter() { Iterable expectedValues = Collections.singletonList(TestRow.fromSeed(1)); PAssert.that(rows).containsInAnyOrder(expectedValues); - 
pipeline.run(); + PipelineResult result = pipeline.run(); + assertThat( + Lineage.query(result.metrics(), Lineage.Type.SOURCE), + hasItem(Lineage.getFqName("derby", ImmutableList.of("memory", "testDB", READ_TABLE_NAME)))); } @Test @@ -531,9 +539,11 @@ public void testWrite() throws Exception { ArrayList> data = getDataToWrite(EXPECTED_ROW_COUNT); pipeline.apply(Create.of(data)).apply(getJdbcWrite(tableName)); - pipeline.run(); - + PipelineResult result = pipeline.run(); assertRowCount(DATA_SOURCE, tableName, EXPECTED_ROW_COUNT); + assertThat( + Lineage.query(result.metrics(), Lineage.Type.SINK), + hasItem(Lineage.getFqName("derby", ImmutableList.of("memory", "testDB", tableName)))); } finally { DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName); } diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java index 5b2e9f27f0a8..356d6c7f8de7 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java @@ -22,7 +22,10 @@ import static org.hamcrest.number.IsCloseTo.closeTo; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; import java.io.File; import java.io.IOException; import java.net.URL; @@ -34,12 +37,17 @@ import java.sql.SQLException; import java.util.ArrayList; import java.util.List; +import java.util.Map.Entry; import java.util.Random; +import javax.sql.DataSource; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTime; import org.junit.Rule; import org.junit.Test; @@ -264,4 +272,97 @@ public void testSavesFilesAsExpected() throws IOException { expectedContent2, new String(Files.readAllBytes(Paths.get(urls[1].getFile())), StandardCharsets.UTF_8)); } + + @Test + public void testJdbcUrl() { + ImmutableMap> testCases = + ImmutableMap.>builder() + .put( + "jdbc:postgresql://localhost:5432/postgres", + ImmutableList.of("postgresql", "localhost:5432", "postgres")) + .put( + "jdbc:mysql://127.0.0.1:3306/db", ImmutableList.of("mysql", "127.0.0.1:3306", "db")) + .put( + "jdbc:oracle:thin:HR/hr@localhost:5221:orcl", + ImmutableList.of("oracle", "localhost:5221", "orcl")) + .put( + "jdbc:derby:memory:testDB;create=true", + ImmutableList.of("derby", "memory", "testDB")) + .put( + "jdbc:oracle:thin:@//myhost.example.com:1521/my_service", + ImmutableList.of("oracle", "myhost.example.com:1521", "my_service")) + .put("jdbc:mysql:///cloud_sql", ImmutableList.of("mysql", "", "cloud_sql")) + .put("invalid", ImmutableList.of()) + .build(); + for (Entry> entry : testCases.entrySet()) { + JdbcUtil.JdbcUrl jdbcUrl = JdbcUtil.JdbcUrl.of(entry.getKey()); + + System.out.println(entry.getKey()); + if (entry.getValue().equals(ImmutableList.of())) { + assertNull(jdbcUrl); + } else { + assertEquals(entry.getValue().get(0), jdbcUrl.getScheme()); + 
assertEquals( + entry.getValue().get(1), + jdbcUrl.getHostAndPort() == null ? "" : jdbcUrl.getHostAndPort()); + assertEquals(entry.getValue().get(2), jdbcUrl.getDatabase()); + } + } + } + + @Test + public void testFqnFromHikariDataSourcePostgreSql() { + HikariConfig config = new HikariConfig(); + config.setJdbcUrl("jdbc:postgresql:///postgres"); + config.setUsername("postgres"); + config.addDataSourceProperty( + "cloudSqlInstance", "example.com:project:some-region:instance-name"); + // instead of `new HikariDataSource(config)`, initialize an empty source to avoid creation + // of actual connection pool + DataSource dataSource = new HikariDataSource(); + config.validate(); + config.copyStateTo((HikariConfig) dataSource); + JdbcUtil.FQNComponents components = JdbcUtil.FQNComponents.of(dataSource); + assertEquals("cloudsql_postgresql", components.getScheme()); + assertEquals( + ImmutableList.of("example.com:project", "some-region", "instance-name", "postgres"), + components.getSegments()); + } + + @Test + public void testFqnFromHikariDataSourceMySql() { + HikariConfig config = new HikariConfig(); + config.setJdbcUrl("jdbc:mysql:///db"); + config.setUsername("root"); + config.addDataSourceProperty("cloudSqlInstance", "some-project:US:instance-name"); + // instead of `new HikariDataSource(config)`, initialize an empty source to avoid creation + // of actual connection pool + DataSource dataSource = new HikariDataSource(); + config.validate(); + config.copyStateTo((HikariConfig) dataSource); + JdbcUtil.FQNComponents components = JdbcUtil.FQNComponents.of(dataSource); + assertEquals("cloudsql_mysql", components.getScheme()); + assertEquals( + ImmutableList.of("some-project", "US", "instance-name", "db"), components.getSegments()); + } + + @Test + public void testExtractTableFromQuery() { + ImmutableList> readCases = + ImmutableList.of( + KV.of("select * from table_1", "table_1"), + KV.of("SELECT a, b FROM [table-2]", "table-2"), + KV.of("drop table not-select", "")); + for (KV testCase : readCases) { + assertEquals(testCase.getValue(), JdbcUtil.extractTableFromReadQuery(testCase.getKey())); + } + ImmutableList> writeCases = + ImmutableList.of( + KV.of("insert into table_1 values ...", "table_1"), + KV.of("INSERT INTO [table-2] values ...", "table-2"), + KV.of("drop table not-select", "")); + for (KV testCase : writeCases) { + assertEquals(testCase.getValue(), JdbcUtil.extractTableFromWriteQuery(testCase.getKey())); + } + } } From 63fc0db0b5c10a1e6170a30166dd271dc93aafd2 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Wed, 13 Nov 2024 10:18:44 -0500 Subject: [PATCH 166/181] Update dataframes to PEP 585 typing --- sdks/python/apache_beam/dataframe/convert.py | 8 ++--- sdks/python/apache_beam/dataframe/doctests.py | 12 +++---- .../apache_beam/dataframe/expressions.py | 33 +++++++++---------- .../apache_beam/dataframe/frame_base.py | 8 ++--- sdks/python/apache_beam/dataframe/frames.py | 7 ++-- .../apache_beam/dataframe/frames_test.py | 3 +- .../dataframe/pandas_top_level_functions.py | 2 +- .../apache_beam/dataframe/partitionings.py | 5 ++- sdks/python/apache_beam/dataframe/schemas.py | 6 ++-- .../apache_beam/dataframe/transforms.py | 23 +++++-------- 10 files changed, 44 insertions(+), 63 deletions(-) diff --git a/sdks/python/apache_beam/dataframe/convert.py b/sdks/python/apache_beam/dataframe/convert.py index e44cc429eac1..c5a0d1025c6d 100644 --- a/sdks/python/apache_beam/dataframe/convert.py +++ b/sdks/python/apache_beam/dataframe/convert.py @@ -17,11 +17,9 @@ import inspect import warnings 
import weakref +from collections.abc import Iterable from typing import Any -from typing import Dict -from typing import Iterable from typing import Optional -from typing import Tuple from typing import Union import pandas as pd @@ -172,7 +170,7 @@ def to_pcollection( always_return_tuple=False, yield_elements='schemas', include_indexes=False, - pipeline=None) -> Union[pvalue.PCollection, Tuple[pvalue.PCollection, ...]]: + pipeline=None) -> Union[pvalue.PCollection, tuple[pvalue.PCollection, ...]]: """Converts one or more deferred dataframe-like objects back to a PCollection. This method creates and applies the actual Beam operations that compute @@ -252,7 +250,7 @@ def extract_input(placeholder): df for df in dataframes if df._expr._id not in TO_PCOLLECTION_CACHE ] if len(new_dataframes): - new_results: Dict[Any, pvalue.PCollection] = { + new_results: dict[Any, pvalue.PCollection] = { p: extract_input(p) for p in placeholders } | label >> transforms._DataframeExpressionsTransform( diff --git a/sdks/python/apache_beam/dataframe/doctests.py b/sdks/python/apache_beam/dataframe/doctests.py index 33faa6b58599..c57d0b0e699e 100644 --- a/sdks/python/apache_beam/dataframe/doctests.py +++ b/sdks/python/apache_beam/dataframe/doctests.py @@ -45,8 +45,6 @@ import traceback from io import StringIO from typing import Any -from typing import Dict -from typing import List import numpy as np import pandas as pd @@ -146,7 +144,7 @@ class _InMemoryResultRecorder(object): """ # Class-level value to survive pickling. - _ALL_RESULTS = {} # type: Dict[str, List[Any]] + _ALL_RESULTS = {} # type: dict[str, list[Any]] def __init__(self): self._id = id(self) @@ -729,15 +727,15 @@ def wrapper(fn): Args: optionflags (int): Passed through to doctests. - extraglobs (Dict[str,Any]): Passed through to doctests. + extraglobs (dict[str,Any]): Passed through to doctests. use_beam (bool): If true, run a Beam pipeline with partitioned input to verify the examples, else use PartitioningSession to simulate distributed execution. - skip (Dict[str,str]): A set of examples to skip entirely. + skip (dict[str,str]): A set of examples to skip entirely. If a key is '*', an example will be skipped in all test scenarios. - wont_implement_ok (Dict[str,str]): A set of examples that are allowed to + wont_implement_ok (dict[str,str]): A set of examples that are allowed to raise WontImplementError. - not_implemented_ok (Dict[str,str]): A set of examples that are allowed to + not_implemented_ok (dict[str,str]): A set of examples that are allowed to raise NotImplementedError. Returns: diff --git a/sdks/python/apache_beam/dataframe/expressions.py b/sdks/python/apache_beam/dataframe/expressions.py index 91d237c7de96..af04e06bdf6b 100644 --- a/sdks/python/apache_beam/dataframe/expressions.py +++ b/sdks/python/apache_beam/dataframe/expressions.py @@ -17,10 +17,10 @@ import contextlib import random import threading +from collections.abc import Callable +from collections.abc import Iterable from typing import Any -from typing import Callable from typing import Generic -from typing import Iterable from typing import Optional from typing import TypeVar @@ -251,9 +251,9 @@ def preserves_partition_by(self) -> partitionings.Partitioning: class PlaceholderExpression(Expression): """An expression whose value must be explicitly bound in the session.""" def __init__( - self, # type: PlaceholderExpression - proxy, # type: T - reference=None, # type: Any + self, + proxy: T, + reference: Any = None, ): """Initialize a placeholder expression. 
@@ -282,11 +282,7 @@ def preserves_partition_by(self): class ConstantExpression(Expression): """An expression whose value is known at pipeline construction time.""" - def __init__( - self, # type: ConstantExpression - value, # type: T - proxy=None # type: Optional[T] - ): + def __init__(self, value: T, proxy: Optional[T] = None): """Initialize a constant expression. Args: @@ -319,14 +315,15 @@ def preserves_partition_by(self): class ComputedExpression(Expression): """An expression whose value must be computed at pipeline execution time.""" def __init__( - self, # type: ComputedExpression - name, # type: str - func, # type: Callable[...,T] - args, # type: Iterable[Expression] - proxy=None, # type: Optional[T] - _id=None, # type: Optional[str] - requires_partition_by=partitionings.Index(), # type: partitionings.Partitioning - preserves_partition_by=partitionings.Singleton(), # type: partitionings.Partitioning + self, + name: str, + func: Callable[..., T], + args: Iterable[Expression], + proxy: Optional[T] = None, + _id: Optional[str] = None, + requires_partition_by: partitionings.Partitioning = partitionings.Index(), + preserves_partition_by: partitionings.Partitioning = partitionings. + Singleton(), ): """Initialize a computed expression. diff --git a/sdks/python/apache_beam/dataframe/frame_base.py b/sdks/python/apache_beam/dataframe/frame_base.py index 3b9755232e80..8e206fc5e037 100644 --- a/sdks/python/apache_beam/dataframe/frame_base.py +++ b/sdks/python/apache_beam/dataframe/frame_base.py @@ -17,15 +17,13 @@ import functools import operator import re +from collections.abc import Callable from inspect import cleandoc from inspect import getfullargspec from inspect import isclass from inspect import ismodule from inspect import unwrap from typing import Any -from typing import Callable -from typing import Dict -from typing import List from typing import Optional from typing import Tuple from typing import Union @@ -38,7 +36,7 @@ class DeferredBase(object): - _pandas_type_map: Dict[Union[type, None], type] = {} + _pandas_type_map: dict[Union[type, None], type] = {} def __init__(self, expr): self._expr = expr @@ -229,7 +227,7 @@ def _elementwise_function( def _proxy_function( func: Union[Callable, str], name: Optional[str] = None, - restrictions: Optional[Dict[str, Union[Any, List[Any]]]] = None, + restrictions: Optional[dict[str, Union[Any, list[Any]]]] = None, inplace: bool = False, base: Optional[type] = None, *, diff --git a/sdks/python/apache_beam/dataframe/frames.py b/sdks/python/apache_beam/dataframe/frames.py index 421430ec972c..ccd01f35f87b 100644 --- a/sdks/python/apache_beam/dataframe/frames.py +++ b/sdks/python/apache_beam/dataframe/frames.py @@ -38,7 +38,6 @@ import math import re import warnings -from typing import List from typing import Optional import numpy as np @@ -2660,7 +2659,7 @@ def get(self, key, default_value=None): @frame_base.populate_defaults(pd.DataFrame) @frame_base.maybe_inplace def set_index(self, keys, **kwargs): - """``keys`` must be a ``str`` or ``List[str]``. Passing an Index or Series + """``keys`` must be a ``str`` or ``list[str]``. 
Passing an Index or Series is not yet supported (`Issue 20759 `_).""" if isinstance(keys, str): @@ -4574,7 +4573,7 @@ def value_counts(self, **kwargs): tshift = frame_base.wont_implement_method( DataFrameGroupBy, 'tshift', reason="deprecated") -def _maybe_project_func(projection: Optional[List[str]]): +def _maybe_project_func(projection: Optional[list[str]]): """ Returns identity func if projection is empty or None, else returns a function that projects the specified columns. """ if projection: @@ -4967,7 +4966,7 @@ def func(*args): else: raise frame_base.WontImplementError( - "others must be None, DeferredSeries, or List[DeferredSeries] " + "others must be None, DeferredSeries, or list[DeferredSeries] " f"(encountered {type(others)}). Other types are not supported " "because they make this operation sensitive to the order of the " "data.", reason="order-sensitive") diff --git a/sdks/python/apache_beam/dataframe/frames_test.py b/sdks/python/apache_beam/dataframe/frames_test.py index 55d9fc5f4dfb..f99b77e446a8 100644 --- a/sdks/python/apache_beam/dataframe/frames_test.py +++ b/sdks/python/apache_beam/dataframe/frames_test.py @@ -18,7 +18,6 @@ import sys import unittest import warnings -from typing import Dict import numpy as np import pandas as pd @@ -1707,7 +1706,7 @@ def test_pivot_no_index_provided_on_multiindex(self): 'describe')) -def numeric_only_kwargs_for_pandas_2(agg_type: str) -> Dict[str, bool]: +def numeric_only_kwargs_for_pandas_2(agg_type: str) -> dict[str, bool]: """Get proper arguments for numeric_only. Behavior for numeric_only in these methods changed in Pandas 2 to default diff --git a/sdks/python/apache_beam/dataframe/pandas_top_level_functions.py b/sdks/python/apache_beam/dataframe/pandas_top_level_functions.py index ce36dbeb09ad..a8139675ad39 100644 --- a/sdks/python/apache_beam/dataframe/pandas_top_level_functions.py +++ b/sdks/python/apache_beam/dataframe/pandas_top_level_functions.py @@ -18,7 +18,7 @@ """ import re -from typing import Mapping +from collections.abc import Mapping import pandas as pd diff --git a/sdks/python/apache_beam/dataframe/partitionings.py b/sdks/python/apache_beam/dataframe/partitionings.py index 0ff09e111480..1fe760fe8589 100644 --- a/sdks/python/apache_beam/dataframe/partitionings.py +++ b/sdks/python/apache_beam/dataframe/partitionings.py @@ -15,9 +15,8 @@ # limitations under the License. import random +from collections.abc import Iterable from typing import Any -from typing import Iterable -from typing import Tuple from typing import TypeVar import numpy as np @@ -47,7 +46,7 @@ def __le__(self, other): return not self.is_subpartitioning_of(other) def partition_fn(self, df: Frame, - num_partitions: int) -> Iterable[Tuple[Any, Frame]]: + num_partitions: int) -> Iterable[tuple[Any, Frame]]: """A callable that actually performs the partitioning of a Frame df. 
This will be invoked via a FlatMap in conjunction with a GroupKey to diff --git a/sdks/python/apache_beam/dataframe/schemas.py b/sdks/python/apache_beam/dataframe/schemas.py index e70229f21f77..f849ab11e77c 100644 --- a/sdks/python/apache_beam/dataframe/schemas.py +++ b/sdks/python/apache_beam/dataframe/schemas.py @@ -24,12 +24,10 @@ # pytype: skip-file import warnings +from collections.abc import Sequence from typing import Any -from typing import Dict from typing import NamedTuple from typing import Optional -from typing import Sequence -from typing import Tuple from typing import TypeVar from typing import Union @@ -170,7 +168,7 @@ def element_typehint_from_dataframe_proxy( fields = [(column, dtype_to_fieldtype(dtype)) for (column, dtype) in output_columns] - field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] + field_options: Optional[dict[str, Sequence[tuple[str, Any]]]] if include_indexes: field_options = { index_name: [(INDEX_OPTION_NAME, None)] diff --git a/sdks/python/apache_beam/dataframe/transforms.py b/sdks/python/apache_beam/dataframe/transforms.py index d0b5be4eb2a9..c8ac8174232d 100644 --- a/sdks/python/apache_beam/dataframe/transforms.py +++ b/sdks/python/apache_beam/dataframe/transforms.py @@ -16,12 +16,9 @@ import collections import logging +from collections.abc import Mapping from typing import TYPE_CHECKING from typing import Any -from typing import Dict -from typing import List -from typing import Mapping -from typing import Tuple from typing import TypeVar from typing import Union @@ -108,7 +105,7 @@ def expand(self, input_pcolls): from apache_beam.dataframe import convert # Convert inputs to a flat dict. - input_dict = _flatten(input_pcolls) # type: Dict[Any, PCollection] + input_dict = _flatten(input_pcolls) # type: dict[Any, PCollection] proxies = _flatten(self._proxy) if self._proxy is not None else { tag: None for tag in input_dict @@ -116,7 +113,7 @@ def expand(self, input_pcolls): input_frames = { k: convert.to_dataframe(pc, proxies[k]) for k, pc in input_dict.items() - } # type: Dict[Any, DeferredFrame] # noqa: F821 + } # type: dict[Any, DeferredFrame] # noqa: F821 # Apply the function. frames_input = _substitute(input_pcolls, input_frames) @@ -152,9 +149,9 @@ def expand(self, inputs): def _apply_deferred_ops( self, - inputs, # type: Dict[expressions.Expression, PCollection] - outputs, # type: Dict[Any, expressions.Expression] - ): # -> Dict[Any, PCollection] + inputs: dict[expressions.Expression, PCollection], + outputs: dict[Any, expressions.Expression], + ) -> dict[Any, PCollection]: """Construct a Beam graph that evaluates a set of expressions on a set of input PCollections. @@ -585,11 +582,9 @@ def _concat(parts): def _flatten( - valueish, # type: Union[T, List[T], Tuple[T], Dict[Any, T]] - root=(), # type: Tuple[Any, ...] - ): - # type: (...) -> Mapping[Tuple[Any, ...], T] - + valueish: Union[T, list[T], tuple[T], dict[Any, T]], + root: tuple[Any, ...] = (), +) -> Mapping[tuple[Any, ...], T]: """Given a nested structure of dicts, tuples, and lists, return a flat dictionary where the values are the leafs and the keys are the "paths" to these leaves. 
From b62e8c42e26992e8ae78b72924f03ba28726bf22 Mon Sep 17 00:00:00 2001 From: Steven van Rossum Date: Wed, 13 Nov 2024 16:27:41 +0100 Subject: [PATCH 167/181] [KafkaIO] Fix per-split metric updates for KafkaUnboundedReader and ReadFromKafkaDoFn (#32921) * Revert "Set backlog in gauge metric (#31137)" * Revert "Add Backlog Metrics to Kafka Splittable DoFn Implementation (#31281)" This reverts commit fd4368f1c4aba18f85e0ef95e61f7c6904a05d19. * Call reportBacklog in nextBatch to report split metrics more often * Report SDF metrics for the active split/partition after processing a record batch * Use KafkaSourceDescriptor as cache key and log entry --- .../sdk/io/kafka/KafkaUnboundedReader.java | 21 +-- .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 136 ++++++++++-------- 2 files changed, 79 insertions(+), 78 deletions(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java index d86a5d0ce686..209dee14da1e 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java @@ -23,7 +23,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; -import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -227,10 +226,6 @@ public boolean advance() throws IOException { METRIC_NAMESPACE, RAW_SIZE_METRIC_PREFIX + pState.topicPartition.toString()); rawSizes.update(recordSize); - for (Map.Entry backlogSplit : perPartitionBacklogMetrics.entrySet()) { - backlogBytesOfSplit.set(backlogSplit.getValue()); - } - // Pass metrics to container. kafkaResults.updateKafkaMetrics(); return true; @@ -349,7 +344,6 @@ public long getSplitBacklogBytes() { private final Counter bytesReadBySplit; private final Gauge backlogBytesOfSplit; private final Gauge backlogElementsOfSplit; - private HashMap perPartitionBacklogMetrics = new HashMap();; private final Counter checkpointMarkCommitsEnqueued = Metrics.counter(METRIC_NAMESPACE, CHECKPOINT_MARK_COMMITS_ENQUEUED_METRIC); // Checkpoint marks skipped in favor of newer mark (only the latest needs to be committed). @@ -506,10 +500,6 @@ Instant updateAndGetWatermark() { lastWatermark = timestampPolicy.getWatermark(mkTimestampPolicyContext()); return lastWatermark; } - - String name() { - return this.topicPartition.toString(); - } } KafkaUnboundedReader( @@ -554,16 +544,14 @@ String name() { prevWatermark = Optional.of(new Instant(ckptMark.getWatermarkMillis())); } - PartitionState state = - new PartitionState( + states.add( + new PartitionState<>( tp, nextOffset, source .getSpec() .getTimestampPolicyFactory() - .createTimestampPolicy(tp, prevWatermark)); - states.add(state); - perPartitionBacklogMetrics.put(state.name(), 0L); + .createTimestampPolicy(tp, prevWatermark))); } partitionStates = ImmutableList.copyOf(states); @@ -680,6 +668,8 @@ private void nextBatch() throws IOException { partitionStates.forEach(p -> p.recordIter = records.records(p.topicPartition).iterator()); + reportBacklog(); + // cycle through the partitions in order to interleave records from each. 
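Editorial aside on the hunk above (the nextBatch() hunk itself continues below): moving the reportBacklog() call into nextBatch() updates the split-level backlog metrics once per consumer poll instead of on every emitted record. A minimal, self-contained sketch of that pattern, with a made-up metric namespace and gauge name:

    import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader;
    import org.apache.beam.sdk.metrics.Gauge;
    import org.apache.beam.sdk.metrics.Metrics;

    // Illustrative sketch only, not part of the patch: report split-level backlog once per
    // fetched batch rather than once per emitted record. Namespace and gauge name are made up.
    class PerBatchBacklogReporter {
      private final Gauge backlogBytesOfSplit;

      PerBatchBacklogReporter(String splitName) {
        this.backlogBytesOfSplit = Metrics.gauge("KafkaIOExample", "backlogBytes_" + splitName);
      }

      // Call right after each consumer poll, e.g. at the end of nextBatch().
      void report(long splitBacklogBytes) {
        if (splitBacklogBytes != UnboundedReader.BACKLOG_UNKNOWN) {
          backlogBytesOfSplit.set(splitBacklogBytes);
        }
      }
    }

The gauge is only recorded when a metrics container is in scope (inside a reader or DoFn running on a worker), which is why the connector does this from the read path rather than at construction time.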
curBatch = Iterators.cycle(new ArrayList<>(partitionStates)); } @@ -758,7 +748,6 @@ private long getSplitBacklogMessageCount() { if (pBacklog == UnboundedReader.BACKLOG_UNKNOWN) { return UnboundedReader.BACKLOG_UNKNOWN; } - perPartitionBacklogMetrics.put(p.name(), pBacklog); backlogCount += pBacklog; } diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 4d7aa6b32aef..26964d43a16f 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -19,11 +19,14 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; +import java.math.BigDecimal; +import java.math.MathContext; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; @@ -222,13 +225,12 @@ private ReadFromKafkaDoFn( // Valid between bundle start and bundle finish. private transient @Nullable Deserializer keyDeserializerInstance = null; private transient @Nullable Deserializer valueDeserializerInstance = null; - private transient @Nullable Map offsetEstimatorCache; + private transient @Nullable LoadingCache + offsetEstimatorCache; - private transient @Nullable LoadingCache avgRecordSize; + private transient @Nullable LoadingCache + avgRecordSizeCache; private static final long DEFAULT_KAFKA_POLL_TIMEOUT = 2L; - - private HashMap perPartitionBacklogMetrics = new HashMap();; - @VisibleForTesting final long consumerPollingTimeout; @VisibleForTesting final DeserializerProvider keyDeserializerProvider; @VisibleForTesting final DeserializerProvider valueDeserializerProvider; @@ -290,7 +292,7 @@ public OffsetRange initialRestriction(@Element KafkaSourceDescriptor kafkaSource Map updatedConsumerConfig = overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); TopicPartition partition = kafkaSourceDescriptor.getTopicPartition(); - LOG.info("Creating Kafka consumer for initial restriction for {}", partition); + LOG.info("Creating Kafka consumer for initial restriction for {}", kafkaSourceDescriptor); try (Consumer offsetConsumer = consumerFactoryFn.apply(updatedConsumerConfig)) { ConsumerSpEL.evaluateAssign(offsetConsumer, ImmutableList.of(partition)); long startOffset; @@ -337,38 +339,31 @@ public WatermarkEstimator newWatermarkEstimator( @GetSize public double getSize( @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange offsetRange) - throws Exception { + throws ExecutionException { // If present, estimates the record size to offset gap ratio. Compacted topics may hold less // records than the estimated offset range due to record deletion within a partition. - final LoadingCache avgRecordSize = - Preconditions.checkStateNotNull(this.avgRecordSize); + final LoadingCache avgRecordSizeCache = + Preconditions.checkStateNotNull(this.avgRecordSizeCache); + final @Nullable AverageRecordSize avgRecordSize = + avgRecordSizeCache.getIfPresent(kafkaSourceDescriptor); // The tracker estimates the offset range by subtracting the last claimed position from the // currently observed end offset for the partition belonging to this split. 
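    // Editorial aside, not part of the patch: a worked example of the size estimate computed just
    // below, using made-up numbers. If the tracker reports workRemaining = 10_000 offsets still to
    // be claimed, and the moving average says each offset carries roughly 250 bytes of key + value,
    // then getSize() is approximately 10_000 * 250 = 2_500_000 bytes.
    // Before the first record of a split has been processed there is no average yet (the cache
    // lookup returns null), so the raw offset count is returned unscaled by the early return below.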
double estimatedOffsetRange = restrictionTracker(kafkaSourceDescriptor, offsetRange).getProgress().getWorkRemaining(); // Before processing elements, we don't have a good estimated size of records and offset gap. // Return the estimated offset range without scaling by a size to gap ratio. - if (!avgRecordSize.asMap().containsKey(kafkaSourceDescriptor.getTopicPartition())) { + if (avgRecordSize == null) { return estimatedOffsetRange; } - if (offsetEstimatorCache != null) { - for (Map.Entry tp : - offsetEstimatorCache.entrySet()) { - perPartitionBacklogMetrics.put(tp.getKey().toString(), tp.getValue().estimate()); - } - } - // When processing elements, a moving average estimates the size of records and offset gap. // Return the estimated offset range scaled by the estimated size to gap ratio. - return estimatedOffsetRange - * avgRecordSize - .get(kafkaSourceDescriptor.getTopicPartition()) - .estimateRecordByteSizeToOffsetCountRatio(); + return estimatedOffsetRange * avgRecordSize.estimateRecordByteSizeToOffsetCountRatio(); } @NewTracker public OffsetRangeTracker restrictionTracker( - @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange restriction) { + @Element KafkaSourceDescriptor kafkaSourceDescriptor, @Restriction OffsetRange restriction) + throws ExecutionException { if (restriction.getTo() < Long.MAX_VALUE) { return new OffsetRangeTracker(restriction); } @@ -376,24 +371,10 @@ public OffsetRangeTracker restrictionTracker( // OffsetEstimators are cached for each topic-partition because they hold a stateful connection, // so we want to minimize the amount of connections that we start and track with Kafka. Another // point is that it has a memoized backlog, and this should make that more reusable estimations. - final Map offsetEstimatorCacheInstance = + final LoadingCache offsetEstimatorCache = Preconditions.checkStateNotNull(this.offsetEstimatorCache); - - TopicPartition topicPartition = kafkaSourceDescriptor.getTopicPartition(); - KafkaLatestOffsetEstimator offsetEstimator = offsetEstimatorCacheInstance.get(topicPartition); - if (offsetEstimator == null) { - Map updatedConsumerConfig = - overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); - - LOG.info("Creating Kafka consumer for offset estimation for {}", topicPartition); - - Consumer offsetConsumer = - consumerFactoryFn.apply( - KafkaIOUtils.getOffsetConsumerConfig( - "tracker-" + topicPartition, offsetConsumerConfig, updatedConsumerConfig)); - offsetEstimator = new KafkaLatestOffsetEstimator(offsetConsumer, topicPartition); - offsetEstimatorCacheInstance.put(topicPartition, offsetEstimator); - } + final KafkaLatestOffsetEstimator offsetEstimator = + offsetEstimatorCache.get(kafkaSourceDescriptor); return new GrowableOffsetRangeTracker(restriction.getFrom(), offsetEstimator); } @@ -405,22 +386,22 @@ public ProcessContinuation processElement( WatermarkEstimator watermarkEstimator, MultiOutputReceiver receiver) throws Exception { - final LoadingCache avgRecordSize = - Preconditions.checkStateNotNull(this.avgRecordSize); + final LoadingCache avgRecordSizeCache = + Preconditions.checkStateNotNull(this.avgRecordSizeCache); + final LoadingCache offsetEstimatorCache = + Preconditions.checkStateNotNull(this.offsetEstimatorCache); final Deserializer keyDeserializerInstance = Preconditions.checkStateNotNull(this.keyDeserializerInstance); final Deserializer valueDeserializerInstance = Preconditions.checkStateNotNull(this.valueDeserializerInstance); + final TopicPartition topicPartition = 
kafkaSourceDescriptor.getTopicPartition(); + final AverageRecordSize avgRecordSize = avgRecordSizeCache.get(kafkaSourceDescriptor); + // TODO: Metrics should be reported per split instead of partition, add bootstrap server hash? final Distribution rawSizes = - Metrics.distribution( - METRIC_NAMESPACE, - RAW_SIZE_METRIC_PREFIX + kafkaSourceDescriptor.getTopicPartition().toString()); - for (Map.Entry backlogSplit : perPartitionBacklogMetrics.entrySet()) { - Gauge backlog = - Metrics.gauge( - METRIC_NAMESPACE, RAW_SIZE_METRIC_PREFIX + "backlogBytes_" + backlogSplit.getKey()); - backlog.set(backlogSplit.getValue()); - } + Metrics.distribution(METRIC_NAMESPACE, RAW_SIZE_METRIC_PREFIX + topicPartition.toString()); + final Gauge backlogBytes = + Metrics.gauge( + METRIC_NAMESPACE, RAW_SIZE_METRIC_PREFIX + "backlogBytes_" + topicPartition.toString()); // Stop processing current TopicPartition when it's time to stop. if (checkStopReadingFn != null @@ -438,13 +419,10 @@ public ProcessContinuation processElement( if (timestampPolicyFactory != null) { timestampPolicy = timestampPolicyFactory.createTimestampPolicy( - kafkaSourceDescriptor.getTopicPartition(), - Optional.ofNullable(watermarkEstimator.currentWatermark())); + topicPartition, Optional.ofNullable(watermarkEstimator.currentWatermark())); } - LOG.info( - "Creating Kafka consumer for process continuation for {}", - kafkaSourceDescriptor.getTopicPartition()); + LOG.info("Creating Kafka consumer for process continuation for {}", kafkaSourceDescriptor); try (Consumer consumer = consumerFactoryFn.apply(updatedConsumerConfig)) { ConsumerSpEL.evaluateAssign( consumer, ImmutableList.of(kafkaSourceDescriptor.getTopicPartition())); @@ -518,8 +496,8 @@ public ProcessContinuation processElement( int recordSize = (rawRecord.key() == null ? 0 : rawRecord.key().length) + (rawRecord.value() == null ? 0 : rawRecord.value().length); - avgRecordSize - .getUnchecked(kafkaSourceDescriptor.getTopicPartition()) + avgRecordSizeCache + .getUnchecked(kafkaSourceDescriptor) .update(recordSize, rawRecord.offset() - expectedOffset); rawSizes.update(recordSize); expectedOffset = rawRecord.offset() + 1; @@ -551,6 +529,15 @@ public ProcessContinuation processElement( } } } + + backlogBytes.set( + (long) + (BigDecimal.valueOf( + Preconditions.checkStateNotNull( + offsetEstimatorCache.get(kafkaSourceDescriptor).estimate())) + .subtract(BigDecimal.valueOf(expectedOffset), MathContext.DECIMAL128) + .doubleValue() + * avgRecordSize.estimateRecordByteSizeToOffsetCountRatio())); } } } @@ -611,19 +598,44 @@ public Coder restrictionCoder() { @Setup public void setup() throws Exception { // Start to track record size and offset gap per bundle. 
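Editorial aside before the cache rebuild below (the setup() hunk continues right after this note): the fix keys both per-split caches by the whole KafkaSourceDescriptor rather than by TopicPartition alone, so two splits that read the same partition from different clusters no longer share state. A minimal sketch of a descriptor-keyed LoadingCache, written against plain Guava with placeholder key and value types (the connector itself uses Beam's vendored Guava and its real KafkaSourceDescriptor and KafkaLatestOffsetEstimator types):

    import com.google.common.cache.CacheBuilder;
    import com.google.common.cache.CacheLoader;
    import com.google.common.cache.LoadingCache;
    import java.util.concurrent.TimeUnit;

    // Illustrative sketch only, not part of the patch. SourceDescriptor and OffsetEstimator are
    // placeholders standing in for the connector's real types.
    class DescriptorKeyedCache {
      static class SourceDescriptor {}

      static class OffsetEstimator {
        long estimate() {
          return 0L; // a real estimator would query the broker for the partition end offset
        }
      }

      // Keyed by the full descriptor, so splits that read the same topic-partition from
      // different bootstrap servers get separate estimators and separate cached state.
      final LoadingCache<SourceDescriptor, OffsetEstimator> offsetEstimators =
          CacheBuilder.newBuilder()
              .weakValues()
              .expireAfterAccess(1, TimeUnit.MINUTES) // let idle splits release their resources
              .build(
                  new CacheLoader<SourceDescriptor, OffsetEstimator>() {
                    @Override
                    public OffsetEstimator load(SourceDescriptor descriptor) {
                      // The real loader opens an offset consumer for the descriptor's partition.
                      return new OffsetEstimator();
                    }
                  });
    }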
- avgRecordSize = + avgRecordSizeCache = CacheBuilder.newBuilder() .maximumSize(1000L) .build( - new CacheLoader() { + new CacheLoader() { @Override - public AverageRecordSize load(TopicPartition topicPartition) throws Exception { + public AverageRecordSize load(KafkaSourceDescriptor kafkaSourceDescriptor) + throws Exception { return new AverageRecordSize(); } }); keyDeserializerInstance = keyDeserializerProvider.getDeserializer(consumerConfig, true); valueDeserializerInstance = valueDeserializerProvider.getDeserializer(consumerConfig, false); - offsetEstimatorCache = new HashMap<>(); + offsetEstimatorCache = + CacheBuilder.newBuilder() + .weakValues() + .expireAfterAccess(1, TimeUnit.MINUTES) + .build( + new CacheLoader() { + @Override + public KafkaLatestOffsetEstimator load( + KafkaSourceDescriptor kafkaSourceDescriptor) throws Exception { + LOG.info( + "Creating Kafka consumer for offset estimation for {}", + kafkaSourceDescriptor); + + TopicPartition topicPartition = kafkaSourceDescriptor.getTopicPartition(); + Map updatedConsumerConfig = + overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); + Consumer offsetConsumer = + consumerFactoryFn.apply( + KafkaIOUtils.getOffsetConsumerConfig( + "tracker-" + topicPartition, + offsetConsumerConfig, + updatedConsumerConfig)); + return new KafkaLatestOffsetEstimator(offsetConsumer, topicPartition); + } + }); if (checkStopReadingFn != null) { checkStopReadingFn.setup(); } @@ -645,7 +657,7 @@ public void teardown() throws Exception { } if (offsetEstimatorCache != null) { - offsetEstimatorCache.clear(); + offsetEstimatorCache.invalidateAll(); } if (checkStopReadingFn != null) { checkStopReadingFn.teardown(); From ff5feed53c217ead9c3b0e178b475871627769a9 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Wed, 13 Nov 2024 11:00:35 -0500 Subject: [PATCH 168/181] Introduce pipeline options to disable user counter and user stringset (#33059) --- .../org/apache/beam/sdk/io/FileSystems.java | 6 +- .../beam/sdk/metrics/DelegatingCounter.java | 4 ++ .../org/apache/beam/sdk/metrics/Metrics.java | 60 +++++++++++++++++++ .../apache/beam/sdk/metrics/MetricsTest.java | 21 +++++++ 4 files changed, 90 insertions(+), 1 deletion(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java index 5ca22749b163..7e2940a2c35b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileSystems.java @@ -51,6 +51,7 @@ import org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions; import org.apache.beam.sdk.io.fs.ResourceId; import org.apache.beam.sdk.metrics.Lineage; +import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.common.ReflectHelpers; import org.apache.beam.sdk.values.KV; @@ -567,7 +568,7 @@ static FileSystem getFileSystemInternal(String scheme) { * *

    Outside of workers where Beam FileSystem API is used (e.g. test methods, user code executed * during pipeline submission), consider use {@link #registerFileSystemsOnce} if initialize - * FIleSystem of supported schema is the main goal. + * FileSystem of supported schema is the main goal. */ @Internal public static void setDefaultPipelineOptions(PipelineOptions options) { @@ -575,6 +576,9 @@ public static void setDefaultPipelineOptions(PipelineOptions options) { long id = options.getOptionsId(); int nextRevision = options.revision(); + // entry to set other PipelineOption determined flags + Metrics.setDefaultPipelineOptions(options); + while (true) { KV revision = FILESYSTEM_REVISION.get(); // only update file systems if the pipeline changed or the options revision increased diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DelegatingCounter.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DelegatingCounter.java index a0b2e3b34678..7e8252d4fb3f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DelegatingCounter.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/DelegatingCounter.java @@ -19,6 +19,7 @@ import java.io.Serializable; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.metrics.Metrics.MetricsFlag; /** Implementation of {@link Counter} that delegates to the instance for the current context. */ @Internal @@ -70,6 +71,9 @@ public void inc() { /** Increment the counter by the given amount. */ @Override public void inc(long n) { + if (MetricsFlag.counterDisabled()) { + return; + } MetricsContainer container = this.processWideContainer ? MetricsEnvironment.getProcessWideContainer() diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java index a963015e98a7..6c8179006640 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java @@ -18,6 +18,13 @@ package org.apache.beam.sdk.metrics; import java.io.Serializable; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.options.ExperimentalOptions; +import org.apache.beam.sdk.options.PipelineOptions; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * The Metrics is a utility class for producing various kinds of metrics for reporting @@ -50,9 +57,59 @@ * example off how to query metrics. */ public class Metrics { + private static final Logger LOG = LoggerFactory.getLogger(Metrics.class); private Metrics() {} + static class MetricsFlag { + private static final AtomicReference<@Nullable MetricsFlag> INSTANCE = new AtomicReference<>(); + final boolean counterDisabled; + final boolean stringSetDisabled; + + private MetricsFlag(boolean counterDisabled, boolean stringSetDisabled) { + this.counterDisabled = counterDisabled; + this.stringSetDisabled = stringSetDisabled; + } + + static boolean counterDisabled() { + MetricsFlag flag = INSTANCE.get(); + return flag != null && flag.counterDisabled; + } + + static boolean stringSetDisabled() { + MetricsFlag flag = INSTANCE.get(); + return flag != null && flag.stringSetDisabled; + } + } + + /** + * Initialize metrics flags if not already done so. + * + *

    Should be called by worker at worker harness initialization. Should not be called by user + * code (and it does not have an effect as the initialization completed before). + */ + @Internal + public static void setDefaultPipelineOptions(PipelineOptions options) { + MetricsFlag flag = MetricsFlag.INSTANCE.get(); + if (flag == null) { + ExperimentalOptions exp = options.as(ExperimentalOptions.class); + boolean counterDisabled = ExperimentalOptions.hasExperiment(exp, "disableCounterMetrics"); + if (counterDisabled) { + LOG.info("Counter metrics are disabled."); + } + boolean stringSetDisabled = ExperimentalOptions.hasExperiment(exp, "disableStringSetMetrics"); + if (stringSetDisabled) { + LOG.info("StringSet metrics are disabled"); + } + MetricsFlag.INSTANCE.compareAndSet(null, new MetricsFlag(counterDisabled, stringSetDisabled)); + } + } + + @Internal + static void resetDefaultPipelineOptions() { + MetricsFlag.INSTANCE.set(null); + } + /** * Create a metric that can be incremented and decremented, and is aggregated by taking the sum. */ @@ -174,6 +231,9 @@ private DelegatingStringSet(MetricName name) { @Override public void add(String value) { + if (MetricsFlag.stringSetDisabled()) { + return; + } MetricsContainer container = MetricsEnvironment.getCurrentContainer(); if (container != null) { container.getStringSet(name).add(value); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java index 750d43a4f9ae..662c4f52628a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/metrics/MetricsTest.java @@ -24,7 +24,9 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.hasItem; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -39,6 +41,7 @@ import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.UsesAttemptedMetrics; @@ -245,6 +248,24 @@ public void testCounterToCell() { counter.dec(5L); verify(mockCounter).inc(-5); } + + @Test + public void testMetricsFlag() { + Metrics.resetDefaultPipelineOptions(); + assertFalse(Metrics.MetricsFlag.counterDisabled()); + assertFalse(Metrics.MetricsFlag.stringSetDisabled()); + PipelineOptions options = + PipelineOptionsFactory.fromArgs("--experiments=disableCounterMetrics").create(); + Metrics.setDefaultPipelineOptions(options); + assertTrue(Metrics.MetricsFlag.counterDisabled()); + assertFalse(Metrics.MetricsFlag.stringSetDisabled()); + Metrics.resetDefaultPipelineOptions(); + options = PipelineOptionsFactory.fromArgs("--experiments=disableStringSetMetrics").create(); + Metrics.setDefaultPipelineOptions(options); + assertFalse(Metrics.MetricsFlag.counterDisabled()); + assertTrue(Metrics.MetricsFlag.stringSetDisabled()); + Metrics.resetDefaultPipelineOptions(); + } } /** Tests for committed metrics. 
 */
From c6a7354b32f522d98f9f7b0aa595fd5161c4b257 Mon Sep 17 00:00:00 2001
From: Israel Herraiz
Date: Wed, 13 Nov 2024 17:11:37 +0100
Subject: [PATCH 169/181] SolaceIO write connector (#32060)

* This is a follow-up PR to #31953, and part of the issue #31905. This PR adds the actual writer functionality, and some additional testing, including integration testing. This should be the final PR for the SolaceIO write connector to be complete.
* Use static imports for Preconditions
* Remove unused method
* Logging has builtin formatting support
* Use TypeDescriptors to check the type used as input
* Fix parameter name
* Use interface + utils class for MessageProducer
* Use null instead of optional
* Avoid using ByteString just to create an empty byte array.
* Fix documentation, we are not using ByteString now.
* Not needed anymore, we are not using ByteString
* Defer transforming latency from nanos to millis. The transform into millis is done at the presentation moment, when the metric is reported to Beam.
* Avoid using top level classes with a single inner class. A couple of DoFns are moved to their own files too, as the abstract class for the UnboundedSolaceWriter was in practice a "package". This commit addresses a few comments about the structure of UnboundedSolaceWriter and some base classes of that abstract class.
* Remove using a state variable, there is already a timer. This DoFn is a stateful DoFn to force a shuffling with a given input key set cardinality.
* Properties must always be set. The warnings are only shown if the user decided to set the properties that are overridden by the connector. This was changed in one of the previous commits but it is actually a bug. I am reverting that change and changing this to a switch block, to make it clearer that the properties always need to be set by the connector.
* Add a new custom mode so no JCSMP property is overridden. This lets the user fully control all the properties used by the connector, instead of making sensible choices on its behalf. This also adds some logging to be more explicit about what the connector is doing. This does not add too much logging pressure, as it only adds logging at the producer creation moment.
* Add some more documentation about the new custom submission mode.
* Fix bug introduced with the refactoring of code for this PR. I forgot to pass the submission mode when the write session is created, and I called the wrong method in the base class because it was defined as public. This makes sure that the submission mode is passed to the session when the session is created for writing messages.
* Remove unnecessary Serializable annotation.
* Make the PublishResult class for handling callbacks non-static to handle pipelines with multiple write transforms.
* Rename maxNumOfUsedWorkers to numShards
* Use RoundRobin assignment of producers to process bundles.
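Editorial aside between the change notes (the list of notes resumes right after this sketch): a rough usage sketch of the writer this patch adds. Only the with*() configuration methods are taken from the diff below; the SolaceIO.write() entry point, Solace.Queue.fromName(...), and the host, queue, and credential values are assumptions made to keep the snippet self-contained, so the released SolaceIO javadoc is the authority on the exact factory methods.

    import org.apache.beam.sdk.io.solace.SolaceIO;
    import org.apache.beam.sdk.io.solace.broker.BasicAuthJcsmpSessionServiceFactory;
    import org.apache.beam.sdk.io.solace.data.Solace;
    import org.apache.beam.sdk.io.solace.write.SolaceOutput;
    import org.apache.beam.sdk.values.PCollection;

    class SolaceWriteExample {
      // Assumes SolaceIO.write() and Solace.Queue.fromName(...) as entry points; all broker
      // details below are placeholders.
      static SolaceOutput publish(PCollection<Solace.Record> records) {
        return records.apply(
            "Write to Solace",
            SolaceIO.write()
                .to(Solace.Queue.fromName("my-queue"))  // hypothetical queue name
                .withNumShards(4)                       // parallel write clients
                .withNumberOfClientsPerWorker(2)        // sessions re-used within each worker
                .withSessionServiceFactory(
                    BasicAuthJcsmpSessionServiceFactory.builder()
                        .host("broker-host:55555")      // hypothetical host and port
                        .username("user")
                        .password("password")
                        .vpnName("default")
                        .build()));
      }
    }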
* Output results in a GlobalWindow * Add ErrorHandler * Fix docs * Remove PublishResultHandler class that was just a wrapper around a Queue * small refactors * Revert CsvIO docs fix * Add withErrorHandler docs * fix var scope --------- Co-authored-by: Bartosz Zablocki --- CHANGES.md | 1 + sdks/java/io/solace/build.gradle | 1 + .../apache/beam/sdk/io/solace/SolaceIO.java | 201 ++++++++-- .../broker/BasicAuthJcsmpSessionService.java | 150 +++++-- .../BasicAuthJcsmpSessionServiceFactory.java | 22 +- .../GCPSecretSessionServiceFactory.java | 2 +- .../sdk/io/solace/broker/MessageProducer.java | 61 +++ .../solace/broker/MessageProducerUtils.java | 110 ++++++ .../solace/broker/PublishResultHandler.java | 100 +++++ .../sdk/io/solace/broker/SessionService.java | 160 +++++--- .../solace/broker/SessionServiceFactory.java | 64 ++- .../solace/broker/SolaceMessageProducer.java | 87 ++++ .../beam/sdk/io/solace/data/Solace.java | 88 +++-- .../io/solace/read/UnboundedSolaceReader.java | 19 +- .../sdk/io/solace/write/AddShardKeyDoFn.java | 45 +++ .../write/RecordToPublishResultDoFn.java | 41 ++ .../sdk/io/solace/write/SolaceOutput.java | 34 +- .../write/SolaceWriteSessionsHandler.java | 112 ++++++ .../write/UnboundedBatchedSolaceWriter.java | 164 ++++++++ .../solace/write/UnboundedSolaceWriter.java | 373 ++++++++++++++++++ .../write/UnboundedStreamingSolaceWriter.java | 138 +++++++ .../io/solace/MockEmptySessionService.java | 24 +- .../beam/sdk/io/solace/MockProducer.java | 110 ++++++ .../sdk/io/solace/MockSessionService.java | 101 +++-- .../io/solace/MockSessionServiceFactory.java | 68 +++- ...olaceIOTest.java => SolaceIOReadTest.java} | 260 ++++++------ .../beam/sdk/io/solace/SolaceIOWriteTest.java | 208 ++++++++++ .../broker/OverrideWriterPropertiesTest.java | 20 +- .../sdk/io/solace/data/SolaceDataUtils.java | 4 +- .../beam/sdk/io/solace/it/SolaceIOIT.java | 132 ++++++- 30 files changed, 2508 insertions(+), 392 deletions(-) create mode 100644 sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageProducer.java create mode 100644 sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageProducerUtils.java create mode 100644 sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/PublishResultHandler.java create mode 100644 sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageProducer.java create mode 100644 sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/AddShardKeyDoFn.java create mode 100644 sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/RecordToPublishResultDoFn.java create mode 100644 sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceWriteSessionsHandler.java create mode 100644 sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedBatchedSolaceWriter.java create mode 100644 sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedSolaceWriter.java create mode 100644 sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedStreamingSolaceWriter.java create mode 100644 sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockProducer.java rename sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/{SolaceIOTest.java => SolaceIOReadTest.java} (72%) create mode 100644 sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOWriteTest.java diff --git a/CHANGES.md b/CHANGES.md index c5731bcff313..6962b0fb8ded 100644 --- a/CHANGES.md +++ 
b/CHANGES.md @@ -70,6 +70,7 @@ * [Managed Iceberg] Add support for TIMESTAMP, TIME, and DATE types ([#32688](https://github.com/apache/beam/pull/32688)) * BigQuery CDC writes are now available in Python SDK, only supported when using StorageWrite API at least once mode ([#32527](https://github.com/apache/beam/issues/32527)) * [Managed Iceberg] Allow updating table partition specs during pipeline runtime ([#32879](https://github.com/apache/beam/pull/32879)) +* Support for writing to [Solace messages queues](https://solace.com/) (`SolaceIO.Write`) added (Java) ([#31905](https://github.com/apache/beam/issues/31905)). ## New Features / Improvements diff --git a/sdks/java/io/solace/build.gradle b/sdks/java/io/solace/build.gradle index 741db51a5772..ef0d49891f08 100644 --- a/sdks/java/io/solace/build.gradle +++ b/sdks/java/io/solace/build.gradle @@ -53,6 +53,7 @@ dependencies { testImplementation library.java.junit testImplementation project(path: ":sdks:java:io:common") testImplementation project(path: ":sdks:java:testing:test-utils") + testImplementation project(path: ":sdks:java:core", configuration: "shadowTest") testRuntimeOnly library.java.slf4j_jdk14 testImplementation library.java.testcontainers_solace testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/SolaceIO.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/SolaceIO.java index dcfdcc4fabb9..a55d8a0a4217 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/SolaceIO.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/SolaceIO.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io.solace; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; @@ -38,16 +39,29 @@ import org.apache.beam.sdk.io.solace.broker.SessionService; import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.io.solace.data.Solace.Record; import org.apache.beam.sdk.io.solace.data.Solace.SolaceRecordMapper; import org.apache.beam.sdk.io.solace.read.UnboundedSolaceSource; +import org.apache.beam.sdk.io.solace.write.AddShardKeyDoFn; import org.apache.beam.sdk.io.solace.write.SolaceOutput; +import org.apache.beam.sdk.io.solace.write.UnboundedBatchedSolaceWriter; +import org.apache.beam.sdk.io.solace.write.UnboundedStreamingSolaceWriter; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.TupleTag; +import 
org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.checkerframework.checker.nullness.qual.Nullable; @@ -147,7 +161,7 @@ * function. * *

    {@code
    - * @DefaultSchema(JavaBeanSchema.class)
    + * {@literal @}DefaultSchema(JavaBeanSchema.class)
      * public static class SimpleRecord {
      *    public String payload;
      *    public String messageId;
    @@ -238,7 +252,7 @@
      * default VPN name by setting the required JCSMP property in the session factory (in this case,
      * with {@link BasicAuthJcsmpSessionServiceFactory#vpnName()}), the number of clients per worker
      * with {@link Write#withNumberOfClientsPerWorker(int)} and the number of parallel write clients
    - * using {@link Write#withMaxNumOfUsedWorkers(int)}.
    + * using {@link Write#withNumShards(int)}.
      *
      * 

    Writing to dynamic destinations

    * @@ -345,13 +359,17 @@ * *

    The streaming connector publishes each message individually, without holding up or batching * before the message is sent to Solace. This will ensure the lowest possible latency, but it will - * offer a much lower throughput. The streaming connector does not use state & timers. + * offer a much lower throughput. The streaming connector does not use state and timers. * - *

    Both connectors uses state & timers to control the level of parallelism. If you are using + *

    Both connectors uses state and timers to control the level of parallelism. If you are using * Cloud Dataflow, it is recommended that you enable Streaming Engine to use this * connector. * + *

    For full control over all the properties, use {@link SubmissionMode#CUSTOM}. The connector + * will not override any property that you set, and you will have full control over all the JCSMP + * properties. + * *

    Authentication

    * *

    When writing to Solace, the user must use {@link @@ -396,7 +414,7 @@ public class SolaceIO { private static final boolean DEFAULT_DEDUPLICATE_RECORDS = false; private static final Duration DEFAULT_WATERMARK_IDLE_DURATION_THRESHOLD = Duration.standardSeconds(30); - public static final int DEFAULT_WRITER_MAX_NUMBER_OF_WORKERS = 20; + public static final int DEFAULT_WRITER_NUM_SHARDS = 20; public static final int DEFAULT_WRITER_CLIENTS_PER_WORKER = 4; public static final Boolean DEFAULT_WRITER_PUBLISH_LATENCY_METRICS = false; public static final SubmissionMode DEFAULT_WRITER_SUBMISSION_MODE = @@ -445,6 +463,7 @@ public static Read read() { .setDeduplicateRecords(DEFAULT_DEDUPLICATE_RECORDS) .setWatermarkIdleDurationThreshold(DEFAULT_WATERMARK_IDLE_DURATION_THRESHOLD)); } + /** * Create a {@link Read} transform, to read from Solace. Specify a {@link SerializableFunction} to * map incoming {@link BytesXMLMessage} records, to the object of your choice. You also need to @@ -805,7 +824,9 @@ private Queue initializeQueueForTopicIfNeeded( public enum SubmissionMode { HIGHER_THROUGHPUT, - LOWER_LATENCY + LOWER_LATENCY, + CUSTOM, // Don't override any property set by the user + TESTING // Send acks 1 by 1, this will be very slow, never use this in an actual pipeline! } public enum WriterType { @@ -816,8 +837,9 @@ public enum WriterType { @AutoValue public abstract static class Write extends PTransform, SolaceOutput> { - public static final TupleTag FAILED_PUBLISH_TAG = - new TupleTag() {}; + private static final Logger LOG = LoggerFactory.getLogger(Write.class); + + public static final TupleTag FAILED_PUBLISH_TAG = new TupleTag() {}; public static final TupleTag SUCCESSFUL_PUBLISH_TAG = new TupleTag() {}; @@ -863,8 +885,8 @@ public Write to(Solace.Queue queue) { * cluster, and the need for performance when writing to Solace (more workers will achieve * higher throughput). */ - public Write withMaxNumOfUsedWorkers(int maxNumOfUsedWorkers) { - return toBuilder().setMaxNumOfUsedWorkers(maxNumOfUsedWorkers).build(); + public Write withNumShards(int numShards) { + return toBuilder().setNumShards(numShards).build(); } /** @@ -877,8 +899,8 @@ public Write withMaxNumOfUsedWorkers(int maxNumOfUsedWorkers) { * the number of clients created per VM. The clients will be re-used across different threads in * the same worker. * - *

    Set this number in combination with {@link #withMaxNumOfUsedWorkers}, to ensure that the - * limit for number of clients in your Solace cluster is not exceeded. + *

    Set this number in combination with {@link #withNumShards}, to ensure that the limit for + * number of clients in your Solace cluster is not exceeded. * *

    Normally, using a higher number of clients with fewer workers will achieve better * throughput at a lower cost, since the workers are better utilized. A good rule of thumb to @@ -921,15 +943,19 @@ public Write publishLatencyMetrics() { *

    For full details, please check https://docs.solace.com/API/API-Developer-Guide/Java-API-Best-Practices.htm. * - *

    The Solace JCSMP client libraries can dispatch messages using two different modes: + *

    The Solace JCSMP client libraries can dispatch messages using three different modes: * *

    One of the modes dispatches messages directly from the same thread that is doing the rest * of I/O work. This mode favors lower latency but lower throughput. Set this to LOWER_LATENCY * to use that mode (MESSAGE_CALLBACK_ON_REACTOR set to True). * - *

    The other mode uses a parallel thread to accumulate and dispatch messages. This mode - * favors higher throughput but also has higher latency. Set this to HIGHER_THROUGHPUT to use - * that mode. This is the default mode (MESSAGE_CALLBACK_ON_REACTOR set to False). + *

    Another mode uses a parallel thread to accumulate and dispatch messages. This mode favors + * higher throughput but also has higher latency. Set this to HIGHER_THROUGHPUT to use that + * mode. This is the default mode (MESSAGE_CALLBACK_ON_REACTOR set to False). + * + *

    If you prefer to have full control over all the JCSMP properties, set this to CUSTOM, and + * override the classes {@link SessionServiceFactory} and {@link SessionService} to have full + * control on how to create the JCSMP sessions and producers used by the connector. * *

    This is optional, the default value is HIGHER_THROUGHPUT. */ @@ -945,10 +971,12 @@ public Write withSubmissionMode(SubmissionMode submissionMode) { *

    In streaming mode, the publishing latency will be lower, but the throughput will also be * lower. * - *

    With the batched mode, messages are accumulated until a batch size of 50 is reached, or 5 - * seconds have elapsed since the first message in the batch was received. The 50 messages are - * sent to Solace in a single batch. This writer offers higher throughput but higher publishing - * latency, as messages can be held up for up to 5 seconds until they are published. + *

    With the batched mode, messages are accumulated until a batch size of 50 is reached, or + * {@link UnboundedBatchedSolaceWriter#ACKS_FLUSHING_INTERVAL_SECS} seconds have elapsed since + * the first message in the batch was received. The 50 messages are sent to Solace in a single + * batch. This writer offers higher throughput but higher publishing latency, as messages can be + * held up for up to {@link UnboundedBatchedSolaceWriter#ACKS_FLUSHING_INTERVAL_SECS}5seconds + * until they are published. * *

    Notice that this is the message publishing latency, not the end-to-end latency. For very * large scale pipelines, you will probably prefer to use the HIGHER_THROUGHPUT mode, as with @@ -971,7 +999,20 @@ public Write withSessionServiceFactory(SessionServiceFactory factory) { return toBuilder().setSessionServiceFactory(factory).build(); } - abstract int getMaxNumOfUsedWorkers(); + /** + * An optional error handler for handling records that failed to publish to Solace. + * + *

    If provided, this error handler will be invoked for each record that could not be + * successfully published. The error handler can implement custom logic for dealing with failed + * records, such as writing them to a dead-letter queue or logging them. + * + *

    If no error handler is provided, failed records will be ignored. + */ + public Write withErrorHandler(ErrorHandler errorHandler) { + return toBuilder().setErrorHandler(errorHandler).build(); + } + + abstract int getNumShards(); abstract int getNumberOfClientsPerWorker(); @@ -989,10 +1030,12 @@ public Write withSessionServiceFactory(SessionServiceFactory factory) { abstract @Nullable SessionServiceFactory getSessionServiceFactory(); + abstract @Nullable ErrorHandler getErrorHandler(); + static Builder builder() { return new AutoValue_SolaceIO_Write.Builder() .setDeliveryMode(DEFAULT_WRITER_DELIVERY_MODE) - .setMaxNumOfUsedWorkers(DEFAULT_WRITER_MAX_NUMBER_OF_WORKERS) + .setNumShards(DEFAULT_WRITER_NUM_SHARDS) .setNumberOfClientsPerWorker(DEFAULT_WRITER_CLIENTS_PER_WORKER) .setPublishLatencyMetrics(DEFAULT_WRITER_PUBLISH_LATENCY_METRICS) .setDispatchMode(DEFAULT_WRITER_SUBMISSION_MODE) @@ -1003,7 +1046,7 @@ static Builder builder() { @AutoValue.Builder abstract static class Builder { - abstract Builder setMaxNumOfUsedWorkers(int maxNumOfUsedWorkers); + abstract Builder setNumShards(int numShards); abstract Builder setNumberOfClientsPerWorker(int numberOfClientsPerWorker); @@ -1021,13 +1064,121 @@ abstract static class Builder { abstract Builder setSessionServiceFactory(SessionServiceFactory factory); + abstract Builder setErrorHandler(ErrorHandler errorHandler); + abstract Write build(); } @Override public SolaceOutput expand(PCollection input) { - // TODO: will be sent in upcoming PR - return SolaceOutput.in(input.getPipeline(), null, null); + boolean usingSolaceRecord = + TypeDescriptor.of(Solace.Record.class) + .isSupertypeOf(checkNotNull(input.getTypeDescriptor())); + + validateWriteTransform(usingSolaceRecord); + + boolean usingDynamicDestinations = getDestination() == null; + SerializableFunction destinationFn; + if (usingDynamicDestinations) { + destinationFn = x -> SolaceIO.convertToJcsmpDestination(checkNotNull(x.getDestination())); + } else { + // Constant destination for all messages (same topic or queue) + // This should not be non-null, as nulls would have been flagged by the + // validateWriteTransform method + destinationFn = x -> checkNotNull(getDestination()); + } + + @SuppressWarnings("unchecked") + PCollection records = + usingSolaceRecord + ? (PCollection) input + : input.apply( + "Format records", + MapElements.into(TypeDescriptor.of(Solace.Record.class)) + .via(checkNotNull(getFormatFunction()))); + + PCollection withGlobalWindow = + records.apply("Global window", Window.into(new GlobalWindows())); + + PCollection> withShardKeys = + withGlobalWindow.apply("Add shard key", ParDo.of(new AddShardKeyDoFn(getNumShards()))); + + String label = + getWriterType() == WriterType.STREAMING ? 
"Publish (streaming)" : "Publish (batched)"; + + PCollectionTuple solaceOutput = withShardKeys.apply(label, getWriterTransform(destinationFn)); + + SolaceOutput output; + if (getDeliveryMode() == DeliveryMode.PERSISTENT) { + if (getErrorHandler() != null) { + checkNotNull(getErrorHandler()).addErrorCollection(solaceOutput.get(FAILED_PUBLISH_TAG)); + } + output = SolaceOutput.in(input.getPipeline(), solaceOutput.get(SUCCESSFUL_PUBLISH_TAG)); + } else { + LOG.info( + "Solace.Write: omitting writer output because delivery mode is {}", getDeliveryMode()); + output = SolaceOutput.in(input.getPipeline(), null); + } + + return output; + } + + private ParDo.MultiOutput, Solace.PublishResult> getWriterTransform( + SerializableFunction destinationFn) { + + ParDo.SingleOutput, Solace.PublishResult> writer = + ParDo.of( + getWriterType() == WriterType.STREAMING + ? new UnboundedStreamingSolaceWriter( + destinationFn, + checkNotNull(getSessionServiceFactory()), + getDeliveryMode(), + getDispatchMode(), + getNumberOfClientsPerWorker(), + getPublishLatencyMetrics()) + : new UnboundedBatchedSolaceWriter( + destinationFn, + checkNotNull(getSessionServiceFactory()), + getDeliveryMode(), + getDispatchMode(), + getNumberOfClientsPerWorker(), + getPublishLatencyMetrics())); + + return writer.withOutputTags(SUCCESSFUL_PUBLISH_TAG, TupleTagList.of(FAILED_PUBLISH_TAG)); + } + + /** + * Called before running the Pipeline to verify this transform is fully and correctly specified. + */ + private void validateWriteTransform(boolean usingSolaceRecords) { + if (!usingSolaceRecords) { + checkNotNull( + getFormatFunction(), + "SolaceIO.Write: If you are not using Solace.Record as the input type, you" + + " must set a format function using withFormatFunction()."); + } + + checkArgument( + getNumShards() > 0, "SolaceIO.Write: The number of used workers must be positive."); + checkArgument( + getNumberOfClientsPerWorker() > 0, + "SolaceIO.Write: The number of clients per worker must be positive."); + checkArgument( + getDeliveryMode() == DeliveryMode.DIRECT || getDeliveryMode() == DeliveryMode.PERSISTENT, + String.format( + "SolaceIO.Write: Delivery mode must be either DIRECT or PERSISTENT. %s" + + " not supported", + getDeliveryMode())); + if (getPublishLatencyMetrics()) { + checkArgument( + getDeliveryMode() == DeliveryMode.PERSISTENT, + "SolaceIO.Write: Publish latency metrics can only be enabled for PERSISTENT" + + " delivery mode."); + } + checkNotNull( + getSessionServiceFactory(), + "SolaceIO: You need to pass a session service factory. 
For basic" + + " authentication, you can use BasicAuthJcsmpSessionServiceFactory."); } } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java index 2137d574b09a..b2196dbf1067 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionService.java @@ -19,6 +19,7 @@ import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import com.google.auto.value.AutoValue; import com.solacesystems.jcsmp.ConsumerFlowProperties; import com.solacesystems.jcsmp.EndpointProperties; import com.solacesystems.jcsmp.FlowReceiver; @@ -28,9 +29,15 @@ import com.solacesystems.jcsmp.JCSMPProperties; import com.solacesystems.jcsmp.JCSMPSession; import com.solacesystems.jcsmp.Queue; +import com.solacesystems.jcsmp.XMLMessageProducer; import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentLinkedQueue; import javax.annotation.Nullable; import org.apache.beam.sdk.io.solace.RetryCallableManager; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.data.Solace.PublishResult; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; /** @@ -39,34 +46,50 @@ *

    This class provides a way to connect to a Solace broker and receive messages from a queue. The * connection is established using basic authentication. */ -public class BasicAuthJcsmpSessionService extends SessionService { - private final String queueName; - private final String host; - private final String username; - private final String password; - private final String vpnName; - @Nullable private JCSMPSession jcsmpSession; - @Nullable private MessageReceiver messageReceiver; - private final RetryCallableManager retryCallableManager = RetryCallableManager.create(); +@AutoValue +public abstract class BasicAuthJcsmpSessionService extends SessionService { + + /** The name of the queue to receive messages from. */ + public abstract @Nullable String queueName(); + + /** The host name or IP address of the Solace broker. Format: Host[:Port] */ + public abstract String host(); + + /** The username to use for authentication. */ + public abstract String username(); + + /** The password to use for authentication. */ + public abstract String password(); + + /** The name of the VPN to connect to. */ + public abstract String vpnName(); + + public static Builder builder() { + return new AutoValue_BasicAuthJcsmpSessionService.Builder().vpnName(DEFAULT_VPN_NAME); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder queueName(@Nullable String queueName); + + public abstract Builder host(String host); - /** - * Creates a new {@link BasicAuthJcsmpSessionService} with the given parameters. - * - * @param queueName The name of the queue to receive messages from. - * @param host The host name or IP address of the Solace broker. Format: Host[:Port] - * @param username The username to use for authentication. - * @param password The password to use for authentication. - * @param vpnName The name of the VPN to connect to. 
- */ - public BasicAuthJcsmpSessionService( - String queueName, String host, String username, String password, String vpnName) { - this.queueName = queueName; - this.host = host; - this.username = username; - this.password = password; - this.vpnName = vpnName; + public abstract Builder username(String username); + + public abstract Builder password(String password); + + public abstract Builder vpnName(String vpnName); + + public abstract BasicAuthJcsmpSessionService build(); } + @Nullable private transient JCSMPSession jcsmpSession; + @Nullable private transient MessageReceiver messageReceiver; + @Nullable private transient MessageProducer messageProducer; + private final java.util.Queue publishedResultsQueue = + new ConcurrentLinkedQueue<>(); + private final RetryCallableManager retryCallableManager = RetryCallableManager.create(); + @Override public void connect() { retryCallableManager.retryCallable(this::connectSession, ImmutableSet.of(JCSMPException.class)); @@ -79,6 +102,9 @@ public void close() { if (messageReceiver != null) { messageReceiver.close(); } + if (messageProducer != null) { + messageProducer.close(); + } if (!isClosed()) { checkStateNotNull(jcsmpSession).closeSession(); } @@ -88,24 +114,64 @@ public void close() { } @Override - public MessageReceiver createReceiver() { - this.messageReceiver = - retryCallableManager.retryCallable( - this::createFlowReceiver, ImmutableSet.of(JCSMPException.class)); + public MessageReceiver getReceiver() { + if (this.messageReceiver == null) { + this.messageReceiver = + retryCallableManager.retryCallable( + this::createFlowReceiver, ImmutableSet.of(JCSMPException.class)); + } return this.messageReceiver; } + @Override + public MessageProducer getInitializedProducer(SubmissionMode submissionMode) { + if (this.messageProducer == null || this.messageProducer.isClosed()) { + Callable create = () -> createXMLMessageProducer(submissionMode); + this.messageProducer = + retryCallableManager.retryCallable(create, ImmutableSet.of(JCSMPException.class)); + } + return checkStateNotNull(this.messageProducer); + } + + @Override + public java.util.Queue getPublishedResultsQueue() { + return publishedResultsQueue; + } + @Override public boolean isClosed() { return jcsmpSession == null || jcsmpSession.isClosed(); } + private MessageProducer createXMLMessageProducer(SubmissionMode submissionMode) + throws JCSMPException, IOException { + + if (isClosed()) { + connectWriteSession(submissionMode); + } + + @SuppressWarnings("nullness") + Callable initProducer = + () -> + Objects.requireNonNull(jcsmpSession) + .getMessageProducer(new PublishResultHandler(publishedResultsQueue)); + + XMLMessageProducer producer = + retryCallableManager.retryCallable(initProducer, ImmutableSet.of(JCSMPException.class)); + if (producer == null) { + throw new IOException("SolaceIO.Write: Could not create producer, producer object is null"); + } + return new SolaceMessageProducer(producer); + } + private MessageReceiver createFlowReceiver() throws JCSMPException, IOException { if (isClosed()) { connectSession(); } - Queue queue = JCSMPFactory.onlyInstance().createQueue(queueName); + Queue queue = + JCSMPFactory.onlyInstance() + .createQueue(checkStateNotNull(queueName(), "SolaceIO.Read: Queue is not set.")); ConsumerFlowProperties flowProperties = new ConsumerFlowProperties(); flowProperties.setEndpoint(queue); @@ -118,7 +184,8 @@ private MessageReceiver createFlowReceiver() throws JCSMPException, IOException createFlowReceiver(jcsmpSession, flowProperties, endpointProperties)); } 
throw new IOException( - "SolaceIO.Read: Could not create a receiver from the Jcsmp session: session object is null."); + "SolaceIO.Read: Could not create a receiver from the Jcsmp session: session object is" + + " null."); } // The `@SuppressWarning` is needed here, because the checkerframework reports an error for the @@ -141,20 +208,33 @@ private int connectSession() throws JCSMPException { return 0; } + private int connectWriteSession(SubmissionMode mode) throws JCSMPException { + if (jcsmpSession == null) { + jcsmpSession = createWriteSessionObject(mode); + } + jcsmpSession.connect(); + return 0; + } + private JCSMPSession createSessionObject() throws InvalidPropertiesException { JCSMPProperties properties = initializeSessionProperties(new JCSMPProperties()); return JCSMPFactory.onlyInstance().createSession(properties); } + private JCSMPSession createWriteSessionObject(SubmissionMode mode) + throws InvalidPropertiesException { + return JCSMPFactory.onlyInstance().createSession(initializeWriteSessionProperties(mode)); + } + @Override public JCSMPProperties initializeSessionProperties(JCSMPProperties baseProps) { - baseProps.setProperty(JCSMPProperties.VPN_NAME, vpnName); + baseProps.setProperty(JCSMPProperties.VPN_NAME, vpnName()); baseProps.setProperty( JCSMPProperties.AUTHENTICATION_SCHEME, JCSMPProperties.AUTHENTICATION_SCHEME_BASIC); - baseProps.setProperty(JCSMPProperties.USERNAME, username); - baseProps.setProperty(JCSMPProperties.PASSWORD, password); - baseProps.setProperty(JCSMPProperties.HOST, host); + baseProps.setProperty(JCSMPProperties.USERNAME, username()); + baseProps.setProperty(JCSMPProperties.PASSWORD, password()); + baseProps.setProperty(JCSMPProperties.HOST, host()); return baseProps; } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionServiceFactory.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionServiceFactory.java index 2084e61b7e38..199dcccee854 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionServiceFactory.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/BasicAuthJcsmpSessionServiceFactory.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.io.solace.broker; import static org.apache.beam.sdk.io.solace.broker.SessionService.DEFAULT_VPN_NAME; -import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import com.google.auto.value.AutoValue; @@ -31,12 +30,16 @@ */ @AutoValue public abstract class BasicAuthJcsmpSessionServiceFactory extends SessionServiceFactory { + /** The host name or IP address of the Solace broker. Format: Host[:Port] */ public abstract String host(); + /** The username to use for authentication. */ public abstract String username(); + /** The password to use for authentication. */ public abstract String password(); + /** The name of the VPN to connect to. */ public abstract String vpnName(); public static Builder builder() { @@ -54,6 +57,7 @@ public abstract static class Builder { /** Set Solace username. */ public abstract Builder username(String username); + /** Set Solace password. 
*/ public abstract Builder password(String password); @@ -65,11 +69,15 @@ public abstract static class Builder { @Override public SessionService create() { - return new BasicAuthJcsmpSessionService( - checkStateNotNull(queue, "SolaceIO.Read: Queue is not set.").getName(), - host(), - username(), - password(), - vpnName()); + BasicAuthJcsmpSessionService.Builder builder = BasicAuthJcsmpSessionService.builder(); + if (queue != null) { + builder = builder.queueName(queue.getName()); + } + return builder + .host(host()) + .username(username()) + .password(password()) + .vpnName(vpnName()) + .build(); } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/GCPSecretSessionServiceFactory.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/GCPSecretSessionServiceFactory.java index dd87e1d75fa5..7f691b46be31 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/GCPSecretSessionServiceFactory.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/GCPSecretSessionServiceFactory.java @@ -117,7 +117,7 @@ public abstract static class Builder { @Override public SessionService create() { - String password = null; + String password; try { password = retrieveSecret(); } catch (IOException e) { diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageProducer.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageProducer.java new file mode 100644 index 000000000000..8aa254b92cb1 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageProducer.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.broker; + +import com.solacesystems.jcsmp.DeliveryMode; +import com.solacesystems.jcsmp.Destination; +import java.util.List; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.transforms.SerializableFunction; + +/** + * Base class for publishing messages to a Solace broker. + * + *

    Implementations of this interface are responsible for managing the connection to the broker + * and for publishing messages to the broker. + */ +@Internal +public interface MessageProducer { + + /** Publishes a message to the broker. */ + void publishSingleMessage( + Solace.Record msg, + Destination topicOrQueue, + boolean useCorrelationKeyLatency, + DeliveryMode deliveryMode); + + /** + * Publishes a batch of messages to the broker. + * + *

    The size of the batch cannot exceed 50 messages; this is a limitation of the Solace API. + * + *

    It returns the number of messages written. + */ + int publishBatch( + List records, + boolean useCorrelationKeyLatency, + SerializableFunction destinationFn, + DeliveryMode deliveryMode); + + /** Returns {@literal true} if the message producer is closed, {@literal false} otherwise. */ + boolean isClosed(); + + /** Closes the message producer. */ + void close(); +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageProducerUtils.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageProducerUtils.java new file mode 100644 index 000000000000..dd4610910ff4 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/MessageProducerUtils.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.broker; + +import com.solacesystems.jcsmp.BytesXMLMessage; +import com.solacesystems.jcsmp.DeliveryMode; +import com.solacesystems.jcsmp.Destination; +import com.solacesystems.jcsmp.JCSMPFactory; +import com.solacesystems.jcsmp.JCSMPSendMultipleEntry; +import java.util.List; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.transforms.SerializableFunction; + +@Internal +public class MessageProducerUtils { + // This is the batch limit supported by the send multiple JCSMP API method. + static final int SOLACE_BATCH_LIMIT = 50; + + /** + * Create a {@link BytesXMLMessage} to be published in Solace. + * + * @param record The record to be published. + * @param useCorrelationKeyLatency Whether to use a complex key for tracking latency. + * @param deliveryMode The {@link DeliveryMode} used to publish the message. + * @return A {@link BytesXMLMessage} that can be sent to Solace "as is". 
+ */ + public static BytesXMLMessage createBytesXMLMessage( + Solace.Record record, boolean useCorrelationKeyLatency, DeliveryMode deliveryMode) { + JCSMPFactory jcsmpFactory = JCSMPFactory.onlyInstance(); + BytesXMLMessage msg = jcsmpFactory.createBytesXMLMessage(); + byte[] payload = record.getPayload(); + msg.writeBytes(payload); + + Long senderTimestamp = record.getSenderTimestamp(); + if (senderTimestamp == null) { + senderTimestamp = System.currentTimeMillis(); + } + msg.setSenderTimestamp(senderTimestamp); + msg.setDeliveryMode(deliveryMode); + if (useCorrelationKeyLatency) { + Solace.CorrelationKey key = + Solace.CorrelationKey.builder() + .setMessageId(record.getMessageId()) + .setPublishMonotonicNanos(System.nanoTime()) + .build(); + msg.setCorrelationKey(key); + } else { + // Use only a string as correlation key + msg.setCorrelationKey(record.getMessageId()); + } + msg.setApplicationMessageId(record.getMessageId()); + return msg; + } + + /** + * Create a {@link JCSMPSendMultipleEntry} array to be published in Solace. This can be used with + * `sendMultiple` to send all the messages in a single API call. + * + *

    The size of the list cannot be larger than 50 messages. This is a hard limit enforced by the + * Solace API. + * + * @param records A {@link List} of records to be published + * @param useCorrelationKeyLatency Whether to use a complex key for tracking latency. + * @param destinationFn A function that maps every record to its destination. + * @param deliveryMode The {@link DeliveryMode} used to publish the message. + * @return A {@link JCSMPSendMultipleEntry} array that can be sent to Solace "as is". + */ + public static JCSMPSendMultipleEntry[] createJCSMPSendMultipleEntry( + List records, + boolean useCorrelationKeyLatency, + SerializableFunction destinationFn, + DeliveryMode deliveryMode) { + if (records.size() > SOLACE_BATCH_LIMIT) { + throw new RuntimeException( + String.format( + "SolaceIO.Write: Trying to create a batch of %d, but Solace supports a" + + " maximum of %d. The batch will likely be rejected by Solace.", + records.size(), SOLACE_BATCH_LIMIT)); + } + + JCSMPSendMultipleEntry[] entries = new JCSMPSendMultipleEntry[records.size()]; + for (int i = 0; i < records.size(); i++) { + Solace.Record record = records.get(i); + JCSMPSendMultipleEntry entry = + JCSMPFactory.onlyInstance() + .createSendMultipleEntry( + createBytesXMLMessage(record, useCorrelationKeyLatency, deliveryMode), + destinationFn.apply(record)); + entries[i] = entry; + } + + return entries; + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/PublishResultHandler.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/PublishResultHandler.java new file mode 100644 index 000000000000..1153bfcb7a1c --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/PublishResultHandler.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.broker; + +import com.solacesystems.jcsmp.JCSMPException; +import com.solacesystems.jcsmp.JCSMPStreamingPublishCorrelatingEventHandler; +import java.util.Queue; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.io.solace.data.Solace.PublishResult; +import org.apache.beam.sdk.io.solace.write.UnboundedSolaceWriter; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This class is required to handle callbacks from Solace, to find out if messages were actually + * published or there were any kind of error. + * + *

    This class is also used to calculate the latency of the publication. The correlation key + * contains the original timestamp of when the message was sent from the pipeline to Solace. The + * comparison of that value with the clock now, using a monotonic clock, is understood as the + * latency of the publication + */ +public final class PublishResultHandler implements JCSMPStreamingPublishCorrelatingEventHandler { + + private static final Logger LOG = LoggerFactory.getLogger(PublishResultHandler.class); + private final Queue publishResultsQueue; + private final Counter batchesRejectedByBroker = + Metrics.counter(UnboundedSolaceWriter.class, "batches_rejected"); + + public PublishResultHandler(Queue publishResultsQueue) { + this.publishResultsQueue = publishResultsQueue; + } + + @Override + public void handleErrorEx(Object key, JCSMPException cause, long timestamp) { + processKey(key, false, cause); + } + + @Override + public void responseReceivedEx(Object key) { + processKey(key, true, null); + } + + private void processKey(Object key, boolean isPublished, @Nullable JCSMPException cause) { + PublishResult.Builder resultBuilder = PublishResult.builder(); + String messageId; + if (key == null) { + messageId = ""; + } else if (key instanceof Solace.CorrelationKey) { + messageId = ((Solace.CorrelationKey) key).getMessageId(); + long latencyNanos = calculateLatency((Solace.CorrelationKey) key); + resultBuilder = resultBuilder.setLatencyNanos(latencyNanos); + } else { + messageId = key.toString(); + } + + resultBuilder = resultBuilder.setMessageId(messageId).setPublished(isPublished); + if (!isPublished) { + batchesRejectedByBroker.inc(); + if (cause != null) { + resultBuilder = resultBuilder.setError(cause.getMessage()); + } else { + resultBuilder = resultBuilder.setError("NULL - Not set by Solace"); + } + } else if (cause != null) { + LOG.warn( + "Message with id {} is published but exception is populated. Ignoring exception", + messageId); + } + + PublishResult publishResult = resultBuilder.build(); + // Static reference, it receives all callbacks from all publications + // from all threads + publishResultsQueue.add(publishResult); + } + + private static long calculateLatency(Solace.CorrelationKey key) { + long currentMillis = System.nanoTime(); + long publishMillis = key.getPublishMonotonicNanos(); + return currentMillis - publishMillis; + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java index aed700a71ded..84a876a9d0bc 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionService.java @@ -19,7 +19,11 @@ import com.solacesystems.jcsmp.JCSMPProperties; import java.io.Serializable; +import java.util.Queue; import org.apache.beam.sdk.io.solace.SolaceIO; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.data.Solace.PublishResult; +import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -69,21 +73,23 @@ *

    For basic authentication, use {@link BasicAuthJcsmpSessionService} and {@link * BasicAuthJcsmpSessionServiceFactory}. * - *

    For other situations, you need to extend this class. For instance: + *

    For other situations, you need to extend this class and implement the `equals` method, so two + * instances of your class can be compared by value. We recommend using AutoValue for that. For + * instance: * *

    {@code
    + * {@literal }@AutoValue
      * public abstract class MySessionService extends SessionService {
    - *   private final String authToken;
    + *   abstract String authToken();
      *
    - *   public MySessionService(String token) {
    - *    this.oauthToken = token;
    - *    ...
    + *   public static MySessionService create(String authToken) {
    + *     return new AutoValue_MySessionService(authToken);
      *   }
      *
      *   {@literal }@Override
      *   public JCSMPProperties initializeSessionProperties(JCSMPProperties baseProps) {
      *     baseProps.setProperty(JCSMPProperties.AUTHENTICATION_SCHEME, JCSMPProperties.AUTHENTICATION_SCHEME_OAUTH2);
    - *     baseProps.setProperty(JCSMPProperties.OAUTH2_ACCESS_TOKEN, authToken);
    + *     baseProps.setProperty(JCSMPProperties.OAUTH2_ACCESS_TOKEN, authToken());
      *     return baseProps;
      *   }
      *
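A minimal standalone sketch (not part of this patch) of why AutoValue is recommended above: the generated class derives value-based equals() and hashCode() from the abstract accessors, which is exactly the by-value contract SessionService now requires. The class and field names below are hypothetical.

    import com.google.auto.value.AutoValue;

    @AutoValue
    abstract class ExampleConfig {
      abstract String authToken();

      static ExampleConfig create(String authToken) {
        // AutoValue generates the constructor, equals() and hashCode() from authToken().
        return new AutoValue_ExampleConfig(authToken);
      }
    }

    // ExampleConfig.create("t").equals(ExampleConfig.create("t")) is true, which is the
    // by-value comparison the connector relies on.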
    @@ -101,6 +107,7 @@ public abstract class SessionService implements Serializable {
     
       public static final String DEFAULT_VPN_NAME = "default";
     
    +  private static final int TESTING_PUB_ACK_WINDOW = 1;
       private static final int STREAMING_PUB_ACK_WINDOW = 50;
       private static final int BATCHED_PUB_ACK_WINDOW = 255;
     
    @@ -121,10 +128,25 @@ public abstract class SessionService implements Serializable {
       public abstract boolean isClosed();
     
       /**
    -   * Creates a MessageReceiver object for receiving messages from Solace. Typically, this object is
    -   * created from the session instance.
    +   * Returns a MessageReceiver object for receiving messages from Solace. The first time this
    +   * method is called, the receiver is created from the session instance; subsequent calls return
    +   * the receiver created initially.
        */
    -  public abstract MessageReceiver createReceiver();
    +  public abstract MessageReceiver getReceiver();
    +
    +  /**
    +   * Returns a MessageProducer object for publishing messages to Solace. The first time this
    +   * method is called, the producer is created from the session instance; subsequent calls return
    +   * the producer created initially.
    +   */
    +  public abstract MessageProducer getInitializedProducer(SubmissionMode mode);
    +
    +  /**
    +   * Returns the {@link Queue} instance associated with this session, holding the callbacks
    +   * received asynchronously from Solace for message publications. The queue implementation has
    +   * to be thread-safe for production use cases.
    +   */
    +  public abstract Queue<PublishResult> getPublishedResultsQueue();
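As a hedged illustration (not part of this patch) of how this queue is meant to be consumed: callers drain the asynchronously delivered Solace.PublishResult objects and inspect each outcome. Here `sessionService` is a hypothetical, already-connected SessionService instance.

    java.util.Queue<Solace.PublishResult> results = sessionService.getPublishedResultsQueue();
    Solace.PublishResult result;
    while ((result = results.poll()) != null) {
      if (!result.getPublished()) {
        // getError() carries the failure details reported by the broker, when available.
        System.err.println("Failed to publish " + result.getMessageId() + ": " + result.getError());
      }
    }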
     
       /**
        * Override this method and provide your specific properties, including all those related to
    @@ -147,6 +169,20 @@ public abstract class SessionService implements Serializable {
        */
       public abstract JCSMPProperties initializeSessionProperties(JCSMPProperties baseProperties);
     
    +  /**
    +   * You need to override this method to be able to compare these objects by value. We recommend
    +   * using AutoValue for that.
    +   */
    +  @Override
    +  public abstract boolean equals(@Nullable Object other);
    +
    +  /**
    +   * You need to override this method to be able to compare these objects by value. We recommend
    +   * using AutoValue for that.
    +   */
    +  @Override
    +  public abstract int hashCode();
    +
       /**
        * This method will be called by the write connector when a new session is started.
        *
    @@ -186,50 +222,80 @@ private static JCSMPProperties overrideConnectorProperties(
         // received from Solace. A value of 1 will have the lowest latency, but a very low
         // throughput and a monumental backpressure.
     
    -    // This controls how the messages are sent to Solace
    -    if (mode == SolaceIO.SubmissionMode.HIGHER_THROUGHPUT) {
    -      // Create a parallel thread and a queue to send the messages
    +    // Retrieve current values of the properties
    +    Boolean msgCbProp = props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR);
    +    Integer ackWindowSize = props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE);
     
    -      Boolean msgCbProp = props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR);
    -      if (msgCbProp != null && msgCbProp) {
    -        LOG.warn(
    -            "SolaceIO.Write: Overriding MESSAGE_CALLBACK_ON_REACTOR to false since"
    -                + " HIGHER_THROUGHPUT mode was selected");
    -      }
    +    switch (mode) {
    +      case HIGHER_THROUGHPUT:
    +        // Check if it was set by user, show override warning
    +        if (msgCbProp != null && msgCbProp) {
    +          LOG.warn(
    +              "SolaceIO.Write: Overriding MESSAGE_CALLBACK_ON_REACTOR to false since"
    +                  + " HIGHER_THROUGHPUT mode was selected");
    +        }
    +        if ((ackWindowSize != null && ackWindowSize != BATCHED_PUB_ACK_WINDOW)) {
    +          LOG.warn(
    +              String.format(
    +                  "SolaceIO.Write: Overriding PUB_ACK_WINDOW_SIZE to %d since"
    +                      + " HIGHER_THROUGHPUT mode was selected",
    +                  BATCHED_PUB_ACK_WINDOW));
    +        }
     
    -      props.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, false);
    +        // Override the properties
    +        // Use a dedicated thread for callbacks, increase the ack window size
    +        props.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, false);
    +        props.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, BATCHED_PUB_ACK_WINDOW);
    +        LOG.info(
    +            "SolaceIO.Write: Using HIGHER_THROUGHPUT mode, MESSAGE_CALLBACK_ON_REACTOR is FALSE,"
    +                + " PUB_ACK_WINDOW_SIZE is {}",
    +            BATCHED_PUB_ACK_WINDOW);
    +        break;
    +      case LOWER_LATENCY:
    +        // Check if it was set by user, show override warning
    +        if (msgCbProp != null && !msgCbProp) {
    +          LOG.warn(
    +              "SolaceIO.Write: Overriding MESSAGE_CALLBACK_ON_REACTOR to true since"
    +                  + " LOWER_LATENCY mode was selected");
    +        }
     
    -      Integer ackWindowSize = props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE);
    -      if ((ackWindowSize != null && ackWindowSize != BATCHED_PUB_ACK_WINDOW)) {
    -        LOG.warn(
    -            String.format(
    -                "SolaceIO.Write: Overriding PUB_ACK_WINDOW_SIZE to %d since"
    -                    + " HIGHER_THROUGHPUT mode was selected",
    -                BATCHED_PUB_ACK_WINDOW));
    -      }
    -      props.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, BATCHED_PUB_ACK_WINDOW);
    -    } else {
    -      // Send from the same thread where the produced is being called. This offers the lowest
    -      // latency, but a low throughput too.
    -      Boolean msgCbProp = props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR);
    -      if (msgCbProp != null && !msgCbProp) {
    -        LOG.warn(
    -            "SolaceIO.Write: Overriding MESSAGE_CALLBACK_ON_REACTOR to true since"
    -                + " LOWER_LATENCY mode was selected");
    -      }
    +        if ((ackWindowSize != null && ackWindowSize != STREAMING_PUB_ACK_WINDOW)) {
    +          LOG.warn(
    +              String.format(
    +                  "SolaceIO.Write: Overriding PUB_ACK_WINDOW_SIZE to %d since"
    +                      + " LOWER_LATENCY mode was selected",
    +                  STREAMING_PUB_ACK_WINDOW));
    +        }
     
    -      props.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, true);
    +        // Override the properties
    +        // Send from the same thread where the producer is being called. This offers the lowest
    +        // latency, but a low throughput too.
    +        props.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, true);
    +        props.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, STREAMING_PUB_ACK_WINDOW);
    +        LOG.info(
    +            "SolaceIO.Write: Using LOWER_LATENCY mode, MESSAGE_CALLBACK_ON_REACTOR is TRUE,"
    +                + " PUB_ACK_WINDOW_SIZE is {}",
    +            STREAMING_PUB_ACK_WINDOW);
     
    -      Integer ackWindowSize = props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE);
    -      if ((ackWindowSize != null && ackWindowSize != STREAMING_PUB_ACK_WINDOW)) {
    +        break;
    +      case CUSTOM:
    +        LOG.info(
    +            "SolaceIO.Write: Using the custom JCSMP properties set by the user. No property has"
    +                + " been overridden by the connector.");
    +        break;
    +      case TESTING:
             LOG.warn(
    -            String.format(
    -                "SolaceIO.Write: Overriding PUB_ACK_WINDOW_SIZE to %d since"
    -                    + " LOWER_LATENCY mode was selected",
    -                STREAMING_PUB_ACK_WINDOW));
    -      }
    -
    -      props.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, STREAMING_PUB_ACK_WINDOW);
    +            "SolaceIO.Write: Overriding JCSMP properties for testing. **IF THIS IS AN"
    +                + " ACTUAL PIPELINE, CHANGE THE SUBMISSION MODE TO HIGHER_THROUGHPUT "
    +                + "OR LOWER_LATENCY.**");
    +        // Minimize multi-threading for testing
    +        props.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, true);
    +        props.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, TESTING_PUB_ACK_WINDOW);
    +        break;
    +      default:
    +        LOG.error(
    +            "SolaceIO.Write: No submission mode is selected. Set the submission mode to"
    +                + " HIGHER_THROUGHPUT or LOWER_LATENCY.");
         }
         return props;
       }
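To summarize the switch above, these are the effective overrides per submission mode; the sketch below only restates what overrideConnectorProperties sets, using the constants declared earlier in this class.

    // HIGHER_THROUGHPUT -> MESSAGE_CALLBACK_ON_REACTOR = false, PUB_ACK_WINDOW_SIZE = 255
    // LOWER_LATENCY     -> MESSAGE_CALLBACK_ON_REACTOR = true,  PUB_ACK_WINDOW_SIZE = 50
    // TESTING           -> MESSAGE_CALLBACK_ON_REACTOR = true,  PUB_ACK_WINDOW_SIZE = 1
    // CUSTOM            -> the user-supplied properties are left untouched
    // For example, the HIGHER_THROUGHPUT configuration is equivalent to:
    JCSMPProperties props = new JCSMPProperties();
    props.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, false);
    props.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, 255);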
    diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionServiceFactory.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionServiceFactory.java
    index 027de2cff134..bd1f3c23694d 100644
    --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionServiceFactory.java
    +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SessionServiceFactory.java
    @@ -19,11 +19,40 @@
     
     import com.solacesystems.jcsmp.Queue;
     import java.io.Serializable;
    +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode;
     import org.checkerframework.checker.nullness.qual.Nullable;
     
     /**
    - * This abstract class serves as a blueprint for creating `SessionService` objects. It introduces a
    - * queue property and mandates the implementation of a create() method in concrete subclasses.
    + * This abstract class serves as a blueprint for creating `SessionServiceFactory` objects. It
    + * introduces a queue property and mandates the implementation of a create() method in concrete
    + * subclasses.
    + *
    + * 

    For basic authentication, use {@link BasicAuthJcsmpSessionServiceFactory}. + * + *

    For other situations, you need to extend this class. Classes extending from this abstract + * class must implement the `equals` method so two instances can be compared by value, and not by + * reference. We recommend using AutoValue for that. + * + *

    {@code
    + * {@literal @}AutoValue
    + * public abstract class MyFactory extends SessionServiceFactory {
    + *
    + *   abstract String value1();
    + *
    + *   abstract String value2();
    + *
    + *   public static MyFactory create(String value1, String value2) {
    + *     return new AutoValue_MyFactory(value1, value2);
    + *   }
    + *
    + *   ...
    + *
    + *   {@literal @}Override
    + *   public SessionService create() {
    + *     ...
    + *   }
    + * }
    + * }
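For reference, a hedged sketch (not part of this patch) of what a concrete create() might return, mirroring the BasicAuthJcsmpSessionService builder introduced elsewhere in this change; the connection values are placeholders.

    @Override
    public SessionService create() {
      // Placeholder values; a real factory would read these from its own abstract accessors.
      return BasicAuthJcsmpSessionService.builder()
          .host("localhost")
          .username("username")
          .password("password")
          .vpnName(SessionService.DEFAULT_VPN_NAME)
          .build();
    }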
    */ public abstract class SessionServiceFactory implements Serializable { /** @@ -34,12 +63,32 @@ public abstract class SessionServiceFactory implements Serializable { */ @Nullable Queue queue; + /** + * The write submission mode. This is set when the writers are created. This property is used only + * by the write connector. + */ + @Nullable SubmissionMode submissionMode; + /** * This is the core method that subclasses must implement. It defines how to construct and return * a SessionService object. */ public abstract SessionService create(); + /** + * You need to override this method to be able to compare these objects by value. We recommend + * using AutoValue for that. + */ + @Override + public abstract boolean equals(@Nullable Object other); + + /** + * You need to override this method to be able to compare these objects by value. We recommend + * using AutoValue for that. + */ + @Override + public abstract int hashCode(); + /** * This method is called in the {@link * org.apache.beam.sdk.io.solace.SolaceIO.Read#expand(org.apache.beam.sdk.values.PBegin)} method @@ -48,4 +97,15 @@ public abstract class SessionServiceFactory implements Serializable { public void setQueue(Queue queue) { this.queue = queue; } + + /** + * Called by the write connector to set the submission mode used to create the message producers. + */ + public void setSubmissionMode(SubmissionMode submissionMode) { + this.submissionMode = submissionMode; + } + + public @Nullable SubmissionMode getSubmissionMode() { + return submissionMode; + } } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageProducer.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageProducer.java new file mode 100644 index 000000000000..b3806b5afae9 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/broker/SolaceMessageProducer.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.solace.broker; + +import static org.apache.beam.sdk.io.solace.broker.MessageProducerUtils.createBytesXMLMessage; +import static org.apache.beam.sdk.io.solace.broker.MessageProducerUtils.createJCSMPSendMultipleEntry; + +import com.solacesystems.jcsmp.BytesXMLMessage; +import com.solacesystems.jcsmp.DeliveryMode; +import com.solacesystems.jcsmp.Destination; +import com.solacesystems.jcsmp.JCSMPException; +import com.solacesystems.jcsmp.JCSMPSendMultipleEntry; +import com.solacesystems.jcsmp.XMLMessageProducer; +import java.util.List; +import java.util.concurrent.Callable; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.RetryCallableManager; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; + +@Internal +public class SolaceMessageProducer implements MessageProducer { + + private final XMLMessageProducer producer; + private final RetryCallableManager retryCallableManager = RetryCallableManager.create(); + + public SolaceMessageProducer(XMLMessageProducer producer) { + this.producer = producer; + } + + @Override + public void publishSingleMessage( + Solace.Record record, + Destination topicOrQueue, + boolean useCorrelationKeyLatency, + DeliveryMode deliveryMode) { + BytesXMLMessage msg = createBytesXMLMessage(record, useCorrelationKeyLatency, deliveryMode); + Callable publish = + () -> { + producer.send(msg, topicOrQueue); + return 0; + }; + + retryCallableManager.retryCallable(publish, ImmutableSet.of(JCSMPException.class)); + } + + @Override + public int publishBatch( + List records, + boolean useCorrelationKeyLatency, + SerializableFunction destinationFn, + DeliveryMode deliveryMode) { + JCSMPSendMultipleEntry[] batch = + createJCSMPSendMultipleEntry( + records, useCorrelationKeyLatency, destinationFn, deliveryMode); + Callable publish = () -> producer.sendMultiple(batch, 0, batch.length, 0); + return retryCallableManager.retryCallable(publish, ImmutableSet.of(JCSMPException.class)); + } + + @Override + public boolean isClosed() { + return producer == null || producer.isClosed(); + } + + @Override + public void close() { + if (!isClosed()) { + this.producer.close(); + } + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Solace.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Solace.java index 00b94b5b9ea9..21274237f46a 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Solace.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/data/Solace.java @@ -21,7 +21,6 @@ import com.solacesystems.jcsmp.BytesXMLMessage; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.nio.ByteBuffer; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.schemas.annotations.SchemaFieldNumber; @@ -52,6 +51,7 @@ public String getName() { return name; } } + /** Represents a Solace topic. */ public static class Topic { private final String name; @@ -68,6 +68,7 @@ public String getName() { return name; } } + /** Represents a Solace destination type. 
*/ public enum DestinationType { TOPIC, @@ -93,17 +94,17 @@ public abstract static class Destination { */ public abstract DestinationType getType(); - static Builder builder() { + public static Builder builder() { return new AutoValue_Solace_Destination.Builder(); } @AutoValue.Builder - abstract static class Builder { - abstract Builder setName(String name); + public abstract static class Builder { + public abstract Builder setName(String name); - abstract Builder setType(DestinationType type); + public abstract Builder setType(DestinationType type); - abstract Destination build(); + public abstract Destination build(); } } @@ -120,17 +121,19 @@ public abstract static class Record { * @return The message ID, or null if not available. */ @SchemaFieldNumber("0") - public abstract @Nullable String getMessageId(); + public abstract String getMessageId(); /** - * Gets the payload of the message as a ByteString. + * Gets the payload of the message as a byte array. * *

    Mapped from {@link BytesXMLMessage#getBytes()} * * @return The message payload. */ + @SuppressWarnings("mutable") @SchemaFieldNumber("1") - public abstract ByteBuffer getPayload(); + public abstract byte[] getPayload(); + /** * Gets the destination (topic or queue) to which the message was sent. * @@ -192,7 +195,7 @@ public abstract static class Record { * @return The timestamp. */ @SchemaFieldNumber("7") - public abstract long getReceiveTimestamp(); + public abstract @Nullable Long getReceiveTimestamp(); /** * Gets the timestamp (in milliseconds since the Unix epoch) when the message was sent by the @@ -241,55 +244,62 @@ public abstract static class Record { public abstract @Nullable String getReplicationGroupMessageId(); /** - * Gets the attachment data of the message as a ByteString, if any. This might represent files + * Gets the attachment data of the message as a byte array, if any. This might represent files * or other binary content associated with the message. * *

    Mapped from {@link BytesXMLMessage#getAttachmentByteBuffer()} * - * @return The attachment data, or an empty ByteString if no attachment is present. + * @return The attachment data, or an empty byte array if no attachment is present. */ + @SuppressWarnings("mutable") @SchemaFieldNumber("12") - public abstract ByteBuffer getAttachmentBytes(); + public abstract byte[] getAttachmentBytes(); - static Builder builder() { - return new AutoValue_Solace_Record.Builder(); + public static Builder builder() { + return new AutoValue_Solace_Record.Builder() + .setExpiration(0L) + .setPriority(-1) + .setRedelivered(false) + .setTimeToLive(0) + .setAttachmentBytes(new byte[0]); } @AutoValue.Builder - abstract static class Builder { - abstract Builder setMessageId(@Nullable String messageId); + public abstract static class Builder { + public abstract Builder setMessageId(String messageId); - abstract Builder setPayload(ByteBuffer payload); + public abstract Builder setPayload(byte[] payload); - abstract Builder setDestination(@Nullable Destination destination); + public abstract Builder setDestination(@Nullable Destination destination); - abstract Builder setExpiration(long expiration); + public abstract Builder setExpiration(long expiration); - abstract Builder setPriority(int priority); + public abstract Builder setPriority(int priority); - abstract Builder setRedelivered(boolean redelivered); + public abstract Builder setRedelivered(boolean redelivered); - abstract Builder setReplyTo(@Nullable Destination replyTo); + public abstract Builder setReplyTo(@Nullable Destination replyTo); - abstract Builder setReceiveTimestamp(long receiveTimestamp); + public abstract Builder setReceiveTimestamp(@Nullable Long receiveTimestamp); - abstract Builder setSenderTimestamp(@Nullable Long senderTimestamp); + public abstract Builder setSenderTimestamp(@Nullable Long senderTimestamp); - abstract Builder setSequenceNumber(@Nullable Long sequenceNumber); + public abstract Builder setSequenceNumber(@Nullable Long sequenceNumber); - abstract Builder setTimeToLive(long timeToLive); + public abstract Builder setTimeToLive(long timeToLive); - abstract Builder setReplicationGroupMessageId(@Nullable String replicationGroupMessageId); + public abstract Builder setReplicationGroupMessageId( + @Nullable String replicationGroupMessageId); - abstract Builder setAttachmentBytes(ByteBuffer attachmentBytes); + public abstract Builder setAttachmentBytes(byte[] attachmentBytes); - abstract Record build(); + public abstract Record build(); } } /** * The result of writing a message to Solace. This will be returned by the {@link - * com.google.cloud.dataflow.dce.io.solace.SolaceIO.Write} connector. + * org.apache.beam.sdk.io.solace.SolaceIO.Write} connector. * *

    This class provides a builder to create instances, but you will probably not need it. The * write connector will create and return instances of {@link Solace.PublishResult}. @@ -311,12 +321,12 @@ public abstract static class PublishResult { public abstract Boolean getPublished(); /** - * The publishing latency in milliseconds. This is the difference between the time the message + * The publishing latency in nanoseconds. This is the difference between the time the message * was created, and the time the message was published. It is only available if the {@link - * CorrelationKey} class is used as correlation key of the messages. + * CorrelationKey} class is used as correlation key of the messages, and null otherwise. */ @SchemaFieldNumber("2") - public abstract @Nullable Long getLatencyMilliseconds(); + public abstract @Nullable Long getLatencyNanos(); /** The error details if the message could not be published. */ @SchemaFieldNumber("3") @@ -332,7 +342,7 @@ public abstract static class Builder { public abstract Builder setPublished(Boolean published); - public abstract Builder setLatencyMilliseconds(Long latencyMs); + public abstract Builder setLatencyNanos(Long latencyNanos); public abstract Builder setError(String error); @@ -354,7 +364,7 @@ public abstract static class CorrelationKey { public abstract String getMessageId(); @SchemaFieldNumber("1") - public abstract long getPublishMonotonicMillis(); + public abstract long getPublishMonotonicNanos(); public static Builder builder() { return new AutoValue_Solace_CorrelationKey.Builder(); @@ -364,7 +374,7 @@ public static Builder builder() { public abstract static class Builder { public abstract Builder setMessageId(String messageId); - public abstract Builder setPublishMonotonicMillis(long millis); + public abstract Builder setPublishMonotonicNanos(long nanos); public abstract CorrelationKey build(); } @@ -414,7 +424,7 @@ public static class SolaceRecordMapper { Destination destination = getDestination(msg.getCorrelationId(), msg.getDestination()); return Record.builder() .setMessageId(msg.getApplicationMessageId()) - .setPayload(ByteBuffer.wrap(payloadBytesStream.toByteArray())) + .setPayload(payloadBytesStream.toByteArray()) .setDestination(destination) .setExpiration(msg.getExpiration()) .setPriority(msg.getPriority()) @@ -428,7 +438,7 @@ public static class SolaceRecordMapper { msg.getReplicationGroupMessageId() != null ? 
msg.getReplicationGroupMessageId().toString() : null) - .setAttachmentBytes(ByteBuffer.wrap(attachmentBytesStream.toByteArray())) + .setAttachmentBytes(attachmentBytesStream.toByteArray()) .build(); } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/UnboundedSolaceReader.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/UnboundedSolaceReader.java index c18a9d110b2a..a421970370da 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/UnboundedSolaceReader.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/read/UnboundedSolaceReader.java @@ -29,7 +29,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; -import org.apache.beam.sdk.io.solace.broker.MessageReceiver; import org.apache.beam.sdk.io.solace.broker.SempClient; import org.apache.beam.sdk.io.solace.broker.SessionService; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -49,7 +48,6 @@ class UnboundedSolaceReader extends UnboundedReader { private final SempClient sempClient; private @Nullable BytesXMLMessage solaceOriginalRecord; private @Nullable T solaceMappedRecord; - private @Nullable MessageReceiver messageReceiver; private @Nullable SessionService sessionService; AtomicBoolean active = new AtomicBoolean(true); @@ -72,7 +70,7 @@ public UnboundedSolaceReader(UnboundedSolaceSource currentSource) { @Override public boolean start() { populateSession(); - populateMessageConsumer(); + checkNotNull(sessionService).getReceiver().start(); return advance(); } @@ -85,22 +83,11 @@ public void populateSession() { } } - private void populateMessageConsumer() { - if (messageReceiver == null) { - messageReceiver = checkNotNull(sessionService).createReceiver(); - messageReceiver.start(); - } - MessageReceiver receiver = checkNotNull(messageReceiver); - if (receiver.isClosed()) { - receiver.start(); - } - } - @Override public boolean advance() { BytesXMLMessage receivedXmlMessage; try { - receivedXmlMessage = checkNotNull(messageReceiver).receive(); + receivedXmlMessage = checkNotNull(sessionService).getReceiver().receive(); } catch (IOException e) { LOG.warn("SolaceIO.Read: Exception when pulling messages from the broker.", e); return false; @@ -125,7 +112,7 @@ public void close() { @Override public Instant getWatermark() { // should be only used by a test receiver - if (checkNotNull(messageReceiver).isEOF()) { + if (checkNotNull(sessionService).getReceiver().isEOF()) { return BoundedWindow.TIMESTAMP_MAX_VALUE; } return watermarkPolicy.getWatermark(); diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/AddShardKeyDoFn.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/AddShardKeyDoFn.java new file mode 100644 index 000000000000..12d8a8507d8a --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/AddShardKeyDoFn.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.write; + +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.values.KV; + +/** + * This class a pseudo-key with a given cardinality. The downstream steps will use state {@literal + * &} timers to distribute the data and control for the number of parallel workers used for writing. + */ +@Internal +public class AddShardKeyDoFn extends DoFn> { + private final int shardCount; + private int shardKey; + + public AddShardKeyDoFn(int shardCount) { + this.shardCount = shardCount; + shardKey = -1; + } + + @ProcessElement + public void processElement( + @Element Solace.Record record, OutputReceiver> c) { + shardKey = (shardKey + 1) % shardCount; + c.output(KV.of(shardKey, record)); + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/RecordToPublishResultDoFn.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/RecordToPublishResultDoFn.java new file mode 100644 index 000000000000..4be5b0a014b3 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/RecordToPublishResultDoFn.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.write; + +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.transforms.DoFn; + +/** + * This class just transforms to PublishResult to be able to capture the windowing with the right + * strategy. The output is not used for anything else. 
+ */ +@Internal +public class RecordToPublishResultDoFn extends DoFn { + @ProcessElement + public void processElement( + @Element Solace.Record record, OutputReceiver receiver) { + Solace.PublishResult result = + Solace.PublishResult.builder() + .setPublished(true) + .setMessageId(record.getMessageId()) + .setLatencyNanos(0L) + .build(); + receiver.output(result); + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceOutput.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceOutput.java index 6c37f879ae7f..d9c37326f83f 100644 --- a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceOutput.java +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceOutput.java @@ -22,6 +22,7 @@ import org.apache.beam.sdk.io.solace.SolaceIO; import org.apache.beam.sdk.io.solace.data.Solace; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; @@ -31,50 +32,33 @@ import org.checkerframework.checker.nullness.qual.Nullable; /** - * The {@link SolaceIO.Write} transform's output return this type, containing both the successful - * publishes ({@link #getSuccessfulPublish()}) and the failed publishes ({@link - * #getFailedPublish()}). + * The {@link SolaceIO.Write} transform's output return this type, containing the successful + * publishes ({@link #getSuccessfulPublish()}). To access failed records, configure the connector + * with {@link SolaceIO.Write#withErrorHandler(ErrorHandler)}. * *

    The streaming writer with DIRECT messages does not return anything, and the output {@link - * PCollection}s will be equal to null. + * PCollection} will be equal to null. */ public final class SolaceOutput implements POutput { private final Pipeline pipeline; - private final TupleTag failedPublishTag; private final TupleTag successfulPublishTag; - private final @Nullable PCollection failedPublish; private final @Nullable PCollection successfulPublish; - public @Nullable PCollection getFailedPublish() { - return failedPublish; - } - public @Nullable PCollection getSuccessfulPublish() { return successfulPublish; } public static SolaceOutput in( - Pipeline pipeline, - @Nullable PCollection failedPublish, - @Nullable PCollection successfulPublish) { - return new SolaceOutput( - pipeline, - SolaceIO.Write.FAILED_PUBLISH_TAG, - SolaceIO.Write.SUCCESSFUL_PUBLISH_TAG, - failedPublish, - successfulPublish); + Pipeline pipeline, @Nullable PCollection successfulPublish) { + return new SolaceOutput(pipeline, SolaceIO.Write.SUCCESSFUL_PUBLISH_TAG, successfulPublish); } private SolaceOutput( Pipeline pipeline, - TupleTag failedPublishTag, TupleTag successfulPublishTag, - @Nullable PCollection failedPublish, @Nullable PCollection successfulPublish) { this.pipeline = pipeline; - this.failedPublishTag = failedPublishTag; this.successfulPublishTag = successfulPublishTag; - this.failedPublish = failedPublish; this.successfulPublish = successfulPublish; } @@ -87,10 +71,6 @@ public Pipeline getPipeline() { public Map, PValue> expand() { ImmutableMap.Builder, PValue> builder = ImmutableMap., PValue>builder(); - if (failedPublish != null) { - builder.put(failedPublishTag, failedPublish); - } - if (successfulPublish != null) { builder.put(successfulPublishTag, successfulPublish); } diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceWriteSessionsHandler.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceWriteSessionsHandler.java new file mode 100644 index 000000000000..109010231d17 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/SolaceWriteSessionsHandler.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.solace.write; + +import static org.apache.beam.sdk.io.solace.SolaceIO.DEFAULT_WRITER_CLIENTS_PER_WORKER; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; + +import com.google.auto.value.AutoValue; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.broker.SessionService; +import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; + +/** + * All the writer threads belonging to the same factory share the same instance of this class, to + * control for the number of clients that are connected to Solace, and minimize problems with quotas + * and limits. + * + *

    This class maintains a map of all the sessions open in a worker, and controls the size of that + * map, to avoid creating more sessions than Solace could handle. + * + *

    This class is thread-safe and creates a pool of producers per SessionServiceFactory. If there + * is only a Write transform in the pipeline, this is effectively a singleton. If there are more + * than one, each {@link SessionServiceFactory} instance keeps their own pool of producers. + */ +final class SolaceWriteSessionsHandler { + + private static final ConcurrentHashMap sessionsMap = + new ConcurrentHashMap<>(DEFAULT_WRITER_CLIENTS_PER_WORKER); + + public static SessionService getSessionServiceWithProducer( + int producerIndex, SessionServiceFactory sessionServiceFactory, UUID writerTransformUuid) { + SessionConfigurationIndex key = + SessionConfigurationIndex.builder() + .producerIndex(producerIndex) + .sessionServiceFactory(sessionServiceFactory) + .writerTransformUuid(writerTransformUuid) + .build(); + return sessionsMap.computeIfAbsent( + key, SolaceWriteSessionsHandler::createSessionAndStartProducer); + } + + private static SessionService createSessionAndStartProducer(SessionConfigurationIndex key) { + SessionServiceFactory factory = key.sessionServiceFactory(); + SessionService sessionService = factory.create(); + // Start the producer now that the initialization is locked for other threads + SubmissionMode mode = factory.getSubmissionMode(); + checkStateNotNull( + mode, + "SolaceIO.Write: Submission mode is not set. You need to set it to create write sessions."); + sessionService.getInitializedProducer(mode); + return sessionService; + } + + /** Disconnect all the sessions from Solace, and clear the corresponding state. */ + public static void disconnectFromSolace( + SessionServiceFactory factory, int producersCardinality, UUID writerTransformUuid) { + for (int i = 0; i < producersCardinality; i++) { + SessionConfigurationIndex key = + SessionConfigurationIndex.builder() + .producerIndex(i) + .sessionServiceFactory(factory) + .writerTransformUuid(writerTransformUuid) + .build(); + + SessionService sessionService = sessionsMap.remove(key); + if (sessionService != null) { + sessionService.close(); + } + } + } + + @AutoValue + abstract static class SessionConfigurationIndex { + abstract int producerIndex(); + + abstract SessionServiceFactory sessionServiceFactory(); + + abstract UUID writerTransformUuid(); + + static Builder builder() { + return new AutoValue_SolaceWriteSessionsHandler_SessionConfigurationIndex.Builder(); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder producerIndex(int producerIndex); + + abstract Builder sessionServiceFactory(SessionServiceFactory sessionServiceFactory); + + abstract Builder writerTransformUuid(UUID writerTransformUuid); + + abstract SessionConfigurationIndex build(); + } + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedBatchedSolaceWriter.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedBatchedSolaceWriter.java new file mode 100644 index 000000000000..dd4f81eeb082 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedBatchedSolaceWriter.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.write; + +import com.solacesystems.jcsmp.DeliveryMode; +import com.solacesystems.jcsmp.Destination; +import java.io.IOException; +import java.util.List; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.io.solace.data.Solace.Record; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.KV; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This DoFn is the responsible for writing to Solace in batch mode (holding up any messages), and + * emit the corresponding output (success or fail; only for persistent messages), so the + * SolaceIO.Write connector can be composed with other subsequent transforms in the pipeline. + * + *

    The DoFn will create several JCSMP sessions per VM, and the sessions and producers will be + * reused across different threads (if the number of threads is higher than the number of sessions, + * which is probably the most common case). + * + *

    The producer uses the JCSMP send multiple mode to publish a batch of messages with a + * single API call. The acks from this publication are also processed in batch, and returned as the + * output of the DoFn. + * + *

The batch size is 50, which is currently the maximum batch size supported by Solace. + * + *
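A compact sketch of that 50-message partitioning (mirroring the finishBundle() loop below; producer, destinationFn and deliveryMode would come from the writer itself, and the class/method names here are illustrative):

import java.util.List;
// MessageProducer, Solace.Record, SerializableFunction and the JCSMP Destination/DeliveryMode
// types are the same ones imported by the surrounding file.

final class BatchSplitSketch {
  static final int SOLACE_BATCH_LIMIT = 50;

  static int publishInBatches(
      List<Solace.Record> bundle,
      MessageProducer producer,
      SerializableFunction<Solace.Record, Destination> destinationFn,
      DeliveryMode deliveryMode) {
    int published = 0;
    for (int i = 0; i < bundle.size(); i += SOLACE_BATCH_LIMIT) {
      List<Solace.Record> batch =
          bundle.subList(i, Math.min(i + SOLACE_BATCH_LIMIT, bundle.size()));
      // publishBatch sends the whole sublist with a single send-multiple call.
      published += producer.publishBatch(batch, false, destinationFn, deliveryMode);
    }
    return published;
  }
}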

    There are no acks if the delivery mode is set to DIRECT. + * + *

    This writer DoFn offers higher throughput than {@link UnboundedStreamingSolaceWriter} but also + * higher latency. + */ +@Internal +public final class UnboundedBatchedSolaceWriter extends UnboundedSolaceWriter { + + private static final Logger LOG = LoggerFactory.getLogger(UnboundedBatchedSolaceWriter.class); + + private static final int ACKS_FLUSHING_INTERVAL_SECS = 10; + + private final Counter sentToBroker = + Metrics.counter(UnboundedBatchedSolaceWriter.class, "msgs_sent_to_broker"); + + private final Counter batchesRejectedByBroker = + Metrics.counter(UnboundedSolaceWriter.class, "batches_rejected"); + + // State variables are never explicitly "used" + @SuppressWarnings("UnusedVariable") + @TimerId("bundle_flusher") + private final TimerSpec bundleFlusherTimerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); + + public UnboundedBatchedSolaceWriter( + SerializableFunction destinationFn, + SessionServiceFactory sessionServiceFactory, + DeliveryMode deliveryMode, + SubmissionMode submissionMode, + int producersMapCardinality, + boolean publishLatencyMetrics) { + super( + destinationFn, + sessionServiceFactory, + deliveryMode, + submissionMode, + producersMapCardinality, + publishLatencyMetrics); + } + + // The state variable is here just to force a shuffling with a certain cardinality + @ProcessElement + public void processElement( + @Element KV element, + @TimerId("bundle_flusher") Timer bundleFlusherTimer, + @Timestamp Instant timestamp) { + + setCurrentBundleTimestamp(timestamp); + + Solace.Record record = element.getValue(); + + if (record == null) { + LOG.error( + "SolaceIO.Write: Found null record with key {}. Ignoring record.", element.getKey()); + } else { + addToCurrentBundle(record); + // Extend timer for bundle flushing + bundleFlusherTimer + .offset(Duration.standardSeconds(ACKS_FLUSHING_INTERVAL_SECS)) + .setRelative(); + } + } + + @FinishBundle + public void finishBundle(FinishBundleContext context) throws IOException { + // Take messages in groups of 50 (if there are enough messages) + List currentBundle = getCurrentBundle(); + for (int i = 0; i < currentBundle.size(); i += SOLACE_BATCH_LIMIT) { + int toIndex = Math.min(i + SOLACE_BATCH_LIMIT, currentBundle.size()); + List batch = currentBundle.subList(i, toIndex); + if (batch.isEmpty()) { + continue; + } + publishBatch(batch); + } + getCurrentBundle().clear(); + + publishResults(BeamContextWrapper.of(context)); + } + + @OnTimer("bundle_flusher") + public void flushBundle(OnTimerContext context) throws IOException { + publishResults(BeamContextWrapper.of(context)); + } + + private void publishBatch(List records) { + try { + int entriesPublished = + solaceSessionServiceWithProducer() + .getInitializedProducer(getSubmissionMode()) + .publishBatch( + records, shouldPublishLatencyMetrics(), getDestinationFn(), getDeliveryMode()); + sentToBroker.inc(entriesPublished); + } catch (Exception e) { + batchesRejectedByBroker.inc(); + Solace.PublishResult errorPublish = + Solace.PublishResult.builder() + .setPublished(false) + .setMessageId(String.format("BATCH_OF_%d_ENTRIES", records.size())) + .setError( + String.format( + "Batch could not be published after several" + " retries. 
Error: %s", + e.getMessage())) + .setLatencyNanos(System.nanoTime()) + .build(); + solaceSessionServiceWithProducer().getPublishedResultsQueue().add(errorPublish); + } + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedSolaceWriter.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedSolaceWriter.java new file mode 100644 index 000000000000..1c98113c2416 --- /dev/null +++ b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedSolaceWriter.java @@ -0,0 +1,373 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.write; + +import static org.apache.beam.sdk.io.solace.SolaceIO.Write.FAILED_PUBLISH_TAG; +import static org.apache.beam.sdk.io.solace.SolaceIO.Write.SUCCESSFUL_PUBLISH_TAG; + +import com.solacesystems.jcsmp.BytesXMLMessage; +import com.solacesystems.jcsmp.DeliveryMode; +import com.solacesystems.jcsmp.Destination; +import com.solacesystems.jcsmp.JCSMPFactory; +import com.solacesystems.jcsmp.JCSMPSendMultipleEntry; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Queue; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.SolaceIO; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.broker.SessionService; +import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.io.solace.data.Solace.PublishResult; +import org.apache.beam.sdk.io.solace.data.Solace.Record; +import org.apache.beam.sdk.metrics.Distribution; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This DoFn encapsulates common code used both for the {@link UnboundedBatchedSolaceWriter} and + * {@link UnboundedStreamingSolaceWriter}. 
+ */ +@Internal +public abstract class UnboundedSolaceWriter + extends DoFn, Solace.PublishResult> { + + private static final Logger LOG = LoggerFactory.getLogger(UnboundedSolaceWriter.class); + + // This is the batch limit supported by the send multiple JCSMP API method. + static final int SOLACE_BATCH_LIMIT = 50; + private final Distribution latencyPublish = + Metrics.distribution(SolaceIO.Write.class, "latency_publish_ms"); + + private final Distribution latencyErrors = + Metrics.distribution(SolaceIO.Write.class, "latency_failed_ms"); + + private final SerializableFunction destinationFn; + + private final SessionServiceFactory sessionServiceFactory; + private final DeliveryMode deliveryMode; + private final SubmissionMode submissionMode; + private final int producersMapCardinality; + private final boolean publishLatencyMetrics; + private static final AtomicInteger bundleProducerIndexCounter = new AtomicInteger(); + private int currentBundleProducerIndex = 0; + + private final List batchToEmit; + + private @Nullable Instant bundleTimestamp; + + final UUID writerTransformUuid = UUID.randomUUID(); + + public UnboundedSolaceWriter( + SerializableFunction destinationFn, + SessionServiceFactory sessionServiceFactory, + DeliveryMode deliveryMode, + SubmissionMode submissionMode, + int producersMapCardinality, + boolean publishLatencyMetrics) { + this.destinationFn = destinationFn; + this.sessionServiceFactory = sessionServiceFactory; + // Make sure that we set the submission mode now that we know which mode has been set by the + // user. + this.sessionServiceFactory.setSubmissionMode(submissionMode); + this.deliveryMode = deliveryMode; + this.submissionMode = submissionMode; + this.producersMapCardinality = producersMapCardinality; + this.publishLatencyMetrics = publishLatencyMetrics; + this.batchToEmit = new ArrayList<>(); + } + + @Teardown + public void teardown() { + SolaceWriteSessionsHandler.disconnectFromSolace( + sessionServiceFactory, producersMapCardinality, writerTransformUuid); + } + + public void updateProducerIndex() { + currentBundleProducerIndex = + bundleProducerIndexCounter.getAndIncrement() % producersMapCardinality; + } + + @StartBundle + public void startBundle() { + // Pick a producer at random for this bundle, reuse for the whole bundle + updateProducerIndex(); + batchToEmit.clear(); + } + + public SessionService solaceSessionServiceWithProducer() { + return SolaceWriteSessionsHandler.getSessionServiceWithProducer( + currentBundleProducerIndex, sessionServiceFactory, writerTransformUuid); + } + + public void publishResults(BeamContextWrapper context) { + long sumPublish = 0; + long countPublish = 0; + long minPublish = Long.MAX_VALUE; + long maxPublish = 0; + + long sumFailed = 0; + long countFailed = 0; + long minFailed = Long.MAX_VALUE; + long maxFailed = 0; + + Queue publishResultsQueue = + solaceSessionServiceWithProducer().getPublishedResultsQueue(); + Solace.PublishResult result = publishResultsQueue.poll(); + + if (result != null) { + if (getCurrentBundleTimestamp() == null) { + setCurrentBundleTimestamp(Instant.now()); + } + } + + while (result != null) { + Long latency = result.getLatencyNanos(); + + if (latency == null && shouldPublishLatencyMetrics()) { + LOG.error( + "SolaceIO.Write: Latency is null but user asked for latency metrics." 
+ + " This may be a bug."); + } + + if (latency != null) { + if (result.getPublished()) { + sumPublish += latency; + countPublish++; + minPublish = Math.min(minPublish, latency); + maxPublish = Math.max(maxPublish, latency); + } else { + sumFailed += latency; + countFailed++; + minFailed = Math.min(minFailed, latency); + maxFailed = Math.max(maxFailed, latency); + } + } + if (result.getPublished()) { + context.output( + SUCCESSFUL_PUBLISH_TAG, result, getCurrentBundleTimestamp(), GlobalWindow.INSTANCE); + } else { + try { + BadRecord b = + BadRecord.fromExceptionInformation( + result, + null, + null, + Optional.ofNullable(result.getError()).orElse("SolaceIO.Write: unknown error.")); + context.output(FAILED_PUBLISH_TAG, b, getCurrentBundleTimestamp(), GlobalWindow.INSTANCE); + } catch (IOException e) { + // ignore, the exception is thrown when the exception argument in the + // `BadRecord.fromExceptionInformation` is not null. + } + } + + result = publishResultsQueue.poll(); + } + + if (shouldPublishLatencyMetrics()) { + // Report all latency value in milliseconds + if (countPublish > 0) { + getPublishLatencyMetric() + .update( + TimeUnit.NANOSECONDS.toMillis(sumPublish), + countPublish, + TimeUnit.NANOSECONDS.toMillis(minPublish), + TimeUnit.NANOSECONDS.toMillis(maxPublish)); + } + + if (countFailed > 0) { + getFailedLatencyMetric() + .update( + TimeUnit.NANOSECONDS.toMillis(sumFailed), + countFailed, + TimeUnit.NANOSECONDS.toMillis(minFailed), + TimeUnit.NANOSECONDS.toMillis(maxFailed)); + } + } + } + + public BytesXMLMessage createSingleMessage( + Solace.Record record, boolean useCorrelationKeyLatency) { + JCSMPFactory jcsmpFactory = JCSMPFactory.onlyInstance(); + BytesXMLMessage msg = jcsmpFactory.createBytesXMLMessage(); + byte[] payload = record.getPayload(); + msg.writeBytes(payload); + + Long senderTimestamp = record.getSenderTimestamp(); + if (senderTimestamp == null) { + LOG.error( + "SolaceIO.Write: Record with id {} has no sender timestamp. Using current" + + " worker clock as timestamp.", + record.getMessageId()); + senderTimestamp = System.currentTimeMillis(); + } + msg.setSenderTimestamp(senderTimestamp); + msg.setDeliveryMode(getDeliveryMode()); + if (useCorrelationKeyLatency) { + Solace.CorrelationKey key = + Solace.CorrelationKey.builder() + .setMessageId(record.getMessageId()) + .setPublishMonotonicNanos(System.nanoTime()) + .build(); + msg.setCorrelationKey(key); + } else { + // Use only a string as correlation key + msg.setCorrelationKey(record.getMessageId()); + } + msg.setApplicationMessageId(record.getMessageId()); + return msg; + } + + public JCSMPSendMultipleEntry[] createMessagesArray( + Iterable records, boolean useCorrelationKeyLatency) { + // Solace batch publishing only supports 50 elements max, so it is safe to convert to + // list here + ArrayList recordsList = Lists.newArrayList(records); + if (recordsList.size() > SOLACE_BATCH_LIMIT) { + LOG.error( + "SolaceIO.Write: Trying to create a batch of {}, but Solace supports a" + + " maximum of {}. 
The batch will likely be rejected by Solace.", + recordsList.size(), + SOLACE_BATCH_LIMIT); + } + + JCSMPSendMultipleEntry[] entries = new JCSMPSendMultipleEntry[recordsList.size()]; + for (int i = 0; i < recordsList.size(); i++) { + Solace.Record record = recordsList.get(i); + JCSMPSendMultipleEntry entry = + JCSMPFactory.onlyInstance() + .createSendMultipleEntry( + createSingleMessage(record, useCorrelationKeyLatency), + getDestinationFn().apply(record)); + entries[i] = entry; + } + + return entries; + } + + public int getProducersMapCardinality() { + return producersMapCardinality; + } + + public Distribution getPublishLatencyMetric() { + return latencyPublish; + } + + public Distribution getFailedLatencyMetric() { + return latencyErrors; + } + + public boolean shouldPublishLatencyMetrics() { + return publishLatencyMetrics; + } + + public SerializableFunction getDestinationFn() { + return destinationFn; + } + + public DeliveryMode getDeliveryMode() { + return deliveryMode; + } + + public SubmissionMode getSubmissionMode() { + return submissionMode; + } + + public void addToCurrentBundle(Solace.Record record) { + batchToEmit.add(record); + } + + public List getCurrentBundle() { + return batchToEmit; + } + + public @Nullable Instant getCurrentBundleTimestamp() { + return bundleTimestamp; + } + + public void setCurrentBundleTimestamp(Instant bundleTimestamp) { + if (this.bundleTimestamp == null || bundleTimestamp.isBefore(this.bundleTimestamp)) { + this.bundleTimestamp = bundleTimestamp; + } + } + + /** + * Since we need to publish from on timer methods and finish bundle methods, we need a consistent + * way to handle both WindowedContext and FinishBundleContext. + */ + static class BeamContextWrapper { + private @Nullable WindowedContext windowedContext; + private @Nullable FinishBundleContext finishBundleContext; + + private BeamContextWrapper() {} + + public static BeamContextWrapper of(WindowedContext windowedContext) { + BeamContextWrapper beamContextWrapper = new BeamContextWrapper(); + beamContextWrapper.windowedContext = windowedContext; + return beamContextWrapper; + } + + public static BeamContextWrapper of(FinishBundleContext finishBundleContext) { + BeamContextWrapper beamContextWrapper = new BeamContextWrapper(); + beamContextWrapper.finishBundleContext = finishBundleContext; + return beamContextWrapper; + } + + public void output( + TupleTag tag, + T output, + @Nullable Instant timestamp, // Not required for windowed context + @Nullable BoundedWindow window) { // Not required for windowed context + if (windowedContext != null) { + windowedContext.output(tag, output); + } else if (finishBundleContext != null) { + if (timestamp == null) { + throw new IllegalStateException( + "SolaceIO.Write.UnboundedSolaceWriter.Context: Timestamp is required for a" + + " FinishBundleContext."); + } + if (window == null) { + throw new IllegalStateException( + "SolaceIO.Write.UnboundedSolaceWriter.Context: BoundedWindow is required for a" + + " FinishBundleContext."); + } + finishBundleContext.output(tag, output, timestamp, window); + } else { + throw new IllegalStateException( + "SolaceIO.Write.UnboundedSolaceWriter.Context: No context provided"); + } + } + } +} diff --git a/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedStreamingSolaceWriter.java b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedStreamingSolaceWriter.java new file mode 100644 index 000000000000..6d6d0b27e2bb --- /dev/null +++ 
b/sdks/java/io/solace/src/main/java/org/apache/beam/sdk/io/solace/write/UnboundedStreamingSolaceWriter.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.solace.write; + +import com.solacesystems.jcsmp.DeliveryMode; +import com.solacesystems.jcsmp.Destination; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.solace.SolaceIO; +import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.KV; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This DoFn is responsible for writing to Solace in streaming mode (one message at a time, not + * holding up any message), and emits the corresponding output (success or failure; only for persistent + * messages), so the SolaceIO.Write connector can be composed with other subsequent transforms in + * the pipeline. + * + *

    The DoFn will create several JCSMP sessions per VM, and the sessions and producers will be + * reused across different threads (if the number of threads is higher than the number of sessions, + * which is probably the most common case). + * + *
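For illustration, the per-bundle producer selection behind this reuse boils down to a round-robin counter modulo the configured cardinality (condensed from updateProducerIndex() in UnboundedSolaceWriter earlier in this patch; the names here are illustrative):

import java.util.concurrent.atomic.AtomicInteger;

final class ProducerIndexSketch {
  private static final AtomicInteger NEXT_PRODUCER = new AtomicInteger();

  // Called once per bundle: many concurrent bundles end up sharing a bounded set of
  // producers/sessions, one per index in [0, producersMapCardinality).
  static int pickProducerIndex(int producersMapCardinality) {
    return NEXT_PRODUCER.getAndIncrement() % producersMapCardinality;
  }
}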

The producer uses the JCSMP streaming mode to publish a single message at a time, processing + * the acks from this publication and returning them as the output of the DoFn. + * + *
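A condensed sketch of that publish-and-report step for a single message (based on processElement() below; the class/method names are illustrative, and the failure branch mirrors how a failed PublishResult is queued for the error output):

// SessionService, SubmissionMode, Solace and the JCSMP Destination/DeliveryMode types are
// the same ones imported by the surrounding file.
final class SinglePublishSketch {
  static void publishOne(
      SessionService session,
      SubmissionMode mode,
      Solace.Record record,
      Destination destination,
      DeliveryMode deliveryMode) {
    try {
      session
          .getInitializedProducer(mode)
          .publishSingleMessage(record, destination, false, deliveryMode);
    } catch (Exception e) {
      // If the producer's own retries are exhausted, enqueue a failed result so the
      // DoFn can emit it to the error output when it publishes results.
      session
          .getPublishedResultsQueue()
          .add(
              Solace.PublishResult.builder()
                  .setPublished(false)
                  .setMessageId(record.getMessageId())
                  .setError(
                      String.format("Message could not be published. Error: %s", e.getMessage()))
                  .setLatencyNanos(System.nanoTime())
                  .build());
    }
  }
}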

    There are no acks if the delivery mode is set to DIRECT. + * + *

    This writer DoFn offers lower latency and lower throughput than {@link + * UnboundedBatchedSolaceWriter}. + */ +@Internal +public final class UnboundedStreamingSolaceWriter extends UnboundedSolaceWriter { + + private static final Logger LOG = LoggerFactory.getLogger(UnboundedStreamingSolaceWriter.class); + + private final Counter sentToBroker = + Metrics.counter(UnboundedStreamingSolaceWriter.class, "msgs_sent_to_broker"); + + private final Counter rejectedByBroker = + Metrics.counter(UnboundedStreamingSolaceWriter.class, "msgs_rejected_by_broker"); + + // We use a state variable to force a shuffling and ensure the cardinality of the processing + @SuppressWarnings("UnusedVariable") + @StateId("current_key") + private final StateSpec> currentKeySpec = StateSpecs.value(); + + public UnboundedStreamingSolaceWriter( + SerializableFunction destinationFn, + SessionServiceFactory sessionServiceFactory, + DeliveryMode deliveryMode, + SolaceIO.SubmissionMode submissionMode, + int producersMapCardinality, + boolean publishLatencyMetrics) { + super( + destinationFn, + sessionServiceFactory, + deliveryMode, + submissionMode, + producersMapCardinality, + publishLatencyMetrics); + } + + @ProcessElement + public void processElement( + @Element KV element, + @Timestamp Instant timestamp, + @AlwaysFetched @StateId("current_key") ValueState currentKeyState) { + + setCurrentBundleTimestamp(timestamp); + + Integer currentKey = currentKeyState.read(); + Integer elementKey = element.getKey(); + Solace.Record record = element.getValue(); + + if (currentKey == null || !currentKey.equals(elementKey)) { + currentKeyState.write(elementKey); + } + + if (record == null) { + LOG.error("SolaceIO.Write: Found null record with key {}. Ignoring record.", elementKey); + return; + } + + // The publish method will retry, let's send a failure message if all the retries fail + try { + solaceSessionServiceWithProducer() + .getInitializedProducer(getSubmissionMode()) + .publishSingleMessage( + record, + getDestinationFn().apply(record), + shouldPublishLatencyMetrics(), + getDeliveryMode()); + sentToBroker.inc(); + } catch (Exception e) { + rejectedByBroker.inc(); + Solace.PublishResult errorPublish = + Solace.PublishResult.builder() + .setPublished(false) + .setMessageId(record.getMessageId()) + .setError( + String.format( + "Message could not be published after several" + " retries. 
Error: %s", + e.getMessage())) + .setLatencyNanos(System.nanoTime()) + .build(); + solaceSessionServiceWithProducer().getPublishedResultsQueue().add(errorPublish); + } + } + + @FinishBundle + public void finishBundle(FinishBundleContext context) { + publishResults(BeamContextWrapper.of(context)); + } +} diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java index ec0ae7194686..38b4953a5984 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockEmptySessionService.java @@ -17,14 +17,24 @@ */ package org.apache.beam.sdk.io.solace; +import com.google.auto.value.AutoValue; import com.solacesystems.jcsmp.JCSMPProperties; +import java.util.Queue; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.broker.MessageProducer; import org.apache.beam.sdk.io.solace.broker.MessageReceiver; import org.apache.beam.sdk.io.solace.broker.SessionService; +import org.apache.beam.sdk.io.solace.data.Solace.PublishResult; -public class MockEmptySessionService extends SessionService { +@AutoValue +public abstract class MockEmptySessionService extends SessionService { String exceptionMessage = "This is an empty client, use a MockSessionService instead."; + public static MockEmptySessionService create() { + return new AutoValue_MockEmptySessionService(); + } + @Override public void close() { throw new UnsupportedOperationException(exceptionMessage); @@ -36,7 +46,17 @@ public boolean isClosed() { } @Override - public MessageReceiver createReceiver() { + public MessageReceiver getReceiver() { + throw new UnsupportedOperationException(exceptionMessage); + } + + @Override + public MessageProducer getInitializedProducer(SubmissionMode mode) { + throw new UnsupportedOperationException(exceptionMessage); + } + + @Override + public Queue getPublishedResultsQueue() { throw new UnsupportedOperationException(exceptionMessage); } diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockProducer.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockProducer.java new file mode 100644 index 000000000000..271310359577 --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockProducer.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.solace; + +import com.solacesystems.jcsmp.DeliveryMode; +import com.solacesystems.jcsmp.Destination; +import com.solacesystems.jcsmp.JCSMPException; +import java.time.Instant; +import java.util.List; +import org.apache.beam.sdk.io.solace.broker.MessageProducer; +import org.apache.beam.sdk.io.solace.broker.PublishResultHandler; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.io.solace.data.Solace.Record; +import org.apache.beam.sdk.transforms.SerializableFunction; + +public abstract class MockProducer implements MessageProducer { + final PublishResultHandler handler; + + public MockProducer(PublishResultHandler handler) { + this.handler = handler; + } + + @Override + public int publishBatch( + List records, + boolean useCorrelationKeyLatency, + SerializableFunction destinationFn, + DeliveryMode deliveryMode) { + for (Record record : records) { + this.publishSingleMessage( + record, destinationFn.apply(record), useCorrelationKeyLatency, deliveryMode); + } + return records.size(); + } + + @Override + public boolean isClosed() { + return false; + } + + @Override + public void close() {} + + public static class MockSuccessProducer extends MockProducer { + public MockSuccessProducer(PublishResultHandler handler) { + super(handler); + } + + @Override + public void publishSingleMessage( + Record msg, + Destination topicOrQueue, + boolean useCorrelationKeyLatency, + DeliveryMode deliveryMode) { + if (useCorrelationKeyLatency) { + handler.responseReceivedEx( + Solace.PublishResult.builder() + .setPublished(true) + .setMessageId(msg.getMessageId()) + .build()); + } else { + handler.responseReceivedEx(msg.getMessageId()); + } + } + } + + public static class MockFailedProducer extends MockProducer { + public MockFailedProducer(PublishResultHandler handler) { + super(handler); + } + + @Override + public void publishSingleMessage( + Record msg, + Destination topicOrQueue, + boolean useCorrelationKeyLatency, + DeliveryMode deliveryMode) { + if (useCorrelationKeyLatency) { + handler.handleErrorEx( + Solace.PublishResult.builder() + .setPublished(false) + .setMessageId(msg.getMessageId()) + .setError("Some error") + .build(), + new JCSMPException("Some JCSMPException"), + Instant.now().toEpochMilli()); + } else { + handler.handleErrorEx( + msg.getMessageId(), + new JCSMPException("Some JCSMPException"), + Instant.now().toEpochMilli()); + } + } + } +} diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java index a4d6a42ef302..bd52dee7ea86 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionService.java @@ -17,38 +17,63 @@ */ package org.apache.beam.sdk.io.solace; +import com.google.auto.value.AutoValue; import com.solacesystems.jcsmp.BytesXMLMessage; import com.solacesystems.jcsmp.JCSMPProperties; import java.io.IOException; -import java.io.Serializable; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import org.apache.beam.sdk.io.solace.MockProducer.MockSuccessProducer; import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.broker.MessageProducer; import org.apache.beam.sdk.io.solace.broker.MessageReceiver; +import 
org.apache.beam.sdk.io.solace.broker.PublishResultHandler; import org.apache.beam.sdk.io.solace.broker.SessionService; +import org.apache.beam.sdk.io.solace.data.Solace.PublishResult; import org.apache.beam.sdk.transforms.SerializableFunction; import org.checkerframework.checker.nullness.qual.Nullable; -public class MockSessionService extends SessionService { +@AutoValue +public abstract class MockSessionService extends SessionService { + public static int ackWindowSizeForTesting = 87; + public static boolean callbackOnReactor = true; - private final SerializableFunction getRecordFn; - private MessageReceiver messageReceiver = null; - private final int minMessagesReceived; - private final @Nullable SubmissionMode mode; - - public MockSessionService( - SerializableFunction getRecordFn, - int minMessagesReceived, - @Nullable SubmissionMode mode) { - this.getRecordFn = getRecordFn; - this.minMessagesReceived = minMessagesReceived; - this.mode = mode; + public abstract @Nullable SerializableFunction recordFn(); + + public abstract int minMessagesReceived(); + + public abstract @Nullable SubmissionMode mode(); + + public abstract Function mockProducerFn(); + + private final Queue publishedResultsReceiver = new ConcurrentLinkedQueue<>(); + + public static Builder builder() { + return new AutoValue_MockSessionService.Builder() + .minMessagesReceived(0) + .mockProducerFn(MockSuccessProducer::new); } - public MockSessionService( - SerializableFunction getRecordFn, int minMessagesReceived) { - this(getRecordFn, minMessagesReceived, null); + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder recordFn( + @Nullable SerializableFunction recordFn); + + public abstract Builder minMessagesReceived(int minMessagesReceived); + + public abstract Builder mode(@Nullable SubmissionMode mode); + + public abstract Builder mockProducerFn( + Function mockProducerFn); + + public abstract MockSessionService build(); } + private MessageReceiver messageReceiver = null; + private MockProducer messageProducer = null; + @Override public void close() {} @@ -58,17 +83,41 @@ public boolean isClosed() { } @Override - public MessageReceiver createReceiver() { + public MessageReceiver getReceiver() { if (messageReceiver == null) { - messageReceiver = new MockReceiver(getRecordFn, minMessagesReceived); + messageReceiver = new MockReceiver(recordFn(), minMessagesReceived()); } return messageReceiver; } + @Override + public MessageProducer getInitializedProducer(SubmissionMode mode) { + if (messageProducer == null) { + messageProducer = mockProducerFn().apply(new PublishResultHandler(publishedResultsReceiver)); + } + return messageProducer; + } + + @Override + public Queue getPublishedResultsQueue() { + return publishedResultsReceiver; + } + @Override public void connect() {} - public static class MockReceiver implements MessageReceiver, Serializable { + @Override + public JCSMPProperties initializeSessionProperties(JCSMPProperties baseProperties) { + // Let's override some properties that will be overriden by the connector + // Opposite of the mode, to test that is overriden + baseProperties.setProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, callbackOnReactor); + + baseProperties.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, ackWindowSizeForTesting); + + return baseProperties; + } + + public static class MockReceiver implements MessageReceiver { private final AtomicInteger counter = new AtomicInteger(); private final SerializableFunction getRecordFn; private final int 
minMessagesReceived; @@ -100,16 +149,4 @@ public boolean isEOF() { return counter.get() >= minMessagesReceived; } } - - @Override - public JCSMPProperties initializeSessionProperties(JCSMPProperties baseProperties) { - // Let's override some properties that will be overriden by the connector - // Opposite of the mode, to test that is overriden - baseProperties.setProperty( - JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR, mode == SubmissionMode.HIGHER_THROUGHPUT); - - baseProperties.setProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE, 87); - - return baseProperties; - } } diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionServiceFactory.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionServiceFactory.java index 603a30ad2c90..9c17ca604201 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionServiceFactory.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/MockSessionServiceFactory.java @@ -17,22 +17,78 @@ */ package org.apache.beam.sdk.io.solace; +import com.google.auto.value.AutoValue; +import com.solacesystems.jcsmp.BytesXMLMessage; +import org.apache.beam.sdk.io.solace.MockProducer.MockFailedProducer; +import org.apache.beam.sdk.io.solace.MockProducer.MockSuccessProducer; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; import org.apache.beam.sdk.io.solace.broker.SessionService; import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.checkerframework.checker.nullness.qual.Nullable; -public class MockSessionServiceFactory extends SessionServiceFactory { - SessionService sessionService; +@AutoValue +public abstract class MockSessionServiceFactory extends SessionServiceFactory { + public abstract @Nullable SubmissionMode mode(); - public MockSessionServiceFactory(SessionService clientService) { - this.sessionService = clientService; + public abstract @Nullable SerializableFunction recordFn(); + + public abstract int minMessagesReceived(); + + public abstract SessionServiceType sessionServiceType(); + + public static Builder builder() { + return new AutoValue_MockSessionServiceFactory.Builder() + .minMessagesReceived(0) + .sessionServiceType(SessionServiceType.WITH_SUCCEEDING_PRODUCER); } public static SessionServiceFactory getDefaultMock() { - return new MockSessionServiceFactory(new MockEmptySessionService()); + return MockSessionServiceFactory.builder().build(); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder mode(@Nullable SubmissionMode mode); + + public abstract Builder recordFn( + @Nullable SerializableFunction recordFn); + + public abstract Builder minMessagesReceived(int minMessagesReceived); + + public abstract Builder sessionServiceType(SessionServiceType sessionServiceType); + + public abstract MockSessionServiceFactory build(); } @Override public SessionService create() { - return sessionService; + switch (sessionServiceType()) { + case EMPTY: + return MockEmptySessionService.create(); + case WITH_SUCCEEDING_PRODUCER: + return MockSessionService.builder() + .recordFn(recordFn()) + .minMessagesReceived(minMessagesReceived()) + .mode(mode()) + .mockProducerFn(MockSuccessProducer::new) + .build(); + case WITH_FAILING_PRODUCER: + return MockSessionService.builder() + .recordFn(recordFn()) + .minMessagesReceived(minMessagesReceived()) + .mode(mode()) + .mockProducerFn(MockFailedProducer::new) + .build(); + default: + throw 
new RuntimeException( + String.format("Unknown sessionServiceType: %s", sessionServiceType().name())); + } + } + + public enum SessionServiceType { + EMPTY, + WITH_SUCCEEDING_PRODUCER, + WITH_FAILING_PRODUCER } } diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOTest.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOReadTest.java similarity index 72% rename from sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOTest.java rename to sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOReadTest.java index cc1fa1d667aa..c718c55e1b48 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOTest.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOReadTest.java @@ -31,10 +31,12 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; +import org.apache.beam.sdk.io.solace.MockSessionServiceFactory.SessionServiceType; import org.apache.beam.sdk.io.solace.SolaceIO.Read; import org.apache.beam.sdk.io.solace.SolaceIO.Read.Configuration; import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; @@ -49,6 +51,7 @@ import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TypeDescriptor; @@ -61,7 +64,7 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) -public class SolaceIOTest { +public class SolaceIOReadTest { @Rule public final transient TestPipeline pipeline = TestPipeline.create(); @@ -69,7 +72,6 @@ private Read getDefaultRead() { return SolaceIO.read() .from(Solace.Queue.fromName("queue")) .withSempClientFactory(MockSempClientFactory.getDefaultMock()) - .withSessionServiceFactory(MockSessionServiceFactory.getDefaultMock()) .withMaxNumConnections(1); } @@ -77,7 +79,6 @@ private Read getDefaultReadForTopic() { return SolaceIO.read() .from(Solace.Topic.fromName("topic")) .withSempClientFactory(MockSempClientFactory.getDefaultMock()) - .withSessionServiceFactory(MockSessionServiceFactory.getDefaultMock()) .withMaxNumConnections(1); } @@ -102,20 +103,18 @@ private static UnboundedSolaceSource getSource(Read spec, TestPi @Test public void testReadMessages() { // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = - ImmutableList.of( - SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), - SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), - SolaceDataUtils.getBytesXmlMessage("payload_test2", "452")); - return getOrNull(index, messages); - }, - 3); + SerializableFunction recordFn = + index -> { + List messages = + ImmutableList.of( + SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), + SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), + SolaceDataUtils.getBytesXmlMessage("payload_test2", "452")); + return getOrNull(index, messages); + }; SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + 
MockSessionServiceFactory.builder().minMessagesReceived(3).recordFn(recordFn).build(); // Expected data List expected = new ArrayList<>(); @@ -137,20 +136,18 @@ public void testReadMessages() { @Test public void testReadMessagesWithDeduplication() { // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = - ImmutableList.of( - SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), - SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), - SolaceDataUtils.getBytesXmlMessage("payload_test2", "451")); - return getOrNull(index, messages); - }, - 3); + SerializableFunction recordFn = + index -> { + List messages = + ImmutableList.of( + SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), + SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), + SolaceDataUtils.getBytesXmlMessage("payload_test2", "451")); + return getOrNull(index, messages); + }; SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(3).build(); // Expected data List expected = new ArrayList<>(); @@ -172,19 +169,18 @@ public void testReadMessagesWithDeduplication() { @Test public void testReadMessagesWithoutDeduplication() { // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = - ImmutableList.of( - SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), - SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), - SolaceDataUtils.getBytesXmlMessage("payload_test2", "451")); - return getOrNull(index, messages); - }, - 3); + SerializableFunction recordFn = + index -> { + List messages = + ImmutableList.of( + SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), + SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), + SolaceDataUtils.getBytesXmlMessage("payload_test2", "451")); + return getOrNull(index, messages); + }; + SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(3).build(); // Expected data List expected = new ArrayList<>(); @@ -206,32 +202,38 @@ public void testReadMessagesWithoutDeduplication() { @Test public void testReadMessagesWithDeduplicationOnReplicationGroupMessageId() { // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = - ImmutableList.of( - SolaceDataUtils.getBytesXmlMessage( - "payload_test0", null, null, new ReplicationGroupMessageIdImpl(2L, 1L)), - SolaceDataUtils.getBytesXmlMessage( - "payload_test1", null, null, new ReplicationGroupMessageIdImpl(2L, 2L)), - SolaceDataUtils.getBytesXmlMessage( - "payload_test2", null, null, new ReplicationGroupMessageIdImpl(2L, 2L))); - return getOrNull(index, messages); - }, - 3); + + String id0 = UUID.randomUUID().toString(); + String id1 = UUID.randomUUID().toString(); + String id2 = UUID.randomUUID().toString(); + + SerializableFunction recordFn = + index -> { + List messages = + ImmutableList.of( + SolaceDataUtils.getBytesXmlMessage( + "payload_test0", id0, null, new ReplicationGroupMessageIdImpl(2L, 1L)), + SolaceDataUtils.getBytesXmlMessage( + "payload_test1", id1, null, new ReplicationGroupMessageIdImpl(2L, 2L)), + SolaceDataUtils.getBytesXmlMessage( + "payload_test2", id2, null, new ReplicationGroupMessageIdImpl(2L, 2L))); + return 
getOrNull(index, messages); + }; SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(3).build(); // Expected data List expected = new ArrayList<>(); expected.add( SolaceDataUtils.getSolaceRecord( - "payload_test0", null, new ReplicationGroupMessageIdImpl(2L, 1L))); + "payload_test0", id0, new ReplicationGroupMessageIdImpl(2L, 1L))); + expected.add( + SolaceDataUtils.getSolaceRecord( + "payload_test1", id1, new ReplicationGroupMessageIdImpl(2L, 2L))); expected.add( SolaceDataUtils.getSolaceRecord( - "payload_test1", null, new ReplicationGroupMessageIdImpl(2L, 2L))); + "payload_test2", id2, new ReplicationGroupMessageIdImpl(2L, 2L))); // Run the pipeline PCollection events = @@ -248,19 +250,18 @@ public void testReadMessagesWithDeduplicationOnReplicationGroupMessageId() { @Test public void testReadWithCoderAndParseFnAndTimestampFn() { // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = - ImmutableList.of( - SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), - SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), - SolaceDataUtils.getBytesXmlMessage("payload_test2", "452")); - return getOrNull(index, messages); - }, - 3); + SerializableFunction recordFn = + index -> { + List messages = + ImmutableList.of( + SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), + SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), + SolaceDataUtils.getBytesXmlMessage("payload_test2", "452")); + return getOrNull(index, messages); + }; + SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(3).build(); // Expected data List expected = new ArrayList<>(); @@ -304,7 +305,10 @@ public void testSplitsForExclusiveQueue() throws Exception { SolaceIO.read() .from(Solace.Queue.fromName("queue")) .withSempClientFactory(new MockSempClientFactory(mockSempClient)) - .withSessionServiceFactory(MockSessionServiceFactory.getDefaultMock()); + .withSessionServiceFactory( + MockSessionServiceFactory.builder() + .sessionServiceType(SessionServiceType.EMPTY) + .build()); int desiredNumSplits = 5; @@ -316,7 +320,10 @@ public void testSplitsForExclusiveQueue() throws Exception { @Test public void testSplitsForNonExclusiveQueueWithMaxNumConnections() throws Exception { - Read spec = getDefaultRead().withMaxNumConnections(3); + Read spec = + getDefaultRead() + .withMaxNumConnections(3) + .withSessionServiceFactory(MockSessionServiceFactory.getDefaultMock()); int desiredNumSplits = 5; @@ -328,7 +335,10 @@ public void testSplitsForNonExclusiveQueueWithMaxNumConnections() throws Excepti @Test public void testSplitsForNonExclusiveQueueWithMaxNumConnectionsRespectDesired() throws Exception { - Read spec = getDefaultRead().withMaxNumConnections(10); + Read spec = + getDefaultRead() + .withMaxNumConnections(10) + .withSessionServiceFactory(MockSessionServiceFactory.getDefaultMock()); int desiredNumSplits = 5; UnboundedSolaceSource initialSource = getSource(spec, pipeline); @@ -346,7 +356,9 @@ public void testCreateQueueForTopic() { .build(); Read spec = - getDefaultReadForTopic().withSempClientFactory(new MockSempClientFactory(mockSempClient)); + getDefaultReadForTopic() + .withSempClientFactory(new MockSempClientFactory(mockSempClient)) + 
.withSessionServiceFactory(MockSessionServiceFactory.getDefaultMock()); spec.expand(PBegin.in(TestPipeline.create())); // check if createQueueForTopic was executed assertEquals(1, createQueueForTopicFnCounter.get()); @@ -358,22 +370,22 @@ public void testCheckpointMark() throws Exception { AtomicInteger countAckMessages = new AtomicInteger(0); // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = new ArrayList<>(); - for (int i = 0; i < 10; i++) { - messages.add( - SolaceDataUtils.getBytesXmlMessage( - "payload_test" + i, "45" + i, (num) -> countAckMessages.incrementAndGet())); - } - countConsumedMessages.incrementAndGet(); - return getOrNull(index, messages); - }, - 10); + + SerializableFunction recordFn = + index -> { + List messages = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + messages.add( + SolaceDataUtils.getBytesXmlMessage( + "payload_test" + i, "45" + i, (num) -> countAckMessages.incrementAndGet())); + } + countConsumedMessages.incrementAndGet(); + return getOrNull(index, messages); + }; SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(10).build(); + Read spec = getDefaultRead().withSessionServiceFactory(fakeSessionServiceFactory); UnboundedSolaceSource initialSource = getSource(spec, pipeline); @@ -407,21 +419,20 @@ public void testCheckpointMarkAndFinalizeSeparately() throws Exception { AtomicInteger countAckMessages = new AtomicInteger(0); // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = new ArrayList<>(); - for (int i = 0; i < 10; i++) { - messages.add( - SolaceDataUtils.getBytesXmlMessage( - "payload_test" + i, "45" + i, (num) -> countAckMessages.incrementAndGet())); - } - countConsumedMessages.incrementAndGet(); - return getOrNull(index, messages); - }, - 10); + SerializableFunction recordFn = + index -> { + List messages = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + messages.add( + SolaceDataUtils.getBytesXmlMessage( + "payload_test" + i, "45" + i, (num) -> countAckMessages.incrementAndGet())); + } + countConsumedMessages.incrementAndGet(); + return getOrNull(index, messages); + }; + SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(10).build(); Read spec = getDefaultRead() @@ -467,22 +478,21 @@ public void testCheckpointMarkSafety() throws Exception { AtomicInteger countAckMessages = new AtomicInteger(0); // Broker that creates input data - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = new ArrayList<>(); - for (int i = 0; i < messagesToProcess; i++) { - messages.add( - SolaceDataUtils.getBytesXmlMessage( - "payload_test" + i, "45" + i, (num) -> countAckMessages.incrementAndGet())); - } - countConsumedMessages.incrementAndGet(); - return getOrNull(index, messages); - }, - 10); + SerializableFunction recordFn = + index -> { + List messages = new ArrayList<>(); + for (int i = 0; i < messagesToProcess; i++) { + messages.add( + SolaceDataUtils.getBytesXmlMessage( + "payload_test" + i, "45" + i, (num) -> countAckMessages.incrementAndGet())); + } + countConsumedMessages.incrementAndGet(); + return getOrNull(index, messages); + }; SessionServiceFactory fakeSessionServiceFactory = - new 
MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(10).build(); + Read spec = getDefaultRead() .withSessionServiceFactory(fakeSessionServiceFactory) @@ -558,20 +568,18 @@ public void testDestinationTopicQueueCreation() { @Test public void testTopicEncoding() { - MockSessionService mockClientService = - new MockSessionService( - index -> { - List messages = - ImmutableList.of( - SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), - SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), - SolaceDataUtils.getBytesXmlMessage("payload_test2", "452")); - return getOrNull(index, messages); - }, - 3); + SerializableFunction recordFn = + index -> { + List messages = + ImmutableList.of( + SolaceDataUtils.getBytesXmlMessage("payload_test0", "450"), + SolaceDataUtils.getBytesXmlMessage("payload_test1", "451"), + SolaceDataUtils.getBytesXmlMessage("payload_test2", "452")); + return getOrNull(index, messages); + }; SessionServiceFactory fakeSessionServiceFactory = - new MockSessionServiceFactory(mockClientService); + MockSessionServiceFactory.builder().recordFn(recordFn).minMessagesReceived(3).build(); // Run PCollection events = diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOWriteTest.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOWriteTest.java new file mode 100644 index 000000000000..e92657c3c3d2 --- /dev/null +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/SolaceIOWriteTest.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.solace; + +import static org.apache.beam.sdk.values.TypeDescriptors.strings; + +import com.solacesystems.jcsmp.DeliveryMode; +import java.util.List; +import java.util.Objects; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.io.solace.MockSessionServiceFactory.SessionServiceType; +import org.apache.beam.sdk.io.solace.SolaceIO.SubmissionMode; +import org.apache.beam.sdk.io.solace.SolaceIO.WriterType; +import org.apache.beam.sdk.io.solace.broker.SessionServiceFactory; +import org.apache.beam.sdk.io.solace.data.Solace; +import org.apache.beam.sdk.io.solace.data.Solace.Record; +import org.apache.beam.sdk.io.solace.data.SolaceDataUtils; +import org.apache.beam.sdk.io.solace.write.SolaceOutput; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.errorhandling.BadRecord; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandler; +import org.apache.beam.sdk.transforms.errorhandling.ErrorHandlingTestUtils.ErrorSinkTransform; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class SolaceIOWriteTest { + + @Rule public final transient TestPipeline pipeline = TestPipeline.create(); + + private final List keys = ImmutableList.of("450", "451", "452"); + private final List payloads = ImmutableList.of("payload0", "payload1", "payload2"); + + private PCollection getRecords(Pipeline p) { + TestStream.Builder> kvBuilder = + TestStream.create(KvCoder.of(AvroCoder.of(String.class), AvroCoder.of(String.class))) + .advanceWatermarkTo(Instant.EPOCH); + + assert keys.size() == payloads.size(); + + for (int k = 0; k < keys.size(); k++) { + kvBuilder = + kvBuilder + .addElements(KV.of(keys.get(k), payloads.get(k))) + .advanceProcessingTime(Duration.standardSeconds(60)); + } + + TestStream> testStream = kvBuilder.advanceWatermarkToInfinity(); + PCollection> kvs = p.apply("Test stream", testStream); + + return kvs.apply( + "To Record", + MapElements.into(TypeDescriptor.of(Record.class)) + .via(kv -> SolaceDataUtils.getSolaceRecord(kv.getValue(), kv.getKey()))); + } + + private SolaceOutput getWriteTransform( + SubmissionMode mode, + WriterType writerType, + Pipeline p, + ErrorHandler errorHandler) { + SessionServiceFactory fakeSessionServiceFactory = + MockSessionServiceFactory.builder().mode(mode).build(); + + PCollection records = getRecords(p); + return records.apply( + "Write to Solace", + SolaceIO.write() + .to(Solace.Queue.fromName("queue")) + .withSubmissionMode(mode) + .withWriterType(writerType) + .withDeliveryMode(DeliveryMode.PERSISTENT) + .withSessionServiceFactory(fakeSessionServiceFactory) + .withErrorHandler(errorHandler)); + } + + private static PCollection getIdsPCollection(SolaceOutput output) { + return output + .getSuccessfulPublish() + .apply( + "Get message ids", MapElements.into(strings()).via(Solace.PublishResult::getMessageId)); + } + + @Test + public void 
testWriteLatencyStreaming() throws Exception { + SubmissionMode mode = SubmissionMode.LOWER_LATENCY; + WriterType writerType = WriterType.STREAMING; + + ErrorHandler> errorHandler = + pipeline.registerBadRecordErrorHandler(new ErrorSinkTransform()); + SolaceOutput output = getWriteTransform(mode, writerType, pipeline, errorHandler); + PCollection ids = getIdsPCollection(output); + + PAssert.that(ids).containsInAnyOrder(keys); + errorHandler.close(); + PAssert.that(errorHandler.getOutput()).empty(); + + pipeline.run(); + } + + @Test + public void testWriteThroughputStreaming() throws Exception { + SubmissionMode mode = SubmissionMode.HIGHER_THROUGHPUT; + WriterType writerType = WriterType.STREAMING; + ErrorHandler> errorHandler = + pipeline.registerBadRecordErrorHandler(new ErrorSinkTransform()); + SolaceOutput output = getWriteTransform(mode, writerType, pipeline, errorHandler); + PCollection ids = getIdsPCollection(output); + + PAssert.that(ids).containsInAnyOrder(keys); + errorHandler.close(); + PAssert.that(errorHandler.getOutput()).empty(); + + pipeline.run(); + } + + @Test + public void testWriteLatencyBatched() throws Exception { + SubmissionMode mode = SubmissionMode.LOWER_LATENCY; + WriterType writerType = WriterType.BATCHED; + ErrorHandler> errorHandler = + pipeline.registerBadRecordErrorHandler(new ErrorSinkTransform()); + SolaceOutput output = getWriteTransform(mode, writerType, pipeline, errorHandler); + PCollection ids = getIdsPCollection(output); + + PAssert.that(ids).containsInAnyOrder(keys); + errorHandler.close(); + PAssert.that(errorHandler.getOutput()).empty(); + pipeline.run(); + } + + @Test + public void testWriteThroughputBatched() throws Exception { + SubmissionMode mode = SubmissionMode.HIGHER_THROUGHPUT; + WriterType writerType = WriterType.BATCHED; + ErrorHandler> errorHandler = + pipeline.registerBadRecordErrorHandler(new ErrorSinkTransform()); + SolaceOutput output = getWriteTransform(mode, writerType, pipeline, errorHandler); + PCollection ids = getIdsPCollection(output); + + PAssert.that(ids).containsInAnyOrder(keys); + errorHandler.close(); + PAssert.that(errorHandler.getOutput()).empty(); + pipeline.run(); + } + + @Test + public void testWriteWithFailedRecords() throws Exception { + SubmissionMode mode = SubmissionMode.HIGHER_THROUGHPUT; + WriterType writerType = WriterType.BATCHED; + ErrorHandler> errorHandler = + pipeline.registerBadRecordErrorHandler(new ErrorSinkTransform()); + + SessionServiceFactory fakeSessionServiceFactory = + MockSessionServiceFactory.builder() + .mode(mode) + .sessionServiceType(SessionServiceType.WITH_FAILING_PRODUCER) + .build(); + + PCollection records = getRecords(pipeline); + SolaceOutput output = + records.apply( + "Write to Solace", + SolaceIO.write() + .to(Solace.Queue.fromName("queue")) + .withSubmissionMode(mode) + .withWriterType(writerType) + .withDeliveryMode(DeliveryMode.PERSISTENT) + .withSessionServiceFactory(fakeSessionServiceFactory) + .withErrorHandler(errorHandler)); + + PCollection ids = getIdsPCollection(output); + + PAssert.that(ids).empty(); + errorHandler.close(); + PAssert.thatSingleton(Objects.requireNonNull(errorHandler.getOutput())) + .isEqualTo((long) payloads.size()); + pipeline.run(); + } +} diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/OverrideWriterPropertiesTest.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/OverrideWriterPropertiesTest.java index 0c6f88a7c9d5..357734f18aad 100644 --- 
a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/OverrideWriterPropertiesTest.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/broker/OverrideWriterPropertiesTest.java @@ -31,9 +31,8 @@ public class OverrideWriterPropertiesTest { @Test public void testOverrideForHigherThroughput() { SolaceIO.SubmissionMode mode = SolaceIO.SubmissionMode.HIGHER_THROUGHPUT; - MockSessionService service = new MockSessionService(null, 0, mode); + MockSessionService service = MockSessionService.builder().mode(mode).build(); - // Test HIGHER_THROUGHPUT mode JCSMPProperties props = service.initializeWriteSessionProperties(mode); assertEquals(false, props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR)); assertEquals( @@ -44,13 +43,26 @@ public void testOverrideForHigherThroughput() { @Test public void testOverrideForLowerLatency() { SolaceIO.SubmissionMode mode = SolaceIO.SubmissionMode.LOWER_LATENCY; - MockSessionService service = new MockSessionService(null, 0, mode); + MockSessionService service = MockSessionService.builder().mode(mode).build(); - // Test HIGHER_THROUGHPUT mode JCSMPProperties props = service.initializeWriteSessionProperties(mode); assertEquals(true, props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR)); assertEquals( Long.valueOf(50), Long.valueOf(props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE))); } + + @Test + public void testDontOverrideForCustom() { + SolaceIO.SubmissionMode mode = SolaceIO.SubmissionMode.CUSTOM; + MockSessionService service = MockSessionService.builder().mode(mode).build(); + + JCSMPProperties props = service.initializeWriteSessionProperties(mode); + assertEquals( + MockSessionService.callbackOnReactor, + props.getBooleanProperty(JCSMPProperties.MESSAGE_CALLBACK_ON_REACTOR)); + assertEquals( + Long.valueOf(MockSessionService.ackWindowSizeForTesting), + Long.valueOf(props.getIntegerProperty(JCSMPProperties.PUB_ACK_WINDOW_SIZE))); + } } diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/data/SolaceDataUtils.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/data/SolaceDataUtils.java index 5134bd131d73..9e04c4cfd276 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/data/SolaceDataUtils.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/data/SolaceDataUtils.java @@ -100,7 +100,7 @@ public static Solace.Record getSolaceRecord( : DEFAULT_REPLICATION_GROUP_ID.toString(); return Solace.Record.builder() - .setPayload(ByteBuffer.wrap(payload.getBytes(StandardCharsets.UTF_8))) + .setPayload(payload.getBytes(StandardCharsets.UTF_8)) .setMessageId(messageId) .setDestination( Solace.Destination.builder() @@ -116,7 +116,7 @@ public static Solace.Record getSolaceRecord( .setTimeToLive(1000L) .setSenderTimestamp(null) .setReplicationGroupMessageId(replicationGroupMessageIdString) - .setAttachmentBytes(ByteBuffer.wrap(new byte[0])) + .setAttachmentBytes(new byte[0]) .build(); } diff --git a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOIT.java b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOIT.java index 1a2a056efd45..ee5d206533dc 100644 --- a/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOIT.java +++ b/sdks/java/io/solace/src/test/java/org/apache/beam/sdk/io/solace/it/SolaceIOIT.java @@ -17,49 +17,71 @@ */ package org.apache.beam.sdk.io.solace.it; +import static 
org.apache.beam.sdk.io.solace.it.SolaceContainerManager.TOPIC_NAME; +import static org.apache.beam.sdk.values.TypeDescriptors.strings; import static org.junit.Assert.assertEquals; +import com.solacesystems.jcsmp.DeliveryMode; import java.io.IOException; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; import org.apache.beam.sdk.io.solace.SolaceIO; import org.apache.beam.sdk.io.solace.broker.BasicAuthJcsmpSessionServiceFactory; import org.apache.beam.sdk.io.solace.broker.BasicAuthSempClientFactory; +import org.apache.beam.sdk.io.solace.data.Solace; import org.apache.beam.sdk.io.solace.data.Solace.Queue; +import org.apache.beam.sdk.io.solace.data.SolaceDataUtils; +import org.apache.beam.sdk.io.solace.write.SolaceOutput; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.testutils.metrics.MetricsReader; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; import org.joda.time.Duration; +import org.joda.time.Instant; import org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.FixMethodOrder; import org.junit.Rule; import org.junit.Test; +import org.junit.runners.MethodSorters; +@FixMethodOrder(MethodSorters.NAME_ASCENDING) public class SolaceIOIT { private static final String NAMESPACE = SolaceIOIT.class.getName(); private static final String READ_COUNT = "read_count"; + private static final String WRITE_COUNT = "write_count"; private static SolaceContainerManager solaceContainerManager; - private static final TestPipelineOptions readPipelineOptions; + private static final String queueName = "test_queue"; + private static final TestPipelineOptions pipelineOptions; + private static final long PUBLISH_MESSAGE_COUNT = 20; static { - readPipelineOptions = PipelineOptionsFactory.create().as(TestPipelineOptions.class); - readPipelineOptions.setBlockOnRun(false); - readPipelineOptions.as(TestPipelineOptions.class).setBlockOnRun(false); - readPipelineOptions.as(StreamingOptions.class).setStreaming(false); + pipelineOptions = PipelineOptionsFactory.create().as(TestPipelineOptions.class); + pipelineOptions.as(StreamingOptions.class).setStreaming(true); + // For the read connector tests, we need to make sure that p.run() does not block + pipelineOptions.setBlockOnRun(false); + pipelineOptions.as(TestPipelineOptions.class).setBlockOnRun(false); } - @Rule public final TestPipeline readPipeline = TestPipeline.fromOptions(readPipelineOptions); + @Rule public final TestPipeline pipeline = TestPipeline.fromOptions(pipelineOptions); @BeforeClass public static void setup() throws IOException { solaceContainerManager = new SolaceContainerManager(); solaceContainerManager.start(); + solaceContainerManager.createQueueWithSubscriptionTopic(queueName); } @AfterClass @@ -69,20 +91,17 @@ public static void afterClass() { } } + // The 
order of the following tests matters. The first test publishes some messages in a Solace + // queue, and those messages are read by the second test. If another writer test is run before + // the read test, that will alter the count for the read test and will make it fail. @Test - public void testRead() { - String queueName = "test_queue"; - solaceContainerManager.createQueueWithSubscriptionTopic(queueName); - - // todo this is very slow, needs to be replaced with the SolaceIO.write connector. - int publishMessagesCount = 20; - for (int i = 0; i < publishMessagesCount; i++) { - solaceContainerManager.sendToTopic( - "{\"field_str\":\"value\",\"field_int\":123}", - ImmutableList.of("Solace-Message-ID:m" + i)); - } + public void test01WriteStreaming() { + testWriteConnector(SolaceIO.WriterType.STREAMING); + } - readPipeline + @Test + public void test02Read() { + pipeline .apply( "Read from Solace", SolaceIO.read() @@ -105,12 +124,83 @@ public void testRead() { .build())) .apply("Count", ParDo.of(new CountingFn<>(NAMESPACE, READ_COUNT))); - PipelineResult pipelineResult = readPipeline.run(); + PipelineResult pipelineResult = pipeline.run(); + // We need enough time for Beam to pull all messages from the queue, but we need a timeout too, + // as the Read connector will keep attempting to read forever. pipelineResult.waitUntilFinish(Duration.standardSeconds(15)); MetricsReader metricsReader = new MetricsReader(pipelineResult, NAMESPACE); long actualRecordsCount = metricsReader.getCounterMetric(READ_COUNT); - assertEquals(publishMessagesCount, actualRecordsCount); + assertEquals(PUBLISH_MESSAGE_COUNT, actualRecordsCount); + } + + @Test + public void test03WriteBatched() { + testWriteConnector(SolaceIO.WriterType.BATCHED); + } + + private void testWriteConnector(SolaceIO.WriterType writerType) { + Pipeline p = createWriterPipeline(writerType); + + PipelineResult pipelineResult = p.run(); + pipelineResult.waitUntilFinish(); + MetricsReader metricsReader = new MetricsReader(pipelineResult, NAMESPACE); + long actualRecordsCount = metricsReader.getCounterMetric(WRITE_COUNT); + assertEquals(PUBLISH_MESSAGE_COUNT, actualRecordsCount); + } + + private Pipeline createWriterPipeline(SolaceIO.WriterType writerType) { + TestStream.Builder> kvBuilder = + TestStream.create(KvCoder.of(AvroCoder.of(String.class), AvroCoder.of(String.class))) + .advanceWatermarkTo(Instant.EPOCH); + + for (int i = 0; i < PUBLISH_MESSAGE_COUNT; i++) { + String key = "Solace-Message-ID:m" + i; + String payload = String.format("{\"field_str\":\"value\",\"field_int\":123%d}", i); + kvBuilder = + kvBuilder + .addElements(KV.of(key, payload)) + .advanceProcessingTime(Duration.standardSeconds(60)); + } + + TestStream> testStream = kvBuilder.advanceWatermarkToInfinity(); + + PCollection> kvs = + pipeline.apply(String.format("Test stream %s", writerType), testStream); + + PCollection records = + kvs.apply( + String.format("To Record %s", writerType), + MapElements.into(TypeDescriptor.of(Solace.Record.class)) + .via(kv -> SolaceDataUtils.getSolaceRecord(kv.getValue(), kv.getKey()))); + + SolaceOutput result = + records.apply( + String.format("Write to Solace %s", writerType), + SolaceIO.write() + .to(Solace.Topic.fromName(TOPIC_NAME)) + .withSubmissionMode(SolaceIO.SubmissionMode.TESTING) + .withWriterType(writerType) + .withDeliveryMode(DeliveryMode.PERSISTENT) + .withNumberOfClientsPerWorker(1) + .withNumShards(1) + .withSessionServiceFactory( + BasicAuthJcsmpSessionServiceFactory.builder() + .host("localhost:" + 
solaceContainerManager.jcsmpPortMapped) + .username(SolaceContainerManager.USERNAME) + .password(SolaceContainerManager.PASSWORD) + .vpnName(SolaceContainerManager.VPN_NAME) + .build())); + result + .getSuccessfulPublish() + .apply( + String.format("Get ids %s", writerType), + MapElements.into(strings()).via(Solace.PublishResult::getMessageId)) + .apply( + String.format("Count %s", writerType), + ParDo.of(new CountingFn<>(NAMESPACE, WRITE_COUNT))); + + return pipeline; } private static class CountingFn extends DoFn { From c6436458f2828ea7f63936fb4939e68e736ea7a6 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Wed, 13 Nov 2024 11:13:21 -0500 Subject: [PATCH 170/181] fix imports --- sdks/python/apache_beam/dataframe/transforms.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/dataframe/transforms.py b/sdks/python/apache_beam/dataframe/transforms.py index c8ac8174232d..8f27ce95c294 100644 --- a/sdks/python/apache_beam/dataframe/transforms.py +++ b/sdks/python/apache_beam/dataframe/transforms.py @@ -17,7 +17,6 @@ import collections import logging from collections.abc import Mapping -from typing import TYPE_CHECKING from typing import Any from typing import TypeVar from typing import Union @@ -29,16 +28,13 @@ from apache_beam.dataframe import expressions from apache_beam.dataframe import frames # pylint: disable=unused-import from apache_beam.dataframe import partitionings +from apache_beam.pvalue import PCollection from apache_beam.utils import windowed_value __all__ = [ 'DataframeTransform', ] -if TYPE_CHECKING: - # pylint: disable=ungrouped-imports - from apache_beam.pvalue import PCollection - T = TypeVar('T') TARGET_PARTITION_SIZE = 1 << 23 # 8M From 306c6d742e1386a9017a1ce48a0c20d21a28ee0a Mon Sep 17 00:00:00 2001 From: Steven van Rossum Date: Wed, 13 Nov 2024 18:22:06 +0100 Subject: [PATCH 171/181] Share AvgRecordSize and KafkaLatestOffsetEstimator caches among DoFns (#32928) --- .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 113 ++++++++++++------ 1 file changed, 75 insertions(+), 38 deletions(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 26964d43a16f..1cf4aad34e4e 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -26,8 +26,10 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.sdk.coders.Coder; @@ -62,6 +64,7 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheLoader; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.LoadingCache; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables; import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; @@ -207,6 +210,23 @@ private ReadFromKafkaDoFn( private static final Logger LOG = 
LoggerFactory.getLogger(ReadFromKafkaDoFn.class); + /** + * A holder class for all construction time unique instances of {@link ReadFromKafkaDoFn}. Caches + * must run clean up tasks when {@link #teardown()} is called. + */ + private static final class SharedStateHolder { + + private static final Map> + OFFSET_ESTIMATOR_CACHE = new ConcurrentHashMap<>(); + private static final Map> + AVG_RECORD_SIZE_CACHE = new ConcurrentHashMap<>(); + } + + private static final AtomicLong FN_ID = new AtomicLong(); + + // A unique identifier for the instance. Generally unique unless the ID generator overflows. + private final long fnId = FN_ID.getAndIncrement(); + private final @Nullable Map offsetConsumerConfig; private final @Nullable CheckStopReadingFn checkStopReadingFn; @@ -599,43 +619,56 @@ public Coder restrictionCoder() { public void setup() throws Exception { // Start to track record size and offset gap per bundle. avgRecordSizeCache = - CacheBuilder.newBuilder() - .maximumSize(1000L) - .build( - new CacheLoader() { - @Override - public AverageRecordSize load(KafkaSourceDescriptor kafkaSourceDescriptor) - throws Exception { - return new AverageRecordSize(); - } - }); + SharedStateHolder.AVG_RECORD_SIZE_CACHE.computeIfAbsent( + fnId, + k -> { + return CacheBuilder.newBuilder() + .maximumSize(1000L) + .build( + new CacheLoader() { + @Override + public AverageRecordSize load(KafkaSourceDescriptor kafkaSourceDescriptor) + throws Exception { + return new AverageRecordSize(); + } + }); + }); keyDeserializerInstance = keyDeserializerProvider.getDeserializer(consumerConfig, true); valueDeserializerInstance = valueDeserializerProvider.getDeserializer(consumerConfig, false); offsetEstimatorCache = - CacheBuilder.newBuilder() - .weakValues() - .expireAfterAccess(1, TimeUnit.MINUTES) - .build( - new CacheLoader() { - @Override - public KafkaLatestOffsetEstimator load( - KafkaSourceDescriptor kafkaSourceDescriptor) throws Exception { - LOG.info( - "Creating Kafka consumer for offset estimation for {}", - kafkaSourceDescriptor); - - TopicPartition topicPartition = kafkaSourceDescriptor.getTopicPartition(); - Map updatedConsumerConfig = - overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); - Consumer offsetConsumer = - consumerFactoryFn.apply( - KafkaIOUtils.getOffsetConsumerConfig( - "tracker-" + topicPartition, - offsetConsumerConfig, - updatedConsumerConfig)); - return new KafkaLatestOffsetEstimator(offsetConsumer, topicPartition); - } - }); + SharedStateHolder.OFFSET_ESTIMATOR_CACHE.computeIfAbsent( + fnId, + k -> { + final Map consumerConfig = ImmutableMap.copyOf(this.consumerConfig); + final @Nullable Map offsetConsumerConfig = + this.offsetConsumerConfig == null + ? 
null + : ImmutableMap.copyOf(this.offsetConsumerConfig); + return CacheBuilder.newBuilder() + .weakValues() + .expireAfterAccess(1, TimeUnit.MINUTES) + .build( + new CacheLoader() { + @Override + public KafkaLatestOffsetEstimator load( + KafkaSourceDescriptor kafkaSourceDescriptor) throws Exception { + LOG.info( + "Creating Kafka consumer for offset estimation for {}", + kafkaSourceDescriptor); + + TopicPartition topicPartition = kafkaSourceDescriptor.getTopicPartition(); + Map updatedConsumerConfig = + overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); + Consumer offsetConsumer = + consumerFactoryFn.apply( + KafkaIOUtils.getOffsetConsumerConfig( + "tracker-" + topicPartition, + offsetConsumerConfig, + updatedConsumerConfig)); + return new KafkaLatestOffsetEstimator(offsetConsumer, topicPartition); + } + }); + }); if (checkStopReadingFn != null) { checkStopReadingFn.setup(); } @@ -643,6 +676,10 @@ public KafkaLatestOffsetEstimator load( @Teardown public void teardown() throws Exception { + final LoadingCache avgRecordSizeCache = + Preconditions.checkStateNotNull(this.avgRecordSizeCache); + final LoadingCache offsetEstimatorCache = + Preconditions.checkStateNotNull(this.offsetEstimatorCache); try { if (valueDeserializerInstance != null) { Closeables.close(valueDeserializerInstance, true); @@ -655,13 +692,13 @@ public void teardown() throws Exception { } catch (Exception anyException) { LOG.warn("Fail to close resource during finishing bundle.", anyException); } - - if (offsetEstimatorCache != null) { - offsetEstimatorCache.invalidateAll(); - } if (checkStopReadingFn != null) { checkStopReadingFn.teardown(); } + + // Allow the cache to perform clean up tasks when this instance is about to be deleted. + avgRecordSizeCache.cleanUp(); + offsetEstimatorCache.cleanUp(); } private Map overrideBootstrapServersConfig( From 50ed69ab560491339409f0b50b5fcd596ff53623 Mon Sep 17 00:00:00 2001 From: Ahmed Abualsaud <65791736+ahmedabu98@users.noreply.github.com> Date: Wed, 13 Nov 2024 12:23:57 -0500 Subject: [PATCH 172/181] Portable Managed BigQuery destinations (#33017) * managed bigqueryio * spotless * move managed dependency to test only * cleanup after merging snake_case PR * choose write method based on boundedness and pipeline options * rename bigquery write config class * spotless * change read output tag to 'output' * spotless * revert logic that depends on DataflowServiceOptions. 
switching BQ methods can instead be done in Dataflow service side * spotless * fix typo * separate BQ write config to a new class * fix doc * resolve after syncing to HEAD * spotless * fork on batch/streaming * cleanup * spotless * portable bigquery destinations * move forking logic to BQ schematransform side * add file loads translation and tests; add test checks that the correct transform is chosen * set top-level wrapper to be the underlying managed BQ transform urn; change tests to verify underlying transform name * move unit tests to respectvie schematransform test classes * expose to Python SDK as well * cleanup * address comment * set enable_streaming_engine option; add to CHANGES --- CHANGES.md | 1 + .../io/google-cloud-platform/build.gradle | 1 + ...QueryFileLoadsSchemaTransformProvider.java | 12 +- ...torageWriteApiSchemaTransformProvider.java | 142 ++++++------------ .../providers/BigQueryWriteConfiguration.java | 30 +++- .../PortableBigQueryDestinations.java | 105 +++++++++++++ ...yFileLoadsSchemaTransformProviderTest.java | 39 +++++ .../bigquery/providers/BigQueryManagedIT.java | 91 ++++++++--- ...geWriteApiSchemaTransformProviderTest.java | 51 +++++-- 9 files changed, 327 insertions(+), 145 deletions(-) create mode 100644 sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java diff --git a/CHANGES.md b/CHANGES.md index 6962b0fb8ded..bc7ec096fe33 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -70,6 +70,7 @@ * [Managed Iceberg] Add support for TIMESTAMP, TIME, and DATE types ([#32688](https://github.com/apache/beam/pull/32688)) * BigQuery CDC writes are now available in Python SDK, only supported when using StorageWrite API at least once mode ([#32527](https://github.com/apache/beam/issues/32527)) * [Managed Iceberg] Allow updating table partition specs during pipeline runtime ([#32879](https://github.com/apache/beam/pull/32879)) +* Added BigQueryIO as a Managed IO ([#31486](https://github.com/apache/beam/pull/31486)) * Support for writing to [Solace messages queues](https://solace.com/) (`SolaceIO.Write`) added (Java) ([#31905](https://github.com/apache/beam/issues/31905)). 
## New Features / Improvements diff --git a/sdks/java/io/google-cloud-platform/build.gradle b/sdks/java/io/google-cloud-platform/build.gradle index 2acce3e94cc2..b8e71e289827 100644 --- a/sdks/java/io/google-cloud-platform/build.gradle +++ b/sdks/java/io/google-cloud-platform/build.gradle @@ -198,6 +198,7 @@ task integrationTest(type: Test, dependsOn: processTestResources) { "--runner=DirectRunner", "--project=${gcpProject}", "--tempRoot=${gcpTempRoot}", + "--tempLocation=${gcpTempRoot}", "--firestoreDb=${firestoreDb}", "--firestoreHost=${firestoreHost}", "--bigtableChangeStreamInstanceId=${bigtableChangeStreamInstanceId}", diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java index 092cf42a29a4..7872c91d1f72 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java @@ -28,6 +28,7 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider; @@ -89,16 +90,19 @@ public static class BigQueryFileLoadsSchemaTransform extends SchemaTransform { @Override public PCollectionRowTuple expand(PCollectionRowTuple input) { PCollection rowPCollection = input.getSinglePCollection(); - BigQueryIO.Write write = toWrite(input.getPipeline().getOptions()); + BigQueryIO.Write write = + toWrite(rowPCollection.getSchema(), input.getPipeline().getOptions()); rowPCollection.apply(write); return PCollectionRowTuple.empty(input.getPipeline()); } - BigQueryIO.Write toWrite(PipelineOptions options) { + BigQueryIO.Write toWrite(Schema schema, PipelineOptions options) { + PortableBigQueryDestinations dynamicDestinations = + new PortableBigQueryDestinations(schema, configuration); BigQueryIO.Write write = BigQueryIO.write() - .to(configuration.getTable()) + .to(dynamicDestinations) .withMethod(BigQueryIO.Write.Method.FILE_LOADS) .withFormatFunction(BigQueryUtils.toTableRow()) // TODO(https://github.com/apache/beam/issues/33074) BatchLoad's @@ -106,7 +110,7 @@ BigQueryIO.Write toWrite(PipelineOptions options) { .withCustomGcsTempLocation( ValueProvider.StaticValueProvider.of(options.getTempLocation())) .withWriteDisposition(WriteDisposition.WRITE_APPEND) - .useBeamSchema(); + .withFormatFunction(dynamicDestinations.getFilterFormatFunction(false)); if (!Strings.isNullOrEmpty(configuration.getCreateDisposition())) { CreateDisposition createDisposition = diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java index c45433aaf0e7..1e53ad3553e0 100644 --- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java @@ -18,15 +18,14 @@ package org.apache.beam.sdk.io.gcp.bigquery.providers; import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteConfiguration.DYNAMIC_DESTINATIONS; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.PortableBigQueryDestinations.DESTINATION; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.PortableBigQueryDestinations.RECORD; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import com.google.api.services.bigquery.model.TableConstraints; -import com.google.api.services.bigquery.model.TableSchema; import com.google.auto.service.AutoService; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Optional; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.Method; @@ -34,9 +33,7 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryStorageApiInsertError; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; -import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinations; import org.apache.beam.sdk.io.gcp.bigquery.RowMutationInformation; -import org.apache.beam.sdk.io.gcp.bigquery.TableDestination; import org.apache.beam.sdk.io.gcp.bigquery.WriteResult; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; @@ -54,7 +51,6 @@ import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptors; -import org.apache.beam.sdk.values.ValueInSingleWindow; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.joda.time.Duration; @@ -80,6 +76,7 @@ public class BigQueryStorageWriteApiSchemaTransformProvider private static final String FAILED_ROWS_TAG = "FailedRows"; private static final String FAILED_ROWS_WITH_ERRORS_TAG = "FailedRowsWithErrors"; // magic string that tells us to write to dynamic destinations + protected static final String DYNAMIC_DESTINATIONS = "DYNAMIC_DESTINATIONS"; protected static final String ROW_PROPERTY_MUTATION_INFO = "row_mutation_info"; protected static final String ROW_PROPERTY_MUTATION_TYPE = "mutation_type"; protected static final String ROW_PROPERTY_MUTATION_SQN = "change_sequence_number"; @@ -176,52 +173,6 @@ private static class NoOutputDoFn extends DoFn { public void process(ProcessContext c) {} } - private static class RowDynamicDestinations extends DynamicDestinations { - final Schema schema; - final String fixedDestination; - final List primaryKey; - - RowDynamicDestinations(Schema schema) { - this.schema = schema; - this.fixedDestination = null; - this.primaryKey = null; - } - - public RowDynamicDestinations( - Schema schema, String fixedDestination, List primaryKey) { - this.schema = schema; - this.fixedDestination = fixedDestination; - this.primaryKey = primaryKey; - } - - @Override - public String getDestination(ValueInSingleWindow element) { - return 
Optional.ofNullable(fixedDestination) - .orElseGet(() -> element.getValue().getString("destination")); - } - - @Override - public TableDestination getTable(String destination) { - return new TableDestination(destination, null); - } - - @Override - public TableSchema getSchema(String destination) { - return BigQueryUtils.toTableSchema(schema); - } - - @Override - public TableConstraints getTableConstraints(String destination) { - return Optional.ofNullable(this.primaryKey) - .filter(pk -> !pk.isEmpty()) - .map( - pk -> - new TableConstraints() - .setPrimaryKey(new TableConstraints.PrimaryKey().setColumns(pk))) - .orElse(null); - } - } - @Override public PCollectionRowTuple expand(PCollectionRowTuple input) { // Check that the input exists @@ -309,13 +260,6 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } } - void validateDynamicDestinationsExpectedSchema(Schema schema) { - checkArgument( - schema.getFieldNames().containsAll(Arrays.asList("destination", "record")), - "When writing to dynamic destinations, we expect Row Schema with a " - + "\"destination\" string field and a \"record\" Row field."); - } - BigQueryIO.Write createStorageWriteApiTransform(Schema schema) { Method writeMethod = configuration.getUseAtLeastOnceSemantics() != null @@ -326,21 +270,37 @@ BigQueryIO.Write createStorageWriteApiTransform(Schema schema) { BigQueryIO.Write write = BigQueryIO.write() .withMethod(writeMethod) - .withFormatFunction(BigQueryUtils.toTableRow()) .withWriteDisposition(WriteDisposition.WRITE_APPEND); - // in case CDC writes are configured we validate and include them in the configuration - if (Optional.ofNullable(configuration.getUseCdcWrites()).orElse(false)) { - write = validateAndIncludeCDCInformation(write, schema); - } else if (configuration.getTable().equals(DYNAMIC_DESTINATIONS)) { - validateDynamicDestinationsExpectedSchema(schema); + Schema rowSchema = schema; + boolean fetchNestedRecord = false; + if (configuration.getTable().equals(DYNAMIC_DESTINATIONS)) { + validateDynamicDestinationsSchema(schema); + rowSchema = schema.getField(RECORD).getType().getRowSchema(); + fetchNestedRecord = true; + } + if (Boolean.TRUE.equals(configuration.getUseCdcWrites())) { + validateCdcSchema(schema); + rowSchema = schema.getField(RECORD).getType().getRowSchema(); + fetchNestedRecord = true; write = write - .to(new RowDynamicDestinations(schema.getField("record").getType().getRowSchema())) - .withFormatFunction(row -> BigQueryUtils.toTableRow(row.getRow("record"))); - } else { - write = write.to(configuration.getTable()).useBeamSchema(); + .withPrimaryKey(configuration.getPrimaryKey()) + .withRowMutationInformationFn( + row -> + RowMutationInformation.of( + RowMutationInformation.MutationType.valueOf( + row.getRow(ROW_PROPERTY_MUTATION_INFO) + .getString(ROW_PROPERTY_MUTATION_TYPE)), + row.getRow(ROW_PROPERTY_MUTATION_INFO) + .getString(ROW_PROPERTY_MUTATION_SQN))); } + PortableBigQueryDestinations dynamicDestinations = + new PortableBigQueryDestinations(rowSchema, configuration); + write = + write + .to(dynamicDestinations) + .withFormatFunction(dynamicDestinations.getFilterFormatFunction(fetchNestedRecord)); if (!Strings.isNullOrEmpty(configuration.getCreateDisposition())) { CreateDisposition createDisposition = @@ -363,19 +323,27 @@ BigQueryIO.Write createStorageWriteApiTransform(Schema schema) { return write; } - BigQueryIO.Write validateAndIncludeCDCInformation( - BigQueryIO.Write write, Schema schema) { + void validateDynamicDestinationsSchema(Schema schema) { + checkArgument( + 
schema.getFieldNames().containsAll(Arrays.asList(DESTINATION, RECORD)), + String.format( + "When writing to dynamic destinations, we expect Row Schema with a " + + "\"%s\" string field and a \"%s\" Row field.", + DESTINATION, RECORD)); + } + + private void validateCdcSchema(Schema schema) { checkArgument( - schema.getFieldNames().containsAll(Arrays.asList(ROW_PROPERTY_MUTATION_INFO, "record")), + schema.getFieldNames().containsAll(Arrays.asList(ROW_PROPERTY_MUTATION_INFO, RECORD)), "When writing using CDC functionality, we expect Row Schema with a " + "\"" + ROW_PROPERTY_MUTATION_INFO + "\" Row field and a \"record\" Row field."); - Schema rowSchema = schema.getField(ROW_PROPERTY_MUTATION_INFO).getType().getRowSchema(); + Schema mutationSchema = schema.getField(ROW_PROPERTY_MUTATION_INFO).getType().getRowSchema(); checkArgument( - rowSchema.equals(ROW_SCHEMA_MUTATION_INFO), + mutationSchema != null && mutationSchema.equals(ROW_SCHEMA_MUTATION_INFO), "When writing using CDC functionality, we expect a \"" + ROW_PROPERTY_MUTATION_INFO + "\" field of Row type with schema:\n" @@ -384,31 +352,7 @@ BigQueryIO.Write validateAndIncludeCDCInformation( + "Received \"" + ROW_PROPERTY_MUTATION_INFO + "\" field with schema:\n" - + rowSchema.toString()); - - String tableDestination = null; - - if (configuration.getTable().equals(DYNAMIC_DESTINATIONS)) { - validateDynamicDestinationsExpectedSchema(schema); - } else { - tableDestination = configuration.getTable(); - } - - return write - .to( - new RowDynamicDestinations( - schema.getField("record").getType().getRowSchema(), - tableDestination, - configuration.getPrimaryKey())) - .withFormatFunction(row -> BigQueryUtils.toTableRow(row.getRow("record"))) - .withPrimaryKey(configuration.getPrimaryKey()) - .withRowMutationInformationFn( - row -> - RowMutationInformation.of( - RowMutationInformation.MutationType.valueOf( - row.getRow(ROW_PROPERTY_MUTATION_INFO) - .getString(ROW_PROPERTY_MUTATION_TYPE)), - row.getRow(ROW_PROPERTY_MUTATION_INFO).getString(ROW_PROPERTY_MUTATION_SQN))); + + mutationSchema); } } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteConfiguration.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteConfiguration.java index 4296da7e0cd5..505ce7125cee 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteConfiguration.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryWriteConfiguration.java @@ -18,20 +18,18 @@ package org.apache.beam.sdk.io.gcp.bigquery.providers; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import com.google.auto.value.AutoValue; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; -import javax.annotation.Nullable; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; +import org.checkerframework.checker.nullness.qual.Nullable; /** * Configuration for writing to BigQuery with SchemaTransforms. Used by {@link @@ -68,11 +66,6 @@ public void validate() { !Strings.isNullOrEmpty(this.getTable()), invalidConfigMessage + "Table spec for a BigQuery Write must be specified."); - // if we have an input table spec, validate it - if (!this.getTable().equals(DYNAMIC_DESTINATIONS)) { - checkNotNull(BigQueryHelpers.parseTableSpec(this.getTable())); - } - // validate create and write dispositions String createDisposition = getCreateDisposition(); if (createDisposition != null && !createDisposition.isEmpty()) { @@ -186,6 +179,21 @@ public static Builder builder() { @Nullable public abstract List getPrimaryKey(); + @SchemaFieldDescription( + "A list of field names to keep in the input record. All other fields are dropped before writing. " + + "Is mutually exclusive with 'drop' and 'only'.") + public abstract @Nullable List getKeep(); + + @SchemaFieldDescription( + "A list of field names to drop from the input record before writing. " + + "Is mutually exclusive with 'keep' and 'only'.") + public abstract @Nullable List getDrop(); + + @SchemaFieldDescription( + "The name of a single record field that should be written. " + + "Is mutually exclusive with 'keep' and 'drop'.") + public abstract @Nullable String getOnly(); + /** Builder for {@link BigQueryWriteConfiguration}. */ @AutoValue.Builder public abstract static class Builder { @@ -212,6 +220,12 @@ public abstract static class Builder { public abstract Builder setPrimaryKey(List pkColumns); + public abstract Builder setKeep(List keep); + + public abstract Builder setDrop(List drop); + + public abstract Builder setOnly(String only); + /** Builds a {@link BigQueryWriteConfiguration} instance. */ public abstract BigQueryWriteConfiguration build(); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java new file mode 100644 index 000000000000..54d125012eac --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.gcp.bigquery.providers; + +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryWriteConfiguration.DYNAMIC_DESTINATIONS; +import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; + +import com.google.api.services.bigquery.model.TableConstraints; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import java.util.List; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; +import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinations; +import org.apache.beam.sdk.io.gcp.bigquery.TableDestination; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.RowFilter; +import org.apache.beam.sdk.util.RowStringInterpolator; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.ValueInSingleWindow; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import org.checkerframework.checker.nullness.qual.Nullable; + +@Internal +public class PortableBigQueryDestinations extends DynamicDestinations { + public static final String DESTINATION = "destination"; + public static final String RECORD = "record"; + private @MonotonicNonNull RowStringInterpolator interpolator = null; + private final @Nullable List primaryKey; + private final RowFilter rowFilter; + + public PortableBigQueryDestinations(Schema rowSchema, BigQueryWriteConfiguration configuration) { + // DYNAMIC_DESTINATIONS magic string is the old way of doing it for cross-language. + // In that case, we do no interpolation + if (!configuration.getTable().equals(DYNAMIC_DESTINATIONS)) { + this.interpolator = new RowStringInterpolator(configuration.getTable(), rowSchema); + } + this.primaryKey = configuration.getPrimaryKey(); + RowFilter rf = new RowFilter(rowSchema); + if (configuration.getDrop() != null) { + rf = rf.drop(checkStateNotNull(configuration.getDrop())); + } + if (configuration.getKeep() != null) { + rf = rf.keep(checkStateNotNull(configuration.getKeep())); + } + if (configuration.getOnly() != null) { + rf = rf.only(checkStateNotNull(configuration.getOnly())); + } + this.rowFilter = rf; + } + + @Override + public String getDestination(@Nullable ValueInSingleWindow element) { + if (interpolator != null) { + return interpolator.interpolate(checkArgumentNotNull(element)); + } + return checkStateNotNull(checkStateNotNull(element).getValue().getString(DESTINATION)); + } + + @Override + public TableDestination getTable(String destination) { + return new TableDestination(destination, null); + } + + @Override + public @Nullable TableSchema getSchema(String destination) { + return BigQueryUtils.toTableSchema(rowFilter.outputSchema()); + } + + @Override + public @Nullable TableConstraints getTableConstraints(String destination) { + if (primaryKey != null) { + return new TableConstraints() + .setPrimaryKey(new TableConstraints.PrimaryKey().setColumns(primaryKey)); + } + return null; + } + + public SerializableFunction getFilterFormatFunction(boolean fetchNestedRecord) { + return row -> { + if (fetchNestedRecord) { + row = checkStateNotNull(row.getRow(RECORD)); + } + Row filtered = rowFilter.filter(row); + return BigQueryUtils.toTableRow(filtered); + }; + } +} diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProviderTest.java index 897d95da3b13..168febea9d88 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProviderTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProviderTest.java @@ -25,6 +25,7 @@ import com.google.api.services.bigquery.model.TableReference; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.stream.Collectors; import org.apache.beam.model.pipeline.v1.RunnerApi; @@ -32,6 +33,7 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryOptions; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryFileLoadsSchemaTransformProvider.BigQueryFileLoadsSchemaTransform; import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; @@ -42,6 +44,7 @@ import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.util.RowFilter; import org.apache.beam.sdk.util.construction.PipelineTranslation; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; @@ -125,6 +128,42 @@ public void testLoad() throws IOException, InterruptedException { assertEquals(ROWS.size(), fakeDatasetService.getAllRows(PROJECT, DATASET, TABLE_ID).size()); } + @Test + public void testWriteToPortableDynamicDestinations() throws Exception { + String destinationTemplate = + String.format("%s:%s.dynamic_write_{name}_{number}", PROJECT, DATASET); + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() + .setTable(destinationTemplate) + .setDrop(Collections.singletonList("number")) + .build(); + BigQueryFileLoadsSchemaTransform write = + (BigQueryFileLoadsSchemaTransform) + new BigQueryFileLoadsSchemaTransformProvider().from(config); + write.setTestBigQueryServices(fakeBigQueryServices); + + PCollection inputRows = p.apply(Create.of(ROWS)).setRowSchema(SCHEMA); + PCollectionRowTuple.of("input", inputRows).apply(write); + p.run().waitUntilFinish(); + + RowFilter rowFilter = new RowFilter(SCHEMA).drop(Collections.singletonList("number")); + assertEquals( + rowFilter.filter(ROWS.get(0)), + BigQueryUtils.toBeamRow( + rowFilter.outputSchema(), + fakeDatasetService.getAllRows(PROJECT, DATASET, "dynamic_write_a_1").get(0))); + assertEquals( + rowFilter.filter(ROWS.get(1)), + BigQueryUtils.toBeamRow( + rowFilter.outputSchema(), + fakeDatasetService.getAllRows(PROJECT, DATASET, "dynamic_write_b_2").get(0))); + assertEquals( + rowFilter.filter(ROWS.get(2)), + BigQueryUtils.toBeamRow( + rowFilter.outputSchema(), + fakeDatasetService.getAllRows(PROJECT, DATASET, "dynamic_write_c_3").get(0))); + } + @Test public void testManagedChoosesFileLoadsForBoundedWrites() { PCollection batchInput = p.apply(Create.of(ROWS)).setRowSchema(SCHEMA); diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java index 63727107a651..3aba2c2c6fef 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java @@ -17,22 +17,30 @@ */ package org.apache.beam.sdk.io.gcp.bigquery.providers; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; + import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.LongStream; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; import org.apache.beam.sdk.io.gcp.testing.BigqueryClient; import org.apache.beam.sdk.managed.Managed; +import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.testing.TestPipelineOptions; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PeriodicImpulse; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.RowFilter; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; @@ -58,17 +66,14 @@ public class BigQueryManagedIT { private static final Schema SCHEMA = Schema.of( Schema.Field.of("str", Schema.FieldType.STRING), - Schema.Field.of("number", Schema.FieldType.INT64)); + Schema.Field.of("number", Schema.FieldType.INT64), + Schema.Field.of("dest", Schema.FieldType.INT64)); + + private static final SerializableFunction ROW_FUNC = + l -> Row.withSchema(SCHEMA).addValue(Long.toString(l)).addValue(l).addValue(l % 3).build(); private static final List ROWS = - LongStream.range(0, 20) - .mapToObj( - i -> - Row.withSchema(SCHEMA) - .withFieldValue("str", Long.toString(i)) - .withFieldValue("number", i) - .build()) - .collect(Collectors.toList()); + LongStream.range(0, 20).mapToObj(ROW_FUNC::apply).collect(Collectors.toList()); private static final BigqueryClient BQ_CLIENT = new BigqueryClient("BigQueryManagedIT"); @@ -93,10 +98,6 @@ public void testBatchFileLoadsWriteRead() { String.format("%s:%s.%s", PROJECT, BIG_QUERY_DATASET_ID, testName.getMethodName()); Map config = ImmutableMap.of("table", table); - // file loads requires a GCS temp location - String tempLocation = writePipeline.getOptions().as(TestPipelineOptions.class).getTempRoot(); - writePipeline.getOptions().setTempLocation(tempLocation); - // batch write PCollectionRowTuple.of("input", getInput(writePipeline, false)) .apply(Managed.write(Managed.BIGQUERY).withConfig(config)); @@ -131,6 +132,59 @@ public void testStreamingStorageWriteRead() { readPipeline.run().waitUntilFinish(); } + public void testDynamicDestinations(boolean streaming) throws IOException, InterruptedException { + String baseTableName = + String.format("%s:%s.dynamic_" + System.nanoTime(), PROJECT, BIG_QUERY_DATASET_ID); + String 
destinationTemplate = baseTableName + "_{dest}"; + Map config = + ImmutableMap.of("table", destinationTemplate, "drop", Collections.singletonList("dest")); + + // write + PCollectionRowTuple.of("input", getInput(writePipeline, streaming)) + .apply(Managed.write(Managed.BIGQUERY).withConfig(config)); + writePipeline.run().waitUntilFinish(); + + List destinations = + Arrays.asList(baseTableName + "_0", baseTableName + "_1", baseTableName + "_2"); + + // read and validate each table destination + RowFilter rowFilter = new RowFilter(SCHEMA).drop(Collections.singletonList("dest")); + for (int i = 0; i < destinations.size(); i++) { + long mod = i; + String dest = destinations.get(i); + List writtenRows = + BQ_CLIENT + .queryUnflattened(String.format("SELECT * FROM [%s]", dest), PROJECT, true, false) + .stream() + .map(tableRow -> BigQueryUtils.toBeamRow(rowFilter.outputSchema(), tableRow)) + .collect(Collectors.toList()); + + List expectedRecords = + ROWS.stream() + .filter(row -> row.getInt64("dest") == mod) + .map(rowFilter::filter) + .collect(Collectors.toList()); + + assertThat(writtenRows, containsInAnyOrder(expectedRecords.toArray())); + } + } + + @Test + public void testStreamingDynamicDestinations() throws IOException, InterruptedException { + if (writePipeline.getOptions().getRunner().getName().contains("DataflowRunner")) { + // Need to manually enable streaming engine for legacy dataflow runner + ExperimentalOptions.addExperiment( + writePipeline.getOptions().as(ExperimentalOptions.class), + GcpOptions.STREAMING_ENGINE_EXPERIMENT); + } + testDynamicDestinations(true); + } + + @Test + public void testBatchDynamicDestinations() throws IOException, InterruptedException { + testDynamicDestinations(false); + } + public PCollection getInput(Pipeline p, boolean isStreaming) { if (isStreaming) { return p.apply( @@ -138,14 +192,7 @@ public PCollection getInput(Pipeline p, boolean isStreaming) { .startAt(new Instant(0)) .stopAt(new Instant(19)) .withInterval(Duration.millis(1))) - .apply( - MapElements.into(TypeDescriptors.rows()) - .via( - i -> - Row.withSchema(SCHEMA) - .withFieldValue("str", Long.toString(i.getMillis())) - .withFieldValue("number", i.getMillis()) - .build())) + .apply(MapElements.into(TypeDescriptors.rows()).via(i -> ROW_FUNC.apply(i.getMillis()))) .setRowSchema(SCHEMA); } return p.apply(Create.of(ROWS)).setRowSchema(SCHEMA); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java index 7b59552bbbe4..584309778286 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.io.gcp.bigquery.providers; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.PortableBigQueryDestinations.DESTINATION; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.PortableBigQueryDestinations.RECORD; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.greaterThan; import static org.junit.Assert.assertEquals; @@ -37,6 +39,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi; import 
org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransform; import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; @@ -56,6 +59,7 @@ import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.util.RowFilter; import org.apache.beam.sdk.util.construction.PipelineTranslation; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; @@ -116,7 +120,6 @@ public void setUp() throws Exception { public void testInvalidConfig() { List invalidConfigs = Arrays.asList( - BigQueryWriteConfiguration.builder().setTable("not_a_valid_table_spec"), BigQueryWriteConfiguration.builder() .setTable("project:dataset.table") .setCreateDisposition("INVALID_DISPOSITION")); @@ -170,10 +173,7 @@ public Boolean rowsEquals(List expectedRows, List actualRows) { } public boolean rowEquals(Row expectedRow, TableRow actualRow) { - return expectedRow.getValue("name").equals(actualRow.get("name")) - && expectedRow - .getValue("number") - .equals(Long.parseLong(actualRow.get("number").toString())); + return expectedRow.equals(BigQueryUtils.toBeamRow(expectedRow.getSchema(), actualRow)); } @Test @@ -199,14 +199,14 @@ public void testWriteToDynamicDestinations() throws Exception { String baseTableSpec = "project:dataset.dynamic_write_"; Schema schemaWithDestinations = - Schema.builder().addStringField("destination").addRowField("record", SCHEMA).build(); + Schema.builder().addStringField(DESTINATION).addRowField(RECORD, SCHEMA).build(); List rowsWithDestinations = ROWS.stream() .map( row -> Row.withSchema(schemaWithDestinations) - .withFieldValue("destination", baseTableSpec + row.getInt64("number")) - .withFieldValue("record", row) + .withFieldValue(DESTINATION, baseTableSpec + row.getInt64("number")) + .withFieldValue(RECORD, row) .build()) .collect(Collectors.toList()); @@ -227,17 +227,44 @@ public void testWriteToDynamicDestinations() throws Exception { fakeDatasetService.getAllRows("project", "dataset", "dynamic_write_3").get(0))); } + @Test + public void testWriteToPortableDynamicDestinations() throws Exception { + String destinationTemplate = "project:dataset.dynamic_write_{name}_{number}"; + BigQueryWriteConfiguration config = + BigQueryWriteConfiguration.builder() + .setTable(destinationTemplate) + .setKeep(Arrays.asList("number", "dt")) + .build(); + + runWithConfig(config); + p.run().waitUntilFinish(); + + RowFilter rowFilter = new RowFilter(SCHEMA).keep(Arrays.asList("number", "dt")); + assertTrue( + rowEquals( + rowFilter.filter(ROWS.get(0)), + fakeDatasetService.getAllRows("project", "dataset", "dynamic_write_a_1").get(0))); + assertTrue( + rowEquals( + rowFilter.filter(ROWS.get(1)), + fakeDatasetService.getAllRows("project", "dataset", "dynamic_write_b_2").get(0))); + assertTrue( + rowEquals( + rowFilter.filter(ROWS.get(2)), + fakeDatasetService.getAllRows("project", "dataset", "dynamic_write_c_3").get(0))); + } + List createCDCUpsertRows(List rows, boolean dynamicDestination, String tablePrefix) { Schema.Builder schemaBuilder = Schema.builder() - .addRowField("record", SCHEMA) + .addRowField(RECORD, SCHEMA) .addRowField( 
BigQueryStorageWriteApiSchemaTransformProvider.ROW_PROPERTY_MUTATION_INFO, BigQueryStorageWriteApiSchemaTransformProvider.ROW_SCHEMA_MUTATION_INFO); if (dynamicDestination) { - schemaBuilder = schemaBuilder.addStringField("destination"); + schemaBuilder = schemaBuilder.addStringField(DESTINATION); } Schema schemaWithCDC = schemaBuilder.build(); @@ -261,10 +288,10 @@ List createCDCUpsertRows(List rows, boolean dynamicDestination, String .ROW_PROPERTY_MUTATION_SQN, "AAA" + idx) .build()) - .withFieldValue("record", row); + .withFieldValue(RECORD, row); if (dynamicDestination) { rowBuilder = - rowBuilder.withFieldValue("destination", tablePrefix + row.getInt64("number")); + rowBuilder.withFieldValue(DESTINATION, tablePrefix + row.getInt64("number")); } return rowBuilder.build(); }) From 7f268acbfcdb30f2c11698232c504717c47d9aad Mon Sep 17 00:00:00 2001 From: claudevdm <33973061+claudevdm@users.noreply.github.com> Date: Wed, 13 Nov 2024 12:40:09 -0500 Subject: [PATCH 173/181] Disable gradle cache for gcp expansion service (#33099) Co-authored-by: Claude --- .../beam_PostCommit_Python_Xlang_Gcp_Dataflow.json | 2 +- .../beam_PostCommit_Python_Xlang_Gcp_Direct.json | 2 +- .../beam_PostCommit_TransformService_Direct.json | 2 +- sdks/java/io/google-cloud-platform/build.gradle | 2 +- .../io/google-cloud-platform/expansion-service/build.gradle | 4 ++++ 5 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json index 27c1f3ae26cd..84cf24574f22 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json @@ -1,5 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 2 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json index e3d6056a5de9..b26833333238 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 1 + "modification": 2 } diff --git a/.github/trigger_files/beam_PostCommit_TransformService_Direct.json b/.github/trigger_files/beam_PostCommit_TransformService_Direct.json index 7663aee09101..876e0ebee981 100644 --- a/.github/trigger_files/beam_PostCommit_TransformService_Direct.json +++ b/.github/trigger_files/beam_PostCommit_TransformService_Direct.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "revision: "1" + "revision: "2" } diff --git a/sdks/java/io/google-cloud-platform/build.gradle b/sdks/java/io/google-cloud-platform/build.gradle index b8e71e289827..0a5a89072963 100644 --- a/sdks/java/io/google-cloud-platform/build.gradle +++ b/sdks/java/io/google-cloud-platform/build.gradle @@ -359,4 +359,4 @@ task postCommit { description = "Integration tests of GCP connectors using the DirectRunner." 
dependsOn integrationTest dependsOn integrationTestKms -} +} \ No newline at end of file diff --git a/sdks/java/io/google-cloud-platform/expansion-service/build.gradle b/sdks/java/io/google-cloud-platform/expansion-service/build.gradle index f6c6f07d0cdf..01181721e9a4 100644 --- a/sdks/java/io/google-cloud-platform/expansion-service/build.gradle +++ b/sdks/java/io/google-cloud-platform/expansion-service/build.gradle @@ -47,3 +47,7 @@ task runExpansionService (type: JavaExec) { classpath = sourceSets.test.runtimeClasspath args = [project.findProperty("constructionService.port") ?: "8097"] } + +shadowJar { + outputs.upToDateWhen { false } +} \ No newline at end of file From 0b1b1546d0b9589910d6f7f81578ccde1294c838 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Wed, 13 Nov 2024 09:44:40 -0800 Subject: [PATCH 174/181] [YAML] Add a new StripErrorMetadata transform. (#33094) Beam Yaml's error handling framework returns per-record errors as a schema'd PCollection with associated error metadata (e.g. error messages, tracebacks). Currently there is no way to "unnest" the nested records (except for field by field) back to the top level if one wants to re-process these records (or otherwise ignore the metadata). Even if there were a way to do this "up-one-level" unnesting it's not clear that this would be obvious to users to find. Worse, various forms of error handling are not consistent in what the "bad records" schema is, or even where the original record is found (though we do have a caveat in the docs that this is still not set in stone). This adds a simple, easy to identify transform that abstracts all of these complexities away for the basic use case. --- sdks/python/apache_beam/yaml/yaml_mapping.py | 61 +++++++++++++++ .../apache_beam/yaml/yaml_transform_test.py | 45 ++++++++++++ .../en/documentation/sdks/yaml-errors.md | 15 ++++- 3 files changed, 118 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 5c14b0f5ea79..130bde75ed96 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -44,6 +44,7 @@ from apache_beam.typehints.row_type import RowTypeConstraint from apache_beam.typehints.schemas import named_fields_from_element_type from apache_beam.typehints.schemas import schema_from_element_type +from apache_beam.typehints.schemas import typing_from_runner_api from apache_beam.utils import python_callable from apache_beam.yaml import json_utils from apache_beam.yaml import options @@ -482,6 +483,65 @@ def expand(pcoll, error_handling=None, **kwargs): return expand +class _StripErrorMetadata(beam.PTransform): + """Strips error metadata from outputs returned via error handling. + + Generally the error outputs for transformations return information about + the error encountered (e.g. error messages and tracebacks) in addition to the + failing element itself. This transformation attempts to remove that metadata + and returns the bad element alone which can be useful for re-processing. + + For example, in the following pipeline snippet:: + + - name: MyMappingTransform + type: MapToFields + input: SomeInput + config: + language: python + fields: + ... 
+ error_handling: + output: errors + + - name: RecoverOriginalElements + type: StripErrorMetadata + input: MyMappingTransform.errors + + the output of `RecoverOriginalElements` will contain exactly those elements + from SomeInput that failed to process (whereas `MyMappingTransform.errors` + would contain those elements paired with error information). + + Note that this relies on the preceding transform actually returning the + failing input in a schema'd way. Most built-in transformations follow the + correct conventions. + """ + + _ERROR_FIELD_NAMES = ('failed_row', 'element', 'record') + + def expand(self, pcoll): + try: + existing_fields = { + fld.name: fld.type + for fld in schema_from_element_type(pcoll.element_type).fields + } + except TypeError: + fld = None + else: + for fld in self._ERROR_FIELD_NAMES: + if fld in existing_fields: + break + else: + raise ValueError( + f"No field name matches one of {self._ERROR_FIELD_NAMES}") + + if fld is None: + # This handles with_exception_handling() that returns bare tuples. + return pcoll | beam.Map(lambda x: x[0]) + else: + return pcoll | beam.Map(lambda x: getattr(x, fld)).with_output_types( + typing_from_runner_api(existing_fields[fld])) + + class _Validate(beam.PTransform): """Validates each element of a PCollection against a json schema. @@ -838,6 +898,7 @@ def create_mapping_providers(): 'Partition-python': _Partition, 'Partition-javascript': _Partition, 'Partition-generic': _Partition, + 'StripErrorMetadata': _StripErrorMetadata, 'ValidateWithSchema': _Validate, }), yaml_provider.SqlBackedProvider({ diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py index fbdae6679e96..7fcea7e2b662 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py @@ -401,6 +401,51 @@ def test_error_handling_outputs(self): assert_that(result['good'], equal_to(['a', 'b']), label="CheckGood") assert_that(result['bad'], equal_to(["ValueError('biiiiig')"])) + def test_strip_error_metadata(self): + with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( + pickle_library='cloudpickle')) as p: + result = p | YamlTransform( + ''' + type: composite + transforms: + - type: Create + config: + elements: ['a', 'b', 'biiiiig'] + + - type: SizeLimiter + input: Create + config: + limit: 5 + error_handling: + output: errors + - type: StripErrorMetadata + name: StripErrorMetadata1 + input: SizeLimiter.errors + + - type: MapToFields + input: Create + config: + language: python + fields: + out: "1/(1-len(element))" + error_handling: + output: errors + - type: StripErrorMetadata + name: StripErrorMetadata2 + input: MapToFields.errors + + output: + good: SizeLimiter + bad1: StripErrorMetadata1 + bad2: StripErrorMetadata2 + ''', + providers=TEST_PROVIDERS) + assert_that(result['good'], equal_to(['a', 'b']), label="CheckGood") + assert_that( + result['bad1'] | beam.Map(lambda x: x.element), equal_to(['biiiiig'])) + assert_that( + result['bad2'] | beam.Map(lambda x: x.element), equal_to(['a', 'b'])) + def test_must_handle_error_output(self): with self.assertRaisesRegex(Exception, 'Unconsumed error output .*line 7'): with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions( diff --git a/website/www/site/content/en/documentation/sdks/yaml-errors.md b/website/www/site/content/en/documentation/sdks/yaml-errors.md index 903e18d6b3c7..6edd1751a65b 100644 --- 
b/website/www/site/content/en/documentation/sdks/yaml-errors.md @@ -78,8 +78,10 @@ for a robust pipeline). Note also that the exact format of the error outputs is still being finalized. They can be safely printed and written to outputs, but their precise schema may change in a future version of Beam and should not yet be depended on. -Currently it has, at the very least, an `element` field which holds the element -that caused the error. +It generally contains the failed record itself as well as information about +the error that was encountered (e.g. error messages and tracebacks). +To recover the bad record alone one can process the error output with the +`StripErrorMetadata` transformation. Some transforms allow for extra arguments in their error_handling config, e.g. for Python functions one can give a `threshold` which limits the relative number @@ -139,9 +141,16 @@ pipeline: error_handling: output: my_error_output + - type: StripErrorMetadata + name: FailedRecordsWithoutMetadata + # Takes the error information from ComputeRatio and returns just the + # failing records themselves for another attempt with a different + # transform. + input: ComputeRatio.my_error_output + - type: MapToFields name: ComputeRatioForBadRecords - input: ComputeRatio.my_error_output + input: FailedRecordsWithoutMetadata config: language: python fields: From cb06b1bd35c5d8cf5ea5de83290d8ef23a17bc5a Mon Sep 17 00:00:00 2001 From: Ahmed Abualsaud <65791736+ahmedabu98@users.noreply.github.com> Date: Wed, 13 Nov 2024 13:09:45 -0500 Subject: [PATCH 175/181] Set streaming engine option to fix V1 tests (#33100) * set enable_streaming_engine option * trigger test * trigger test * revert test trigger --- .../sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java index 3aba2c2c6fef..6b685392809f 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryManagedIT.java @@ -118,6 +118,13 @@ public void testStreamingStorageWriteRead() { String.format("%s:%s.%s", PROJECT, BIG_QUERY_DATASET_ID, testName.getMethodName()); Map config = ImmutableMap.of("table", table); + if (writePipeline.getOptions().getRunner().getName().contains("DataflowRunner")) { + // Need to manually enable streaming engine for legacy dataflow runner + ExperimentalOptions.addExperiment( + writePipeline.getOptions().as(ExperimentalOptions.class), + GcpOptions.STREAMING_ENGINE_EXPERIMENT); + } + // streaming write PCollectionRowTuple.of("input", getInput(writePipeline, true)) .apply(Managed.write(Managed.BIGQUERY).withConfig(config)); From 3c664e9eacf19e8fc6e455d4e4bf4ffc7d51f650 Mon Sep 17 00:00:00 2001 From: damccorm Date: Wed, 13 Nov 2024 18:11:34 +0000 Subject: [PATCH 176/181] Moving to 2.62.0-SNAPSHOT on master branch. 
--- .asf.yaml | 1 + gradle.properties | 4 ++-- sdks/go/pkg/beam/core/core.go | 2 +- sdks/python/apache_beam/version.py | 2 +- sdks/typescript/package.json | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.asf.yaml b/.asf.yaml index 703aca276e6b..50886f2cea5a 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -49,6 +49,7 @@ github: protected_branches: master: {} + release-2.61.0: {} release-2.60.0: {} release-2.59.0: {} release-2.58.1: {} diff --git a/gradle.properties b/gradle.properties index ffd4efaaab32..3923dc204272 100644 --- a/gradle.properties +++ b/gradle.properties @@ -30,8 +30,8 @@ signing.gnupg.useLegacyGpg=true # buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy. # To build a custom Beam version make sure you change it in both places, see # https://github.com/apache/beam/issues/21302. -version=2.61.0-SNAPSHOT -sdk_version=2.61.0.dev +version=2.62.0-SNAPSHOT +sdk_version=2.62.0.dev javaVersion=1.8 diff --git a/sdks/go/pkg/beam/core/core.go b/sdks/go/pkg/beam/core/core.go index e1b660e99ac6..1b478f483077 100644 --- a/sdks/go/pkg/beam/core/core.go +++ b/sdks/go/pkg/beam/core/core.go @@ -27,7 +27,7 @@ const ( // SdkName is the human readable name of the SDK for UserAgents. SdkName = "Apache Beam SDK for Go" // SdkVersion is the current version of the SDK. - SdkVersion = "2.61.0.dev" + SdkVersion = "2.62.0.dev" // DefaultDockerImage represents the associated image for this release. DefaultDockerImage = "apache/beam_go_sdk:" + SdkVersion diff --git a/sdks/python/apache_beam/version.py b/sdks/python/apache_beam/version.py index dfe451175fde..9974bb68bccf 100644 --- a/sdks/python/apache_beam/version.py +++ b/sdks/python/apache_beam/version.py @@ -17,4 +17,4 @@ """Apache Beam SDK version information and utilities.""" -__version__ = '2.61.0.dev' +__version__ = '2.62.0.dev' diff --git a/sdks/typescript/package.json b/sdks/typescript/package.json index 3dcbab684090..9ccfcaa663d1 100644 --- a/sdks/typescript/package.json +++ b/sdks/typescript/package.json @@ -1,6 +1,6 @@ { "name": "apache-beam", - "version": "2.61.0-SNAPSHOT", + "version": "2.62.0-SNAPSHOT", "devDependencies": { "@google-cloud/bigquery": "^5.12.0", "@types/mocha": "^9.0.0", From 38415b8239f05f06f9bdf2ea1681bca143a86d37 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Wed, 13 Nov 2024 15:47:30 -0500 Subject: [PATCH 177/181] Make AvroUtils compatible with older versions of Avro (#33102) * Make AvroUtils compatible with older versions of Avro * Create beam_PostCommit_Java_Avro_Versions.json * Update AvroUtils.java * Fix nullness --- .../trigger_files/beam_PostCommit_Java_Avro_Versions.json | 4 ++++ .../beam/sdk/extensions/avro/schemas/utils/AvroUtils.java | 8 ++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) create mode 100644 .github/trigger_files/beam_PostCommit_Java_Avro_Versions.json diff --git a/.github/trigger_files/beam_PostCommit_Java_Avro_Versions.json b/.github/trigger_files/beam_PostCommit_Java_Avro_Versions.json new file mode 100644 index 000000000000..1efc8e9e4405 --- /dev/null +++ b/.github/trigger_files/beam_PostCommit_Java_Avro_Versions.json @@ -0,0 +1,4 @@ +{ + "comment": "Modify this file in a trivial way to cause this test suite to run", + "modification": 1 +} diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java index bfbab6fe87f6..da7daf605d89 100644 --- 
a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/schemas/utils/AvroUtils.java @@ -141,7 +141,10 @@ * * is used. */ -@SuppressWarnings({"rawtypes"}) +@SuppressWarnings({ + "nullness", // TODO(https://github.com/apache/beam/issues/20497) + "rawtypes" +}) public class AvroUtils { private static final ForLoadedType BYTES = new ForLoadedType(byte[].class); private static final ForLoadedType JAVA_INSTANT = new ForLoadedType(java.time.Instant.class); @@ -484,7 +487,8 @@ public static Field toBeamField(org.apache.avro.Schema.Field field) { public static org.apache.avro.Schema.Field toAvroField(Field field, String namespace) { org.apache.avro.Schema fieldSchema = getFieldSchema(field.getType(), field.getName(), namespace); - return new org.apache.avro.Schema.Field(field.getName(), fieldSchema, field.getDescription()); + return new org.apache.avro.Schema.Field( + field.getName(), fieldSchema, field.getDescription(), (Object) null); } private AvroUtils() {} From 6a4562488f1a7910b423a26c305a0f6070b82e59 Mon Sep 17 00:00:00 2001 From: Jeff Kinard Date: Wed, 13 Nov 2024 15:49:25 -0500 Subject: [PATCH 178/181] fix JDBC providers (#32985) * fix JDBC providers Signed-off-by: Jeffrey Kinard * fix test failures Signed-off-by: Jeffrey Kinard * fix typo Signed-off-by: Jeffrey Kinard --------- Signed-off-by: Jeffrey Kinard --- sdks/python/apache_beam/yaml/standard_io.yaml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sdks/python/apache_beam/yaml/standard_io.yaml b/sdks/python/apache_beam/yaml/standard_io.yaml index 400ab07a41fa..269c14e17baa 100644 --- a/sdks/python/apache_beam/yaml/standard_io.yaml +++ b/sdks/python/apache_beam/yaml/standard_io.yaml @@ -235,21 +235,21 @@ 'WriteToSqlServer': 'WriteToJdbc' defaults: 'ReadFromMySql': - jdbcType: 'mysql' + jdbc_type: 'mysql' 'WriteToMySql': - jdbcType: 'mysql' + jdbc_type: 'mysql' 'ReadFromPostgres': - jdbcType: 'postgres' + jdbc_type: 'postgres' 'WriteToPostgres': - jdbcType: 'postgres' + jdbc_type: 'postgres' 'ReadFromOracle': - jdbcType: 'oracle' + jdbc_type: 'oracle' 'WriteToOracle': - jdbcType: 'oracle' + jdbc_type: 'oracle' 'ReadFromSqlServer': - jdbcType: 'mssql' + jdbc_type: 'mssql' 'WriteToSqlServer': - jdbcType: 'mssql' + jdbc_type: 'mssql' underlying_provider: type: beamJar transforms: From f4d07c40460be87cfa2c3ed17ef865f22d834212 Mon Sep 17 00:00:00 2001 From: Naireen Hussain Date: Wed, 13 Nov 2024 13:50:38 -0800 Subject: [PATCH 179/181] Add logging to see which topic each split is reading from (#33031) Co-authored-by: Naireen --- .../sdk/io/kafka/KafkaUnboundedReader.java | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java index 209dee14da1e..069607955c6d 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaUnboundedReader.java @@ -28,6 +28,7 @@ import java.util.Map; import java.util.NoSuchElementException; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -398,10 +399,10 @@ public long getSplitBacklogBytes() { /** 
watermark before any records have been read. */ private static Instant initialWatermark = BoundedWindow.TIMESTAMP_MIN_VALUE; - // Created in each next batch, and updated at the end. public KafkaMetrics kafkaResults = KafkaSinkMetrics.kafkaMetrics(); private Stopwatch stopwatch = Stopwatch.createUnstarted(); - private String kafkaTopic = ""; + + private Set kafkaTopics; @Override public String toString() { @@ -510,12 +511,9 @@ Instant updateAndGetWatermark() { List partitions = Preconditions.checkArgumentNotNull(source.getSpec().getTopicPartitions()); - // Each source has a single unique topic. - for (TopicPartition topicPartition : partitions) { - this.kafkaTopic = topicPartition.topic(); - break; - } + this.kafkaTopics = partitions.stream().map(TopicPartition::topic).collect(Collectors.toSet()); + LOG.info("{} is reading from topics {}", this.name, kafkaTopics); List> states = new ArrayList<>(partitions.size()); if (checkpointMark != null) { @@ -573,16 +571,14 @@ private void consumerPollLoop() { while (!closed.get()) { try { if (records.isEmpty()) { - // Each source has a single unique topic. - List topicPartitions = source.getSpec().getTopicPartitions(); - Preconditions.checkStateNotNull(topicPartitions); - stopwatch.start(); records = consumer.poll(KAFKA_POLL_TIMEOUT.getMillis()); stopwatch.stop(); - kafkaResults.updateSuccessfulRpcMetrics( - kafkaTopic, java.time.Duration.ofMillis(stopwatch.elapsed(TimeUnit.MILLISECONDS))); - + for (String kafkaTopic : kafkaTopics) { + kafkaResults.updateSuccessfulRpcMetrics( + kafkaTopic, + java.time.Duration.ofMillis(stopwatch.elapsed(TimeUnit.MILLISECONDS))); + } } else if (availableRecordsQueue.offer( records, RECORDS_ENQUEUE_POLL_TIMEOUT.getMillis(), TimeUnit.MILLISECONDS)) { records = ConsumerRecords.empty(); From bff3eac8b6a156f836cdf956e92fd4916cb44e1d Mon Sep 17 00:00:00 2001 From: tvalentyn Date: Wed, 13 Nov 2024 13:58:57 -0800 Subject: [PATCH 180/181] Update names.py (#33107) Update container image to pick up recent changes. --- sdks/python/apache_beam/runners/dataflow/internal/names.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/runners/dataflow/internal/names.py b/sdks/python/apache_beam/runners/dataflow/internal/names.py index a68f422e08e6..912814b6bb5f 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/names.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/names.py @@ -34,6 +34,6 @@ # Unreleased sdks use container image tag specified below. # Update this tag whenever there is a change that # requires changes to SDK harness container or SDK harness launcher. 
-BEAM_DEV_SDK_CONTAINER_TAG = 'beam-master-20241007' +BEAM_DEV_SDK_CONTAINER_TAG = 'beam-master-20241113' DATAFLOW_CONTAINER_IMAGE_REPOSITORY = 'gcr.io/cloud-dataflow/v1beta3' From ab5c0695530fb73568cf4810da101a683da56874 Mon Sep 17 00:00:00 2001 From: Jeff Kinard Date: Wed, 13 Nov 2024 18:31:36 -0500 Subject: [PATCH 181/181] [yaml] Fix examples catalog tests (#33027) --- .../yaml/examples/testing/examples_test.py | 105 ++++++++++++++++-- .../{ => transforms}/io/spanner_read.yaml | 23 ++-- .../{ => transforms}/io/spanner_write.yaml | 17 +-- .../apache_beam/yaml/generate_yaml_docs.py | 2 +- sdks/python/apache_beam/yaml/yaml_errors.py | 88 +++++++++++++++ sdks/python/apache_beam/yaml/yaml_io.py | 6 +- sdks/python/apache_beam/yaml/yaml_mapping.py | 78 +------------ sdks/python/apache_beam/yaml/yaml_provider.py | 7 +- .../yaml/yaml_transform_scope_test.py | 6 +- .../yaml/yaml_transform_unit_test.py | 2 +- 10 files changed, 224 insertions(+), 110 deletions(-) rename sdks/python/apache_beam/yaml/examples/{ => transforms}/io/spanner_read.yaml (73%) rename sdks/python/apache_beam/yaml/examples/{ => transforms}/io/spanner_write.yaml (69%) create mode 100644 sdks/python/apache_beam/yaml/yaml_errors.py diff --git a/sdks/python/apache_beam/yaml/examples/testing/examples_test.py b/sdks/python/apache_beam/yaml/examples/testing/examples_test.py index 6c8efac980aa..3b497ed1efab 100644 --- a/sdks/python/apache_beam/yaml/examples/testing/examples_test.py +++ b/sdks/python/apache_beam/yaml/examples/testing/examples_test.py @@ -40,8 +40,8 @@ def check_output(expected: List[str]): - def _check_inner(actual: PCollection[str]): - formatted_actual = actual | beam.Map( + def _check_inner(actual: List[PCollection[str]]): + formatted_actual = actual | beam.Flatten() | beam.Map( lambda row: str(beam.Row(**row._asdict()))) assert_matches_stdout(formatted_actual, expected) @@ -59,6 +59,57 @@ def products_csv(): ]) +def spanner_data(): + return [{ + 'shipment_id': 'S1', + 'customer_id': 'C1', + 'shipment_date': '2023-05-01', + 'shipment_cost': 150.0, + 'customer_name': 'Alice', + 'customer_email': 'alice@example.com' + }, + { + 'shipment_id': 'S2', + 'customer_id': 'C2', + 'shipment_date': '2023-06-12', + 'shipment_cost': 300.0, + 'customer_name': 'Bob', + 'customer_email': 'bob@example.com' + }, + { + 'shipment_id': 'S3', + 'customer_id': 'C1', + 'shipment_date': '2023-05-10', + 'shipment_cost': 20.0, + 'customer_name': 'Alice', + 'customer_email': 'alice@example.com' + }, + { + 'shipment_id': 'S4', + 'customer_id': 'C4', + 'shipment_date': '2024-07-01', + 'shipment_cost': 150.0, + 'customer_name': 'Derek', + 'customer_email': 'derek@example.com' + }, + { + 'shipment_id': 'S5', + 'customer_id': 'C5', + 'shipment_date': '2023-05-09', + 'shipment_cost': 300.0, + 'customer_name': 'Erin', + 'customer_email': 'erin@example.com' + }, + { + 'shipment_id': 'S6', + 'customer_id': 'C4', + 'shipment_date': '2024-07-02', + 'shipment_cost': 150.0, + 'customer_name': 'Derek', + 'customer_email': 'derek@example.com' + }] + + def create_test_method( pipeline_spec_file: str, custom_preprocessors: List[Callable[..., Union[Dict, List]]]): @@ -84,9 +135,12 @@ def test_yaml_example(self): pickle_library='cloudpickle', **yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get( 'options', {})))) as p: - actual = yaml_transform.expand_pipeline(p, pipeline_spec) - if not actual: - actual = p.transforms_stack[0].parts[-1].outputs[None] + actual = [yaml_transform.expand_pipeline(p, pipeline_spec)] + if not actual[0]: + actual = 
list(p.transforms_stack[0].parts[-1].outputs.values()) + for transform in p.transforms_stack[0].parts[:-1]: + if transform.transform.label == 'log_for_testing': + actual += list(transform.outputs.values()) check_output(expected)(actual) return test_yaml_example @@ -155,9 +209,13 @@ def _wordcount_test_preprocessor( env.input_file('kinglear.txt', '\n'.join(lines))) -@YamlExamplesTestSuite.register_test_preprocessor( - ['test_simple_filter_yaml', 'test_simple_filter_and_combine_yaml']) -def _file_io_write_test_preprocessor( +@YamlExamplesTestSuite.register_test_preprocessor([ + 'test_simple_filter_yaml', + 'test_simple_filter_and_combine_yaml', + 'test_spanner_read_yaml', + 'test_spanner_write_yaml' +]) +def _io_write_test_preprocessor( test_spec: dict, expected: List[str], env: TestEnvironment): if pipeline := test_spec.get('pipeline', None): @@ -166,8 +224,8 @@ def _file_io_write_test_preprocessor( transform['type'] = 'LogForTesting' transform['config'] = { k: v - for k, - v in transform.get('config', {}).items() if k.startswith('__') + for (k, v) in transform.get('config', {}).items() + if (k.startswith('__') or k == 'error_handling') } return test_spec @@ -191,7 +249,30 @@ def _file_io_read_test_preprocessor( return test_spec +@YamlExamplesTestSuite.register_test_preprocessor(['test_spanner_read_yaml']) +def _spanner_io_read_test_preprocessor( + test_spec: dict, expected: List[str], env: TestEnvironment): + + if pipeline := test_spec.get('pipeline', None): + for transform in pipeline.get('transforms', []): + if transform.get('type', '').startswith('ReadFromSpanner'): + config = transform['config'] + instance, database = config['instance_id'], config['database_id'] + if table := config.get('table', None) is None: + table = config.get('query', '').split('FROM')[-1].strip() + transform['type'] = 'Create' + transform['config'] = { + k: v + for k, v in config.items() if k.startswith('__') + } + transform['config']['elements'] = INPUT_TABLES[( + str(instance), str(database), str(table))] + + return test_spec + + INPUT_FILES = {'products.csv': products_csv()} +INPUT_TABLES = {('shipment-test', 'shipment', 'shipments'): spanner_data()} YAML_DOCS_DIR = os.path.join(os.path.dirname(__file__)) ExamplesTest = YamlExamplesTestSuite( @@ -205,6 +286,10 @@ def _file_io_read_test_preprocessor( 'AggregationExamplesTest', os.path.join(YAML_DOCS_DIR, '../transforms/aggregation/*.yaml')).run() +IOTest = YamlExamplesTestSuite( + 'IOExamplesTest', os.path.join(YAML_DOCS_DIR, + '../transforms/io/*.yaml')).run() + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/yaml/examples/io/spanner_read.yaml b/sdks/python/apache_beam/yaml/examples/transforms/io/spanner_read.yaml similarity index 73% rename from sdks/python/apache_beam/yaml/examples/io/spanner_read.yaml rename to sdks/python/apache_beam/yaml/examples/transforms/io/spanner_read.yaml index c86d42c1e0c6..26f68b68d931 100644 --- a/sdks/python/apache_beam/yaml/examples/io/spanner_read.yaml +++ b/sdks/python/apache_beam/yaml/examples/transforms/io/spanner_read.yaml @@ -18,10 +18,10 @@ pipeline: transforms: - # Reading data from a Spanner database. 
The table used here has the following columns: - # shipment_id (String), customer_id (String), shipment_date (String), shipment_cost (Float64), customer_name (String), customer_email (String) - # ReadFromSpanner transform is called using project_id, instance_id, database_id and a query - # A table with a list of columns can also be specified instead of a query + # Reading data from a Spanner database. The table used here has the following columns: + # shipment_id (String), customer_id (String), shipment_date (String), shipment_cost (Float64), customer_name (String), customer_email (String) + # ReadFromSpanner transform is called using project_id, instance_id, database_id and a query + # A table with a list of columns can also be specified instead of a query - type: ReadFromSpanner name: ReadShipments config: @@ -30,8 +30,8 @@ pipeline: database_id: 'shipment' query: 'SELECT * FROM shipments' - # Filtering the data based on a specific condition - # Here, the condition is used to keep only the rows where the customer_id is 'C1' + # Filtering the data based on a specific condition + # Here, the condition is used to keep only the rows where the customer_id is 'C1' - type: Filter name: FilterShipments input: ReadShipments @@ -39,9 +39,9 @@ pipeline: language: python keep: "customer_id == 'C1'" - # Mapping the data fields and applying transformations - # A new field 'shipment_cost_category' is added with a custom transformation - # A callable is defined to categorize shipment cost + # Mapping the data fields and applying transformations + # A new field 'shipment_cost_category' is added with a custom transformation + # A callable is defined to categorize shipment cost - type: MapToFields name: MapFieldsForSpanner input: FilterShipments @@ -65,7 +65,7 @@ pipeline: else: return 'High Cost' - # Writing the transformed data to a CSV file + # Writing the transformed data to a CSV file - type: WriteToCsv name: WriteBig input: MapFieldsForSpanner @@ -73,8 +73,7 @@ pipeline: path: shipments.csv - # On executing the above pipeline, a new CSV file is created with the following records - +# On executing the above pipeline, a new CSV file is created with the following records # Expected: # Row(shipment_id='S1', customer_id='C1', shipment_date='2023-05-01', shipment_cost=150.0, customer_name='Alice', customer_email='alice@example.com', shipment_cost_category='Medium Cost') # Row(shipment_id='S3', customer_id='C1', shipment_date='2023-05-10', shipment_cost=20.0, customer_name='Alice', customer_email='alice@example.com', shipment_cost_category='Low Cost') diff --git a/sdks/python/apache_beam/yaml/examples/io/spanner_write.yaml b/sdks/python/apache_beam/yaml/examples/transforms/io/spanner_write.yaml similarity index 69% rename from sdks/python/apache_beam/yaml/examples/io/spanner_write.yaml rename to sdks/python/apache_beam/yaml/examples/transforms/io/spanner_write.yaml index 74ac35de260f..1667fcfcc163 100644 --- a/sdks/python/apache_beam/yaml/examples/io/spanner_write.yaml +++ b/sdks/python/apache_beam/yaml/examples/transforms/io/spanner_write.yaml @@ -18,8 +18,8 @@ pipeline: transforms: - # Step 1: Creating rows to be written to Spanner - # The element names correspond to the column names in the Spanner table + # Step 1: Creating rows to be written to Spanner + # The element names correspond to the column names in the Spanner table - type: Create name: CreateRows config: @@ -31,10 +31,10 @@ pipeline: customer_name: "Erin" customer_email: "erin@example.com" - # Step 2: Writing the created rows to a Spanner 
database - # We require the project ID, instance ID, database ID and table ID to connect to Spanner - # Error handling can be specified optionally to ensure any failed operations aren't lost - # The failed data is passed on in the pipeline and can be handled + # Step 2: Writing the created rows to a Spanner database + # We require the project ID, instance ID, database ID and table ID to connect to Spanner + # Error handling can be specified optionally to ensure any failed operations aren't lost + # The failed data is passed on in the pipeline and can be handled - type: WriteToSpanner name: WriteSpanner input: CreateRows @@ -46,8 +46,11 @@ pipeline: error_handling: output: my_error_output - # Step 3: Writing the failed records to a JSON file + # Step 3: Writing the failed records to a JSON file - type: WriteToJson input: WriteSpanner.my_error_output config: path: errors.json + +# Expected: +# Row(shipment_id='S5', customer_id='C5', shipment_date='2023-05-09', shipment_cost=300.0, customer_name='Erin', customer_email='erin@example.com') diff --git a/sdks/python/apache_beam/yaml/generate_yaml_docs.py b/sdks/python/apache_beam/yaml/generate_yaml_docs.py index 4088e17afe2c..2123c7a9f202 100644 --- a/sdks/python/apache_beam/yaml/generate_yaml_docs.py +++ b/sdks/python/apache_beam/yaml/generate_yaml_docs.py @@ -30,7 +30,7 @@ from apache_beam.version import __version__ as beam_version from apache_beam.yaml import json_utils from apache_beam.yaml import yaml_provider -from apache_beam.yaml.yaml_mapping import ErrorHandlingConfig +from apache_beam.yaml.yaml_errors import ErrorHandlingConfig def _singular(name): diff --git a/sdks/python/apache_beam/yaml/yaml_errors.py b/sdks/python/apache_beam/yaml/yaml_errors.py new file mode 100644 index 000000000000..c0d448473f42 --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_errors.py @@ -0,0 +1,88 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import functools +import inspect +from typing import NamedTuple + +import apache_beam as beam +from apache_beam.typehints.row_type import RowTypeConstraint + + +class ErrorHandlingConfig(NamedTuple): + """Class to define Error Handling parameters. + + Args: + output (str): Name to use for the output error collection + """ + output: str + # TODO: Other parameters are valid here too, but not common to Java. + + +def exception_handling_args(error_handling_spec): + if error_handling_spec: + return { + 'dead_letter_tag' if k == 'output' else k: v + for (k, v) in error_handling_spec.items() + } + else: + return None + + +def map_errors_to_standard_format(input_type): + # TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple. 
+ + return beam.Map( + lambda x: beam.Row( + element=x[0], msg=str(x[1][1]), stack=''.join(x[1][2])) + ).with_output_types( + RowTypeConstraint.from_fields([("element", input_type), ("msg", str), + ("stack", str)])) + + +def maybe_with_exception_handling(inner_expand): + def expand(self, pcoll): + wrapped_pcoll = beam.core._MaybePValueWithErrors( + pcoll, self._exception_handling_args) + return inner_expand(self, wrapped_pcoll).as_result( + map_errors_to_standard_format(pcoll.element_type)) + + return expand + + +def maybe_with_exception_handling_transform_fn(transform_fn): + @functools.wraps(transform_fn) + def expand(pcoll, error_handling=None, **kwargs): + wrapped_pcoll = beam.core._MaybePValueWithErrors( + pcoll, exception_handling_args(error_handling)) + return transform_fn(wrapped_pcoll, **kwargs).as_result( + map_errors_to_standard_format(pcoll.element_type)) + + original_signature = inspect.signature(transform_fn) + new_parameters = list(original_signature.parameters.values()) + error_handling_param = inspect.Parameter( + 'error_handling', + inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=ErrorHandlingConfig) + if new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD: + new_parameters.insert(-1, error_handling_param) + else: + new_parameters.append(error_handling_param) + expand.__signature__ = original_signature.replace(parameters=new_parameters) + + return expand diff --git a/sdks/python/apache_beam/yaml/yaml_io.py b/sdks/python/apache_beam/yaml/yaml_io.py index 22663bdb8461..a6525aef9877 100644 --- a/sdks/python/apache_beam/yaml/yaml_io.py +++ b/sdks/python/apache_beam/yaml/yaml_io.py @@ -45,7 +45,7 @@ from apache_beam.portability.api import schema_pb2 from apache_beam.typehints import schemas from apache_beam.yaml import json_utils -from apache_beam.yaml import yaml_mapping +from apache_beam.yaml import yaml_errors from apache_beam.yaml import yaml_provider @@ -289,7 +289,7 @@ def formatter(row): @beam.ptransform_fn -@yaml_mapping.maybe_with_exception_handling_transform_fn +@yaml_errors.maybe_with_exception_handling_transform_fn def read_from_pubsub( root, *, @@ -393,7 +393,7 @@ def mapper(msg): @beam.ptransform_fn -@yaml_mapping.maybe_with_exception_handling_transform_fn +@yaml_errors.maybe_with_exception_handling_transform_fn def write_to_pubsub( pcoll, *, diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 130bde75ed96..3bef1a0a1101 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -16,8 +16,6 @@ # """This module defines the basic MapToFields operation.""" -import functools -import inspect import itertools import re from collections import abc @@ -27,7 +25,6 @@ from typing import Dict from typing import List from typing import Mapping -from typing import NamedTuple from typing import Optional from typing import TypeVar from typing import Union @@ -41,7 +38,6 @@ from apache_beam.typehints import trivial_inference from apache_beam.typehints import typehints from apache_beam.typehints.native_type_compatibility import convert_to_beam_type -from apache_beam.typehints.row_type import RowTypeConstraint from apache_beam.typehints.schemas import named_fields_from_element_type from apache_beam.typehints.schemas import schema_from_element_type from apache_beam.typehints.schemas import typing_from_runner_api @@ -49,6 +45,10 @@ from apache_beam.yaml import json_utils from apache_beam.yaml import options from apache_beam.yaml import yaml_provider +from 
apache_beam.yaml.yaml_errors import exception_handling_args +from apache_beam.yaml.yaml_errors import map_errors_to_standard_format +from apache_beam.yaml.yaml_errors import maybe_with_exception_handling +from apache_beam.yaml.yaml_errors import maybe_with_exception_handling_transform_fn from apache_beam.yaml.yaml_provider import dicts_to_rows # Import js2py package if it exists @@ -418,71 +418,6 @@ def checking_func(row): return func -class ErrorHandlingConfig(NamedTuple): - """Class to define Error Handling parameters. - - Args: - output (str): Name to use for the output error collection - """ - output: str - # TODO: Other parameters are valid here too, but not common to Java. - - -def exception_handling_args(error_handling_spec): - if error_handling_spec: - return { - 'dead_letter_tag' if k == 'output' else k: v - for (k, v) in error_handling_spec.items() - } - else: - return None - - -def _map_errors_to_standard_format(input_type): - # TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple. - - return beam.Map( - lambda x: beam.Row( - element=x[0], msg=str(x[1][1]), stack=''.join(x[1][2])) - ).with_output_types( - RowTypeConstraint.from_fields([("element", input_type), ("msg", str), - ("stack", str)])) - - -def maybe_with_exception_handling(inner_expand): - def expand(self, pcoll): - wrapped_pcoll = beam.core._MaybePValueWithErrors( - pcoll, self._exception_handling_args) - return inner_expand(self, wrapped_pcoll).as_result( - _map_errors_to_standard_format(pcoll.element_type)) - - return expand - - -def maybe_with_exception_handling_transform_fn(transform_fn): - @functools.wraps(transform_fn) - def expand(pcoll, error_handling=None, **kwargs): - wrapped_pcoll = beam.core._MaybePValueWithErrors( - pcoll, exception_handling_args(error_handling)) - return transform_fn(wrapped_pcoll, **kwargs).as_result( - _map_errors_to_standard_format(pcoll.element_type)) - - original_signature = inspect.signature(transform_fn) - new_parameters = list(original_signature.parameters.values()) - error_handling_param = inspect.Parameter( - 'error_handling', - inspect.Parameter.KEYWORD_ONLY, - default=None, - annotation=ErrorHandlingConfig) - if new_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD: - new_parameters.insert(-1, error_handling_param) - else: - new_parameters.append(error_handling_param) - expand.__signature__ = original_signature.replace(parameters=new_parameters) - - return expand - - class _StripErrorMetadata(beam.PTransform): """Strips error metadata from outputs returned via error handling. 
@@ -845,9 +780,8 @@ def split(element): splits = pcoll | mapping_transform.with_input_types(T).with_output_types(T) result = {out: getattr(splits, out) for out in output_set} if error_output: - result[ - error_output] = result[error_output] | _map_errors_to_standard_format( - pcoll.element_type) + result[error_output] = result[error_output] | map_errors_to_standard_format( + pcoll.element_type) return result diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index ef2316f51f0e..a07638953551 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -63,6 +63,7 @@ from apache_beam.utils import subprocess_server from apache_beam.version import __version__ as beam_version from apache_beam.yaml import json_utils +from apache_beam.yaml.yaml_errors import maybe_with_exception_handling_transform_fn class Provider: @@ -876,8 +877,10 @@ def _parse_window_spec(spec): return beam.WindowInto(window_fn) @staticmethod + @beam.ptransform_fn + @maybe_with_exception_handling_transform_fn def log_for_testing( - level: Optional[str] = 'INFO', prefix: Optional[str] = ''): + pcoll, *, level: Optional[str] = 'INFO', prefix: Optional[str] = ''): """Logs each element of its input PCollection. The output of this transform is a copy of its input for ease of use in @@ -918,7 +921,7 @@ def log_and_return(x): logger(prefix + json.dumps(to_loggable_json_recursive(x))) return x - return "LogForTesting" >> beam.Map(log_and_return) + return pcoll | "LogForTesting" >> beam.Map(log_and_return) @staticmethod def create_builtin_provider(): diff --git a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py index f00403b07e2a..2a5a96aa42df 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py @@ -72,10 +72,12 @@ def test_get_pcollection_output(self): str(scope.get_pcollection("Create"))) self.assertEqual( - "PCollection[Square.None]", str(scope.get_pcollection("Square"))) + "PCollection[Square/LogForTesting.None]", + str(scope.get_pcollection("Square"))) self.assertEqual( - "PCollection[Square.None]", str(scope.get_pcollection("LogForTesting"))) + "PCollection[Square/LogForTesting.None]", + str(scope.get_pcollection("LogForTesting"))) self.assertTrue( scope.get_pcollection("Square") == scope.get_pcollection( diff --git a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py index 8c4b00351b24..bc0493509d5a 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py @@ -213,7 +213,7 @@ def test_expand_composite_transform_with_name_input(self): inputs={'elements': elements}) self.assertRegex( str(expand_composite_transform(spec, scope)['output']), - r"PCollection.*Composite/LogForTesting.*") + r"PCollection.*Composite/log_for_testing/LogForTesting.*") def test_expand_composite_transform_root(self): with new_pipeline() as p: