Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GeneratedClassRowTypeConstraint #22679

Merged
2 changes: 2 additions & 0 deletions sdks/python/apache_beam/coders/row_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def __reduce__(self):


typecoders.registry.register_coder(row_type.RowTypeConstraint, RowCoder)
typecoders.registry.register_coder(
row_type.GeneratedClassRowTypeConstraint, RowCoder)


def _coder_from_type(field_type):
Expand Down
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2977,7 +2977,7 @@ def _key_type_hint(self, input_type):
expr = self._key_fields[0][1]
return trivial_inference.infer_return_type(expr, [input_type])
else:
return row_type.RowTypeConstraint([
return row_type.RowTypeConstraint.from_fields([
(name, trivial_inference.infer_return_type(expr, [input_type]))
for (name, expr) in self._key_fields
])
Expand Down Expand Up @@ -3089,7 +3089,7 @@ def expand(self, pcoll):
for name, expr in self._fields}))

def infer_output_type(self, input_type):
return row_type.RowTypeConstraint([
return row_type.RowTypeConstraint.from_fields([
(name, trivial_inference.infer_return_type(expr, [input_type]))
for (name, expr) in self._fields
])
Expand Down
80 changes: 67 additions & 13 deletions sdks/python/apache_beam/typehints/row_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@

from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple

from apache_beam.typehints import typehints
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from apache_beam.typehints.schema_registry import SchemaTypeRegistry

# Name of the attribute added to user types (existing and generated) to store
# the corresponding schema ID
Expand All @@ -37,10 +37,10 @@
class RowTypeConstraint(typehints.TypeConstraint):
def __init__(
self,
fields: List[Tuple[str, type]],
user_type=None,
schema_options: Optional[List[Tuple[str, Any]]] = None,
field_options: Optional[Dict[str, List[Tuple[str, Any]]]] = None):
fields: Sequence[Tuple[str, type]],
user_type,
schema_options: Optional[Sequence[Tuple[str, Any]]] = None,
field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None):
"""For internal use only, no backwards comatibility guaratees. See
https://beam.apache.org/documentation/programming-guide/#schemas-for-pl-types
for guidance on creating PCollections with inferred schemas.
Expand Down Expand Up @@ -83,19 +83,16 @@ def __init__(
# Note schema ID can be None if the schema is not registered yet.
# Currently registration happens when converting to schema protos, in
# apache_beam.typehints.schemas
if self._user_type is not None:
self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None)
else:
self._schema_id = None
self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None)

self._schema_options = schema_options or []
self._field_options = field_options or {}

@staticmethod
def from_user_type(
user_type: type,
schema_options: Optional[List[Tuple[str, Any]]] = None,
field_options: Optional[Dict[str, List[Tuple[str, Any]]]] = None
schema_options: Optional[Sequence[Tuple[str, Any]]] = None,
field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None
) -> Optional[RowTypeConstraint]:
if match_is_named_tuple(user_type):
fields = [(name, user_type.__annotations__[name])
Expand All @@ -112,8 +109,19 @@ def from_user_type(
return None

@staticmethod
def from_fields(fields: Sequence[Tuple[str, type]]) -> RowTypeConstraint:
return RowTypeConstraint(fields=fields, user_type=None)
def from_fields(
fields: Sequence[Tuple[str, type]],
schema_id: Optional[str] = None,
schema_options: Optional[Sequence[Tuple[str, Any]]] = None,
field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None,
schema_registry: Optional[SchemaTypeRegistry] = None,
) -> RowTypeConstraint:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
) -> RowTypeConstraint:
) -> GeneratedClassRowTypeConstraint:

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually prefer to use the base-class here. GeneratedClassRowTypeConstraint can be an implementation detail.

return GeneratedClassRowTypeConstraint(
fields,
schema_id=schema_id,
schema_options=schema_options,
field_options=field_options,
schema_registry=schema_registry)

@property
def user_type(self):
Expand Down Expand Up @@ -160,3 +168,49 @@ def __repr__(self):

def get_type_for(self, name):
return dict(self._fields)[name]


class GeneratedClassRowTypeConstraint(RowTypeConstraint):
"""Specialization of RowTypeConstraint which relies on a generated user_type.

Since the generated user_type cannot be pickled, we supply a custom __reduce__
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the generated user_type cannot be pickled

Can you please explain why this is the case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually don't fully understand it, but it's been a consistent issue with the Schema code. Each pickle library (built-in, dill, cloudpickle) fails for a different reason. I filed #22714 to track this, and added a (skipped) test.

function that will regenerate the user_type.
"""
def __init__(
self,
fields,
schema_id: Optional[str] = None,
schema_options: Optional[Sequence[Tuple[str, Any]]] = None,
field_options: Optional[Dict[str, Sequence[Tuple[str, Any]]]] = None,
schema_registry: Optional[SchemaTypeRegistry] = None,
):
from apache_beam.typehints.schemas import named_fields_to_schema
from apache_beam.typehints.schemas import named_tuple_from_schema

kwargs = {'schema_registry': schema_registry} if schema_registry else {}

schema = named_fields_to_schema(
fields,
schema_id=schema_id,
schema_options=schema_options,
field_options=field_options,
**kwargs)
user_type = named_tuple_from_schema(schema, **kwargs)
setattr(user_type, _BEAM_SCHEMA_ID, schema_id)

super().__init__(
fields,
user_type,
schema_options=schema_options,
field_options=field_options)

def __reduce__(self):
return (
RowTypeConstraint.from_fields,
(
self._fields,
self._schema_id,
self._schema_options,
self._field_options,
None,
))
54 changes: 54 additions & 0 deletions sdks/python/apache_beam/typehints/schema_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""This module is intended for internal use only. Nothing defined here provides
any backwards-compatibility guarantee.
"""

from uuid import uuid4


# Registry of typings for a schema by UUID
class SchemaTypeRegistry(object):
def __init__(self):
self.by_id = {}
self.by_typing = {}

def generate_new_id(self):
for _ in range(100):
schema_id = str(uuid4())
if schema_id not in self.by_id:
return schema_id

raise AssertionError(
"Failed to generate a unique UUID for schema after "
f"100 tries! Registry contains {len(self.by_id)} "
"schemas.")

def add(self, typing, schema):
self.by_id[schema.id] = (typing, schema)

def get_typing_by_id(self, unique_id):
result = self.by_id.get(unique_id, None)
return result[0] if result is not None else None

def get_schema_by_id(self, unique_id):
result = self.by_id.get(unique_id, None)
return result[1] if result is not None else None


SCHEMA_REGISTRY = SchemaTypeRegistry()
Loading