Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Classify imports correctly with isort #93

Merged
merged 3 commits into from
Feb 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions fawltydeps/extract_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@
logger = logging.getLogger(__name__)


def isort_config(path: Path) -> isort.Config:
"""Configure isort to correctly classify import statements.

In order for isort to correctly differentiate between first- and third-party
imports, we need to pass in a configuration object that tells isort where
to look for first-party imports.
"""
return isort.Config(
directory=str(path), # Resolve first-party imports from this directory
py_version="all", # Ignore stdlib imports from all stdlib versions
)


ISORT_CONFIG = isort_config(Path("."))


class ArgParseError(Exception):
"""Indicate errors while parsing command-line arguments"""

Expand All @@ -29,15 +45,15 @@ def parse_code(code: str, *, source: Location) -> Iterator[ParsedImport]:
they appear in the code.
"""

def is_stdlib_import(name: str) -> bool:
return isort.place_module(name) == "STDLIB"
def is_external_import(name: str) -> bool:
return isort.place_module(name, config=ISORT_CONFIG) == "THIRDPARTY"

for node in ast.walk(ast.parse(code, filename=str(source.path))):
if isinstance(node, ast.Import):
logger.debug(ast.dump(node))
for alias in node.names:
name = alias.name.split(".", 1)[0]
if not is_stdlib_import(name):
if is_external_import(name):
yield ParsedImport(
name=name, source=source.supply(lineno=node.lineno)
)
Expand All @@ -48,7 +64,7 @@ def is_stdlib_import(name: str) -> bool:
# They are therefore uninteresting to us.
if node.level == 0 and node.module is not None:
name = node.module.split(".", 1)[0]
if not is_stdlib_import(name):
if is_external_import(name):
yield ParsedImport(
name=name, source=source.supply(lineno=node.lineno)
)
Expand Down Expand Up @@ -116,11 +132,17 @@ def parse_dir(path: Path) -> Iterator[ParsedImport]:
unspecified. Modules that are imported multiple times (in the same file or
across several files) will be yielded multiple times.
"""
for file in walk_dir(path):
if file.suffix == ".py":
yield from parse_python_file(file)
elif file.suffix == ".ipynb":
yield from parse_notebook_file(file)
global ISORT_CONFIG # pylint: disable=global-statement
old_config = ISORT_CONFIG
ISORT_CONFIG = isort_config(path)
try:
for file in walk_dir(path):
if file.suffix == ".py":
yield from parse_python_file(file)
elif file.suffix == ".ipynb":
yield from parse_notebook_file(file)
finally:
ISORT_CONFIG = old_config


def parse_any_arg(arg: PathOrSpecial) -> Iterator[ParsedImport]:
Expand Down
12 changes: 0 additions & 12 deletions tests/real_projects/requests.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,16 @@ sha256 = "375d6bb6b73af27c69487dcf1df51659a8ee7428420caff21253825fb338ce10"
# Value: list (set) of external packages imported
# TODO: Analyze requests properly
"<unused>" = [
"BaseHTTPServer",
"certifi",
"chardet",
"charset_normalizer",
"cryptography",
"cStringIO",
"idna",
"OpenSSL",
"pygments",
"pytest",
"requests",
"setuptools",
"SimpleHTTPServer",
"simplejson",
"StringIO",
"tests",
"trustme",
"urllib3",
]
Expand All @@ -58,21 +52,15 @@ sha256 = "375d6bb6b73af27c69487dcf1df51659a8ee7428420caff21253825fb338ce10"
# Value: list (set) of external packages imported without being declared
# TODO: Analyze requests properly
"<unused>" = [
"BaseHTTPServer",
"certifi",
"charset_normalizer",
"cryptography",
"cStringIO",
"idna",
"OpenSSL",
"pygments",
"pytest",
"requests",
"setuptools",
"SimpleHTTPServer",
"simplejson",
"StringIO",
"tests",
"trustme",
"urllib3",
]
Expand Down
55 changes: 55 additions & 0 deletions tests/test_extract_imports_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def cell_template(cell_type: str, source: List[str]):
[],
id="simple_import_from_stdlib__is_omitted",
),
pytest.param(
"from __future__ import annotations",
[],
id="import_from_future__is_omitted",
),
pytest.param(
"from numpy import array",
[("numpy", 1)],
Expand Down Expand Up @@ -148,6 +153,32 @@ def cell_template(cell_type: str, source: List[str]):
[("requests", 5), ("foo", 6), ("numpy", 7)],
id="combo_of_simple_imports__extracts_all_external_imports",
),
pytest.param(
dedent(
"""\
try: # Python 3
from http.server import HTTPServer, SimpleHTTPRequestHandler
except ImportError: # Python 2
from BaseHTTPServer import HTTPServer
from SimpleHTTPServer import SimpleHTTPRequestHandler
"""
),
[],
id="stdlib_import_with_ImportError_fallback__ignores_all",
),
pytest.param(
dedent(
"""\
if sys.version_info >= (3, 0):
from http.server import HTTPServer, SimpleHTTPRequestHandler
else:
from BaseHTTPServer import HTTPServer
from SimpleHTTPServer import SimpleHTTPRequestHandler
"""
),
[],
id="stdlib_import_with_if_else_fallback__ignores_all",
),
],
)
def test_parse_code(code, expected_import_line_pairs):
Expand Down Expand Up @@ -355,6 +386,30 @@ def test_parse_dir__imports__are_extracted_in_order_of_encounter(write_tmp_files
assert list(parse_dir(tmp_path)) == expect


def test_parse_dir__imports_from_same_dir__are_ignored(write_tmp_files):
tmp_path = write_tmp_files(
{
"my_application.py": "import my_utils",
"my_utils.py": "import sys",
}
)

assert set(parse_dir(tmp_path)) == set()


def test_parse_dir__self_imports__are_ignored(write_tmp_files):
tmp_path = write_tmp_files(
{
"my_app/__init__.py": "",
"my_app/main.py": "from my_app import utils",
"my_app/utils.py": "import numpy",
}
)

expect = {ParsedImport("numpy", Location(tmp_path / "my_app/utils.py", lineno=1))}
assert set(parse_dir(tmp_path)) == expect


def test_parse_dir__files_in_dot_dirs__are_ignored(write_tmp_files):
tmp_path = write_tmp_files(
{
Expand Down