Skip to content

Commit

Permalink
Add GeneratedClassRowTypeConstraint (apache#22679)
Browse files Browse the repository at this point in the history
* Add (failing) pickling tests

* Add GeneratedClassRowTypeConstraint, plumb options

* Add top-level option conversion functions

* Refactor NamedTuple generation, always create GeneratedClassRowTypeConstraint

* Move registry to apache_beam.typehints.schema_registry

* yapf,lint

* fixup! Move registry to apache_beam.typehints.schema_registry

* Apply suggestions from code review

Co-authored-by: Andy Ye <[email protected]>

* Add None SchemaRegistry

* Add skipped test for pickling generated type

Co-authored-by: Andy Ye <[email protected]>
  • Loading branch information
2 people authored and MarcoRob committed Aug 26, 2022
1 parent 5f01a81 commit 909178b
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 118 deletions.
2 changes: 2 additions & 0 deletions sdks/python/apache_beam/coders/row_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def __reduce__(self):


typecoders.registry.register_coder(row_type.RowTypeConstraint, RowCoder)
typecoders.registry.register_coder(
row_type.GeneratedClassRowTypeConstraint, RowCoder)


def _coder_from_type(field_type):
Expand Down
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2977,7 +2977,7 @@ def _key_type_hint(self, input_type):
expr = self._key_fields[0][1]
return trivial_inference.infer_return_type(expr, [input_type])
else:
return row_type.RowTypeConstraint([
return row_type.RowTypeConstraint.from_fields([
(name, trivial_inference.infer_return_type(expr, [input_type]))
for (name, expr) in self._key_fields
])
Expand Down Expand Up @@ -3089,7 +3089,7 @@ def expand(self, pcoll):
for name, expr in self._fields}))

def infer_output_type(self, input_type):
return row_type.RowTypeConstraint([
return row_type.RowTypeConstraint.from_fields([
(name, trivial_inference.infer_return_type(expr, [input_type]))
for (name, expr) in self._fields
])
Expand Down
80 changes: 67 additions & 13 deletions sdks/python/apache_beam/typehints/row_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@

from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple

from apache_beam.typehints import typehints
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.typehints.schema_registry import SchemaTypeRegistry

# Name of the attribute added to user types (existing and generated) to store
# the corresponding schema ID
Expand All @@ -37,10 +37,10 @@
class RowTypeConstraint(typehints.TypeConstraint):
def __init__(
self,
fields: List[Tuple[str, type]],
user_type=None,
schema_options: Optional[List[Tuple[str, Any]]] = None,
field_options: Optional[Dict[str, List[Tuple[str, Any]]]] = None):
fields: Sequence[Tuple[str, type]],
user_type,
schema_options: Optional[Sequence[Tuple[str, Any]]] = None,
field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None):
"""For internal use only, no backwards comatibility guaratees. See
https://beam.apache.org/documentation/programming-guide/#schemas-for-pl-types
for guidance on creating PCollections with inferred schemas.
Expand Down Expand Up @@ -83,19 +83,16 @@ def __init__(
# Note schema ID can be None if the schema is not registered yet.
# Currently registration happens when converting to schema protos, in
# apache_beam.typehints.schemas
if self._user_type is not None:
self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None)
else:
self._schema_id = None
self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None)

self._schema_options = schema_options or []
self._field_options = field_options or {}

@staticmethod
def from_user_type(
user_type: type,
schema_options: Optional[List[Tuple[str, Any]]] = None,
field_options: Optional[Dict[str, List[Tuple[str, Any]]]] = None
schema_options: Optional[Sequence[Tuple[str, Any]]] = None,
field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None
) -> Optional[RowTypeConstraint]:
if match_is_named_tuple(user_type):
fields = [(name, user_type.__annotations__[name])
Expand All @@ -112,8 +109,19 @@ def from_user_type(
return None

@staticmethod
def from_fields(fields: Sequence[Tuple[str, type]]) -> RowTypeConstraint:
return RowTypeConstraint(fields=fields, user_type=None)
def from_fields(
fields: Sequence[Tuple[str, type]],
schema_id: Optional[str] = None,
schema_options: Optional[Sequence[Tuple[str, Any]]] = None,
field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None,
schema_registry: Optional[SchemaTypeRegistry] = None,
) -> RowTypeConstraint:
return GeneratedClassRowTypeConstraint(
fields,
schema_id=schema_id,
schema_options=schema_options,
field_options=field_options,
schema_registry=schema_registry)

@property
def user_type(self):
Expand Down Expand Up @@ -160,3 +168,49 @@ def __repr__(self):

def get_type_for(self, name):
return dict(self._fields)[name]


class GeneratedClassRowTypeConstraint(RowTypeConstraint):
"""Specialization of RowTypeConstraint which relies on a generated user_type.
Since the generated user_type cannot be pickled, we supply a custom __reduce__
function that will regenerate the user_type.
"""
def __init__(
self,
fields,
schema_id: Optional[str] = None,
schema_options: Optional[Sequence[Tuple[str, Any]]] = None,
field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None,
schema_registry: Optional[SchemaTypeRegistry] = None,
):
from apache_beam.typehints.schemas import named_fields_to_schema
from apache_beam.typehints.schemas import named_tuple_from_schema

kwargs = {'schema_registry': schema_registry} if schema_registry else {}

schema = named_fields_to_schema(
fields,
schema_id=schema_id,
schema_options=schema_options,
field_options=field_options,
**kwargs)
user_type = named_tuple_from_schema(schema, **kwargs)
setattr(user_type, _BEAM_SCHEMA_ID, schema_id)

super().__init__(
fields,
user_type,
schema_options=schema_options,
field_options=field_options)

def __reduce__(self):
return (
RowTypeConstraint.from_fields,
(
self._fields,
self._schema_id,
self._schema_options,
self._field_options,
None,
))
54 changes: 54 additions & 0 deletions sdks/python/apache_beam/typehints/schema_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""This module is intended for internal use only. Nothing defined here provides
any backwards-compatibility guarantee.
"""

from uuid import uuid4


# Registry of typings for a schema by UUID
class SchemaTypeRegistry(object):
def __init__(self):
self.by_id = {}
self.by_typing = {}

def generate_new_id(self):
for _ in range(100):
schema_id = str(uuid4())
if schema_id not in self.by_id:
return schema_id

raise AssertionError(
"Failed to generate a unique UUID for schema after "
f"100 tries! Registry contains {len(self.by_id)} "
"schemas.")

def add(self, typing, schema):
self.by_id[schema.id] = (typing, schema)

def get_typing_by_id(self, unique_id):
result = self.by_id.get(unique_id, None)
return result[0] if result is not None else None

def get_schema_by_id(self, unique_id):
result = self.by_id.get(unique_id, None)
return result[1] if result is not None else None


SCHEMA_REGISTRY = SchemaTypeRegistry()
Loading

0 comments on commit 909178b

Please sign in to comment.