diff --git a/ChangeLog b/ChangeLog index 5ad8277452..e65a4c4104 100644 --- a/ChangeLog +++ b/ChangeLog @@ -11,6 +11,7 @@ What's New in astroid 2.6.6? ============================ Release date: TBA +* Added support to infer return type of ``typing.cast()`` What's New in astroid 2.6.5? diff --git a/astroid/brain/brain_typing.py b/astroid/brain/brain_typing.py index a949bd91e3..a28da573b5 100644 --- a/astroid/brain/brain_typing.py +++ b/astroid/brain/brain_typing.py @@ -28,6 +28,7 @@ Call, Const, Name, + NodeNG, Subscript, ) from astroid.scoped_nodes import ClassDef, FunctionDef @@ -356,6 +357,36 @@ def infer_tuple_alias( return iter([class_def]) +def _looks_like_typing_cast(node: Call) -> bool: + return isinstance(node, Call) and ( + isinstance(node.func, Name) + and node.func.name == "cast" + or isinstance(node.func, Attribute) + and node.func.attrname == "cast" + ) + + +def infer_typing_cast( + node: Call, ctx: context.InferenceContext = None +) -> typing.Iterator[NodeNG]: + """Infer call to cast() returning same type as casted-from var""" + if not isinstance(node.func, (Name, Attribute)): + raise UseInferenceDefault + + try: + func = next(node.func.infer(context=ctx)) + except InferenceError as exc: + raise UseInferenceDefault from exc + if ( + not isinstance(func, FunctionDef) + or func.qname() != "typing.cast" + or len(node.args) != 2 + ): + raise UseInferenceDefault + + return node.args[1].infer(context=ctx) + + AstroidManager().register_transform( Call, inference_tip(infer_typing_typevar_or_newtype), @@ -364,6 +395,9 @@ def infer_tuple_alias( AstroidManager().register_transform( Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript ) +AstroidManager().register_transform( + Call, inference_tip(infer_typing_cast), _looks_like_typing_cast +) if PY39_PLUS: AstroidManager().register_transform( diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py index 31d190d201..dd040bd4b0 100644 --- a/tests/unittest_brain.py +++ b/tests/unittest_brain.py @@ -1810,6 +1810,38 @@ def test_typing_object_builtin_subscriptable(self): self.assertIsInstance(inferred, nodes.ClassDef) self.assertIsInstance(inferred.getattr("__iter__")[0], nodes.FunctionDef) + def test_typing_cast(self): + node = builder.extract_node( + """ + from typing import cast + class A: + pass + + b = 42 + a = cast(A, b) + a + """ + ) + inferred = next(node.infer()) + assert isinstance(inferred, nodes.Const) + assert inferred.value == 42 + + def test_typing_cast_attribute(self): + node = builder.extract_node( + """ + import typing + class A: + pass + + b = 42 + a = typing.cast(A, b) + a + """ + ) + inferred = next(node.infer()) + assert isinstance(inferred, nodes.Const) + assert inferred.value == 42 + class ReBrainTest(unittest.TestCase): def test_regex_flags(self):