diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py index 061ca6fc602..fb255c5b0b0 100644 --- a/sdks/python/apache_beam/yaml/yaml_ml.py +++ b/sdks/python/apache_beam/yaml/yaml_ml.py @@ -17,6 +17,7 @@ """This module defines yaml wrappings for some ML transforms.""" from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import Optional @@ -40,7 +41,7 @@ class ModelHandlerProvider: - handler_types: Dict[str, "ModelHandlerProvider"] = {} + handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {} def __init__( self, @@ -435,16 +436,15 @@ def fn(x: PredictionResult): typ = model_handler['type'] model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None) if model_handler_provider and issubclass(model_handler_provider, - ModelHandlerProvider): + type(ModelHandlerProvider)): model_handler_provider.validate(model_handler['config']) else: raise NotImplementedError(f'Unknown model handler type: {typ}.') model_handler_provider = ModelHandlerProvider.create_handler(model_handler) + user_type = RowTypeConstraint.from_user_type(pcoll.element_type.user_type) schema = RowTypeConstraint.from_fields( - list( - RowTypeConstraint.from_user_type( - pcoll.element_type.user_type)._fields) + + list(user_type._fields if user_type else []) + [(inference_tag, model_handler_provider.inference_output_type())]) return ( diff --git a/sdks/python/apache_beam/yaml/yaml_utils_test.py b/sdks/python/apache_beam/yaml/yaml_utils_test.py index 70f6ba9b519..4fd2c793e57 100644 --- a/sdks/python/apache_beam/yaml/yaml_utils_test.py +++ b/sdks/python/apache_beam/yaml/yaml_utils_test.py @@ -15,6 +15,7 @@ # limitations under the License. # +import logging import unittest import yaml