From 7e077dc1cd72854a6e5a088edf72679a14ef2717 Mon Sep 17 00:00:00 2001 From: Jeff Kinard Date: Thu, 26 Dec 2024 04:24:23 -0600 Subject: [PATCH 1/8] [yaml] add RunInference support with VertexAI (#33406) * [yaml] add RunInference support with VertexAI Signed-off-by: Jeffrey Kinard * address comments and fix tests Signed-off-by: Jeffrey Kinard * add more docs Signed-off-by: Jeffrey Kinard * fix failing tests Signed-off-by: Jeffrey Kinard * fix errors Signed-off-by: Jeffrey Kinard * fix lint Signed-off-by: Jeffrey Kinard --------- Signed-off-by: Jeffrey Kinard --- .../apache_beam/yaml/standard_providers.yaml | 1 + sdks/python/apache_beam/yaml/yaml_ml.py | 438 +++++++++++++++++- .../python/apache_beam/yaml/yaml_transform.py | 56 +-- .../yaml/yaml_transform_unit_test.py | 53 +-- sdks/python/apache_beam/yaml/yaml_utils.py | 75 +++ .../apache_beam/yaml/yaml_utils_test.py | 79 ++++ 6 files changed, 592 insertions(+), 110 deletions(-) create mode 100644 sdks/python/apache_beam/yaml/yaml_utils.py create mode 100644 sdks/python/apache_beam/yaml/yaml_utils_test.py diff --git a/sdks/python/apache_beam/yaml/standard_providers.yaml b/sdks/python/apache_beam/yaml/standard_providers.yaml index 242faaa9a77b..31eb5e1c6daa 100644 --- a/sdks/python/apache_beam/yaml/standard_providers.yaml +++ b/sdks/python/apache_beam/yaml/standard_providers.yaml @@ -56,6 +56,7 @@ config: {} transforms: MLTransform: 'apache_beam.yaml.yaml_ml.ml_transform' + RunInference: 'apache_beam.yaml.yaml_ml.run_inference' - type: renaming transforms: diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py index 33f2eeefd296..e958ea70aff8 100644 --- a/sdks/python/apache_beam/yaml/yaml_ml.py +++ b/sdks/python/apache_beam/yaml/yaml_ml.py @@ -16,13 +16,20 @@ # """This module defines yaml wrappings for some ML transforms.""" - from typing import Any +from typing import Callable +from typing import Dict from typing import List from typing import Optional import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference import RunInference +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.typehints.row_type import RowTypeConstraint +from apache_beam.utils import python_callable from apache_beam.yaml import options +from apache_beam.yaml.yaml_utils import SafeLineLoader try: from apache_beam.ml.transforms import tft @@ -33,11 +40,436 @@ tft = None # type: ignore +class ModelHandlerProvider: + handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {} + + def __init__( + self, + handler, + preprocess: Optional[Dict[str, str]] = None, + postprocess: Optional[Dict[str, str]] = None): + self._handler = handler + self._preprocess_fn = self.parse_processing_transform( + preprocess, 'preprocess') or self.default_preprocess_fn() + self._postprocess_fn = self.parse_processing_transform( + postprocess, 'postprocess') or self.default_postprocess_fn() + + def inference_output_type(self): + return Any + + @staticmethod + def parse_processing_transform(processing_transform, typ): + def _parse_config(callable=None, path=None, name=None): + if callable and (path or name): + raise ValueError( + f"Cannot specify 'callable' with 'path' and 'name' for {typ} " + f"function.") + if path and name: + return python_callable.PythonCallableWithSource.load_from_script( + FileSystems.open(path).read().decode(), name) + elif callable: + return python_callable.PythonCallableWithSource(callable) + else: + raise ValueError( + f"Must specify one of 'callable' or 'path' and 'name' for {typ} " + f"function.") + + if processing_transform: + if isinstance(processing_transform, dict): + return _parse_config(**processing_transform) + else: + raise ValueError("Invalid model_handler specification.") + + def underlying_handler(self): + return self._handler + + @staticmethod + def default_preprocess_fn(): + raise ValueError( + 'Model Handler does not implement a default preprocess ' + 'method. Please define a preprocessing method using the ' + '\'preprocess\' tag. This is required in most cases because ' + 'most models will have a different input shape, so the model ' + 'cannot generalize how the input Row should be transformed. For ' + 'an example preprocess method, see VertexAIModelHandlerJSONProvider') + + def _preprocess_fn_internal(self): + return lambda row: (row, self._preprocess_fn(row)) + + @staticmethod + def default_postprocess_fn(): + return lambda x: x + + def _postprocess_fn_internal(self): + return lambda result: (result[0], self._postprocess_fn(result[1])) + + @staticmethod + def validate(model_handler_spec): + raise NotImplementedError(type(ModelHandlerProvider)) + + @classmethod + def register_handler_type(cls, type_name): + def apply(constructor): + cls.handler_types[type_name] = constructor + return constructor + + return apply + + @classmethod + def create_handler(cls, model_handler_spec) -> "ModelHandlerProvider": + typ = model_handler_spec['type'] + config = model_handler_spec['config'] + try: + result = cls.handler_types[typ](**config) + if not hasattr(result, 'to_json'): + result.to_json = lambda: model_handler_spec + return result + except Exception as exn: + raise ValueError( + f'Unable to instantiate model handler of type {typ}. {exn}') + + +@ModelHandlerProvider.register_handler_type('VertexAIModelHandlerJSON') +class VertexAIModelHandlerJSONProvider(ModelHandlerProvider): + def __init__( + self, + endpoint_id: str, + project: str, + location: str, + preprocess: Dict[str, str], + postprocess: Optional[Dict[str, str]] = None, + experiment: Optional[str] = None, + network: Optional[str] = None, + private: bool = False, + min_batch_size: Optional[int] = None, + max_batch_size: Optional[int] = None, + max_batch_duration_secs: Optional[int] = None, + env_vars: Optional[Dict[str, Any]] = None): + """ + ModelHandler for Vertex AI. + + This Model Handler can be used with RunInference to load a model hosted + on VertexAI. Every model that is hosted on VertexAI should have three + distinct, required, parameters - `endpoint_id`, `project` and `location`. + These parameters tell the Model Handler how to access the model's endpoint + so that input data can be sent using an API request, and inferences can be + received as a response. + + This Model Handler also requires a `preprocess` function to be defined. + Preprocessing and Postprocessing are described in more detail in the + RunInference docs: + https://beam.apache.org/releases/yamldoc/current/#runinference + + Every model will have a unique input, but all requests should be + JSON-formatted. For example, most language models such as Llama and Gemma + expect a JSON with the key "prompt" (among other optional keys). In Python, + JSON can be expressed as a dictionary. + + For example: :: + + - type: RunInference + config: + inference_tag: 'my_inference' + model_handler: + type: VertexAIModelHandlerJSON + config: + endpoint_id: 9876543210 + project: my-project + location: us-east1 + preprocess: + callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}' + + In the above example, which mimics a call to a Llama 3 model hosted on + VertexAI, the preprocess function (in this case a lambda) takes in a Beam + Row with a single field, "prompt", and maps it to a dict with the same + field. It also specifies an optional parameter, "max_tokens", that tells the + model the allowed token size (in this case input + output token size). + + Args: + endpoint_id: the numerical ID of the Vertex AI endpoint to query. + project: the GCP project name where the endpoint is deployed. + location: the GCP location where the endpoint is deployed. + preprocess: A python callable, defined either inline, or using a file, + that is invoked on the input row before sending to the model to be + loaded by this ModelHandler. This parameter is required by the + `VertexAIModelHandlerJSON` ModelHandler. + postprocess: A python callable, defined either inline, or using a file, + that is invoked on the PredictionResult output by the ModelHandler + before parsing into the output Beam Row under the field name defined + by the inference_tag. + experiment: Experiment label to apply to the + queries. See + https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments + for more information. + network: The full name of the Compute Engine + network the endpoint is deployed on; used for private + endpoints. The network or subnetwork Dataflow pipeline + option must be set and match this network for pipeline + execution. + Ex: "projects/12345/global/networks/myVPC" + private: If the deployed Vertex AI endpoint is + private, set to true. Requires a network to be provided + as well. + min_batch_size: The minimum batch size to use when batching + inputs. + max_batch_size: The maximum batch size to use when batching + inputs. + max_batch_duration_secs: The maximum amount of time to buffer + a batch before emitting; used in streaming contexts. + env_vars: Environment variables. + """ + + try: + from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON + except ImportError: + raise ValueError( + 'Unable to import VertexAIModelHandlerJSON. Please ' + 'install gcp dependencies: `pip install apache_beam[gcp]`') + + _handler = VertexAIModelHandlerJSON( + endpoint_id=str(endpoint_id), + project=project, + location=location, + experiment=experiment, + network=network, + private=private, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + max_batch_duration_secs=max_batch_duration_secs, + env_vars=env_vars or {}) + + super().__init__(_handler, preprocess, postprocess) + + @staticmethod + def validate(model_handler_spec): + pass + + def inference_output_type(self): + return RowTypeConstraint.from_fields([('example', Any), ('inference', Any), + ('model_id', Optional[str])]) + + +def get_user_schema_fields(user_type): + return [(name, type(typ) if not isinstance(typ, type) else typ) + for (name, typ) in user_type._fields] if user_type else [] + + +@beam.ptransform.ptransform_fn +def run_inference( + pcoll, + model_handler: Dict[str, Any], + inference_tag: Optional[str] = 'inference', + inference_args: Optional[Dict[str, Any]] = None) -> beam.PCollection[beam.Row]: # pylint: disable=line-too-long + """ + A transform that takes the input rows, containing examples (or features), for + use on an ML model. The transform then appends the inferences + (or predictions) for those examples to the input row. + + A ModelHandler must be passed to the `model_handler` parameter. The + ModelHandler is responsible for configuring how the ML model will be loaded + and how input data will be passed to it. Every ModelHandler has a config tag, + similar to how a transform is defined, where the parameters are defined. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + ... + + By default, the RunInference transform will return the + input row with a single field appended named by the `inference_tag` parameter + ("inference" by default) that contains the inference directly returned by the + underlying ModelHandler, after any optional postprocessing. + + For example, if the input had the following: :: + + Row(question="What is a car?") + + The output row would look like: :: + + Row(question="What is a car?", inference=...) + + where the `inference` tag can be overridden with the `inference_tag` + parameter. + + However, if one specified the following transform config: :: + + - type: RunInference + config: + inference_tag: my_inference + model_handler: ... + + The output row would look like: :: + + Row(question="What is a car?", my_inference=...) + + See more complete documentation on the underlying + [RunInference](https://beam.apache.org/documentation/ml/inference-overview/) + transform. + + ### Preprocessing input data + + In most cases, the model will be expecting data in a particular data format, + whether it be a Python Dict, PyTorch tensor, etc. However, the outputs of all + built-in Beam YAML transforms are Beam Rows. To allow for transforming + the Beam Row into a data format the model recognizes, each ModelHandler is + equipped with a `preprocessing` parameter for performing necessary data + preprocessing. It is possible for a ModelHandler to define a default + preprocessing function, but in most cases, one will need to be specified by + the caller. + + For example, using `callable`: :: + + pipeline: + type: chain + + transforms: + - type: Create + config: + elements: + - question: "What is a car?" + - question: "Where is the Eiffel Tower located?" + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + preprocess: + callable: 'lambda row: {"prompt": row.question}' + ... + + In the above example, the Create transform generates a collection of two Beam + Row elements, each with a single field - "question". The model, however, + expects a Python Dict with a single key, "prompt". In this case, we can + specify a simple Lambda function (alternatively could define a full function), + to map the data. + + ### Postprocessing predictions + + It is also possible to define a postprocessing function to postprocess the + data output by the ModelHandler. See the documentation for the ModelHandler + you intend to use (list defined below under `model_handler` parameter doc). + + In many cases, before postprocessing, the object + will be a + [PredictionResult](https://beam.apache.org/releases/pydoc/BEAM_VERSION/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult). # pylint: disable=line-too-long + This type behaves very similarly to a Beam Row and fields can be accessed + using dot notation. However, make sure to check the docs for your ModelHandler + to see which fields its PredictionResult contains or if it returns a + different object altogether. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + postprocess: + callable: | + def fn(x: PredictionResult): + return beam.Row(x.example, x.inference, x.model_id) + ... + + The above example demonstrates converting the original output data type (in + this case it is PredictionResult), and converts to a Beam Row, which allows + for easier mapping in a later transform. + + ### File-based pre/postprocessing functions + + For both preprocessing and postprocessing, it is also possible to specify a + Python UDF (User-defined function) file that contains the function. This is + possible by specifying the `path` to the file (local file or GCS path) and + the `name` of the function in the file. + + For example: :: + + - type: RunInference + config: + model_handler: + type: ModelHandler + config: + param_1: arg1 + param_2: arg2 + preprocess: + path: gs://my-bucket/path/to/preprocess.py + name: my_preprocess_fn + postprocess: + path: gs://my-bucket/path/to/postprocess.py + name: my_postprocess_fn + ... + + Args: + model_handler: Specifies the parameters for the respective + enrichment_handler in a YAML/JSON format. To see the full set of + handler_config parameters, see their corresponding doc pages: + + - [VertexAIModelHandlerJSON](https://beam.apache.org/releases/pydoc/current/apache_beam.yaml.yaml_ml.VertexAIModelHandlerJSONProvider) # pylint: disable=line-too-long + inference_tag: The tag to use for the returned inference. Default is + 'inference'. + inference_args: Extra arguments for models whose inference call requires + extra parameters. Make sure to check the underlying ModelHandler docs to + see which args are allowed. + + """ + + options.YamlOptions.check_enabled(pcoll.pipeline, 'ML') + + if not isinstance(model_handler, dict): + raise ValueError( + 'Invalid model_handler specification. Expected dict but was ' + f'{type(model_handler)}.') + expected_model_handler_params = {'type', 'config'} + given_model_handler_params = set( + SafeLineLoader.strip_metadata(model_handler).keys()) + extra_params = given_model_handler_params - expected_model_handler_params + if extra_params: + raise ValueError(f'Unexpected parameters in model_handler: {extra_params}') + missing_params = expected_model_handler_params - given_model_handler_params + if missing_params: + raise ValueError(f'Missing parameters in model_handler: {missing_params}') + typ = model_handler['type'] + model_handler_provider_type = ModelHandlerProvider.handler_types.get( + typ, None) + if not model_handler_provider_type: + raise NotImplementedError(f'Unknown model handler type: {typ}.') + + model_handler_provider = ModelHandlerProvider.create_handler(model_handler) + model_handler_provider.validate(model_handler['config']) + user_type = RowTypeConstraint.from_user_type(pcoll.element_type.user_type) + schema = RowTypeConstraint.from_fields( + get_user_schema_fields(user_type) + + [(str(inference_tag), model_handler_provider.inference_output_type())]) + + return ( + pcoll | RunInference( + model_handler=KeyedModelHandler( + model_handler_provider.underlying_handler()).with_preprocess_fn( + model_handler_provider._preprocess_fn_internal()). + with_postprocess_fn( + model_handler_provider._postprocess_fn_internal()), + inference_args=inference_args) + | beam.Map( + lambda row: beam.Row(**{ + inference_tag: row[1], **row[0]._asdict() + })).with_output_types(schema)) + + def _config_to_obj(spec): if 'type' not in spec: - raise ValueError(r"Missing type in ML transform spec {spec}") + raise ValueError(f"Missing type in ML transform spec {spec}") if 'config' not in spec: - raise ValueError(r"Missing config in ML transform spec {spec}") + raise ValueError(f"Missing config in ML transform spec {spec}") constructor = _transform_constructors.get(spec['type']) if constructor is None: raise ValueError("Unknown ML transform type: %r" % spec['type']) diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py index 327023742bc6..12161d3d580d 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform.py +++ b/sdks/python/apache_beam/yaml/yaml_transform.py @@ -23,7 +23,6 @@ import os import pprint import re -import uuid from typing import Any from typing import Iterable from typing import List @@ -32,7 +31,6 @@ import jinja2 import yaml -from yaml.loader import SafeLoader import apache_beam as beam from apache_beam.io.filesystems import FileSystems @@ -42,6 +40,7 @@ from apache_beam.yaml.yaml_combine import normalize_combine from apache_beam.yaml.yaml_mapping import normalize_mapping from apache_beam.yaml.yaml_mapping import validate_generic_expressions +from apache_beam.yaml.yaml_utils import SafeLineLoader __all__ = ["YamlTransform"] @@ -130,59 +129,6 @@ def empty_if_explicitly_empty(io): return io -class SafeLineLoader(SafeLoader): - """A yaml loader that attaches line information to mappings and strings.""" - class TaggedString(str): - """A string class to which we can attach metadata. - - This is primarily used to trace a string's origin back to its place in a - yaml file. - """ - def __reduce__(self): - # Pickle as an ordinary string. - return str, (str(self), ) - - def construct_scalar(self, node): - value = super().construct_scalar(node) - if isinstance(value, str): - value = SafeLineLoader.TaggedString(value) - value._line_ = node.start_mark.line + 1 - return value - - def construct_mapping(self, node, deep=False): - mapping = super().construct_mapping(node, deep=deep) - mapping['__line__'] = node.start_mark.line + 1 - mapping['__uuid__'] = self.create_uuid() - return mapping - - @classmethod - def create_uuid(cls): - return str(uuid.uuid4()) - - @classmethod - def strip_metadata(cls, spec, tagged_str=True): - if isinstance(spec, Mapping): - return { - cls.strip_metadata(key, tagged_str): - cls.strip_metadata(value, tagged_str) - for (key, value) in spec.items() - if key not in ('__line__', '__uuid__') - } - elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)): - return [cls.strip_metadata(value, tagged_str) for value in spec] - elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str: - return str(spec) - else: - return spec - - @staticmethod - def get_line(obj): - if isinstance(obj, dict): - return obj.get('__line__', 'unknown') - else: - return getattr(obj, '_line_', 'unknown') - - class LightweightScope(object): def __init__(self, transforms): self._transforms = transforms diff --git a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py index 084e03cdb197..5bc9de24bb38 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_unit_test.py @@ -23,7 +23,6 @@ from apache_beam.yaml import YamlTransform from apache_beam.yaml import yaml_provider from apache_beam.yaml.yaml_provider import InlineProvider -from apache_beam.yaml.yaml_transform import SafeLineLoader from apache_beam.yaml.yaml_transform import Scope from apache_beam.yaml.yaml_transform import chain_as_composite from apache_beam.yaml.yaml_transform import ensure_errors_consumed @@ -39,57 +38,7 @@ from apache_beam.yaml.yaml_transform import preprocess_flattened_inputs from apache_beam.yaml.yaml_transform import preprocess_windowing from apache_beam.yaml.yaml_transform import push_windowing_to_roots - - -class SafeLineLoaderTest(unittest.TestCase): - def test_get_line(self): - pipeline_yaml = ''' - type: composite - input: - elements: input - transforms: - - type: PyMap - name: Square - input: elements - config: - fn: "lambda x: x * x" - - type: PyMap - name: Cube - input: elements - config: - fn: "lambda x: x * x * x" - output: - Flatten - ''' - spec = yaml.load(pipeline_yaml, Loader=SafeLineLoader) - self.assertEqual(SafeLineLoader.get_line(spec['type']), 2) - self.assertEqual(SafeLineLoader.get_line(spec['input']), 4) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]), 6) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['type']), 6) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['name']), 7) - self.assertEqual(SafeLineLoader.get_line(spec['transforms'][1]), 11) - self.assertEqual(SafeLineLoader.get_line(spec['output']), 17) - self.assertEqual(SafeLineLoader.get_line(spec['transforms']), "unknown") - - def test_strip_metadata(self): - spec_yaml = ''' - transforms: - - type: PyMap - name: Square - ''' - spec = yaml.load(spec_yaml, Loader=SafeLineLoader) - stripped = SafeLineLoader.strip_metadata(spec['transforms']) - - self.assertFalse(hasattr(stripped[0], '__line__')) - self.assertFalse(hasattr(stripped[0], '__uuid__')) - - def test_strip_metadata_nothing_to_strip(self): - spec_yaml = 'prop: 123' - spec = yaml.load(spec_yaml, Loader=SafeLineLoader) - stripped = SafeLineLoader.strip_metadata(spec['prop']) - - self.assertFalse(hasattr(stripped, '__line__')) - self.assertFalse(hasattr(stripped, '__uuid__')) +from apache_beam.yaml.yaml_utils import SafeLineLoader def new_pipeline(): diff --git a/sdks/python/apache_beam/yaml/yaml_utils.py b/sdks/python/apache_beam/yaml/yaml_utils.py new file mode 100644 index 000000000000..63beb90f0711 --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_utils.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import uuid +from typing import Iterable +from typing import Mapping + +from yaml import SafeLoader + + +class SafeLineLoader(SafeLoader): + """A yaml loader that attaches line information to mappings and strings.""" + class TaggedString(str): + """A string class to which we can attach metadata. + + This is primarily used to trace a string's origin back to its place in a + yaml file. + """ + def __reduce__(self): + # Pickle as an ordinary string. + return str, (str(self), ) + + def construct_scalar(self, node): + value = super().construct_scalar(node) + if isinstance(value, str): + value = SafeLineLoader.TaggedString(value) + value._line_ = node.start_mark.line + 1 + return value + + def construct_mapping(self, node, deep=False): + mapping = super().construct_mapping(node, deep=deep) + mapping['__line__'] = node.start_mark.line + 1 + mapping['__uuid__'] = self.create_uuid() + return mapping + + @classmethod + def create_uuid(cls): + return str(uuid.uuid4()) + + @classmethod + def strip_metadata(cls, spec, tagged_str=True): + if isinstance(spec, Mapping): + return { + cls.strip_metadata(key, tagged_str): + cls.strip_metadata(value, tagged_str) + for (key, value) in spec.items() + if key not in ('__line__', '__uuid__') + } + elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)): + return [cls.strip_metadata(value, tagged_str) for value in spec] + elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str: + return str(spec) + else: + return spec + + @staticmethod + def get_line(obj): + if isinstance(obj, dict): + return obj.get('__line__', 'unknown') + else: + return getattr(obj, '_line_', 'unknown') diff --git a/sdks/python/apache_beam/yaml/yaml_utils_test.py b/sdks/python/apache_beam/yaml/yaml_utils_test.py new file mode 100644 index 000000000000..4fd2c793e57e --- /dev/null +++ b/sdks/python/apache_beam/yaml/yaml_utils_test.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import unittest + +import yaml + +from apache_beam.yaml.yaml_utils import SafeLineLoader + + +class SafeLineLoaderTest(unittest.TestCase): + def test_get_line(self): + pipeline_yaml = ''' + type: composite + input: + elements: input + transforms: + - type: PyMap + name: Square + input: elements + config: + fn: "lambda x: x * x" + - type: PyMap + name: Cube + input: elements + config: + fn: "lambda x: x * x * x" + output: + Flatten + ''' + spec = yaml.load(pipeline_yaml, Loader=SafeLineLoader) + self.assertEqual(SafeLineLoader.get_line(spec['type']), 2) + self.assertEqual(SafeLineLoader.get_line(spec['input']), 4) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]), 6) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['type']), 6) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][0]['name']), 7) + self.assertEqual(SafeLineLoader.get_line(spec['transforms'][1]), 11) + self.assertEqual(SafeLineLoader.get_line(spec['output']), 17) + self.assertEqual(SafeLineLoader.get_line(spec['transforms']), "unknown") + + def test_strip_metadata(self): + spec_yaml = ''' + transforms: + - type: PyMap + name: Square + ''' + spec = yaml.load(spec_yaml, Loader=SafeLineLoader) + stripped = SafeLineLoader.strip_metadata(spec['transforms']) + + self.assertFalse(hasattr(stripped[0], '__line__')) + self.assertFalse(hasattr(stripped[0], '__uuid__')) + + def test_strip_metadata_nothing_to_strip(self): + spec_yaml = 'prop: 123' + spec = yaml.load(spec_yaml, Loader=SafeLineLoader) + stripped = SafeLineLoader.strip_metadata(spec['prop']) + + self.assertFalse(hasattr(stripped, '__line__')) + self.assertFalse(hasattr(stripped, '__uuid__')) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() From c1b9facf7e0ee50fcf9cd07fdbae9e45d1715391 Mon Sep 17 00:00:00 2001 From: Damon Date: Thu, 26 Dec 2024 10:19:38 -0800 Subject: [PATCH 2/8] Clean up Python Tests workflow (#33396) * Remove static GCP credentials from workflow * Remove workflow_dispatch blocking input * Remove conditional from Python SDK source step * Fix parenthesis error * Remove redundant Dataflow test --- .github/workflows/python_tests.yml | 56 ++---------------------------- 1 file changed, 2 insertions(+), 54 deletions(-) diff --git a/.github/workflows/python_tests.yml b/.github/workflows/python_tests.yml index 3000d1871be3..2c3b39a33c1d 100644 --- a/.github/workflows/python_tests.yml +++ b/.github/workflows/python_tests.yml @@ -30,10 +30,6 @@ on: tags: 'v*' paths: ['sdks/python/**', 'model/**'] workflow_dispatch: - inputs: - runDataflow: - description: 'Type "true" if you want to run Dataflow tests (GCP variables must be configured, check CI.md)' - default: false # This allows a subsequently queued workflow run to interrupt previous runs concurrency: @@ -57,7 +53,6 @@ jobs: GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }} GCP_REGION: ${{ secrets.GCP_REGION }} GCP_SA_EMAIL: ${{ secrets.GCP_SA_EMAIL }} - GCP_SA_KEY: ${{ secrets.GCP_SA_KEY }} GCP_TESTING_BUCKET: ${{ secrets.GCP_TESTING_BUCKET }} GCP_PYTHON_WHEELS_BUCKET: "not-needed-here" @@ -65,8 +60,8 @@ jobs: name: 'Build python source distribution' if: | needs.check_gcp_variables.outputs.gcp-variables-set == 'true' && ( - (github.event_name == 'push' || github.event_name == 'schedule') || - (github.event_name == 'workflow_dispatch' && github.event.inputs.runDataflow == 'true') + ((github.event_name == 'push' || github.event_name == 'schedule') || + github.event_name == 'workflow_dispatch') ) needs: - check_gcp_variables @@ -153,50 +148,3 @@ jobs: working-directory: ./sdks/python shell: bash run: python -m apache_beam.examples.wordcount --input MANIFEST.in --output counts - - python_wordcount_dataflow: - name: 'Python Wordcount Dataflow' - # TODO(https://github.com/apache/beam/issues/31848) run on Dataflow after fixes credential on macOS/win GHA runner - if: (github.event_name == 'workflow_dispatch' && github.event.inputs.runDataflow == 'true') - needs: - - build_python_sdk_source - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [[self-hosted, ubuntu-20.04, main], macos-latest, windows-latest] - python: ["3.9", "3.10", "3.11", "3.12"] - steps: - - name: Checkout code - uses: actions/checkout@v4 - - name: Setup environment - uses: ./.github/actions/setup-environment-action - with: - python-version: ${{ matrix.python }} - go-version: default - - name: Download source from artifacts - uses: actions/download-artifact@v4.1.8 - with: - name: python_sdk_source - path: apache-beam-source - - name: Authenticate on GCP - id: auth - uses: google-github-actions/auth@v1 - with: - credentials_json: ${{ secrets.GCP_SA_KEY }} - project_id: ${{ secrets.GCP_PROJECT_ID }} - - name: Install requirements - working-directory: ./sdks/python - run: pip install setuptools --upgrade && pip install -e ".[gcp]" - - name: Run WordCount - working-directory: ./sdks/python - shell: bash - run: | - python -m apache_beam.examples.wordcount \ - --input gs://dataflow-samples/shakespeare/kinglear.txt \ - --output gs://${{ secrets.GCP_TESTING_BUCKET }}/python_wordcount_dataflow/counts \ - --runner DataflowRunner \ - --project ${{ secrets.GCP_PROJECT_ID }} \ - --region ${{ secrets.GCP_REGION }} \ - --temp_location gs://${{ secrets.GCP_TESTING_BUCKET }}/tmp/python_wordcount_dataflow/ \ - --sdk_location ../../apache-beam-source/apache-beam-source.tar.gz From 74bcba1180f80b9361b973351dc1a5eed035724e Mon Sep 17 00:00:00 2001 From: Jeff Kinard Date: Thu, 26 Dec 2024 14:55:17 -0600 Subject: [PATCH 3/8] [yaml] allow logging bytes in LogForTesting (#33433) Signed-off-by: Jeffrey Kinard --- sdks/python/apache_beam/yaml/yaml_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py index a07638953551..8dfa314aeb62 100755 --- a/sdks/python/apache_beam/yaml/yaml_provider.py +++ b/sdks/python/apache_beam/yaml/yaml_provider.py @@ -907,7 +907,7 @@ def log_for_testing( def to_loggable_json_recursive(o): if isinstance(o, (str, bytes)): - return o + return str(o) elif callable(getattr(o, '_asdict', None)): return to_loggable_json_recursive(o._asdict()) elif isinstance(o, Mapping) and callable(getattr(o, 'items', None)): From ab738554eecf3b239a0d2a86f6eda3ebb8ee281c Mon Sep 17 00:00:00 2001 From: Damon Date: Thu, 26 Dec 2024 13:14:33 -0800 Subject: [PATCH 4/8] Update .GitHub/workflows README.md (#33341) * Update .GitHub/workflows README.md * Update .github/workflows/README.md Co-authored-by: Danny McCormick * Add screenshot showing how to run workflow against branch. * Remove trailing whitespace --------- Co-authored-by: Danny McCormick --- .github/workflows/README.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 206364f416f7..de85a99d7bc9 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -15,6 +15,27 @@ under the License. --> +# How to fix Workflows for Committers + +The following is guidance on how to practically make changes that fix workflows. + +1) Create a branch in https://github.com/apache/beam not your fork. + + The reason to perform changes to a branch of the main repo instead of your fork is due to the challenge in replicating the environment within which Beam GitHub workflows execute. GitHub workflows allow you to execute against a branch of a repo. + +2) Make changes in this branch you anticipate will fix the failing workflow. + +3) Run the workflow designating your branch. + + In the GitHub workflow interface, you can designate any branch of the repository to run the workflow against. Selecting your branch allows you to test the changes you made. The following screenshot shows an example of this feature. + ![image](https://github.com/user-attachments/assets/33ca43fb-b0f8-42c8-80e2-ac84a49e2490) + +5) Create a PR, pasting the link to your successful workflow run in the branch + + When doing a PR, the checks will not run against your branch. Your reviewer may not know this so you'll want to mention this in your PR description, pasting the link to your successful run. + +6) After PR merges, execute the workflow manually to validate your merged changes. + # Running Workflows Manually Most workflows will get kicked off automatically when you open a PR, push code, or on a schedule. From bf7e317a7aedd104099c4c1d029bb9716016f617 Mon Sep 17 00:00:00 2001 From: Ahmed Abualsaud <65791736+ahmedabu98@users.noreply.github.com> Date: Fri, 27 Dec 2024 14:09:15 +0000 Subject: [PATCH 5/8] [Managed BigQuery] use file loads with Avro format for better performance (#33392) * use avro file format * add comment * add unit test --- ...QueryFileLoadsSchemaTransformProvider.java | 9 ++++---- .../PortableBigQueryDestinations.java | 15 ++++++++++++ .../io/gcp/bigquery/BigQueryIOWriteTest.java | 23 +++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java index 7872c91d1f72..8899ac82eb06 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryFileLoadsSchemaTransformProvider.java @@ -25,7 +25,6 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.schemas.Schema; @@ -97,20 +96,22 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { return PCollectionRowTuple.empty(input.getPipeline()); } - BigQueryIO.Write toWrite(Schema schema, PipelineOptions options) { + @VisibleForTesting + public BigQueryIO.Write toWrite(Schema schema, PipelineOptions options) { PortableBigQueryDestinations dynamicDestinations = new PortableBigQueryDestinations(schema, configuration); BigQueryIO.Write write = BigQueryIO.write() .to(dynamicDestinations) .withMethod(BigQueryIO.Write.Method.FILE_LOADS) - .withFormatFunction(BigQueryUtils.toTableRow()) // TODO(https://github.com/apache/beam/issues/33074) BatchLoad's // createTempFilePrefixView() doesn't pick up the pipeline option .withCustomGcsTempLocation( ValueProvider.StaticValueProvider.of(options.getTempLocation())) .withWriteDisposition(WriteDisposition.WRITE_APPEND) - .withFormatFunction(dynamicDestinations.getFilterFormatFunction(false)); + // Use Avro format for better performance. Don't change this unless it's for a good + // reason. + .withAvroFormatFunction(dynamicDestinations.getAvroFilterFormatFunction(false)); if (!Strings.isNullOrEmpty(configuration.getCreateDisposition())) { CreateDisposition createDisposition = diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java index 54d125012eac..0cd2b65b0858 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/PortableBigQueryDestinations.java @@ -25,7 +25,10 @@ import com.google.api.services.bigquery.model.TableRow; import com.google.api.services.bigquery.model.TableSchema; import java.util.List; +import org.apache.avro.generic.GenericRecord; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; +import org.apache.beam.sdk.io.gcp.bigquery.AvroWriteRequest; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils; import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinations; import org.apache.beam.sdk.io.gcp.bigquery.TableDestination; @@ -102,4 +105,16 @@ public SerializableFunction getFilterFormatFunction(boolean fetch return BigQueryUtils.toTableRow(filtered); }; } + + public SerializableFunction, GenericRecord> getAvroFilterFormatFunction( + boolean fetchNestedRecord) { + return request -> { + Row row = request.getElement(); + if (fetchNestedRecord) { + row = checkStateNotNull(row.getRow(RECORD)); + } + Row filtered = rowFilter.filter(row); + return AvroUtils.toGenericRecord(filtered); + }; + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java index 69994c019509..57c71c023fcb 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java @@ -19,6 +19,7 @@ import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.toJsonString; import static org.apache.beam.sdk.io.gcp.bigquery.WriteTables.ResultCoder.INSTANCE; +import static org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryFileLoadsSchemaTransformProvider.BigQueryFileLoadsSchemaTransform; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; @@ -32,6 +33,7 @@ import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -117,11 +119,13 @@ import org.apache.beam.sdk.io.gcp.bigquery.WritePartition.ResultCoder; import org.apache.beam.sdk.io.gcp.bigquery.WriteRename.TempTableCleanupFn; import org.apache.beam.sdk.io.gcp.bigquery.WriteTables.Result; +import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryFileLoadsSchemaTransformProvider; import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; import org.apache.beam.sdk.io.gcp.testing.FakeJobService; import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.schemas.JavaFieldSchema; import org.apache.beam.sdk.schemas.Schema; @@ -818,6 +822,25 @@ public void testStreamingFileLoadsWithAutoSharding() throws Exception { assertEquals(2 * numTables, fakeDatasetService.getInsertCount()); } + @Test + public void testFileLoadSchemaTransformUsesAvroFormat() { + // ensure we are writing with the more performant avro format + assumeTrue(!useStreaming); + assumeTrue(!useStorageApi); + BigQueryFileLoadsSchemaTransformProvider provider = + new BigQueryFileLoadsSchemaTransformProvider(); + Row configuration = + Row.withSchema(provider.configurationSchema()) + .withFieldValue("table", "some-table") + .build(); + BigQueryFileLoadsSchemaTransform schemaTransform = + (BigQueryFileLoadsSchemaTransform) provider.from(configuration); + BigQueryIO.Write write = + schemaTransform.toWrite(Schema.of(), PipelineOptionsFactory.create()); + assertNull(write.getFormatFunction()); + assertNotNull(write.getAvroRowWriterFactory()); + } + @Test public void testBatchFileLoads() throws Exception { assumeTrue(!useStreaming); From 1abc6c8195aefdcecace3890bacc6f650e6f1898 Mon Sep 17 00:00:00 2001 From: Filipe Regadas Date: Fri, 27 Dec 2024 15:38:22 +0000 Subject: [PATCH 6/8] Support Iceberg partition identity transform (#33332) * Support Iceberg partition identity transform * remove uneeded avro dep * Trigger icerberg integration tests * Revert "remove uneeded avro dep" This reverts commit 0b075af322c0cdec3f7ed06593d7c0766c8b654c. --- .../IO_Iceberg_Integration_Tests.json | 2 +- sdks/java/io/iceberg/build.gradle | 1 + .../beam/sdk/io/iceberg/ScanTaskReader.java | 36 +++++- .../sdk/io/iceberg/IcebergIOReadTest.java | 72 ++++++++++++ .../sdk/io/iceberg/TestDataWarehouse.java | 15 ++- .../beam/sdk/io/iceberg/TestFixtures.java | 111 +++++++++++------- 6 files changed, 188 insertions(+), 49 deletions(-) diff --git a/.github/trigger_files/IO_Iceberg_Integration_Tests.json b/.github/trigger_files/IO_Iceberg_Integration_Tests.json index a84f69a97721..5cf4f475f317 100644 --- a/.github/trigger_files/IO_Iceberg_Integration_Tests.json +++ b/.github/trigger_files/IO_Iceberg_Integration_Tests.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 6 + "modification": 7 } diff --git a/sdks/java/io/iceberg/build.gradle b/sdks/java/io/iceberg/build.gradle index 319848b7626b..cd9e7044632b 100644 --- a/sdks/java/io/iceberg/build.gradle +++ b/sdks/java/io/iceberg/build.gradle @@ -45,6 +45,7 @@ dependencies { implementation library.java.vendored_guava_32_1_2_jre implementation project(path: ":sdks:java:core", configuration: "shadow") implementation project(path: ":model:pipeline", configuration: "shadow") + implementation library.java.avro implementation library.java.slf4j_api implementation library.java.joda_time implementation "org.apache.parquet:parquet-column:$parquet_version" diff --git a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ScanTaskReader.java b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ScanTaskReader.java index b7cb42b2eacb..5784dfd79744 100644 --- a/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ScanTaskReader.java +++ b/sdks/java/io/iceberg/src/main/java/org/apache/beam/sdk/io/iceberg/ScanTaskReader.java @@ -21,15 +21,22 @@ import java.io.IOException; import java.util.ArrayDeque; +import java.util.Collections; +import java.util.Map; import java.util.NoSuchElementException; import java.util.Queue; +import java.util.Set; +import java.util.function.BiFunction; import javax.annotation.Nullable; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.values.Row; import org.apache.iceberg.DataFile; import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.Schema; import org.apache.iceberg.Table; import org.apache.iceberg.avro.Avro; +import org.apache.iceberg.data.IdentityPartitionConverters; import org.apache.iceberg.data.Record; import org.apache.iceberg.data.avro.DataReader; import org.apache.iceberg.data.orc.GenericOrcReader; @@ -42,6 +49,9 @@ import org.apache.iceberg.io.InputFile; import org.apache.iceberg.orc.ORC; import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.util.PartitionUtil; import org.checkerframework.checker.nullness.qual.NonNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -112,6 +122,8 @@ public boolean advance() throws IOException { FileScanTask fileTask = fileScanTasks.remove(); DataFile file = fileTask.file(); InputFile input = decryptor.getInputFile(fileTask); + Map idToConstants = + constantsMap(fileTask, IdentityPartitionConverters::convertConstant, project); CloseableIterable iterable; switch (file.format()) { @@ -121,7 +133,9 @@ public boolean advance() throws IOException { ORC.read(input) .split(fileTask.start(), fileTask.length()) .project(project) - .createReaderFunc(fileSchema -> GenericOrcReader.buildReader(project, fileSchema)) + .createReaderFunc( + fileSchema -> + GenericOrcReader.buildReader(project, fileSchema, idToConstants)) .filter(fileTask.residual()) .build(); break; @@ -132,7 +146,8 @@ public boolean advance() throws IOException { .split(fileTask.start(), fileTask.length()) .project(project) .createReaderFunc( - fileSchema -> GenericParquetReaders.buildReader(project, fileSchema)) + fileSchema -> + GenericParquetReaders.buildReader(project, fileSchema, idToConstants)) .filter(fileTask.residual()) .build(); break; @@ -142,7 +157,8 @@ public boolean advance() throws IOException { Avro.read(input) .split(fileTask.start(), fileTask.length()) .project(project) - .createReaderFunc(DataReader::create) + .createReaderFunc( + fileSchema -> DataReader.create(project, fileSchema, idToConstants)) .build(); break; default: @@ -155,6 +171,20 @@ public boolean advance() throws IOException { return false; } + private Map constantsMap( + FileScanTask task, BiFunction converter, Schema schema) { + PartitionSpec spec = task.spec(); + Set idColumns = spec.identitySourceIds(); + Schema partitionSchema = TypeUtil.select(schema, idColumns); + boolean projectsIdentityPartitionColumns = !partitionSchema.columns().isEmpty(); + + if (projectsIdentityPartitionColumns) { + return PartitionUtil.constantsMap(task, converter); + } else { + return Collections.emptyMap(); + } + } + @Override public Row getCurrent() throws NoSuchElementException { if (current == null) { diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOReadTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOReadTest.java index fe4a07dedfdf..39c621975547 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOReadTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOReadTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io.iceberg; +import static org.apache.beam.sdk.io.iceberg.TestFixtures.createRecord; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; @@ -35,8 +36,11 @@ import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.PartitionKey; +import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Table; import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.types.Types; import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; @@ -122,4 +126,72 @@ public void testSimpleScan() throws Exception { testPipeline.run(); } + + @Test + public void testIdentityColumnScan() throws Exception { + TableIdentifier tableId = + TableIdentifier.of("default", "table" + Long.toString(UUID.randomUUID().hashCode(), 16)); + Table simpleTable = warehouse.createTable(tableId, TestFixtures.SCHEMA); + + String identityColumnName = "identity"; + String identityColumnValue = "some-value"; + simpleTable.updateSchema().addColumn(identityColumnName, Types.StringType.get()).commit(); + simpleTable.updateSpec().addField(identityColumnName).commit(); + + PartitionSpec spec = simpleTable.spec(); + PartitionKey partitionKey = new PartitionKey(simpleTable.spec(), simpleTable.schema()); + partitionKey.set(0, identityColumnValue); + + simpleTable + .newFastAppend() + .appendFile( + warehouse.writeRecords( + "file1s1.parquet", + TestFixtures.SCHEMA, + spec, + partitionKey, + TestFixtures.FILE1SNAPSHOT1)) + .commit(); + + final Schema schema = IcebergUtils.icebergSchemaToBeamSchema(simpleTable.schema()); + final List expectedRows = + Stream.of(TestFixtures.FILE1SNAPSHOT1_DATA) + .flatMap(List::stream) + .map( + d -> + ImmutableMap.builder() + .putAll(d) + .put(identityColumnName, identityColumnValue) + .build()) + .map(r -> createRecord(simpleTable.schema(), r)) + .map(record -> IcebergUtils.icebergRecordToBeamRow(schema, record)) + .collect(Collectors.toList()); + + Map catalogProps = + ImmutableMap.builder() + .put("type", CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP) + .put("warehouse", warehouse.location) + .build(); + + IcebergCatalogConfig catalogConfig = + IcebergCatalogConfig.builder() + .setCatalogName("name") + .setCatalogProperties(catalogProps) + .build(); + + PCollection output = + testPipeline + .apply(IcebergIO.readRows(catalogConfig).from(tableId)) + .apply(ParDo.of(new PrintRow())) + .setCoder(RowCoder.of(IcebergUtils.icebergSchemaToBeamSchema(simpleTable.schema()))); + + PAssert.that(output) + .satisfies( + (Iterable rows) -> { + assertThat(rows, containsInAnyOrder(expectedRows.toArray())); + return null; + }); + + testPipeline.run(); + } } diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestDataWarehouse.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestDataWarehouse.java index 1e1c84d31de9..9352123b5c77 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestDataWarehouse.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestDataWarehouse.java @@ -32,6 +32,7 @@ import org.apache.iceberg.FileFormat; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; import org.apache.iceberg.Table; import org.apache.iceberg.catalog.Catalog; import org.apache.iceberg.catalog.Namespace; @@ -108,6 +109,16 @@ protected void after() { public DataFile writeRecords(String filename, Schema schema, List records) throws IOException { + return writeRecords(filename, schema, PartitionSpec.unpartitioned(), null, records); + } + + public DataFile writeRecords( + String filename, + Schema schema, + PartitionSpec spec, + StructLike partition, + List records) + throws IOException { Path path = new Path(location, filename); FileFormat format = FileFormat.fromFileName(filename); @@ -134,9 +145,11 @@ public DataFile writeRecords(String filename, Schema schema, List record } appender.addAll(records); appender.close(); - return DataFiles.builder(PartitionSpec.unpartitioned()) + + return DataFiles.builder(spec) .withInputFile(HadoopInputFile.fromPath(path, hadoopConf)) .withMetrics(appender.metrics()) + .withPartition(partition) .build(); } diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestFixtures.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestFixtures.java index 6143bd03491d..a2ca86d1b5a2 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestFixtures.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/TestFixtures.java @@ -21,11 +21,13 @@ import static org.apache.iceberg.types.Types.NestedField.required; import java.util.ArrayList; +import java.util.List; +import java.util.Map; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.iceberg.Schema; -import org.apache.iceberg.data.GenericRecord; import org.apache.iceberg.data.Record; import org.apache.iceberg.types.Types; @@ -34,58 +36,75 @@ public class TestFixtures { new Schema( required(1, "id", Types.LongType.get()), optional(2, "data", Types.StringType.get())); - private static final Record genericRecord = GenericRecord.create(SCHEMA); - - /* First file in test table */ - public static final ImmutableList FILE1SNAPSHOT1 = + public static final List> FILE1SNAPSHOT1_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 0L, "data", "clarification")), - genericRecord.copy(ImmutableMap.of("id", 1L, "data", "risky")), - genericRecord.copy(ImmutableMap.of("id", 2L, "data", "falafel"))); - public static final ImmutableList FILE1SNAPSHOT2 = + ImmutableMap.of("id", 0L, "data", "clarification"), + ImmutableMap.of("id", 1L, "data", "risky"), + ImmutableMap.of("id", 2L, "data", "falafel")); + public static final List> FILE1SNAPSHOT2_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 3L, "data", "obscure")), - genericRecord.copy(ImmutableMap.of("id", 4L, "data", "secure")), - genericRecord.copy(ImmutableMap.of("id", 5L, "data", "feta"))); - public static final ImmutableList FILE1SNAPSHOT3 = + ImmutableMap.of("id", 3L, "data", "obscure"), + ImmutableMap.of("id", 4L, "data", "secure"), + ImmutableMap.of("id", 5L, "data", "feta")); + public static final List> FILE1SNAPSHOT3_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 6L, "data", "brainy")), - genericRecord.copy(ImmutableMap.of("id", 7L, "data", "film")), - genericRecord.copy(ImmutableMap.of("id", 8L, "data", "feta"))); - - /* Second file in test table */ - public static final ImmutableList FILE2SNAPSHOT1 = + ImmutableMap.of("id", 6L, "data", "brainy"), + ImmutableMap.of("id", 7L, "data", "film"), + ImmutableMap.of("id", 8L, "data", "feta")); + public static final List> FILE2SNAPSHOT1_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 10L, "data", "clammy")), - genericRecord.copy(ImmutableMap.of("id", 11L, "data", "evacuate")), - genericRecord.copy(ImmutableMap.of("id", 12L, "data", "tissue"))); - public static final ImmutableList FILE2SNAPSHOT2 = + ImmutableMap.of("id", 10L, "data", "clammy"), + ImmutableMap.of("id", 11L, "data", "evacuate"), + ImmutableMap.of("id", 12L, "data", "tissue")); + public static final List> FILE2SNAPSHOT2_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 14L, "data", "radical")), - genericRecord.copy(ImmutableMap.of("id", 15L, "data", "collocation")), - genericRecord.copy(ImmutableMap.of("id", 16L, "data", "book"))); - public static final ImmutableList FILE2SNAPSHOT3 = + ImmutableMap.of("id", 14L, "data", "radical"), + ImmutableMap.of("id", 15L, "data", "collocation"), + ImmutableMap.of("id", 16L, "data", "book")); + public static final List> FILE2SNAPSHOT3_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 16L, "data", "cake")), - genericRecord.copy(ImmutableMap.of("id", 17L, "data", "intrinsic")), - genericRecord.copy(ImmutableMap.of("id", 18L, "data", "paper"))); - - /* Third file in test table */ - public static final ImmutableList FILE3SNAPSHOT1 = + ImmutableMap.of("id", 16L, "data", "cake"), + ImmutableMap.of("id", 17L, "data", "intrinsic"), + ImmutableMap.of("id", 18L, "data", "paper")); + public static final List> FILE3SNAPSHOT1_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 20L, "data", "ocean")), - genericRecord.copy(ImmutableMap.of("id", 21L, "data", "holistic")), - genericRecord.copy(ImmutableMap.of("id", 22L, "data", "preventative"))); - public static final ImmutableList FILE3SNAPSHOT2 = + ImmutableMap.of("id", 20L, "data", "ocean"), + ImmutableMap.of("id", 21L, "data", "holistic"), + ImmutableMap.of("id", 22L, "data", "preventative")); + public static final List> FILE3SNAPSHOT2_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 24L, "data", "cloud")), - genericRecord.copy(ImmutableMap.of("id", 25L, "data", "zen")), - genericRecord.copy(ImmutableMap.of("id", 26L, "data", "sky"))); - public static final ImmutableList FILE3SNAPSHOT3 = + ImmutableMap.of("id", 24L, "data", "cloud"), + ImmutableMap.of("id", 25L, "data", "zen"), + ImmutableMap.of("id", 26L, "data", "sky")); + public static final List> FILE3SNAPSHOT3_DATA = ImmutableList.of( - genericRecord.copy(ImmutableMap.of("id", 26L, "data", "belleview")), - genericRecord.copy(ImmutableMap.of("id", 27L, "data", "overview")), - genericRecord.copy(ImmutableMap.of("id", 28L, "data", "tender"))); + ImmutableMap.of("id", 26L, "data", "belleview"), + ImmutableMap.of("id", 27L, "data", "overview"), + ImmutableMap.of("id", 28L, "data", "tender")); + + /* First file in test table */ + public static final List FILE1SNAPSHOT1 = + Lists.transform(FILE1SNAPSHOT1_DATA, d -> createRecord(SCHEMA, d)); + public static final List FILE1SNAPSHOT2 = + Lists.transform(FILE1SNAPSHOT2_DATA, d -> createRecord(SCHEMA, d)); + public static final List FILE1SNAPSHOT3 = + Lists.transform(FILE1SNAPSHOT3_DATA, d -> createRecord(SCHEMA, d)); + + /* Second file in test table */ + public static final List FILE2SNAPSHOT1 = + Lists.transform(FILE2SNAPSHOT1_DATA, d -> createRecord(SCHEMA, d)); + public static final List FILE2SNAPSHOT2 = + Lists.transform(FILE2SNAPSHOT2_DATA, d -> createRecord(SCHEMA, d)); + public static final List FILE2SNAPSHOT3 = + Lists.transform(FILE2SNAPSHOT3_DATA, d -> createRecord(SCHEMA, d)); + + /* Third file in test table */ + public static final List FILE3SNAPSHOT1 = + Lists.transform(FILE3SNAPSHOT1_DATA, d -> createRecord(SCHEMA, d)); + public static final List FILE3SNAPSHOT2 = + Lists.transform(FILE3SNAPSHOT2_DATA, d -> createRecord(SCHEMA, d)); + public static final List FILE3SNAPSHOT3 = + Lists.transform(FILE3SNAPSHOT3_DATA, d -> createRecord(SCHEMA, d)); public static final ImmutableList asRows(Iterable records) { ArrayList rows = new ArrayList<>(); @@ -98,4 +117,8 @@ public static final ImmutableList asRows(Iterable records) { } return ImmutableList.copyOf(rows); } + + public static Record createRecord(org.apache.iceberg.Schema schema, Map values) { + return org.apache.iceberg.data.GenericRecord.create(schema).copy(values); + } } From 067deed9e4e4c8bab4fa7c9fa55e150a0555e462 Mon Sep 17 00:00:00 2001 From: Ahmed Abualsaud <65791736+ahmedabu98@users.noreply.github.com> Date: Fri, 27 Dec 2024 15:42:03 +0000 Subject: [PATCH 7/8] python multi-lang with schematransforms guide (#33362) --- ...n-custom-multi-language-pipelines-guide.md | 307 ++++++++++++++++++ 1 file changed, 307 insertions(+) create mode 100644 website/www/site/content/en/documentation/sdks/python-custom-multi-language-pipelines-guide.md diff --git a/website/www/site/content/en/documentation/sdks/python-custom-multi-language-pipelines-guide.md b/website/www/site/content/en/documentation/sdks/python-custom-multi-language-pipelines-guide.md new file mode 100644 index 000000000000..60523cbb3b2a --- /dev/null +++ b/website/www/site/content/en/documentation/sdks/python-custom-multi-language-pipelines-guide.md @@ -0,0 +1,307 @@ +--- +type: languages +title: "Python custom multi-language pipelines guide" +--- + + +# Python custom multi-language pipelines guide + +Apache Beam's powerful model enables the development of scalable, resilient, and production-ready transforms, but the process often requires significant time and effort. + +With SDKs available in multiple languages (Java, Python, Golang, YAML, etc.), creating and maintaining transforms for each language becomes a challenge, particularly for IOs. Developers must navigate different APIs, address unique quirks, and manage ongoing maintenance—such as updates, new features, and documentation—while ensuring consistent behavior across SDKs. This results in redundant work, as the same functionality is implemented repeatedly for each language (M x N effort, where M is the number of SDKs and N is the number of transforms). + +To streamline this process, Beam’s portability framework enables the use of portable transforms that can be shared across languages. This reduces duplication, allowing developers to focus on maintaining only N transforms. Pipelines combining [portable transforms](#portable-transform) from other SDK(s) are known as [“multi-language” pipelines](../programming-guide.md#13-multi-language-pipelines-multi-language-pipelines). + +The SchemaTransform framework represents the latest advancement in enhancing this multi-language capability. + +The following jumps straight into the guide. Check out the [appendix](#appendix) section below for some of the terminology used here. For a runnable example, check out this [page](python-multi-language-pipelines-2.md). + +## Create a Java SchemaTransform + +For better readability, use [**TypedSchemaTransformProvider**](https://beam.apache.org/releases/javadoc/current/index.html?org/apache/beam/sdk/schemas/transforms/TypedSchemaTransformProvider.html), a [SchemaTransformProvider](#schematransformprovider) parameterized on a custom configuration type `T`. TypedSchemaTransformProvider will take care of converting the custom type definition to a Beam [Schema](../basics.md#schema), and converting an instance to a Beam Row. + +```java +TypedSchemaTransformProvider extends SchemaTransformProvider { + String identifier(); + + SchemaTransform from(T configuration); +} +``` + +### Implement a configuration + +First, set up a Beam Schema-compatible configuration. This will be used to construct the transform. AutoValue types are encouraged for readability. Adding the appropriate `@DefaultSchema` annotation will help Beam do the conversions mentioned above. + +```java +@DefaultSchema(AutoValueSchema.class) +@AutoValue +public abstract static class MyConfiguration { + public static Builder builder() { + return new AutoValue_MyConfiguration.Builder(); + } + @SchemaFieldDescription("Description of what foo does...") + public abstract String getFoo(); + + @SchemaFieldDescription("Description of what bar does...") + public abstract Integer getBar(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setFoo(String foo); + + public abstract Builder setBar(Integer bar); + + public abstract MyConfiguration build(); + } +} +``` + +This configuration is surfaced to foreign SDKs. For example, when using this transform in Python, use the following format: + +```python +with beam.Pipeline() as p: + (p + | Create([...]) + | MySchemaTransform(foo="abc", bar=123) +``` + +When using this transform in YAML, use the following format: + +```yaml +pipeline: + transforms: + - type: Create + ... + - type: MySchemaTransform + config: + foo: "abc" + bar: 123 +``` + +### Implement a TypedSchemaTransformProvider +Next, implement the `TypedSchemaTransformProvider`. The following two methods are required: + +- `identifier`: Returns a unique identifier for this transform. The [Beam standard](../programming-guide.md#1314-defining-a-urn) follows this structure: `:::`. +- `from`: Builds the transform using a provided configuration. + +An [expansion service](#expansion-service) uses these methods to find and build the transform. The `@AutoService(SchemaTransformProvider.class)` annotation is also required to ensure this provider is recognized by the expansion service. + +```java +@AutoService(SchemaTransformProvider.class) +public class MyProvider extends TypedSchemaTransformProvider { + @Override + public String identifier() { + return "beam:schematransform:org.apache.beam:my_transform:v1"; + } + + @Override + protected SchemaTransform from(MyConfiguration configuration) { + return new MySchemaTransform(configuration); + } + + private static class MySchemaTransform extends SchemaTransform { + private final MyConfiguration config; + MySchemaTransform(MyConfiguration configuration) { + this.config = configuration; + } + + @Override + public PCollectionRowTuple expand(PCollectionRowTuple input) { + PCollection inputRows = input.get("input"); + PCollection outputRows = inputRows.apply( + new MyJavaTransform(config.getFoo(), config.getBar())); + + return PCollectionRowTuple.of("output", outputRows); + } + } +} +``` + +#### Additional metadata (optional) +The following optional methods can help provide relevant metadata: +- `description`: Provide a human-readable description for the transform. Remote SDKs can use this text to generate documentation. +- `inputCollectionNames`: Provide PCollection tags that this transform expects to take in. +- `outputCollectionNames`: Provide PCollection tags this transform expects to produce. + +```java + @Override + public String description() { + return "This transform does this and that..."; + } + + @Override + public List inputCollectionNames() { + return Arrays.asList("input_1", "input_2"); + } + + @Override + public List outputCollectionNames() { + return Collections.singletonList("output"); + } +``` + +## Build an expansion service that contains the transform + +Use an expansion service to make the transform available to foreign SDKs. + +First, build a shaded JAR file that includes: +1. the transform, +2. the [**ExpansionService artifact**](https://central.sonatype.com/artifact/org.apache.beam/beam-sdks-java-expansion-service), +3. and some additional dependencies. + +### Gradle build file +```groovy +plugins { + id 'com.github.johnrengelman.shadow' version '8.1.1' + id 'application' +} + +mainClassName = "org.apache.beam.sdk.expansion.service.ExpansionService" + +dependencies { + // Dependencies for your transform + ... + + // Beam's expansion service + runtimeOnly "org.apache.beam:beam-sdks-java-expansion-service:$beamVersion" + // AutoService annotation for our SchemaTransform provider + compileOnly "com.google.auto.service:auto-service-annotations:1.0.1" + annotationProcessor "com.google.auto.service:auto-service:1.0.1" + // AutoValue annotation for our configuration object + annotationProcessor "com.google.auto.value:auto-value:1.9" +} +``` + +Next, run the shaded JAR file, and provide a port to host the service. A list of available SchemaTransformProviders will be displayed. + +```shell +$ java -jar path/to/my-expansion-service.jar 12345 + +Starting expansion service at localhost:12345 + +Registered transforms: + ... +Registered SchemaTransformProviders: + beam:schematransform:org.apache.beam:my_transform:v1 +``` + +The transform is discoverable at `localhost:12345`. Foreign SDKs can now discover and add it to their pipelines. The next section demonstrates how to do this with a Python pipeline. + +## Use the portable transform in a Python pipeline + +The Python SDK’s [**ExternalTransformProvider**](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.external_transform_provider.html#apache_beam.transforms.external_transform_provider.ExternalTransformProvider) +can dynamically generate wrappers for portable transforms. + +```python +from apache_beam.transforms.external_transform_provider import ExternalTransformProvider +``` + +### Connect to an expansion service +First, connect to an expansion service that contains the transform. This section demonstrates two methods of connecting to the expansion service. + +#### Connect to an already running service + +If your expansion service JAR file is already running, pass in the address: + +```python +provider = ExternalTransformProvider("localhost:12345") +``` + +#### Start a service based on a Java JAR file + +If the service lives in a JAR file but isn’t currently running, use Beam utilities to run the service in a subprocess: + +```python +from apache_beam.transforms.external import JavaJarExpansionService + +provider = ExternalTransformProvider( + JavaJarExpansionService("path/to/my-expansion-service.jar")) +``` + +You can also provide a list of services: + +```python +provider = ExternalTransformProvider([ + "localhost:12345", + JavaJarExpansionService("path/to/my-expansion-service.jar"), + JavaJarExpansionService("path/to/another-expansion-service.jar")]) +``` + +When initialized, the `ExternalTransformProvider` connects to the expansion service(s), retrieves all portable transforms, and generates a Pythonic wrapper for each one. + +### Retrieve and use the transform + +Retrieve the transform using its unique identifier and use it in your multi-language pipeline: + +```python +identifier = "beam:schematransform:org.apache.beam:my_transform:v1" +MyTransform = provider.get_urn(identifier) + +with beam.Pipeline() as p: + p | beam.Create(...) | MyTransform(foo="abc", bar=123) +``` + + +### Inspect the transform's metadata +You can learn more about a portable transform’s configuration by inspecting its metadata: + +```python +import inspect + +inspect.getdoc(MyTransform) +# Output: "This transform does this and that..." + +inspect.signature(MyTransform) +# Output: (foo: "str: Description of what foo does...", +# bar: "int: Description of what bar does....") +``` + +This metadata is generated directly from the provider's implementation. The class documentation is generated from the [optional **description** method](#additional-metadata). The signature information is generated from the `@SchemaFieldDescription` annotations in the [configuration object](#implement-a-configuration). + +## Appendix + +### Portable transform + +Also known as a [cross-language transform](../glossary.md#cross-language-transforms): a transform that is made available to other SDKs (i.e. other languages) via an expansion service. Such a transform must offer a way to be constructed using language-agnostic parameter types. + +### Expansion Service + +A container that can hold multiple portable transforms. During pipeline expansion, this service will +- Look up the transform in its internal registry +- Build the transform in its native language using the provided configuration +- Expand the transform – i.e. construct the transform’s sub-graph to be inserted in the pipeline +- Establish a gRPC communication channel with the runner to exchange data and signals during pipeline execution. + +### SchemaTransform + +A transform that takes and produces PCollections of Beam Rows with a predefined Schema, i.e.: + +```java +SchemaTransform extends PTransform {} +``` + +### SchemaTransformProvider + +Produces a SchemaTransform using a provided configuration. An expansion service uses this interface to identify and build the transform for foreign SDKs. + +```java +SchemaTransformProvider { + String identifier(); + + SchemaTransform from(Row configuration); + + Schema configurationSchema(); +} +``` \ No newline at end of file From 5944a3043f8c0dc7b8b59ff9293b894e8a8a08ac Mon Sep 17 00:00:00 2001 From: Razvan Culea <40352446+razvanculea@users.noreply.github.com> Date: Fri, 27 Dec 2024 16:48:35 +0100 Subject: [PATCH 8/8] BigQueryIO : control StorageWrite parallelism in batch, by reshuffling before write on the number of streams set for BigQueryIO.write() using .withNumStorageWriteApiStreams(numStorageWriteApiStreams) (#32805) * BigQueryIO : control StorageWrite parallelism in batch, by reshuffling before write on the number of streams set for BigQueryIO.write() using .withNumStorageWriteApiStreams(numStorageWriteApiStreams) * fix unused dep and comment * spotlessApply * spotlessApply * fix typo --- .../beam/it/gcp/bigquery/BigQueryIOLT.java | 50 +++++++++++++++++-- .../beam/sdk/io/gcp/bigquery/BigQueryIO.java | 11 ++-- .../sdk/io/gcp/bigquery/StorageApiLoads.java | 15 +++++- 3 files changed, 68 insertions(+), 8 deletions(-) diff --git a/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOLT.java b/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOLT.java index a9ae68142778..7ea8dece31bb 100644 --- a/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOLT.java +++ b/it/google-cloud-platform/src/test/java/org/apache/beam/it/gcp/bigquery/BigQueryIOLT.java @@ -79,11 +79,20 @@ * *

Example trigger command for specific test running on Dataflow runner: * + *

Maven + * *

  * mvn test -pl it/google-cloud-platform -am -Dtest="BigQueryIOLT#testAvroFileLoadsWriteThenRead" \
  * -Dconfiguration=medium -Dproject=[gcpProject] -DartifactBucket=[temp bucket] -DfailIfNoTests=false
  * 
* + *

Gradle + * + *

+ * ./gradlew :it:google-cloud-platform:BigQueryPerformanceTest --tests='BigQueryIOLT.testAvroFileLoadsWriteThenRead' \
+ * -Dconfiguration=medium -Dproject=[gcpProject] -DartifactBucket=[temp bucket] -DfailIfNoTests=false
+ * 
+ * *

Example trigger command for specific test and custom data configuration: * *

mvn test -pl it/google-cloud-platform -am \
@@ -172,11 +181,11 @@ public static void tearDownClass() {
                   Configuration.class), // 1 MB
               "medium",
               Configuration.fromJsonString(
-                  "{\"numRecords\":10000000,\"valueSizeBytes\":1000,\"pipelineTimeout\":20,\"runner\":\"DataflowRunner\"}",
+                  "{\"numRecords\":10000000,\"valueSizeBytes\":1000,\"pipelineTimeout\":20,\"runner\":\"DataflowRunner\",\"workerMachineType\":\"e2-standard-2\",\"experiments\":\"disable_runner_v2\",\"numWorkers\":\"1\",\"maxNumWorkers\":\"1\"}",
                   Configuration.class), // 10 GB
               "large",
               Configuration.fromJsonString(
-                  "{\"numRecords\":100000000,\"valueSizeBytes\":1000,\"pipelineTimeout\":80,\"runner\":\"DataflowRunner\"}",
+                  "{\"numRecords\":100000000,\"valueSizeBytes\":1000,\"pipelineTimeout\":80,\"runner\":\"DataflowRunner\",\"workerMachineType\":\"e2-standard-2\",\"experiments\":\"disable_runner_v2\",\"numWorkers\":\"1\",\"maxNumWorkers\":\"1\",\"numStorageWriteApiStreams\":4,\"storageWriteApiTriggeringFrequencySec\":20}",
                   Configuration.class) // 100 GB
               );
     } catch (IOException e) {
@@ -230,16 +239,19 @@ public void testWriteAndRead() throws IOException {
         writeIO =
             BigQueryIO.write()
                 .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_TRUNCATE)
+                .withNumStorageWriteApiStreams(
+                    configuration.numStorageWriteApiStreams) // control the number of streams
                 .withAvroFormatFunction(
                     new AvroFormatFn(
                         configuration.numColumns,
                         !("STORAGE_WRITE_API".equalsIgnoreCase(configuration.writeMethod))));
-
         break;
       case JSON:
         writeIO =
             BigQueryIO.write()
                 .withSuccessfulInsertsPropagation(false)
+                .withNumStorageWriteApiStreams(
+                    configuration.numStorageWriteApiStreams) // control the number of streams
                 .withFormatFunction(new JsonFormatFn(configuration.numColumns));
         break;
     }
@@ -268,6 +280,10 @@ private void testWrite(BigQueryIO.Write writeIO) throws IOException {
             .setSdk(PipelineLauncher.Sdk.JAVA)
             .setPipeline(writePipeline)
             .addParameter("runner", configuration.runner)
+            .addParameter("workerMachineType", configuration.workerMachineType)
+            .addParameter("experiments", configuration.experiments)
+            .addParameter("numWorkers", configuration.numWorkers)
+            .addParameter("maxNumWorkers", configuration.maxNumWorkers)
             .build();
 
     PipelineLauncher.LaunchInfo launchInfo = pipelineLauncher.launch(project, region, options);
@@ -304,6 +320,10 @@ private void testRead() throws IOException {
             .setSdk(PipelineLauncher.Sdk.JAVA)
             .setPipeline(readPipeline)
             .addParameter("runner", configuration.runner)
+            .addParameter("workerMachineType", configuration.workerMachineType)
+            .addParameter("experiments", configuration.experiments)
+            .addParameter("numWorkers", configuration.numWorkers)
+            .addParameter("maxNumWorkers", configuration.maxNumWorkers)
             .build();
 
     PipelineLauncher.LaunchInfo launchInfo = pipelineLauncher.launch(project, region, options);
@@ -445,12 +465,36 @@ static class Configuration extends SyntheticSourceOptions {
     /** Runner specified to run the pipeline. */
     @JsonProperty public String runner = "DirectRunner";
 
+    /** Worker machine type specified to run the pipeline with Dataflow Runner. */
+    @JsonProperty public String workerMachineType = "";
+
+    /** Experiments specified to run the pipeline. */
+    @JsonProperty public String experiments = "";
+
+    /** Number of workers to start the pipeline. Must be a positive value. */
+    @JsonProperty public String numWorkers = "1";
+
+    /** Maximum umber of workers for the pipeline. Must be a positive value. */
+    @JsonProperty public String maxNumWorkers = "1";
+
     /** BigQuery read method: DEFAULT/DIRECT_READ/EXPORT. */
     @JsonProperty public String readMethod = "DEFAULT";
 
     /** BigQuery write method: DEFAULT/FILE_LOADS/STREAMING_INSERTS/STORAGE_WRITE_API. */
     @JsonProperty public String writeMethod = "DEFAULT";
 
+    /**
+     * BigQuery number of streams for write method STORAGE_WRITE_API. 0 let's the runner determine
+     * the number of streams. Remark : max limit for open connections per hour is 10K streams.
+     */
+    @JsonProperty public int numStorageWriteApiStreams = 0;
+
+    /**
+     * BigQuery triggering frequency in second in combination with the number of streams for write
+     * method STORAGE_WRITE_API.
+     */
+    @JsonProperty public int storageWriteApiTriggeringFrequencySec = 20;
+
     /** BigQuery write format: AVRO/JSON. */
     @JsonProperty public String writeFormat = "AVRO";
   }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
index 30626da31c7c..ca9dfdb65caf 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
@@ -3040,9 +3040,14 @@ public Write withNumFileShards(int numFileShards) {
     }
 
     /**
-     * Control how many parallel streams are used when using Storage API writes. Applicable only for
-     * streaming pipelines, and when {@link #withTriggeringFrequency} is also set. To let runner
-     * determine the sharding at runtime, set this to zero, or {@link #withAutoSharding()} instead.
+     * Control how many parallel streams are used when using Storage API writes.
+     *
+     * 

For streaming pipelines, and when {@link #withTriggeringFrequency} is also set. To let + * runner determine the sharding at runtime, set this to zero, or {@link #withAutoSharding()} + * instead. + * + *

For batch pipelines, it inserts a redistribute. To not reshufle and keep the pipeline + * parallelism as is, set this to zero. */ public Write withNumStorageWriteApiStreams(int numStorageWriteApiStreams) { return toBuilder().setNumStorageWriteApiStreams(numStorageWriteApiStreams).build(); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java index 0bc60e98b253..22e0f955abb5 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java @@ -36,6 +36,7 @@ import org.apache.beam.sdk.transforms.GroupIntoBatches; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Redistribute; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.errorhandling.BadRecord; import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; @@ -360,9 +361,19 @@ public WriteResult expandUntriggered( rowUpdateFn, badRecordRouter)); + PCollection> successfulConvertedRows = + convertMessagesResult.get(successfulConvertedRowsTag); + + if (numShards > 0) { + successfulConvertedRows = + successfulConvertedRows.apply( + "ResdistibuteNumShards", + Redistribute.>arbitrarily() + .withNumBuckets(numShards)); + } + PCollectionTuple writeRecordsResult = - convertMessagesResult - .get(successfulConvertedRowsTag) + successfulConvertedRows .apply( "StorageApiWriteUnsharded", new StorageApiWriteUnshardedRecords<>(