Skip to content

Commit

Permalink
Refactor main API (#6)
Browse files Browse the repository at this point in the history
Separate argument creation from execution. This will allow us to
access the list of files before any checks are run on them.
  • Loading branch information
KyleFromNVIDIA authored Jan 19, 2024
1 parent b3f8c8d commit 878df7b
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 30 deletions.
30 changes: 19 additions & 11 deletions src/rapids_pre_commit_hooks/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,9 @@ def _calculate_lines(self):
self.lines.append((line_begin, line_end))


class LintMain(contextlib.AbstractContextManager):
def __init__(self):
self.argparser = argparse.ArgumentParser()
self.argparser.add_argument("--fix", action="store_true")
self.argparser.add_argument("file", nargs="+")
class ExecutionContext(contextlib.AbstractContextManager):
def __init__(self, args):
self.args = args
self.checks = []

def add_check(self, check):
Expand All @@ -210,18 +208,16 @@ def __exit__(self, exc_type, exc_value, traceback):

warnings = False

args = self.argparser.parse_args()

for file in args.file:
for file in self.args.files:
with open(file) as f:
content = f.read()

linter = Linter(file, content)
for check in self.checks:
check(linter, args)
check(linter, self.args)

linter.print_warnings(args.fix)
if args.fix:
linter.print_warnings(self.args.fix)
if self.args.fix:
fix = linter.fix()
if fix != content:
with open(file, "w") as f:
Expand All @@ -232,3 +228,15 @@ def __exit__(self, exc_type, exc_value, traceback):

if warnings:
exit(1)


class LintMain:
context_class = ExecutionContext

def __init__(self):
self.argparser = argparse.ArgumentParser()
self.argparser.add_argument("--fix", action="store_true")
self.argparser.add_argument("files", nargs="+", metavar="file")

def execute(self):
return self.context_class(self.argparser.parse_args())
12 changes: 8 additions & 4 deletions src/rapids_pre_commit_hooks/shell/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import bashlex

from ..lint import LintMain
from ..lint import ExecutionContext, LintMain


class LintVisitor(bashlex.ast.nodevisitor):
Expand All @@ -26,9 +26,9 @@ def add_warning(self, pos, msg):
return self.linter.add_warning(pos, msg)


class ShellMain(LintMain):
def __init__(self):
super().__init__()
class ShellExecutionContext(ExecutionContext):
def __init__(self, args):
super().__init__(args)
self.visitors = []
self.add_check(self.check_shell)

Expand All @@ -42,3 +42,7 @@ def check_shell(self, linter, args):
visitor = cls(linter, args)
for part in parts:
visitor.visit(part)


class ShellMain(LintMain):
context_class = ShellExecutionContext
5 changes: 3 additions & 2 deletions src/rapids_pre_commit_hooks/shell/verify_conda_yes.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ def visitcommand(self, n, parts):


def main():
with ShellMain() as m:
m.add_visitor_class(VerifyCondaYesVisitor)
m = ShellMain()
with m.execute() as ctx:
ctx.add_visitor_class(VerifyCondaYesVisitor)


if __name__ == "__main__":
Expand Down
33 changes: 20 additions & 13 deletions test/rapids_pre_commit_hooks/test_lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,20 @@ def the_check(self, linter, args):

def test_no_warnings_no_fix(self, hello_world_file, capsys):
with MockArgv("check-test", "--check-test", hello_world_file.name):
with LintMain() as m:
m.argparser.add_argument("--check-test", action="store_true")
m = LintMain()
m.argparser.add_argument("--check-test", action="store_true")
with m.execute():
pass
assert hello_world_file.read() == "Hello world!"
captured = capsys.readouterr()
assert captured.out == ""

def test_no_warnings_fix(self, hello_world_file, capsys):
with MockArgv("check-test", "--check-test", "--fix", hello_world_file.name):
with LintMain() as m:
m.argparser.add_argument("--check-test", action="store_true")
m = LintMain()
m.argparser.add_argument("--check-test", action="store_true")
with m.execute():
pass
assert hello_world_file.read() == "Hello world!"
captured = capsys.readouterr()
assert captured.out == ""
Expand All @@ -170,9 +174,10 @@ def test_warnings_no_fix(self, hello_world_file, capsys):
with MockArgv(
"check-test", "--check-test", hello_world_file.name
), pytest.raises(SystemExit, match=r"^1$"):
with LintMain() as m:
m.argparser.add_argument("--check-test", action="store_true")
m.add_check(self.the_check)
m = LintMain()
m.argparser.add_argument("--check-test", action="store_true")
with m.execute() as ctx:
ctx.add_check(self.the_check)
assert hello_world_file.read() == "Hello world!"
captured = capsys.readouterr()
assert (
Expand Down Expand Up @@ -204,9 +209,10 @@ def test_warnings_fix(self, hello_world_file, capsys):
with MockArgv(
"check-test", "--check-test", "--fix", hello_world_file.name
), pytest.raises(SystemExit, match=r"^1$"):
with LintMain() as m:
m.argparser.add_argument("--check-test", action="store_true")
m.add_check(self.the_check)
m = LintMain()
m.argparser.add_argument("--check-test", action="store_true")
with m.execute() as ctx:
ctx.add_check(self.the_check)
assert hello_world_file.read() == "Good bye, world!"
captured = capsys.readouterr()
assert (
Expand Down Expand Up @@ -242,9 +248,10 @@ def test_multiple_files(self, hello_world_file, hello_file, capsys):
hello_world_file.name,
hello_file.name,
), pytest.raises(SystemExit, match=r"^1$"):
with LintMain() as m:
m.argparser.add_argument("--check-test", action="store_true")
m.add_check(self.the_check)
m = LintMain()
m.argparser.add_argument("--check-test", action="store_true")
with m.execute() as ctx:
ctx.add_check(self.the_check)
assert hello_world_file.read() == "Good bye, world!"
assert hello_file.read() == "Good bye!"
captured = capsys.readouterr()
Expand Down

0 comments on commit 878df7b

Please sign in to comment.