From 909178ba0eaf6f1c7fbde0c04121e495e007df1f Mon Sep 17 00:00:00 2001 From: Brian Hulette Date: Fri, 12 Aug 2022 12:45:53 -0700 Subject: [PATCH] Add GeneratedClassRowTypeConstraint (#22679) * Add (failing) pickling tests * Add GeneratedClassRowTypeConstraint, plumb options * Add top-level option conversion functions * Refactor NamedTuple generation, always create GeneratedClassRowTypeConstraint * Move registry to apache_beam.typehints.schema_registry * yapf,lint * fixup! Move registry to apache_beam.typehints.schema_registry * Apply suggestions from code review Co-authored-by: Andy Ye * Add None SchemaRegistry * Add skipped test for pickling generated type Co-authored-by: Andy Ye --- sdks/python/apache_beam/coders/row_coder.py | 2 + sdks/python/apache_beam/transforms/core.py | 4 +- sdks/python/apache_beam/typehints/row_type.py | 80 +++++++-- .../apache_beam/typehints/schema_registry.py | 54 ++++++ sdks/python/apache_beam/typehints/schemas.py | 165 +++++++++--------- .../apache_beam/typehints/schemas_test.py | 96 +++++++--- .../typehints/trivial_inference_test.py | 8 +- 7 files changed, 291 insertions(+), 118 deletions(-) create mode 100644 sdks/python/apache_beam/typehints/schema_registry.py diff --git a/sdks/python/apache_beam/coders/row_coder.py b/sdks/python/apache_beam/coders/row_coder.py index 8f3421ca70b1..600d6595f105 100644 --- a/sdks/python/apache_beam/coders/row_coder.py +++ b/sdks/python/apache_beam/coders/row_coder.py @@ -133,6 +133,8 @@ def __reduce__(self): typecoders.registry.register_coder(row_type.RowTypeConstraint, RowCoder) +typecoders.registry.register_coder( + row_type.GeneratedClassRowTypeConstraint, RowCoder) def _coder_from_type(field_type): diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 516928967401..50ff32e57a33 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -2977,7 +2977,7 @@ def _key_type_hint(self, input_type): expr = self._key_fields[0][1] return trivial_inference.infer_return_type(expr, [input_type]) else: - return row_type.RowTypeConstraint([ + return row_type.RowTypeConstraint.from_fields([ (name, trivial_inference.infer_return_type(expr, [input_type])) for (name, expr) in self._key_fields ]) @@ -3089,7 +3089,7 @@ def expand(self, pcoll): for name, expr in self._fields})) def infer_output_type(self, input_type): - return row_type.RowTypeConstraint([ + return row_type.RowTypeConstraint.from_fields([ (name, trivial_inference.infer_return_type(expr, [input_type])) for (name, expr) in self._fields ]) diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py index 50d7ff6a50b8..b1f6fd99d979 100644 --- a/sdks/python/apache_beam/typehints/row_type.py +++ b/sdks/python/apache_beam/typehints/row_type.py @@ -21,13 +21,13 @@ from typing import Any from typing import Dict -from typing import List from typing import Optional from typing import Sequence from typing import Tuple from apache_beam.typehints import typehints from apache_beam.typehints.native_type_compatibility import match_is_named_tuple +from apache_beam.typehints.schema_registry import SchemaTypeRegistry # Name of the attribute added to user types (existing and generated) to store # the corresponding schema ID @@ -37,10 +37,10 @@ class RowTypeConstraint(typehints.TypeConstraint): def __init__( self, - fields: List[Tuple[str, type]], - user_type=None, - schema_options: Optional[List[Tuple[str, Any]]] = None, - field_options: Optional[Dict[str, List[Tuple[str, Any]]]] = None): + fields: Sequence[Tuple[str, type]], + user_type, + schema_options: Optional[Sequence[Tuple[str, Any]]] = None, + field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None): """For internal use only, no backwards comatibility guaratees. See https://beam.apache.org/documentation/programming-guide/#schemas-for-pl-types for guidance on creating PCollections with inferred schemas. @@ -83,10 +83,7 @@ def __init__( # Note schema ID can be None if the schema is not registered yet. # Currently registration happens when converting to schema protos, in # apache_beam.typehints.schemas - if self._user_type is not None: - self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None) - else: - self._schema_id = None + self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None) self._schema_options = schema_options or [] self._field_options = field_options or {} @@ -94,8 +91,8 @@ def __init__( @staticmethod def from_user_type( user_type: type, - schema_options: Optional[List[Tuple[str, Any]]] = None, - field_options: Optional[Dict[str, List[Tuple[str, Any]]]] = None + schema_options: Optional[Sequence[Tuple[str, Any]]] = None, + field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None ) -> Optional[RowTypeConstraint]: if match_is_named_tuple(user_type): fields = [(name, user_type.__annotations__[name]) @@ -112,8 +109,19 @@ def from_user_type( return None @staticmethod - def from_fields(fields: Sequence[Tuple[str, type]]) -> RowTypeConstraint: - return RowTypeConstraint(fields=fields, user_type=None) + def from_fields( + fields: Sequence[Tuple[str, type]], + schema_id: Optional[str] = None, + schema_options: Optional[Sequence[Tuple[str, Any]]] = None, + field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None, + schema_registry: Optional[SchemaTypeRegistry] = None, + ) -> RowTypeConstraint: + return GeneratedClassRowTypeConstraint( + fields, + schema_id=schema_id, + schema_options=schema_options, + field_options=field_options, + schema_registry=schema_registry) @property def user_type(self): @@ -160,3 +168,49 @@ def __repr__(self): def get_type_for(self, name): return dict(self._fields)[name] + + +class GeneratedClassRowTypeConstraint(RowTypeConstraint): + """Specialization of RowTypeConstraint which relies on a generated user_type. + + Since the generated user_type cannot be pickled, we supply a custom __reduce__ + function that will regenerate the user_type. + """ + def __init__( + self, + fields, + schema_id: Optional[str] = None, + schema_options: Optional[Sequence[Tuple[str, Any]]] = None, + field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None, + schema_registry: Optional[SchemaTypeRegistry] = None, + ): + from apache_beam.typehints.schemas import named_fields_to_schema + from apache_beam.typehints.schemas import named_tuple_from_schema + + kwargs = {'schema_registry': schema_registry} if schema_registry else {} + + schema = named_fields_to_schema( + fields, + schema_id=schema_id, + schema_options=schema_options, + field_options=field_options, + **kwargs) + user_type = named_tuple_from_schema(schema, **kwargs) + setattr(user_type, _BEAM_SCHEMA_ID, schema_id) + + super().__init__( + fields, + user_type, + schema_options=schema_options, + field_options=field_options) + + def __reduce__(self): + return ( + RowTypeConstraint.from_fields, + ( + self._fields, + self._schema_id, + self._schema_options, + self._field_options, + None, + )) diff --git a/sdks/python/apache_beam/typehints/schema_registry.py b/sdks/python/apache_beam/typehints/schema_registry.py new file mode 100644 index 000000000000..9ec7b1b65ccf --- /dev/null +++ b/sdks/python/apache_beam/typehints/schema_registry.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This module is intended for internal use only. Nothing defined here provides +any backwards-compatibility guarantee. +""" + +from uuid import uuid4 + + +# Registry of typings for a schema by UUID +class SchemaTypeRegistry(object): + def __init__(self): + self.by_id = {} + self.by_typing = {} + + def generate_new_id(self): + for _ in range(100): + schema_id = str(uuid4()) + if schema_id not in self.by_id: + return schema_id + + raise AssertionError( + "Failed to generate a unique UUID for schema after " + f"100 tries! Registry contains {len(self.by_id)} " + "schemas.") + + def add(self, typing, schema): + self.by_id[schema.id] = (typing, schema) + + def get_typing_by_id(self, unique_id): + result = self.by_id.get(unique_id, None) + return result[0] if result is not None else None + + def get_schema_by_id(self, unique_id): + result = self.by_id.get(unique_id, None) + return result[1] if result is not None else None + + +SCHEMA_REGISTRY = SchemaTypeRegistry() diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index db4415b87509..4f38b6695f38 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -56,6 +56,7 @@ from typing import Any from typing import ByteString +from typing import Dict from typing import Generic from typing import List from typing import Mapping @@ -78,47 +79,14 @@ from apache_beam.typehints.native_type_compatibility import _safe_issubclass from apache_beam.typehints.native_type_compatibility import extract_optional_type from apache_beam.typehints.native_type_compatibility import match_is_named_tuple +from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY +from apache_beam.typehints.schema_registry import SchemaTypeRegistry from apache_beam.utils import proto_utils from apache_beam.utils.python_callable import PythonCallableWithSource from apache_beam.utils.timestamp import Timestamp PYTHON_ANY_URN = "beam:logical:pythonsdk_any:v1" - -# Registry of typings for a schema by UUID -class SchemaTypeRegistry(object): - def __init__(self): - self.by_id = {} - self.by_typing = {} - - def generate_new_id(self): - # Import uuid locally to guarantee we don't actually generate a uuid - # elsewhere in this file. - from uuid import uuid4 - for _ in range(100): - schema_id = str(uuid4()) - if schema_id not in self.by_id: - return schema_id - - raise AssertionError( - "Failed to generate a unique UUID for schema after " - f"100 tries! Registry contains {len(self.by_id)} " - "schemas.") - - def add(self, typing, schema): - self.by_id[schema.id] = (typing, schema) - - def get_typing_by_id(self, unique_id): - result = self.by_id.get(unique_id, None) - return result[0] if result is not None else None - - def get_schema_by_id(self, unique_id): - result = self.by_id.get(unique_id, None) - return result[1] if result is not None else None - - -SCHEMA_REGISTRY = SchemaTypeRegistry() - # Bi-directional mappings _PRIMITIVES = ( (np.int8, schema_pb2.BYTE), @@ -146,16 +114,37 @@ def get_schema_by_id(self, unique_id): }) -def named_fields_to_schema(names_and_types): - # type: (Union[Dict[str, type], Sequence[Tuple[str, type]]]) -> schema_pb2.Schema # noqa: F821 +def named_fields_to_schema( + names_and_types: Union[Dict[str, type], Sequence[Tuple[str, type]]], + schema_id: Optional[str] = None, + schema_options: Optional[Sequence[Tuple[str, Any]]] = None, + field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None, + schema_registry: SchemaTypeRegistry = SCHEMA_REGISTRY, +): + schema_options = schema_options or [] + field_options = field_options or {} + if isinstance(names_and_types, dict): names_and_types = names_and_types.items() + + if schema_id is None: + schema_id = schema_registry.generate_new_id() + return schema_pb2.Schema( fields=[ - schema_pb2.Field(name=name, type=typing_to_runner_api(type)) - for (name, type) in names_and_types + schema_pb2.Field( + name=name, + type=typing_to_runner_api(type), + options=[ + option_to_runner_api(option_tuple) + for option_tuple in field_options.get(name, []) + ], + ) for (name, type) in names_and_types ], - id=SCHEMA_REGISTRY.generate_new_id()) + options=[ + option_to_runner_api(option_tuple) for option_tuple in schema_options + ], + id=schema_id) def named_fields_from_schema( @@ -179,6 +168,20 @@ def typing_from_runner_api( schema_registry=schema_registry).typing_from_runner_api(fieldtype_proto) +def option_to_runner_api( + option: Tuple[str, Any], + schema_registry: SchemaTypeRegistry = SCHEMA_REGISTRY) -> schema_pb2.Option: + return SchemaTranslation( + schema_registry=schema_registry).option_to_runner_api(option) + + +def option_from_runner_api( + option_proto: schema_pb2.Option, + schema_registry: SchemaTypeRegistry = SCHEMA_REGISTRY) -> type: + return SchemaTranslation( + schema_registry=schema_registry).option_from_runner_api(option_proto) + + class SchemaTranslation(object): def __init__(self, schema_registry: SchemaTypeRegistry = SCHEMA_REGISTRY): self.schema_registry = schema_registry @@ -397,41 +400,15 @@ def typing_from_runner_api( if user_type is None: # If not in SDK options (the coder likely came from another SDK), # generate a NamedTuple type to use. - from apache_beam import coders - type_name = 'BeamSchema_{}'.format(schema.id.replace('-', '_')) - - subfields = [] - for field in schema.fields: - try: - field_py_type = self.typing_from_runner_api(field.type) - if isinstance(field_py_type, row_type.RowTypeConstraint): - field_py_type = field_py_type.user_type - except ValueError as e: - raise ValueError( - "Failed to decode schema due to an issue with Field proto:\n\n" - f"{text_format.MessageToString(field)}") from e - - subfields.append((field.name, field_py_type)) - - user_type = NamedTuple(type_name, subfields) - - # Define a reduce function, otherwise these types can't be pickled - # (See BEAM-9574) - def __reduce__(self): - return ( - _hydrate_namedtuple_instance, - (schema.SerializeToString(), tuple(self))) - - setattr(user_type, '__reduce__', __reduce__) - - self.schema_registry.add(user_type, schema) - coders.registry.register_coder(user_type, coders.RowCoder) - result = row_type.RowTypeConstraint.from_user_type( - user_type, + fields = named_fields_from_schema(schema) + result = row_type.RowTypeConstraint.from_fields( + fields=fields, + schema_id=schema.id, schema_options=schema_options, - field_options=field_options) - result.set_schema_id(schema.id) + field_options=field_options, + schema_registry=self.schema_registry, + ) return result else: return row_type.RowTypeConstraint.from_user_type( @@ -449,6 +426,41 @@ def __reduce__(self): else: raise ValueError(f"Unrecognized type_info: {type_info!r}") + def named_tuple_from_schema(self, schema: schema_pb2.Schema) -> type: + from apache_beam import coders + + assert schema.id + type_name = 'BeamSchema_{}'.format(schema.id.replace('-', '_')) + + subfields = [] + for field in schema.fields: + try: + field_py_type = self.typing_from_runner_api(field.type) + if isinstance(field_py_type, row_type.RowTypeConstraint): + field_py_type = field_py_type.user_type + except ValueError as e: + raise ValueError( + "Failed to decode schema due to an issue with Field proto:\n\n" + f"{text_format.MessageToString(field)}") from e + + subfields.append((field.name, field_py_type)) + + user_type = NamedTuple(type_name, subfields) + + # Define a reduce function, otherwise these types can't be pickled + # (See BEAM-9574) + def __reduce__(self): + return ( + _hydrate_namedtuple_instance, + (schema.SerializeToString(), tuple(self))) + + setattr(user_type, '__reduce__', __reduce__) + + self.schema_registry.add(user_type, schema) + coders.registry.register_coder(user_type, coders.RowCoder) + + return user_type + def _hydrate_namedtuple_instance(encoded_schema, values): return named_tuple_from_schema( @@ -457,11 +469,8 @@ def _hydrate_namedtuple_instance(encoded_schema, values): def named_tuple_from_schema( schema, schema_registry: SchemaTypeRegistry = SCHEMA_REGISTRY) -> type: - row_type_constraint = typing_from_runner_api( - schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=schema)), - schema_registry=schema_registry) - assert isinstance(row_type_constraint, row_type.RowTypeConstraint) - return row_type_constraint.user_type + return SchemaTranslation( + schema_registry=schema_registry).named_tuple_from_schema(schema) def named_tuple_to_schema( diff --git a/sdks/python/apache_beam/typehints/schemas_test.py b/sdks/python/apache_beam/typehints/schemas_test.py index 9b4f5c785150..370b9c92cde7 100644 --- a/sdks/python/apache_beam/typehints/schemas_test.py +++ b/sdks/python/apache_beam/typehints/schemas_test.py @@ -29,8 +29,11 @@ from typing import Optional from typing import Sequence +import cloudpickle +import dill import numpy as np from parameterized import parameterized +from parameterized import parameterized_class from apache_beam.portability import common_urns from apache_beam.portability.api import schema_pb2 @@ -312,7 +315,7 @@ def test_namedtuple_roundtrip(self, user_type): def test_row_type_constraint_to_schema(self): result_type = typing_to_runner_api( - row_type.RowTypeConstraint([ + row_type.RowTypeConstraint.from_fields([ ('foo', np.int8), ('bar', float), ('baz', bytes), @@ -337,7 +340,7 @@ def test_row_type_constraint_to_schema(self): self.assertEqual(list(schema.fields), expected) def test_row_type_constraint_to_schema_with_options(self): - row_type_with_options = row_type.RowTypeConstraint( + row_type_with_options = row_type.RowTypeConstraint.from_fields( [ ('foo', np.int8), ('bar', float), @@ -382,15 +385,17 @@ def test_row_type_constraint_to_schema_with_options(self): def test_row_type_constraint_to_schema_with_field_options(self): result_type = typing_to_runner_api( - row_type.RowTypeConstraint([ + row_type.RowTypeConstraint.from_fields([ ('foo', np.int8), ('bar', float), ('baz', bytes), ], - field_options={ - 'foo': [('some_metadata', 123), - ('some_flag', None)] - })) + field_options={ + 'foo': [ + ('some_metadata', 123), + ('some_flag', None) + ] + })) self.assertIsInstance(result_type, schema_pb2.FieldType) self.assertEqual(result_type.WhichOneof("type_info"), "row_type") @@ -529,20 +534,6 @@ def test_trivial_example(self): expected.row_type.schema.fields, typing_to_runner_api(MyCuteClass).row_type.schema.fields) - def test_generated_class_pickle(self): - schema = schema_pb2.Schema( - id="some-uuid", - fields=[ - schema_pb2.Field( - name='name', - type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING), - ) - ]) - user_type = named_tuple_from_schema(schema) - instance = user_type(name="test") - - self.assertEqual(instance, pickle.loads(pickle.dumps(instance))) - def test_user_type_annotated_with_id_after_conversion(self): MyCuteClass = NamedTuple('MyCuteClass', [ ('name', str), @@ -572,5 +563,68 @@ def test_schema_with_bad_field_raises_helpful_error(self): schema_registry=SchemaTypeRegistry())) +@parameterized_class([ + { + 'pickler': pickle, + }, + { + 'pickler': dill, + }, + { + 'pickler': cloudpickle, + }, +]) +class PickleTest(unittest.TestCase): + def test_generated_class_pickle_instance(self): + schema = schema_pb2.Schema( + id="some-uuid", + fields=[ + schema_pb2.Field( + name='name', + type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING), + ) + ]) + user_type = named_tuple_from_schema(schema) + instance = user_type(name="test") + + self.assertEqual(instance, self.pickler.loads(self.pickler.dumps(instance))) + + @unittest.skip("https://github.com/apache/beam/issues/22714") + def test_generated_class_pickle(self): + schema = schema_pb2.Schema( + id="some-uuid", + fields=[ + schema_pb2.Field( + name='name', + type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING), + ) + ]) + user_type = named_tuple_from_schema(schema) + + self.assertEqual( + user_type, self.pickler.loads(self.pickler.dumps(user_type))) + + def test_generated_class_row_type_pickle(self): + row_proto = schema_pb2.FieldType( + row_type=schema_pb2.RowType( + schema=schema_pb2.Schema( + id="some-other-uuid", + fields=[ + schema_pb2.Field( + name='name', + type=schema_pb2.FieldType( + atomic_type=schema_pb2.STRING), + ) + ]))) + row_type_constraint = typing_from_runner_api( + row_proto, schema_registry=SchemaTypeRegistry()) + + self.assertIsInstance(row_type_constraint, row_type.RowTypeConstraint) + + self.assertEqual( + row_type_constraint, + self.pickler.loads(self.pickler.dumps(row_type_constraint))) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/typehints/trivial_inference_test.py b/sdks/python/apache_beam/typehints/trivial_inference_test.py index aaac6f4a6e0e..4cb3e1b04ee9 100644 --- a/sdks/python/apache_beam/typehints/trivial_inference_test.py +++ b/sdks/python/apache_beam/typehints/trivial_inference_test.py @@ -401,7 +401,7 @@ def method(self): (MyClass, MyClass()), (type(MyClass.method), MyClass.method), (types.MethodType, MyClass().method), - (row_type.RowTypeConstraint([('x', int)]), beam.Row(x=37)), + (row_type.RowTypeConstraint.from_fields([('x', int)]), beam.Row(x=37)), ] for expected_type, instance in test_cases: self.assertEqual( @@ -411,18 +411,18 @@ def method(self): def testRow(self): self.assertReturnType( - row_type.RowTypeConstraint([('x', int), ('y', str)]), + row_type.RowTypeConstraint.from_fields([('x', int), ('y', str)]), lambda x, y: beam.Row(x=x + 1, y=y), [int, str]) self.assertReturnType( - row_type.RowTypeConstraint([('x', int), ('y', str)]), + row_type.RowTypeConstraint.from_fields([('x', int), ('y', str)]), lambda x: beam.Row(x=x, y=str(x)), [int]) def testRowAttr(self): self.assertReturnType( typehints.Tuple[int, str], lambda row: (row.x, getattr(row, 'y')), - [row_type.RowTypeConstraint([('x', int), ('y', str)])]) + [row_type.RowTypeConstraint.from_fields([('x', int), ('y', str)])]) if __name__ == '__main__':