diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index 153b9d4b4588..0fa31bd70eec 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -37,6 +37,14 @@ _TypeMapEntry = collections.namedtuple( '_TypeMapEntry', ['match', 'arity', 'beam_type']) +_BUILTINS_TO_TYPING = { + dict: typing.Dict, + list: typing.List, + tuple: typing.Tuple, + set: typing.Set, + frozenset: typing.FrozenSet, +} + def _get_args(typ): """Returns a list of arguments to the given type. @@ -163,6 +171,22 @@ def is_forward_ref(typ): _type_var_cache = {} # type: typing.Dict[int, typehints.TypeVariable] +def convert_builtin_to_typing(typ): + """Convert recursively a given builtin to a typing object. + + Args: + typ (`builtins`): builtin object that exist in _BUILTINS_TO_TYPING. + + Returns: + type: The given builtins converted to a type. + + """ + if getattr(typ, '__origin__', None) in _BUILTINS_TO_TYPING: + args = map(convert_builtin_to_typing, typ.__args__) + typ = _BUILTINS_TO_TYPING[typ.__origin__].copy_with(tuple(args)) + return typ + + def convert_to_beam_type(typ): """Convert a given typing type to a Beam type. @@ -185,6 +209,9 @@ def convert_to_beam_type(typ): sys.version_info.minor >= 10) and (isinstance(typ, types.UnionType)): typ = typing.Union[typ] + if sys.version_info >= (3, 9) and isinstance(typ, types.GenericAlias): + typ = convert_builtin_to_typing(typ) + if isinstance(typ, typing.TypeVar): # This is a special case, as it's not parameterized by types. # Also, identity must be preserved through conversion (i.e. the same diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py index b13df6c20627..8dcff9fc2d7d 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility_test.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility_test.py @@ -24,6 +24,7 @@ import unittest from apache_beam.typehints import typehints +from apache_beam.typehints.native_type_compatibility import convert_builtin_to_typing from apache_beam.typehints.native_type_compatibility import convert_to_beam_type from apache_beam.typehints.native_type_compatibility import convert_to_beam_types from apache_beam.typehints.native_type_compatibility import convert_to_typing_type @@ -111,6 +112,47 @@ def test_convert_to_beam_type(self): converted_typing_type = convert_to_typing_type(converted_beam_type) self.assertEqual(converted_typing_type, typing_type, description) + def test_convert_to_beam_type_with_builtin_types(self): + if sys.version_info >= (3, 9): + test_cases = [ + ('builtin dict', dict[str, int], typehints.Dict[str, int]), + ('builtin list', list[str], typehints.List[str]), + ('builtin tuple', tuple[str], typehints.Tuple[str]), + ('builtin set', set[str], typehints.Set[str]), + ( + 'nested builtin', + dict[str, list[tuple[float]]], + typehints.Dict[str, typehints.List[typehints.Tuple[float]]]), + ] + + for test_case in test_cases: + description = test_case[0] + builtins_type = test_case[1] + expected_beam_type = test_case[2] + converted_beam_type = convert_to_beam_type(builtins_type) + self.assertEqual(converted_beam_type, expected_beam_type, description) + + def test_convert_builtin_to_typing(self): + if sys.version_info >= (3, 9): + test_cases = [ + ('dict', dict[str, int], typing.Dict[str, int]), + ('list', list[str], typing.List[str]), + ('tuple', tuple[str], typing.Tuple[str]), + ('set', set[str], typing.Set[str]), + ( + 'nested', + dict[str, list[tuple[float]]], + typing.Dict[str, typing.List[typing.Tuple[float]]]), + ] + + for test_case in test_cases: + description = test_case[0] + builtin_type = test_case[1] + expected_typing_type = test_case[2] + converted_typing_type = convert_builtin_to_typing(builtin_type) + self.assertEqual( + converted_typing_type, expected_typing_type, description) + def test_generator_converted_to_iterator(self): self.assertEqual( typehints.Iterator[int], diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 5cbb41e4d664..71d56ae4b4f9 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -1186,20 +1186,18 @@ def __getitem__(self, type_params): def normalize(x, none_as_type=False): # None is inconsistantly used for Any, unknown, or NoneType. + + # Avoid circular imports + from apache_beam.typehints import native_type_compatibility + + if sys.version_info >= (3, 9) and isinstance(x, types.GenericAlias): + x = native_type_compatibility.convert_builtin_to_typing(x) + if none_as_type and x is None: return type(None) elif x in _KNOWN_PRIMITIVE_TYPES: return _KNOWN_PRIMITIVE_TYPES[x] - elif sys.version_info >= (3, 9) and isinstance(x, types.GenericAlias): - # TODO(https://github.com/apache/beam/issues/23366): handle PEP 585 - # generic type hints properly - raise TypeError( - 'PEP 585 generic type hints like %s are not yet supported, ' - 'use typing module containers instead. See equivalents listed ' - 'at https://docs.python.org/3/library/typing.html' % x) elif getattr(x, '__module__', None) == 'typing': - # Avoid circular imports - from apache_beam.typehints import native_type_compatibility beam_type = native_type_compatibility.convert_to_beam_type(x) if beam_type != x: # We were able to do the conversion. diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index 7e2c390de320..cd7f9fc4e30f 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -525,14 +525,14 @@ def test_type_check_invalid_composite_type_arbitrary_length(self): def test_normalize_with_builtin_tuple(self): if sys.version_info >= (3, 9): - with self.assertRaises(TypeError) as e: - typehints.normalize(tuple[int, int], False) + expected_beam_type = typehints.Tuple[int, int] + converted_beam_type = typehints.normalize(tuple[int, int], False) + self.assertEqual(converted_beam_type, expected_beam_type) - self.assertEqual( - 'PEP 585 generic type hints like tuple[int, int] are not yet ' - 'supported, use typing module containers instead. See equivalents ' - 'listed at https://docs.python.org/3/library/typing.html', - e.exception.args[0]) + def test_builtin_and_type_compatibility(self): + if sys.version_info >= (3, 9): + self.assertCompatible(tuple, typing.Tuple) + self.assertCompatible(tuple[int, int], typing.Tuple[int, int]) class ListHintTestCase(TypeHintTestCase): @@ -595,14 +595,14 @@ def test_enforce_list_type_constraint_invalid_composite_type(self): def test_normalize_with_builtin_list(self): if sys.version_info >= (3, 9): - with self.assertRaises(TypeError) as e: - typehints.normalize(list[int], False) + expected_beam_type = typehints.List[int] + converted_beam_type = typehints.normalize(list[int], False) + self.assertEqual(converted_beam_type, expected_beam_type) - self.assertEqual( - 'PEP 585 generic type hints like list[int] are not yet supported, ' - 'use typing module containers instead. See equivalents listed ' - 'at https://docs.python.org/3/library/typing.html', - e.exception.args[0]) + def test_builtin_and_type_compatibility(self): + if sys.version_info >= (3, 9): + self.assertCompatible(list, typing.List) + self.assertCompatible(list[int], typing.List[int]) class KVHintTestCase(TypeHintTestCase): @@ -741,14 +741,16 @@ def test_match_type_variables(self): def test_normalize_with_builtin_dict(self): if sys.version_info >= (3, 9): - with self.assertRaises(TypeError) as e: - typehints.normalize(dict[int, str], False) + expected_beam_type = typehints.Dict[str, int] + converted_beam_type = typehints.normalize(dict[str, int], False) + self.assertEqual(converted_beam_type, expected_beam_type) - self.assertEqual( - 'PEP 585 generic type hints like dict[int, str] are not yet ' - 'supported, use typing module containers instead. See equivalents ' - 'listed at https://docs.python.org/3/library/typing.html', - e.exception.args[0]) + def test_builtin_and_type_compatibility(self): + if sys.version_info >= (3, 9): + self.assertCompatible(dict, typing.Dict) + self.assertCompatible(dict[str, int], typing.Dict[str, int]) + self.assertCompatible( + dict[str, list[int]], typing.Dict[str, typing.List[int]]) class BaseSetHintTest: