diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py
old mode 100755
new mode 100644
index f6078769c654..55de96dd4993
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -21,22 +21,18 @@
 import collections
 import hashlib
-import inspect
 import json
-import logging
 import os
-import re
 import subprocess
 import sys
 import urllib.parse
+import uuid
 
 from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import Iterable
 from typing import Mapping
-from typing import Optional
 
-import docstring_parser
 import yaml
 from yaml.loader import SafeLoader
@@ -50,8 +46,6 @@
 from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
 from apache_beam.typehints import schemas
 from apache_beam.typehints import trivial_inference
-from apache_beam.typehints.schemas import named_tuple_to_schema
-from apache_beam.typehints.schemas import typing_to_runner_api
 from apache_beam.utils import python_callable
 from apache_beam.utils import subprocess_server
 from apache_beam.version import __version__ as beam_version
@@ -63,29 +57,10 @@ def available(self) -> bool:
     """Returns whether this provider is available to use in this environment."""
     raise NotImplementedError(type(self))
 
-  def cache_artifacts(self) -> Optional[Iterable[str]]:
-    raise NotImplementedError(type(self))
-
   def provided_transforms(self) -> Iterable[str]:
     """Returns a list of transform type names this provider can handle."""
     raise NotImplementedError(type(self))
 
-  def config_schema(self, type):
-    return None
-
-  def description(self, type):
-    return None
-
-  def requires_inputs(self, typ: str, args: Mapping[str, Any]) -> bool:
-    """Returns whether this transform requires inputs.
-
-    Specifically, if this returns True and inputs are not provided then an
-    error will be thrown.
-
-    This is best-effort, primarily for better and earlier error messages.
-    """
-    return not typ.startswith('Read')
-
   def create_transform(
       self,
       typ: str,
@@ -97,12 +72,6 @@
     """
     raise NotImplementedError(type(self))
 
-  def underlying_provider(self):
-    """If this provider is simply a proxy to another provider, return the
-    provider that should actually be used for affinity checking.
-    """
-    return self
-
   def affinity(self, other: "Provider"):
     """Returns a value approximating how good it would be for this provider
     to be used immediately following a transform from the other provider
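With `underlying_provider` removed, `affinity` no longer dereferences proxy providers before scoring; it reduces to the sum of the two directed `_affinity` calls. A standalone sketch of that symmetric score (the class and the constants below are hypothetical stand-ins, not Beam code):

```python
class FakeProvider:
  """Hypothetical stand-in for Provider, for illustration only."""
  def __init__(self, name):
    self.name = name

  def _affinity(self, other):
    # A large constant approximates "same environment, cheap to compose";
    # zero means the two providers are unrelated.
    return 100 if self.name == other.name else 0

  def affinity(self, other):
    # The symmetric combination now computed directly by Provider.affinity.
    return self._affinity(other) + other._affinity(self)


a, b = FakeProvider('python'), FakeProvider('java')
assert a.affinity(a) == 200 and a.affinity(b) == 0
```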
@@ -112,9 +81,7 @@ def affinity(self, other: "Provider"):
     # E.g. we could look at the expected environments themselves.
     # Possibly, we could provide multiple expansions and have the runner itself
     # choose the actual implementation based on fusion (and other) criteria.
-    a = self.underlying_provider()
-    b = other.underlying_provider()
-    return a._affinity(b) + b._affinity(a)
+    return self._affinity(other) + other._affinity(self)
 
   def _affinity(self, other: "Provider"):
     if self is other or self == other:
@@ -150,43 +117,21 @@ def __init__(self, urns, service):
   def provided_transforms(self):
     return self._urns.keys()
 
-  def schema_transforms(self):
+  def create_transform(self, type, args, yaml_create_transform):
     if callable(self._service):
       self._service = self._service()
     if self._schema_transforms is None:
       try:
-        self._schema_transforms = {
-            config.identifier: config
+        self._schema_transforms = [
+            config.identifier
             for config in external.SchemaAwareExternalTransform.discover(
-                self._service, ignore_errors=True)
-        }
+                self._service)
+        ]
       except Exception:
-        # It's possible this service doesn't vend schema transforms.
-        self._schema_transforms = {}
-    return self._schema_transforms
-
-  def config_schema(self, type):
-    if self._urns[type] in self.schema_transforms():
-      return named_tuple_to_schema(
-          self.schema_transforms()[self._urns[type]].configuration_schema)
-
-  def description(self, type):
-    if self._urns[type] in self.schema_transforms():
-      return self.schema_transforms()[self._urns[type]].description
-
-  def requires_inputs(self, typ, args):
-    if self._urns[typ] in self.schema_transforms():
-      return bool(self.schema_transforms()[self._urns[typ]].inputs)
-    else:
-      return super().requires_inputs(typ, args)
-
-  def create_transform(self, type, args, yaml_create_transform):
-    if callable(self._service):
-      self._service = self._service()
+        self._schema_transforms = []
     urn = self._urns[type]
-    if urn in self.schema_transforms():
-      return external.SchemaAwareExternalTransform(
-          urn, self._service, rearrange_based_on_discovery=True, **args)
+    if urn in self._schema_transforms:
+      return external.SchemaAwareExternalTransform(urn, self._service, **args)
     else:
       return type >> self.create_external_transform(urn, args)
@@ -198,21 +143,10 @@ def create_external_transform(self, urn, args):
   @classmethod
   def provider_from_spec(cls, spec):
-    from apache_beam.yaml.yaml_transform import SafeLineLoader
-    for required in ('type', 'transforms'):
-      if required not in spec:
-        raise ValueError(
-            f'Missing {required} in provider '
-            f'at line {SafeLineLoader.get_line(spec)}')
     urns = spec['transforms']
     type = spec['type']
+    from apache_beam.yaml.yaml_transform import SafeLineLoader
     config = SafeLineLoader.strip_metadata(spec.get('config', {}))
-    extra_params = set(SafeLineLoader.strip_metadata(spec).keys()) - set(
-        ['transforms', 'type', 'config'])
-    if extra_params:
-      raise ValueError(
-          f'Unexpected parameters in provider of type {type} '
-          f'at line {SafeLineLoader.get_line(spec)}: {extra_params}')
     if config.get('version', None) == 'BEAM_VERSION':
       config['version'] = beam_version
     if type in cls._provider_types:
@@ -231,7 +165,6 @@ def register_provider_type(cls, type_name):
     def apply(constructor):
       cls._provider_types[type_name] = constructor
-      return constructor
 
     return apply
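`create_transform` now folds discovery into its first call: it instantiates the expansion service if it was given as a factory, memoizes the list of schema-aware transform identifiers (an empty list if discovery fails), and only then dispatches on the URN. The same lazy discover-and-memoize pattern in isolation (the class and its `discover()` service method are hypothetical stand-ins):

```python
class TransformCatalog:
  """Hypothetical stand-in showing the lazy discovery pattern above."""
  def __init__(self, service):
    self._service = service      # an instance, or a zero-arg factory
    self._identifiers = None     # None means "not yet discovered"

  def identifiers(self):
    if callable(self._service):
      self._service = self._service()  # instantiate on first use
    if self._identifiers is None:
      try:
        self._identifiers = [
            config.identifier for config in self._service.discover()
        ]
      except Exception:
        # The service may not support discovery at all.
        self._identifiers = []
    return self._identifiers
```

One behavioral consequence worth noting: the removed `ignore_errors=True` tolerated partial discovery failures, whereas now any error during discovery downgrades every transform from this service to the non-schema-aware path.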
@@ -302,37 +235,17 @@ def available(self):
         self._is_available = False
     return self._is_available
 
-  def cache_artifacts(self):
-    pass
-
 
 class ExternalJavaProvider(ExternalProvider):
   def __init__(self, urns, jar_provider):
     super().__init__(
         urns, lambda: external.JavaJarExpansionService(jar_provider()))
-    self._jar_provider = jar_provider
 
   def available(self):
     # pylint: disable=subprocess-run-check
     return subprocess.run(['which', 'java'],
                           capture_output=True).returncode == 0
 
-  def cache_artifacts(self):
-    return [self._jar_provider()]
-
-
-@ExternalProvider.register_provider_type('python')
-def python(urns, packages=()):
-  if packages:
-    return ExternalPythonProvider(urns, packages)
-  else:
-    return InlineProvider({
-        name:
-        python_callable.PythonCallableWithSource.load_from_fully_qualified_name(
-            constructor)
-        for (name, constructor) in urns.items()
-    })
-
 
 @ExternalProvider.register_provider_type('pythonPackage')
 class ExternalPythonProvider(ExternalProvider):
@@ -342,9 +255,6 @@ def __init__(self, urns, packages):
   def available(self):
     return True  # If we're running this script, we have Python installed.
 
-  def cache_artifacts(self):
-    return [self._service._venv()]
-
   def create_external_transform(self, urn, args):
     # Python transforms are "registered" by fully qualified name.
     return external.ExternalTransform(
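With the inline `python` provider type deregistered, `pythonPackage` is the one remaining Python provider type handled here. For reference, the shape of spec that `provider_from_spec` now consumes (all field values below are hypothetical); note that with the validation loop removed, a misspelled key such as `transform:` would surface as a bare `KeyError` rather than a line-numbered `ValueError`:

```python
# Hypothetical provider spec, as it would arrive parsed from YAML:
spec = {
    'type': 'pythonPackage',   # looked up in ExternalProvider._provider_types
    'transforms': {            # YAML transform name -> fully qualified name
        'MyTransform': 'my_pkg.transforms.MyTransform',
    },
    'config': {'packages': ['my_pkg==1.2.3']},  # kwargs for the constructor
}
```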
@@ -401,127 +311,27 @@ def fn_takes_side_inputs(fn):
 
 class InlineProvider(Provider):
-  def __init__(self, transform_factories, no_input_transforms=()):
+  def __init__(self, transform_factories):
     self._transform_factories = transform_factories
-    self._no_input_transforms = set(no_input_transforms)
 
   def available(self):
     return True
 
-  def cache_artifacts(self):
-    pass
-
   def provided_transforms(self):
     return self._transform_factories.keys()
 
-  def config_schema(self, typ):
-    factory = self._transform_factories[typ]
-    if isinstance(factory, type) and issubclass(factory, beam.PTransform):
-      # https://bugs.python.org/issue40897
-      params = dict(inspect.signature(factory.__init__).parameters)
-      if 'self' in params:
-        del params['self']
-    else:
-      params = inspect.signature(factory).parameters
-
-    def type_of(p):
-      t = p.annotation
-      if t == p.empty:
-        return Any
-      else:
-        return t
-
-    docs = {
-        param.arg_name: param.description
-        for param in self.get_docs(typ).params
-    }
-
-    names_and_types = [
-        (name, typing_to_runner_api(type_of(p))) for name, p in params.items()
-    ]
-    return schema_pb2.Schema(
-        fields=[
-            schema_pb2.Field(name=name, type=type, description=docs.get(name))
-            for (name, type) in names_and_types
-        ])
-
-  def description(self, typ):
-    def empty_if_none(s):
-      return s or ''
-
-    docs = self.get_docs(typ)
-    return (
-        empty_if_none(docs.short_description) + '\n\n' +
-        empty_if_none(docs.long_description)).strip() or None
-
-  def get_docs(self, typ):
-    docstring = self._transform_factories[typ].__doc__ or ''
-    # These "extra" docstring parameters are not relevant for YAML and mess
-    # up the parsing.
-    docstring = re.sub(
-        r'Pandas Parameters\s+-----.*', '', docstring, flags=re.S)
-    return docstring_parser.parse(
-        docstring, docstring_parser.DocstringStyle.GOOGLE)
-
   def create_transform(self, type, args, yaml_create_transform):
     return self._transform_factories[type](**args)
 
   def to_json(self):
     return {'type': "InlineProvider"}
 
-  def requires_inputs(self, typ, args):
-    if typ in self._no_input_transforms:
-      return False
-    elif hasattr(self._transform_factories[typ], '_yaml_requires_inputs'):
-      return self._transform_factories[typ]._yaml_requires_inputs
-    else:
-      return super().requires_inputs(typ, args)
-
 
 class MetaInlineProvider(InlineProvider):
   def create_transform(self, type, args, yaml_create_transform):
     return self._transform_factories[type](yaml_create_transform, **args)
 
 
-class SqlBackedProvider(Provider):
-  def __init__(
-      self,
-      transforms: Mapping[str, Callable[..., beam.PTransform]],
-      sql_provider: Optional[Provider] = None):
-    self._transforms = transforms
-    if sql_provider is None:
-      sql_provider = beam_jar(
-          urns={'Sql': 'beam:external:java:sql:v1'},
-          gradle_target='sdks:java:extensions:sql:expansion-service:shadowJar')
-    self._sql_provider = sql_provider
-
-  def sql_provider(self):
-    return self._sql_provider
-
-  def provided_transforms(self):
-    return self._transforms.keys()
-
-  def available(self):
-    return self.sql_provider().available()
-
-  def cache_artifacts(self):
-    return self.sql_provider().cache_artifacts()
-
-  def underlying_provider(self):
-    return self.sql_provider()
-
-  def to_json(self):
-    return {'type': "SqlBackedProvider"}
-
-  def create_transform(
-      self, typ: str, args: Mapping[str, Any],
-      yaml_create_transform: Any) -> beam.PTransform:
-    return self._transforms[typ](
-        lambda query: self.sql_provider().create_transform(
-            'Sql', {'query': query}, yaml_create_transform),
-        **args)
-
-
 PRIMITIVE_NAMES_TO_ATOMIC_TYPE = {
     py_type.__name__: schema_type
     for (py_type, schema_type) in schemas.PRIMITIVE_TO_ATOMIC_TYPE.items()
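Stripped of schema introspection and docstring parsing, `InlineProvider` is now just a name-to-factory registry whose `create_transform` applies the YAML `config` block as keyword arguments. A beam-free model of that contract (class and transform names hypothetical):

```python
class MiniInlineProvider:
  """Hypothetical, beam-free model of the slimmed-down InlineProvider."""
  def __init__(self, transform_factories):
    self._transform_factories = transform_factories

  def provided_transforms(self):
    return self._transform_factories.keys()

  def create_transform(self, type, args):
    # args is the YAML 'config' mapping, splatted as keyword arguments.
    return self._transform_factories[type](**args)


provider = MiniInlineProvider(
    {'Scale': lambda factor=1: (lambda x: x * factor)})
assert provider.create_transform('Scale', {'factor': 3})(5) == 15
```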
@@ -529,79 +339,50 @@ def create_transform(
 }
 
 
-def element_to_rows(e):
-  if isinstance(e, dict):
-    return dicts_to_rows(e)
-  else:
-    return beam.Row(element=dicts_to_rows(e))
-
-
-def dicts_to_rows(o):
-  if isinstance(o, dict):
-    return beam.Row(**{k: dicts_to_rows(v) for k, v in o.items()})
-  elif isinstance(o, list):
-    return [dicts_to_rows(e) for e in o]
-  else:
-    return o
-
-
 def create_builtin_provider():
-  def create(elements: Iterable[Any], reshuffle: Optional[bool] = True):
-    """Creates a collection containing a specified set of elements.
-
-    YAML/JSON-style mappings will be interpreted as Beam rows. For example::
-
-        type: Create
-        elements:
-          - {first: 0, second: {str: "foo", values: [1, 2, 3]}}
-
-    will result in a schema of the form (int, Row(string, List[int])).
+  def with_schema(**args):
+    # TODO: This is preliminary.
+    def parse_type(spec):
+      if spec in PRIMITIVE_NAMES_TO_ATOMIC_TYPE:
+        return schema_pb2.FieldType(
+            atomic_type=PRIMITIVE_NAMES_TO_ATOMIC_TYPE[spec])
+      elif isinstance(spec, list):
+        if len(spec) != 1:
+          raise ValueError("Use single-element lists to denote list types.")
+        else:
+          return schema_pb2.FieldType(
+              iterable_type=schema_pb2.IterableType(
+                  element_type=parse_type(spec[0])))
+      elif isinstance(spec, dict):
+        return schema_pb2.FieldType(
+            row_type=schema_pb2.RowType(schema=parse_schema(spec)))
+      else:
+        raise ValueError(f"Unknown schema type: {spec}")
+
+    def parse_schema(spec):
+      return schema_pb2.Schema(
+          fields=[
+              schema_pb2.Field(name=key, type=parse_type(value), id=ix)
+              for (ix, (key, value)) in enumerate(spec.items())
+          ],
+          id=str(uuid.uuid4()))
+
+    named_tuple = schemas.named_tuple_from_schema(parse_schema(args))
+    names = list(args.keys())
+
+    def extract_field(x, name):
+      if isinstance(x, dict):
+        return x[name]
+      else:
+        return getattr(x, name)
-
-    Args:
-        elements: The set of elements that should belong to the PCollection.
-          YAML/JSON-style mappings will be interpreted as Beam rows.
-        reshuffle (optional): Whether to introduce a reshuffle (to possibly
-          redistribute the work) if there is more than one element in the
-          collection. Defaults to True.
-    """
-    return beam.Create([element_to_rows(e) for e in elements],
-                       reshuffle=reshuffle is not False)
+
+    return 'WithSchema(%s)' % ', '.join(names) >> beam.Map(
+        lambda x: named_tuple(*[extract_field(x, name) for name in names])
+    ).with_output_types(named_tuple)
 
   # Or should this be posargs, args?
   # pylint: disable=dangerous-default-value
-  def fully_qualified_named_transform(
-      constructor: str,
-      args: Optional[Iterable[Any]] = (),
-      kwargs: Optional[Mapping[str, Any]] = {}):
-    """A Python PTransform identified by fully qualified name.
-
-    This allows one to import, construct, and apply any Beam Python transform.
-    This can be useful for using transforms that have not yet been exposed
-    via a YAML interface. Note, however, that conversion may be required if
-    this transform does not accept or produce Beam Rows.
-
-    For example,
-
-        type: PyTransform
-        config:
-          constructor: apache_beam.pkg.mod.SomeClass
-          args: [1, 'foo']
-          kwargs:
-            baz: 3
-
-    can be used to access the transform
-    `apache_beam.pkg.mod.SomeClass(1, 'foo', baz=3)`.
-
-    Args:
-        constructor: Fully qualified name of a callable used to construct the
-          transform. Often this is a class such as
-          `apache_beam.pkg.mod.SomeClass` but it can also be a function or
-          any other callable that returns a PTransform.
-        args: A list of parameters to pass to the callable as positional
-          arguments.
-        kwargs: A list of parameters to pass to the callable as keyword
-          arguments.
-    """
+  def fully_qualified_named_transform(constructor, args=(), kwargs={}):
     with FullyQualifiedNamedTransform.with_filter('*'):
       return constructor >> FullyQualifiedNamedTransform(
          constructor, args, kwargs)
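Three fixes in the `parse_type` hunk as submitted: the dict branch passed a `RowType` to `iterable_type` (it belongs in the `row_type` field of the `FieldType` oneof) and indexed the dict with `spec[0]` rather than passing `spec` itself, and the final error string was missing its `f` prefix. For review, a worked example of the spec mini-language this accepts: primitive types are named by Python type name, a single-element list denotes an iterable, and a nested mapping denotes a row (field names below are hypothetical):

```python
# Hypothetical WithSchema config, as it would arrive from YAML:
spec = {
    'name': 'str',               # atomic STRING field
    'scores': ['int'],           # single-element list -> iterable of ints
    'address': {'city': 'str'},  # nested mapping -> row-typed field
}
# with_schema(**spec) builds the corresponding proto schema, derives a
# NamedTuple from it, and maps each element (attribute- or dict-style)
# into that output type.
```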
@@ -610,19 +391,6 @@
   # exactly zero or one PCollection in yaml (as they would be interpreted as
   # PBegin and the PCollection itself respectively).
   class Flatten(beam.PTransform):
-    """Flattens multiple PCollections into a single PCollection.
-
-    The elements of the resulting PCollection will be the (disjoint) union of
-    all the elements of all the inputs.
-
-    Note that in YAML, transforms can always take a list of inputs, which will
-    be implicitly flattened.
-    """
-    def __init__(self):
-      # Suppress the "label" argument from the superclass for better docs.
-      # pylint: disable=useless-parent-delegation
-      super().__init__()
-
     def expand(self, pcolls):
       if isinstance(pcolls, beam.PCollection):
         pipeline_arg = {}
@@ -636,24 +404,6 @@ def expand(self, pcolls):
       return pcolls | beam.Flatten(**pipeline_arg)
 
   class WindowInto(beam.PTransform):
-    # pylint: disable=line-too-long
-
-    """A window transform assigning windows to each element of a PCollection.
-
-    The assigned windows will affect all downstream aggregating operations,
-    which will aggregate by window as well as by key.
-
-    See [the Beam documentation on windowing](https://beam.apache.org/documentation/programming-guide/#windowing)
-    for more details.
-
-    Note that any YAML transform can have a
-    [windowing parameter](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/yaml/README.md#windowing),
-    which is applied to its inputs (if any) or outputs (if there are no inputs),
-    which means that explicit WindowInto operations are not typically needed.
-
-    Args:
-      windowing: the type and parameters of the windowing to perform
-    """
     def __init__(self, windowing):
       self._window_transform = self._parse_window_spec(windowing)
@@ -679,26 +429,43 @@ def _parse_window_spec(spec):
       # TODO: Triggering, etc.
       return beam.WindowInto(window_fn)
 
-  def LogForTesting():
-    """Logs each element of its input PCollection.
-
-    The output of this transform is a copy of its input for ease of use in
-    chain-style pipelines.
-    """
-    def log_and_return(x):
-      logging.info(x)
-      return x
-
-    return beam.Map(log_and_return)
-
-  return InlineProvider({
-      'Create': create,
-      'LogForTesting': LogForTesting,
-      'PyTransform': fully_qualified_named_transform,
-      'Flatten': Flatten,
-      'WindowInto': WindowInto,
-  },
-                        no_input_transforms=('Create', ))
+  ios = {
+      key: getattr(apache_beam.io, key)
+      for key in dir(apache_beam.io)
+      if key.startswith('ReadFrom') or key.startswith('WriteTo')
+  }
+
+  return InlineProvider(
+      dict({
+          'Create': lambda elements, reshuffle=True: beam.Create(
+              elements, reshuffle),
+          'PyMap': lambda fn: beam.Map(
+              python_callable.PythonCallableWithSource(fn)),
+          'PyMapTuple': lambda fn: beam.MapTuple(
+              python_callable.PythonCallableWithSource(fn)),
+          'PyFlatMap': lambda fn: beam.FlatMap(
+              python_callable.PythonCallableWithSource(fn)),
+          'PyFlatMapTuple': lambda fn: beam.FlatMapTuple(
+              python_callable.PythonCallableWithSource(fn)),
+          'PyFilter': lambda keep: beam.Filter(
+              python_callable.PythonCallableWithSource(keep)),
+          'PyTransform': fully_qualified_named_transform,
+          'PyToRow': lambda fields: beam.Select(
+              **{
+                  name: python_callable.PythonCallableWithSource(fn)
+                  for (name, fn) in fields.items()
+              }),
+          'WithSchema': with_schema,
+          'Flatten': Flatten,
+          'WindowInto': WindowInto,
+          'GroupByKey': beam.GroupByKey,
+          'CombinePerKey': lambda combine_fn: beam.CombinePerKey(
+              python_callable.PythonCallableWithSource(combine_fn)),
+          'TopNLargest': lambda n, key=None: beam.combiners.Top.Largest(
+              n=n,
+              key=python_callable.PythonCallableWithSource(key)
+              if key else None),
+      },
+      **ios))
 
 
 class PypiExpansionService:
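Two small fixes folded into the hunk above: the README link had picked up a `github.com` proxy artifact (restored to `github.com`), and `TopNLargest` now guards its optional `key` so that omitting it does not hand `None` to `PythonCallableWithSource`. Most of the new builtin entries wrap a *source string* from the YAML file rather than a Python object; a minimal sketch of what, e.g., `PyMap` and `PyFilter` construct under the hood (the pipeline itself is illustrative, not part of the diff):

```python
import apache_beam as beam
from apache_beam.utils import python_callable

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([1, 2, 3])
      # PyMap: the YAML string 'lambda x: x * x' becomes a callable.
      | beam.Map(python_callable.PythonCallableWithSource('lambda x: x * x'))
      # PyFilter: likewise for the keep predicate.
      | beam.Filter(
          python_callable.PythonCallableWithSource('lambda x: x > 1')))
```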
@@ -711,60 +478,23 @@ def __init__(self, packages, base_python=sys.executable):
     self._packages = packages
     self._base_python = base_python
 
-  @classmethod
-  def _key(cls, base_python, packages):
-    return json.dumps({
-        'binary': base_python, 'packages': sorted(packages)
-    },
-                      sort_keys=True)
-
-  @classmethod
-  def _path(cls, base_python, packages):
-    return os.path.join(
-        cls.VENV_CACHE,
-        hashlib.sha256(cls._key(base_python,
-                                packages).encode('utf-8')).hexdigest())
-
-  @classmethod
-  def _create_venv_from_scratch(cls, base_python, packages):
-    venv = cls._path(base_python, packages)
-    if not os.path.exists(venv):
-      subprocess.run([base_python, '-m', 'venv', venv], check=True)
-      venv_python = os.path.join(venv, 'bin', 'python')
-      subprocess.run([venv_python, '-m', 'ensurepip'], check=True)
-      subprocess.run([venv_python, '-m', 'pip', 'install'] + packages,
-                     check=True)
-      with open(venv + '-requirements.txt', 'w') as fout:
-        fout.write('\n'.join(packages))
-    return venv
+  def _key(self):
+    return json.dumps({'binary': self._base_python, 'packages': self._packages})
 
-  @classmethod
-  def _create_venv_from_clone(cls, base_python, packages):
-    venv = cls._path(base_python, packages)
+  def _venv(self):
+    venv = os.path.join(
+        self.VENV_CACHE,
+        hashlib.sha256(self._key().encode('utf-8')).hexdigest())
     if not os.path.exists(venv):
-      clonable_venv = cls._create_venv_to_clone(base_python)
-      clonable_python = os.path.join(clonable_venv, 'bin', 'python')
-      subprocess.run(
-          [clonable_python, '-m', 'clonevirtualenv', clonable_venv, venv],
-          check=True)
-      venv_binary = os.path.join(venv, 'bin', 'python')
-      subprocess.run([venv_binary, '-m', 'pip', 'install'] + packages,
+      python_binary = os.path.join(venv, 'bin', 'python')
+      subprocess.run([self._base_python, '-m', 'venv', venv], check=True)
+      subprocess.run([python_binary, '-m', 'ensurepip'], check=True)
+      subprocess.run([python_binary, '-m', 'pip', 'install'] + self._packages,
                      check=True)
       with open(venv + '-requirements.txt', 'w') as fout:
-        fout.write('\n'.join(packages))
+        fout.write('\n'.join(self._packages))
     return venv
 
-  @classmethod
-  def _create_venv_to_clone(cls, base_python):
-    return cls._create_venv_from_scratch(
-        base_python, [
-            'apache_beam[dataframe,gcp,test]==' + beam_version,
-            'virtualenv-clone'
-        ])
-
-  def _venv(self):
-    return self._create_venv_from_clone(self._base_python, self._packages)
-
   def __enter__(self):
     venv = self._venv()
     self._service_provider = subprocess_server.SubprocessServer(
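The venv cache directory is derived from a sha256 of the instance key. One behavioral change worth flagging: the removed classmethod sorted the package list and dumped with `sort_keys=True`, while the new `_key` does not, so the same packages listed in a different order now hash to (and build) a second venv. A self-contained check of that property:

```python
import hashlib
import json

def venv_digest(base_python, packages):
  # Mirrors _key/_venv above: JSON key, then sha256 -> cache directory name.
  key = json.dumps({'binary': base_python, 'packages': packages})
  return hashlib.sha256(key.encode('utf-8')).hexdigest()

assert venv_digest('python3', ['a', 'b']) != venv_digest('python3', ['b', 'a'])
```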
@@ -787,107 +517,6 @@ def __exit__(self, *args):
     self._service = None
 
 
-@ExternalProvider.register_provider_type('renaming')
-class RenamingProvider(Provider):
-  def __init__(self, transforms, mappings, underlying_provider, defaults=None):
-    if isinstance(underlying_provider, dict):
-      underlying_provider = ExternalProvider.provider_from_spec(
-          underlying_provider)
-    self._transforms = transforms
-    self._underlying_provider = underlying_provider
-    for transform in transforms.keys():
-      if transform not in mappings:
-        raise ValueError(f'Missing transform {transform} in mappings.')
-    self._mappings = self.expand_mappings(mappings)
-    self._defaults = defaults or {}
-
-  @staticmethod
-  def expand_mappings(mappings):
-    if not isinstance(mappings, dict):
-      raise ValueError(
-          "RenamingProvider mappings must be dict of transform "
-          "mappings.")
-    for key, value in mappings.items():
-      if isinstance(value, str):
-        if value not in mappings.keys():
-          raise ValueError(
-              "RenamingProvider transform mappings must be dict or "
-              "specify transform that has mappings within same "
-              "provider.")
-        mappings[key] = mappings[value]
-    return mappings
-
-  def available(self) -> bool:
-    return self._underlying_provider.available()
-
-  def provided_transforms(self) -> Iterable[str]:
-    return self._transforms.keys()
-
-  def config_schema(self, type):
-    underlying_schema = self._underlying_provider.config_schema(
-        self._transforms[type])
-    if underlying_schema is None:
-      return None
-    defaults = self._defaults.get(type, {})
-    underlying_schema_fields = {f.name: f for f in underlying_schema.fields}
-    missing = set(self._mappings[type].values()) - set(
-        underlying_schema_fields.keys())
-    if missing:
-      raise ValueError(
-          f"Mapping destinations {missing} for {type} are not in the "
-          f"underlying config schema {list(underlying_schema_fields.keys())}")
-
-    def with_name(
-        original: schema_pb2.Field, new_name: str) -> schema_pb2.Field:
-      result = schema_pb2.Field()
-      result.CopyFrom(original)
-      result.name = new_name
-      return result
-
-    return schema_pb2.Schema(
-        fields=[
-            with_name(underlying_schema_fields[dest], src)
-            for (src, dest) in self._mappings[type].items()
-            if dest not in defaults
-        ])
-
-  def description(self, typ):
-    return self._underlying_provider.description(typ)
-
-  def requires_inputs(self, typ, args):
-    return self._underlying_provider.requires_inputs(typ, args)
-
-  def create_transform(
-      self,
-      typ: str,
-      args: Mapping[str, Any],
-      yaml_create_transform: Callable[
-          [Mapping[str, Any], Iterable[beam.PCollection]], beam.PTransform]
-  ) -> beam.PTransform:
-    """Creates a PTransform instance for the given transform type and
-    arguments.
-    """
-    mappings = self._mappings[typ]
-    remapped_args = {
-        mappings.get(key, key): value
-        for key, value in args.items()
-    }
-    for key, value in self._defaults.get(typ, {}).items():
-      if key not in remapped_args:
-        remapped_args[key] = value
-    return self._underlying_provider.create_transform(
-        self._transforms[typ], remapped_args, yaml_create_transform)
-
-  def _affinity(self, other):
-    raise NotImplementedError(
-        'Should not be calling _affinity directly on this provider.')
-
-  def underlying_provider(self):
-    return self._underlying_provider.underlying_provider()
-
-  def cache_artifacts(self):
-    self._underlying_provider.cache_artifacts()
-
-
 def parse_providers(provider_specs):
   providers = collections.defaultdict(list)
   for provider_spec in provider_specs:
@@ -908,24 +537,17 @@ def merge_providers(*provider_sets):
           transform_type: [provider]
           for transform_type in provider.provided_transforms()
       }
-    elif isinstance(provider_set, list):
-      provider_set = merge_providers(*provider_set)
     for transform_type, providers in provider_set.items():
      result[transform_type].extend(providers)
   return result
 
 
 def standard_providers():
-  from apache_beam.yaml.yaml_combine import create_combine_providers
-  from apache_beam.yaml.yaml_mapping import create_mapping_providers
-  from apache_beam.yaml.yaml_io import io_providers
+  from apache_beam.yaml.yaml_mapping import create_mapping_provider
   with open(os.path.join(os.path.dirname(__file__),
                          'standard_providers.yaml')) as fin:
     standard_providers = yaml.load(fin, Loader=SafeLoader)
-
   return merge_providers(
       create_builtin_provider(),
-      create_mapping_providers(),
-      create_combine_providers(),
-      io_providers(),
+      create_mapping_provider(),
       parse_providers(standard_providers))
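Finally, `merge_providers` still accumulates a transform-name to list-of-providers registry, but with the `list` branch removed, nested provider lists must now be flattened by the caller. A standalone model of the remaining accumulation (provider values reduced to strings for brevity):

```python
import collections

def merge(*provider_sets):
  # Same accumulation as merge_providers, minus the Provider/list coercion.
  result = collections.defaultdict(list)
  for provider_set in provider_sets:
    for transform_type, providers in provider_set.items():
      result[transform_type].extend(providers)
  return result

merged = merge(
    {'Create': ['builtin']}, {'Create': ['external'], 'Sql': ['java']})
assert merged['Create'] == ['builtin', 'external']
assert merged['Sql'] == ['java']
```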