From 9c85ba80046a5dad777fad0198b66c1226ec1e8a Mon Sep 17 00:00:00 2001 From: Brian Kohan Date: Tue, 19 Nov 2024 15:09:11 -0800 Subject: [PATCH] fix #143, some work on #140 --- django_typer/management/__init__.py | 171 +++++++++++++++--- django_typer/utils.py | 21 +++ .../commands/dj_params_all_suppressed.py | 13 ++ .../commands/dj_params_all_suppressed_init.py | 20 ++ .../management/commands/dj_params_inherit.py | 33 ++++ .../commands/dj_params_none_suppressed.py | 13 ++ .../commands/dj_params_some_suppressed.py | 22 +++ .../dj_params_some_suppressed_init.py | 34 ++++ .../commands/dj_params_subgroups.py | 48 +++++ tests/test_django_params.py | 90 +++++++++ tests/test_native.py | 7 +- 11 files changed, 445 insertions(+), 27 deletions(-) create mode 100644 tests/apps/test_app/management/commands/dj_params_all_suppressed.py create mode 100644 tests/apps/test_app/management/commands/dj_params_all_suppressed_init.py create mode 100644 tests/apps/test_app/management/commands/dj_params_inherit.py create mode 100644 tests/apps/test_app/management/commands/dj_params_none_suppressed.py create mode 100644 tests/apps/test_app/management/commands/dj_params_some_suppressed.py create mode 100644 tests/apps/test_app/management/commands/dj_params_some_suppressed_init.py create mode 100644 tests/apps/test_app/management/commands/dj_params_subgroups.py diff --git a/django_typer/management/__init__.py b/django_typer/management/__init__.py index 7be7dbc..216156a 100644 --- a/django_typer/management/__init__.py +++ b/django_typer/management/__init__.py @@ -48,8 +48,10 @@ from ..utils import ( # noqa: E402 _command_context, _load_command_plugins, + accepted_kwargs, called_from_command_definition, called_from_module, + get_current_command, get_usage_script, is_method, with_typehint, @@ -323,13 +325,35 @@ def _common_options( _common_params: t.Sequence[t.Union[click.Argument, click.Option]] = [] -def _get_common_params() -> t.Sequence[t.Union[click.Argument, click.Option]]: +def _normalize_suppressed_arguments( + command: t.Union[t.Type["TyperCommand"], "TyperCommand"], +) -> t.Set[str]: + if command.suppressed_base_arguments: + return set( + [ + arg.lstrip("--").replace("-", "_") + for arg in command.suppressed_base_arguments + ] + ) + return set() + + +def _get_common_params( + command: t.Type["TyperCommand"], +) -> t.Sequence[t.Union[click.Argument, click.Option]]: """Use typer to convert the common options to click options""" global _common_params if not _common_params: _common_params = get_params_convertors_ctx_param_name_from_function( _common_options )[0] + suppressed = _normalize_suppressed_arguments(command) + if suppressed: + return [ + param + for param in _common_params + if param.name and param.name not in suppressed + ] return _common_params @@ -339,6 +363,13 @@ def _get_common_params() -> t.Sequence[t.Union[click.Argument, click.Option]]: } +def _remove_suppressed( + command: "TyperCommand", params: t.Dict[str, t.Any], manual: t.Set[str] = set() +) -> t.Dict[str, t.Any]: + suppressed = _normalize_suppressed_arguments(command) + return {k: v for k, v in params.items() if k not in suppressed or k in manual} + + class _ParsedArgs(SimpleNamespace): """ Emulate the argparse.Namespace class so that we can pass the parsed arguments @@ -349,9 +380,13 @@ def __init__(self, args: t.Sequence[t.Any], **kwargs: t.Any): super().__init__(**kwargs) self.args = args self.traceback = kwargs.get("traceback", TyperCommand._traceback) + if not hasattr(self, "pythonpath"): + self.pythonpath = None + if not hasattr(self, "settings"): + self.settings = None def _get_kwargs(self): - return {**COMMON_DEFAULTS, **vars(self)} + return vars(self) class Context(TyperContext): @@ -404,14 +439,30 @@ def __init__( **kwargs: t.Any, ): super().__init__(command, parent=parent, **kwargs) - if supplied_params: - self._supplied_params = supplied_params if django_command: self.django_command = django_command + if supplied_params: + # if we're at the top level, default django parameters that + # were suppressed may have been injected into execute() and + # wound up here. We remove them to keep the interface honest + supplied_params = _remove_suppressed( + self.django_command, + supplied_params, + { + param.name + for param in get_typer_command( + self.django_command.typer_app + ).params + if param.name + }, + ) else: assert parent self.django_command = parent.django_command + if supplied_params: + self._supplied_params = supplied_params + self.params = self.ParamDict( {**self.params, **self.supplied_params}, supplied=list(self.supplied_params.keys()), @@ -428,7 +479,7 @@ class DjangoTyperMixin(with_typehint(CoreTyperGroup)): # type: ignore[misc] """ context_class: t.Type[click.Context] = Context - django_command: "TyperCommand" + django_command: t.Type["TyperCommand"] _callback: t.Optional[t.Callable[..., t.Any]] = None _callback_is_method: t.Optional[bool] = None common_init: bool = False @@ -502,17 +553,8 @@ def common_params(self) -> t.Sequence[t.Union[click.Argument, click.Option]]: Add the common parameters to this group only if this group is the root command's user specified initialize callback. """ - suppressed = getattr(self.django_command, "suppressed_base_arguments", None) return ( - [ - param - for param in _get_common_params() - if param.name - and param.name - not in ( - {arg.lstrip("--").replace("-", "_") for arg in suppressed or []} - ) - ] + _get_common_params(self.django_command) if self.common_init or self.no_callback else [] ) @@ -637,6 +679,37 @@ def list_commands(self, ctx: click.Context) -> t.List[str]: # hence the following mishegoss +class Finalizer(t.Generic[P, R]): + """ + A class that wraps a finalizer function and makes it callable while passing the + django command instance if expected. + """ + + finalizer: t.Callable[P, R] + is_method: bool + + def __init__(self, finalizer: t.Callable[P, R]): + self.finalizer = finalizer + self.is_method = bool(is_method(finalizer)) + + def __call__( + self, + *args: P.args, + **kwargs: P.kwargs, + ) -> R: + if self.is_method: + cmd = kwargs.pop("_command", None) or ( + getattr( + t.cast(Context, click.get_current_context(silent=True)), + "django_command", + None, + ) + or get_current_command() + ) + args = [cmd, *args] # type: ignore + return self.finalizer(*args, **accepted_kwargs(self.finalizer, kwargs)) + + @t.overload # pragma: no cover def _check_static( func: typer.models.CommandFunctionType, @@ -682,7 +755,7 @@ def _cache_initializer( **kwargs: t.Any, ): def register( - cmd: "TyperCommand", + cmd: t.Type["TyperCommand"], _name: t.Optional[str] = Default(None), _help: t.Optional[t.Union[str, Promise]] = Default(None), **extra, @@ -702,6 +775,14 @@ def register( setattr(callback, _CACHE_KEY, register) +def _cache_finalizer(callback: t.Callable[..., t.Any]): + def register(cmd: t.Type["TyperCommand"]): + finalizer = Finalizer(_strip_static(callback)) + cmd.typer_app.info.result_callback = finalizer + + setattr(callback, _CACHE_KEY, register) + + def _cache_command( callback: t.Callable[..., t.Any], name: t.Optional[str] = None, @@ -710,7 +791,7 @@ def _cache_command( **kwargs: t.Any, ): def register( - cmd: "TyperCommand", + cmd: t.Type["TyperCommand"], _name: t.Optional[str] = None, _help: t.Optional[t.Union[str, Promise]] = None, **extra, @@ -1319,9 +1400,6 @@ def __getattr__(self, name: str) -> t.Any: ) ) - # def __repr__(self) -> str: - # return f'<{self.__class__.__module__}.{self.__class__.__name__} for {repr(self.proxied)}>' - def initialize( name: t.Optional[str] = Default(None), @@ -1468,6 +1546,19 @@ def make_initializer(func: t.Callable[P2, R2]) -> t.Callable[P2, R2]: callback = initialize # allow callback as an alias +def finalize() -> t.Callable[[t.Callable[P2, R2]], t.Callable[P2, R2]]: + """ + TODO + """ + + def make_finalizer(func: t.Callable[P2, R2]) -> t.Callable[P2, R2]: + func = _check_static(func) + _cache_finalizer(func) + return func + + return make_finalizer + + def command( name: t.Optional[str] = None, *, @@ -1878,6 +1969,7 @@ class TyperCommandMeta(type): typer_app: Typer no_color: bool force_color: bool + skip_checks: bool is_compound_command: bool _handle: t.Optional[t.Callable[..., t.Any]] @@ -2332,7 +2424,14 @@ def parse_args(self, args=None, namespace=None) -> _ParsedArgs: django_command=self.django_command, args=list(args or []), ) as ctx: - common = {**COMMON_DEFAULTS, **ctx.params} + common = { + **_remove_suppressed( + self.django_command, + COMMON_DEFAULTS, + {param.name for param in cmd.params if param.name}, + ), + **ctx.params, + } self.django_command._traceback = common.get( "traceback", self.django_command._traceback ) @@ -2500,10 +2599,13 @@ def command2(self, option: t.Optional[str] = None): typer_app: Typer no_color: bool = False force_color: bool = False + skip_checks: bool = False _handle: t.Callable[..., t.Any] _traceback: bool = False _help_kwarg: t.Optional[str] = Default(None) _defined_groups: t.Dict[str, Typer] = {} + _finalizer: t.Optional[Finalizer] = None + _suppressed_base_arguments: t.Optional[t.Set[str]] = None help: t.Optional[t.Union[DefaultPlaceholder, str, Promise]] = Default(None) # type: ignore @@ -3065,7 +3167,7 @@ def _run(self, *args, **options): :return: t.Any object returned by the Typer app """ with self: - return self.typer_app( + result = self.typer_app( args=args, standalone_mode=False, supplied_params=options, @@ -3073,6 +3175,15 @@ def _run(self, *args, **options): complete_var=None, prog_name=f"{sys.argv[0]} {self.typer_app.info.name}", ) + if not self.is_compound_command and isinstance( + self.typer_app.info.result_callback, Finalizer + ): + # result callbacks are not called on singular commands by click/typer + # we do that here to keep our interface consistent + return self.typer_app.info.result_callback( + result, **options, _command=self + ) + return result def run_from_argv(self, argv): """ @@ -3102,16 +3213,30 @@ def execute(self, *args, **options): """ no_color = self.no_color force_color = self.force_color + skip_checks = self.skip_checks if options.get("no_color", None) is not None: self.no_color = options["no_color"] if options.get("force_color", None) is not None: self.force_color = options["force_color"] + if options.get("skip_checks", None) is not None: + self.skip_checks = options["skip_checks"] try: with self: - return super().execute(*args, **options) + # base class requires force_color, no_color and skip_checks to be present - we + # allow them to be suppressed + return super().execute( + *args, + **{ + "force_color": self.force_color, + "no_color": self.no_color, + "skip_checks": self.skip_checks, + **options, + }, + ) finally: self.no_color = no_color self.force_color = force_color + self.skip_checks = skip_checks def echo( self, message: t.Optional[t.Any] = None, nl: bool = True, err: bool = False diff --git a/django_typer/utils.py b/django_typer/utils.py index b7928d0..dec2e51 100644 --- a/django_typer/utils.py +++ b/django_typer/utils.py @@ -196,3 +196,24 @@ def is_method( return params[0] == "self" return isinstance(func_or_params, MethodType) return None + + +def accepts_var_kwargs(func: t.Callable[..., t.Any]) -> bool: + """ + Determines if the given function accepts variable keyword arguments. + """ + for param in reversed(inspect.signature(func).parameters.values()): + return param.kind is inspect.Parameter.VAR_KEYWORD + return False + + +def accepted_kwargs( + func: t.Callable[..., t.Any], kwargs: t.Dict[str, t.Any] +) -> t.Dict[str, t.Any]: + """ + Return the named keyword arguments that are accepted by the given function. + """ + if accepts_var_kwargs(func): + return kwargs + param_names = set(inspect.signature(func).parameters.keys()) + return {k: v for k, v in kwargs.items() if k in param_names} diff --git a/tests/apps/test_app/management/commands/dj_params_all_suppressed.py b/tests/apps/test_app/management/commands/dj_params_all_suppressed.py new file mode 100644 index 0000000..72e7626 --- /dev/null +++ b/tests/apps/test_app/management/commands/dj_params_all_suppressed.py @@ -0,0 +1,13 @@ +from django_typer.management import TyperCommand, COMMON_DEFAULTS +from typer.models import Context as TyperContext + + +class Command(TyperCommand): + help = "Test that django parameter suppression works as expected" + + suppressed_base_arguments = set(COMMON_DEFAULTS.keys()) + + def handle(self, ctx: TyperContext): + assert self.__class__ is Command + assert isinstance(ctx, TyperContext) + assert not ctx.params diff --git a/tests/apps/test_app/management/commands/dj_params_all_suppressed_init.py b/tests/apps/test_app/management/commands/dj_params_all_suppressed_init.py new file mode 100644 index 0000000..aa2dd20 --- /dev/null +++ b/tests/apps/test_app/management/commands/dj_params_all_suppressed_init.py @@ -0,0 +1,20 @@ +from django_typer.management import TyperCommand, COMMON_DEFAULTS, initialize, command +from typer.models import Context as TyperContext + + +class Command(TyperCommand): + help = "Test that django parameter suppression works as expected" + + suppressed_base_arguments = set(COMMON_DEFAULTS.keys()) + + @initialize() + def init(self, ctx: TyperContext): + assert self.__class__ is Command + assert isinstance(ctx, TyperContext) + assert not ctx.params + + @command(name="cmd") + def handle(self, ctx: TyperContext): + assert self.__class__ is Command + assert isinstance(ctx, TyperContext) + assert not ctx.params diff --git a/tests/apps/test_app/management/commands/dj_params_inherit.py b/tests/apps/test_app/management/commands/dj_params_inherit.py new file mode 100644 index 0000000..c26b7c7 --- /dev/null +++ b/tests/apps/test_app/management/commands/dj_params_inherit.py @@ -0,0 +1,33 @@ +from django_typer.management import TyperCommand, COMMON_DEFAULTS, initialize, command +from typer.models import Context as TyperContext +from .dj_params_some_suppressed_init import Command as BaseCommand + + +class Command(BaseCommand): + suppressed_base_arguments = {"verbosity", "skip_checks", "traceback", "--no-color"} + + tb: bool + + @initialize(invoke_without_command=True) + def init(self, ctx: TyperContext, traceback: bool = True): + self.tb = traceback + return self.check_context(ctx, traceback) + + @command() + def cmd(self, ctx: TyperContext): + return self.check_context(ctx, self.tb) + + def check_context(self, ctx, traceback): + assert self.suppressed_base_arguments + assert self.__class__ is Command + assert isinstance(ctx, TyperContext) + set(COMMON_DEFAULTS.keys()) + assert ctx.params + assert "verbosity" not in ctx.params + assert "skip_checks" not in ctx.params + assert "no-color" not in ctx.params + assert "traceback" in ctx.params + for param in COMMON_DEFAULTS.keys(): + if param not in self.suppressed_base_arguments and param != "no_color": + assert param in ctx.params + return f"traceback={traceback}" diff --git a/tests/apps/test_app/management/commands/dj_params_none_suppressed.py b/tests/apps/test_app/management/commands/dj_params_none_suppressed.py new file mode 100644 index 0000000..78e4146 --- /dev/null +++ b/tests/apps/test_app/management/commands/dj_params_none_suppressed.py @@ -0,0 +1,13 @@ +from django_typer.management import TyperCommand, COMMON_DEFAULTS +from typer.models import Context as TyperContext + + +class Command(TyperCommand): + help = "Test that django parameter suppression works as expected" + + suppressed_base_arguments = [] + + def handle(self, ctx: TyperContext): + assert self.__class__ is Command + assert isinstance(ctx, TyperContext) + assert not set(ctx.params.keys()).symmetric_difference(COMMON_DEFAULTS.keys()) diff --git a/tests/apps/test_app/management/commands/dj_params_some_suppressed.py b/tests/apps/test_app/management/commands/dj_params_some_suppressed.py new file mode 100644 index 0000000..7de6401 --- /dev/null +++ b/tests/apps/test_app/management/commands/dj_params_some_suppressed.py @@ -0,0 +1,22 @@ +from django_typer.management import TyperCommand, COMMON_DEFAULTS, initialize, command +from typer.models import Context as TyperContext + + +class Command(TyperCommand): + help = "Test that django parameter suppression works as expected" + + suppressed_base_arguments = {"verbosity", "skip_checks", "traceback"} + + def handle(self, ctx: TyperContext, traceback: bool = True): + assert self.suppressed_base_arguments + assert self.__class__ is Command + assert isinstance(ctx, TyperContext) + set(COMMON_DEFAULTS.keys()) + assert ctx.params + assert "verbosity" not in ctx.params + assert "skip_checks" not in ctx.params + assert "traceback" in ctx.params + for param in COMMON_DEFAULTS.keys(): + if param not in self.suppressed_base_arguments: + assert param in ctx.params + return f"traceback={traceback}" diff --git a/tests/apps/test_app/management/commands/dj_params_some_suppressed_init.py b/tests/apps/test_app/management/commands/dj_params_some_suppressed_init.py new file mode 100644 index 0000000..8e0ca37 --- /dev/null +++ b/tests/apps/test_app/management/commands/dj_params_some_suppressed_init.py @@ -0,0 +1,34 @@ +from typing import Any, TextIO +from django_typer.management import TyperCommand, COMMON_DEFAULTS, initialize, command +from typer.models import Context as TyperContext + + +class Command(TyperCommand): + help = "Test that django parameter suppression works as expected" + + suppressed_base_arguments = {"verbosity", "skip_checks", "traceback"} + + tb: bool + + @initialize(invoke_without_command=True) + def init(self, ctx: TyperContext, traceback: bool = True): + self.tb = traceback + return self.check_context(ctx, traceback) + + @command() + def cmd(self, ctx: TyperContext): + return self.check_context(ctx, self.tb) + + def check_context(self, ctx, traceback): + assert self.suppressed_base_arguments + assert self.__class__ is Command + assert isinstance(ctx, TyperContext) + set(COMMON_DEFAULTS.keys()) + assert ctx.params + assert "verbosity" not in ctx.params + assert "skip_checks" not in ctx.params + assert "traceback" in ctx.params + for param in COMMON_DEFAULTS.keys(): + if param not in self.suppressed_base_arguments: + assert param in ctx.params + return f"traceback={traceback}" diff --git a/tests/apps/test_app/management/commands/dj_params_subgroups.py b/tests/apps/test_app/management/commands/dj_params_subgroups.py new file mode 100644 index 0000000..ac52a8d --- /dev/null +++ b/tests/apps/test_app/management/commands/dj_params_subgroups.py @@ -0,0 +1,48 @@ +from typing import Any, TextIO +from django_typer.management import ( + TyperCommand, + COMMON_DEFAULTS, + initialize, + command, + group, +) +from typer.models import Context as TyperContext + + +class Command(TyperCommand): + help = "Test that django parameter suppression works as expected" + + suppressed_base_arguments = {"verbosity", "skip_checks", "traceback"} + + tb: bool + + @initialize(invoke_without_command=True) + def init(self, ctx: TyperContext, traceback: bool = True): + self.tb = traceback + return self.check_context(ctx, traceback) + + @command() + def cmd(self, ctx: TyperContext): + return self.check_context(ctx, self.tb) + + @group(invoke_without_command=True) + def subgroup(self, ctx: TyperContext, skip_checks: bool = True): + return self.check_context(ctx, self.tb, skip_checks=skip_checks) + + def check_context(self, ctx, traceback, skip_checks=None): + assert self.suppressed_base_arguments + assert self.__class__ is Command + assert isinstance(ctx, TyperContext) + set(COMMON_DEFAULTS.keys()) + assert ctx.params + assert "verbosity" not in ctx.params + assert ( + "skip_checks" not in ctx.params + if skip_checks is None + else "skip_checks" in ctx.params + ) + assert "traceback" in ctx.params + for param in COMMON_DEFAULTS.keys(): + if param not in self.suppressed_base_arguments: + assert param in ctx.params + return f"traceback={traceback}, skipchecks={skip_checks}" diff --git a/tests/test_django_params.py b/tests/test_django_params.py index 45af811..ef07296 100644 --- a/tests/test_django_params.py +++ b/tests/test_django_params.py @@ -227,3 +227,93 @@ def test_override_call_command(self): "optional_arg": 1, }, ) + + def test_no_suppression(self): + self.assertEqual(run_command("dj_params_none_suppressed")[2], 0) + call_command("dj_params_none_suppressed") + + def test_all_suppressed(self): + self.assertEqual(run_command("dj_params_all_suppressed")[2], 0) + call_command("dj_params_all_suppressed") + + def test_all_suppressed_init(self): + self.assertEqual(run_command("dj_params_all_suppressed_init", "cmd")[2], 0) + call_command("dj_params_all_suppressed_init", "cmd") + + def test_some_suppressed(self): + stdout, _, retcode = run_command("dj_params_some_suppressed") + self.assertEqual(retcode, 0) + self.assertEqual(stdout.strip(), "traceback=True") + + self.assertEqual(call_command("dj_params_some_suppressed"), "traceback=True") + + stdout, _, retcode = run_command("dj_params_some_suppressed", "--no-traceback") + self.assertEqual(retcode, 0) + self.assertEqual(stdout.strip(), "traceback=False") + + self.assertEqual( + call_command("dj_params_some_suppressed", "--no-traceback"), + "traceback=False", + ) + self.assertEqual( + call_command("dj_params_some_suppressed", traceback=False), + "traceback=False", + ) + + def test_some_suppressed_init(self, command_name="dj_params_some_suppressed_init"): + stdout, _, retcode = run_command(command_name) + self.assertEqual(retcode, 0) + self.assertTrue("traceback=True" in stdout.strip()) + + stdout, _, retcode = run_command(command_name, "cmd") + self.assertEqual(retcode, 0) + self.assertTrue("traceback=True" in stdout.strip()) + + self.assertTrue("traceback=True" in call_command(command_name)) + self.assertTrue("traceback=True" in call_command(command_name, "cmd")) + + stdout, _, retcode = run_command(command_name, "--no-traceback") + self.assertEqual(retcode, 0) + self.assertTrue("traceback=False" in stdout.strip()) + + self.assertTrue( + "traceback=False" in call_command(command_name, "--no-traceback") + ) + self.assertTrue( + "traceback=False" in call_command(command_name, traceback=False), "False" + ) + + stdout, _, retcode = run_command(command_name, "--no-traceback", "cmd") + self.assertEqual(retcode, 0) + self.assertTrue("traceback=False" in stdout.strip()) + + self.assertTrue( + "traceback=False" in call_command(command_name, "--no-traceback", "cmd") + ) + self.assertTrue( + "traceback=False" in call_command(command_name, "cmd", traceback=False) + ) + + def test_some_suppressed_inherit(self): + self.test_some_suppressed_init("dj_params_inherit") + + def test_some_suppressed_subgroups(self): + self.test_some_suppressed_init("dj_params_subgroups") + stdout, _, retcode = run_command( + "dj_params_subgroups", "--no-traceback", "subgroup", "--no-skip-checks" + ) + self.assertEqual(retcode, 0) + self.assertTrue("traceback=False, skipchecks=False" in stdout.strip()) + + self.assertTrue( + "traceback=True, skipchecks=True" + in call_command( + "dj_params_subgroups", "--traceback", "subgroup", "--skip-checks" + ) + ) + self.assertTrue( + "traceback=False, skipchecks=True" + in call_command( + "dj_params_subgroups", "subgroup", traceback=False, skip_checks=True + ) + ) diff --git a/tests/test_native.py b/tests/test_native.py index 5fbc03a..a0937e9 100644 --- a/tests/test_native.py +++ b/tests/test_native.py @@ -172,10 +172,9 @@ def test_native_direct(self): self.assertEqual(native.main("Brian"), {"name": "Brian"}) def test_native_cli(self): - self.assertEqual( - run_command(self.command, *self.settings, "Brian")[0].strip(), - str({"name": "Brian"}), - ) + stdout, stderr, retcode = run_command(self.command, *self.settings, "Brian") + self.assertEqual(retcode, 0, stderr) + self.assertEqual(stdout.strip(), str({"name": "Brian"})) self.assertEqual( str(run_command(self.command, *self.settings, "--version")[0]).strip(), DJANGO_VERSION,