Skip to content

Commit

Permalink
Add mechanism for dynamically adding new arguments
Browse files Browse the repository at this point in the history
This will allow checks to dynamically add arguments based on the
values of other arguments. This will be needed when we add an API
to verify-copyright for projects to inject their own configuration
modules which specify a different branching strategy and optionally
have their own arguments.
  • Loading branch information
KyleFromNVIDIA committed Mar 12, 2024
1 parent 4d14f4c commit 9f9efca
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/rapids_pre_commit_hooks/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,9 @@ def _calculate_lines(self):


class ExecutionContext(contextlib.AbstractContextManager):
def __init__(self, args):
def __init__(self, args, extra_args):
self.args = args
self.extra_args = extra_args
self.checks = []

def add_check(self, check):
Expand Down Expand Up @@ -291,6 +292,10 @@ def __exit__(self, exc_type, exc_value, traceback):
class LintMain:
context_class = ExecutionContext

@classmethod
def get_extra_argparser(cls, namespace):
return argparse.ArgumentParser()

def __init__(self):
self.argparser = argparse.ArgumentParser()
self.argparser.add_argument(
Expand All @@ -299,4 +304,7 @@ def __init__(self):
self.argparser.add_argument("files", nargs="+", metavar="file")

def execute(self):
return self.context_class(self.argparser.parse_args())
namespace, extra_args = self.argparser.parse_known_args()
extra_argparser = self.get_extra_argparser(namespace)
extra_namespace = extra_argparser.parse_args(extra_args)
return self.context_class(namespace, extra_namespace)
21 changes: 21 additions & 0 deletions test/rapids_pre_commit_hooks/test_lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import contextlib
import os.path
import tempfile
Expand Down Expand Up @@ -524,3 +525,23 @@ def test_bracket_file(self, bracket_file):
call().print("[bold]note:[/bold] suggested fix applied"),
call().print(),
]

def test_extra_args(self, hello_file):
class ExtraArgsMain(LintMain):
@classmethod
def get_extra_argparser(cls, args):
parser = argparse.ArgumentParser()
parser.add_argument("--extra-dynamic-arg", action="store")
return parser

with patch(
"sys.argv",
[
"check-test",
"--extra-dynamic-arg=Hello",
hello_file.name,
],
):
m = ExtraArgsMain()
with m.execute() as ctx:
assert ctx.extra_args.extra_dynamic_arg == "Hello"

0 comments on commit 9f9efca

Please sign in to comment.