Skip to content

Commit

Permalink
Fixes to disallowed_imports (#646)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Jun 12, 2023
1 parent 6128e20 commit 3766cf8
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 17 deletions.
2 changes: 1 addition & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## Unreleased

- Add `disallowed_imports` configuration option to disallow
imports of specific modules (#645)
imports of specific modules (#645, #646)
- Consider an annotated assignment without a value to be
an exported name (#644)
- Improve the location where `missing_parameter_annotation`
Expand Down
23 changes: 14 additions & 9 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2320,15 +2320,20 @@ def check_deprecation(self, node: ast.AST, value: Value) -> bool:

# Imports

def check_for_disallowed_import(self, node: ast.AST, name: str) -> None:
print("CHECK", name)
def check_for_disallowed_import(
self, node: ast.AST, name: str, *, check_parents: bool = True
) -> None:
disallowed = self.options.get_value_for(DisallowedImports)
if name in disallowed:
self._show_error_if_checking(
node,
f"Disallowed import of module {name!r}",
error_code=ErrorCode.disallowed_import,
)
parts = name.split(".") if check_parents else [name]
for i in range(len(parts)):
name_to_check = ".".join(parts[: i + 1])
if name_to_check in disallowed:
self._show_error_if_checking(
node,
f"Disallowed import of module {name!r}",
error_code=ErrorCode.disallowed_import,
)
break

def visit_Import(self, node: ast.Import) -> None:
self.generic_visit(node)
Expand Down Expand Up @@ -2434,7 +2439,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
else:
error_node = node
self.check_for_disallowed_import(
error_node, f"{node.module}.{alias.name}"
error_node, f"{node.module}.{alias.name}", check_parents=False
)

self._maybe_record_usages_from_import(node)
Expand Down
1 change: 1 addition & 0 deletions pyanalyze/test.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ class_attribute_transformers = [
disallowed_imports = [
"getopt",
"email.quoprimime",
"xml",
]
21 changes: 14 additions & 7 deletions pyanalyze/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,38 +47,45 @@ class TestDisallowedImport(TestNameCheckVisitorBase):
@assert_passes()
def test_top_level(self):
import getopt # E: disallowed_import
import xml.etree.ElementTree # E: disallowed_import

from getopt import GetoptError # E: disallowed_import

print(getopt, GetoptError) # shut up flake8
print(getopt, GetoptError, xml) # shut up flake8

def capybara():
import getopt # E: disallowed_import
from getopt import GetoptError # E: disallowed_import
import xml.etree.ElementTree # E: disallowed_import

print(getopt, GetoptError)
print(getopt, GetoptError, xml)

@assert_passes()
def test_nested(self):
import email.quoprimime # E: disallowed_import
import email.base64mime # ok
from email.quoprimime import unquote # E: disallowed_import
from xml.etree import ElementTree # E: disallowed_import

print(email, unquote)
print(email, unquote, ElementTree)

def capybara():
import email.quoprimime # E: disallowed_import
import email.base64mime # ok
from email.quoprimime import unquote # E: disallowed_import
from email import quoprimime # E: disallowed_import
from xml.etree import ElementTree # E: disallowed_import

print(email, unquote, quoprimime)
print(email, unquote, ElementTree)

@assert_passes()
def test_import_from(self):
from email import quoprimime # E: disallowed_import
from email import base64mime # ok

print(quoprimime)
print(quoprimime, base64mime)

def capybara():
from email import quoprimime # E: disallowed_import
from email import base64mime # ok

print(quoprimime)
print(quoprimime, base64mime)

0 comments on commit 3766cf8

Please sign in to comment.