Skip to content

Commit

Permalink
Infer the type of the result of calling typing.cast()
Browse files Browse the repository at this point in the history
  • Loading branch information
timmartin committed Jul 22, 2021
1 parent 22e0cdc commit 9e51d37
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ What's New in astroid 2.6.6?
============================
Release date: TBA


* Added support to infer return type of typing.cast()

What's New in astroid 2.6.5?
============================
Expand Down
27 changes: 27 additions & 0 deletions astroid/brain/brain_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Call,
Const,
Name,
NodeNG,
Subscript,
)
from astroid.scoped_nodes import ClassDef, FunctionDef
Expand Down Expand Up @@ -356,6 +357,29 @@ def infer_tuple_alias(
return iter([class_def])


def _looks_like_typing_cast(node: Call) -> bool:
return isinstance(node, Call) and (
isinstance(node.func, Name)
and node.func.name == "cast"
or isinstance(node.func, Attribute)
and node.func.attrname == "cast"
)


def infer_typing_cast(
node: Call, ctx: context.InferenceContext = None
) -> typing.Iterator[NodeNG]:
"""Infer call to cast() returning same type as casted-from var"""
try:
func = next(node.func.infer(context=ctx))
except InferenceError as exc:
raise UseInferenceDefault from exc
if func.qname() != "typing.cast" or len(node.args) != 2:
raise UseInferenceDefault

return node.args[1].infer(context=ctx)


AstroidManager().register_transform(
Call,
inference_tip(infer_typing_typevar_or_newtype),
Expand All @@ -364,6 +388,9 @@ def infer_tuple_alias(
AstroidManager().register_transform(
Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript
)
AstroidManager().register_transform(
Call, inference_tip(infer_typing_cast), _looks_like_typing_cast
)

if PY39_PLUS:
AstroidManager().register_transform(
Expand Down
32 changes: 32 additions & 0 deletions tests/unittest_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1810,6 +1810,38 @@ def test_typing_object_builtin_subscriptable(self):
self.assertIsInstance(inferred, nodes.ClassDef)
self.assertIsInstance(inferred.getattr("__iter__")[0], nodes.FunctionDef)

def test_typing_cast(self):
node = builder.extract_node(
"""
from typing import cast
class A:
pass
b = list()
a = cast(A, b)
a
"""
)
inferred = next(node.infer())
assert isinstance(inferred, bases.Instance)
assert inferred.name == "list"

def test_typing_cast_attribute(self):
node = builder.extract_node(
"""
import typing
class A:
pass
b = list()
a = typing.cast(A, b)
a
"""
)
inferred = next(node.infer())
assert isinstance(inferred, bases.Instance)
assert inferred.name == "list"


class ReBrainTest(unittest.TestCase):
def test_regex_flags(self):
Expand Down

0 comments on commit 9e51d37

Please sign in to comment.