From a5b77c9fbffb4687d9f6e31ac7a660134030d922 Mon Sep 17 00:00:00 2001
From: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
Date: Wed, 6 Oct 2021 14:29:41 +0200
Subject: [PATCH] Add guard helper functions from astroid

---
 ChangeLog                        |  3 ++
 pylint/checkers/imports.py       |  5 +-
 pylint/checkers/utils.py         | 36 ++++++++++++-
 tests/checkers/unittest_utils.py | 87 +++++++++++++++++++++++++++-----
 4 files changed, 114 insertions(+), 17 deletions(-)

diff --git a/ChangeLog b/ChangeLog
index e01a402a7f6..6f6d79a045e 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -64,6 +64,9 @@ Release date: TBA
 
 * Improve node information for ``invalid-name`` on function argument.
 
+* Add ``is_sys_guard`` and ``is_typing_guard`` helper functions from astroid
+  to ``pylint.checkers.utils``.
+
 
 What's New in Pylint 2.11.1?
 ============================
diff --git a/pylint/checkers/imports.py b/pylint/checkers/imports.py
index 05f1a28e9c2..abecc1712ec 100644
--- a/pylint/checkers/imports.py
+++ b/pylint/checkers/imports.py
@@ -62,6 +62,7 @@
     get_import_name,
     is_from_fallback_block,
     is_node_in_guarded_import_block,
+    is_typing_guard,
     node_ignores_exception,
 )
 from pylint.exceptions import EmptyReportError
@@ -843,8 +844,8 @@ def _add_imported_module(
         except ImportError:
             pass
 
-        in_type_checking_block = (
-            isinstance(node.parent, nodes.If) and node.parent.is_typing_guard()
+        in_type_checking_block = isinstance(node.parent, nodes.If) and is_typing_guard(
+            node.parent
         )
 
         if context_name == importedmodname:
diff --git a/pylint/checkers/utils.py b/pylint/checkers/utils.py
index 6e96b3a9783..ccbc1a8fb1c 100644
--- a/pylint/checkers/utils.py
+++ b/pylint/checkers/utils.py
@@ -1551,12 +1551,46 @@ def get_import_name(
     return modname
 
 
+def is_sys_guard(node: nodes.If) -> bool:
+    """Return True if IF stmt is a sys.version_info guard.
+
+    >>> import sys
+    >>> if sys.version_info > (3, 8):
+    >>>     from typing import Literal
+    >>> else:
+    >>>     from typing_extensions import Literal
+    """
+    if isinstance(node.test, nodes.Compare):
+        value = node.test.left
+        if isinstance(value, nodes.Subscript):
+            value = value.value
+        if (
+            isinstance(value, nodes.Attribute)
+            and value.as_string() == "sys.version_info"
+        ):
+            return True
+
+    return False
+
+
+def is_typing_guard(node: nodes.If) -> bool:
+    """Return True if IF stmt is a typing guard.
+
+    >>> from typing import TYPE_CHECKING
+    >>> if TYPE_CHECKING:
+    >>>     from xyz import a
+    """
+    return isinstance(
+        node.test, (nodes.Name, nodes.Attribute)
+    ) and node.test.as_string().endswith("TYPE_CHECKING")
+
+
 def is_node_in_guarded_import_block(node: nodes.NodeNG) -> bool:
     """Return True if node is part for guarded if block.
     I.e. `sys.version_info` or `typing.TYPE_CHECKING`
     """
     return isinstance(node.parent, nodes.If) and (
-        node.parent.is_sys_guard() or node.parent.is_typing_guard()
+        is_sys_guard(node.parent) or is_typing_guard(node.parent)
     )
 
 
diff --git a/tests/checkers/unittest_utils.py b/tests/checkers/unittest_utils.py
index 9774b2d5879..583bf3823a6 100644
--- a/tests/checkers/unittest_utils.py
+++ b/tests/checkers/unittest_utils.py
@@ -23,6 +23,7 @@
 
 import astroid
 import pytest
+from astroid import nodes
 
 from pylint.checkers import utils
 
@@ -79,7 +80,7 @@ def testGetArgumentFromCall() -> None:
 
 
 def test_error_of_type() -> None:
-    nodes = astroid.extract_node(
+    code = astroid.extract_node(
         """
     try: pass
     except AttributeError: #@
@@ -91,14 +92,14 @@ def test_error_of_type() -> None:
          pass
     """
     )
-    assert utils.error_of_type(nodes[0], AttributeError)
-    assert utils.error_of_type(nodes[0], (AttributeError,))
-    assert not utils.error_of_type(nodes[0], Exception)
-    assert utils.error_of_type(nodes[1], Exception)
+    assert utils.error_of_type(code[0], AttributeError)
+    assert utils.error_of_type(code[0], (AttributeError,))
+    assert not utils.error_of_type(code[0], Exception)
+    assert utils.error_of_type(code[1], Exception)
 
 
 def test_node_ignores_exception() -> None:
-    nodes = astroid.extract_node(
+    code = astroid.extract_node(
         """
     try:
         1/0 #@
@@ -118,14 +119,14 @@ def test_node_ignores_exception() -> None:
         pass
     """
     )
-    assert utils.node_ignores_exception(nodes[0], ZeroDivisionError)
-    assert not utils.node_ignores_exception(nodes[1], ZeroDivisionError)
-    assert not utils.node_ignores_exception(nodes[2], ZeroDivisionError)
-    assert not utils.node_ignores_exception(nodes[3], ZeroDivisionError)
+    assert utils.node_ignores_exception(code[0], ZeroDivisionError)
+    assert not utils.node_ignores_exception(code[1], ZeroDivisionError)
+    assert not utils.node_ignores_exception(code[2], ZeroDivisionError)
+    assert not utils.node_ignores_exception(code[3], ZeroDivisionError)
 
 
 def test_is_subclass_of_node_b_derived_from_node_a() -> None:
-    nodes = astroid.extract_node(
+    code = astroid.extract_node(
         """
     class Superclass: #@
         pass
@@ -134,11 +135,11 @@ class Subclass(Superclass): #@
         pass
     """
     )
-    assert utils.is_subclass_of(nodes[1], nodes[0])
+    assert utils.is_subclass_of(code[1], code[0])
 
 
 def test_is_subclass_of_node_b_not_derived_from_node_a() -> None:
-    nodes = astroid.extract_node(
+    code = astroid.extract_node(
         """
     class OneClass: #@
         pass
@@ -147,7 +148,7 @@ class AnotherClass: #@
         pass
     """
     )
-    assert not utils.is_subclass_of(nodes[1], nodes[0])
+    assert not utils.is_subclass_of(code[1], code[0])
 
 
 def test_is_subclass_of_not_classdefs() -> None:
@@ -404,3 +405,61 @@ def y(self):
         """
     )
     assert utils.get_node_last_lineno(node) == 11
+
+
+def test_if_sys_guard() -> None:
+    code = astroid.extract_node(
+        """
+    import sys
+    if sys.version_info > (3, 8):  #@
+        pass
+
+    if sys.version_info[:2] > (3, 8):  #@
+        pass
+
+    if sys.some_other_function > (3, 8):  #@
+        pass
+    """
+    )
+    assert isinstance(code, list) and len(code) == 3
+
+    assert isinstance(code[0], nodes.If)
+    assert utils.is_sys_guard(code[0]) is True
+    assert isinstance(code[1], nodes.If)
+    assert utils.is_sys_guard(code[1]) is True
+
+    assert isinstance(code[2], nodes.If)
+    assert utils.is_sys_guard(code[2]) is False
+
+
+def test_if_typing_guard() -> None:
+    code = astroid.extract_node(
+        """
+    import typing
+    import typing as t
+    from typing import TYPE_CHECKING
+
+    if typing.TYPE_CHECKING:  #@
+        pass
+
+    if t.TYPE_CHECKING:  #@
+        pass
+
+    if TYPE_CHECKING:  #@
+        pass
+
+    if typing.SOME_OTHER_CONST:  #@
+        pass
+    """
+    )
+    assert isinstance(code, list) and len(code) == 4
+
+    assert isinstance(code[0], nodes.If)
+    assert utils.is_typing_guard(code[0]) is True
+    assert isinstance(code[1], nodes.If)
+    assert utils.is_typing_guard(code[1]) is True
+    assert isinstance(code[2], nodes.If)
+    assert utils.is_typing_guard(code[2]) is True
+
+    assert isinstance(code[3], nodes.If)
+    assert utils.is_typing_guard(code[3]) is False