From 53bcdded534494674f893112f71d3be344d65363 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Mon, 3 Jun 2024 04:45:13 -0700 Subject: [PATCH] Avoid error if origin has a buggy __eq__ (#422) Fixes #419 Co-authored-by: Alex Waygood --- CHANGELOG.md | 6 ++++++ src/test_typing_extensions.py | 16 ++++++++++++++++ src/typing_extensions.py | 17 ++++++++++++----- 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a5937a6..776a101e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +# Unreleased + +- Fix regression in v4.12.0 where specialization of certain + generics with an overridden `__eq__` method would raise errors. + Patch by Jelle Zijlstra. + # Release 4.12.1 (June 1, 2024) - Preliminary changes for compatibility with the draft implementation diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 8ba0bf74..bf7600a1 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -6617,6 +6617,22 @@ def test_allow_default_after_non_default_in_alias(self): a4 = Callable[[Unpack[Ts]], T] self.assertEqual(a4.__args__, (Unpack[Ts], T)) + @skip_if_py313_beta_1 + def test_generic_with_broken_eq(self): + # See https://github.com/python/typing_extensions/pull/422 for context + class BrokenEq(type): + def __eq__(self, other): + if other is typing_extensions.Protocol: + raise TypeError("I'm broken") + return False + + class G(Generic[T], metaclass=BrokenEq): + pass + + alias = G[int] + self.assertIs(get_origin(alias), G) + self.assertEqual(get_args(alias), (int,)) + @skipIf( sys.version_info < (3, 11, 1), "Not yet backported for older versions of Python" diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 46084fa5..dec429ca 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -2954,13 +2954,20 @@ def _check_generic(cls, parameters, elen): def _has_generic_or_protocol_as_origin() -> bool: try: frame = sys._getframe(2) - # not all platforms have sys._getframe() - except AttributeError: + # - Catch AttributeError: not all Python implementations have sys._getframe() + # - Catch ValueError: maybe we're called from an unexpected module + # and the call stack isn't deep enough + except (AttributeError, ValueError): return False # err on the side of leniency else: - return frame.f_locals.get("origin") in ( - typing.Generic, Protocol, typing.Protocol - ) + # If we somehow get invoked from outside typing.py, + # also err on the side of leniency + if frame.f_globals.get("__name__") != "typing": + return False + origin = frame.f_locals.get("origin") + # Cannot use "in" because origin may be an object with a buggy __eq__ that + # throws an error. + return origin is typing.Generic or origin is Protocol or origin is typing.Protocol _TYPEVARTUPLE_TYPES = {TypeVarTuple, getattr(typing, "TypeVarTuple", None)}