diff --git a/riposte/__init__.py b/riposte/__init__.py index 04008d5..b0b52b1 100644 --- a/riposte/__init__.py +++ b/riposte/__init__.py @@ -1 +1,2 @@ +from .group import Group # noqa from .riposte import Riposte # noqa diff --git a/riposte/riposte.py b/riposte/riposte.py index 3fe9d38..617115e 100644 --- a/riposte/riposte.py +++ b/riposte/riposte.py @@ -3,11 +3,11 @@ from pathlib import Path import readline import shlex -from typing import Callable, Dict, Iterable, List, Optional, Sequence +from typing import List, Optional, Sequence from . import input_streams -from .command import Command from .exceptions import CommandError, RiposteException, StopRiposteException +from .group import Group from .printer.mixins import PrinterMixin from .printer.thread import PrinterThread @@ -16,7 +16,7 @@ def is_libedit(): return "libedit" in readline.__doc__ -class Riposte(PrinterMixin): +class Riposte(Group, PrinterMixin): def __init__( self, prompt: str = "riposte:~ $ ", @@ -24,6 +24,8 @@ def __init__( history_file: Path = Path.home() / ".riposte", history_length: int = 100, ): + super().__init__() + self.banner = banner self.print_banner = True self.parser = None @@ -31,7 +33,6 @@ def __init__( self.input_stream = input_streams.prompt_input(lambda: self.prompt) self._prompt = prompt - self._commands: Dict[str, Command] = {} self.setup_cli() @@ -145,13 +146,6 @@ def _parse_line(line: str) -> List[str]: except ValueError as err: raise RiposteException(err) - def _get_command(self, command_name: str) -> Command: - """Resolve command name into registered `Command` object.""" - try: - return self._commands[command_name] - except KeyError: - raise CommandError(f"Unknown command: {command_name}") - def setup_cli(self): """Initialize CLI @@ -180,6 +174,15 @@ def parse_cli_arguments(self) -> None: Path(self.arguments.file) ) + def register_group(self, group: Group) -> None: + for command in group._commands.values(): + if command.name in self._commands: + raise RiposteException( + f"'{command.name}' command already exists." + ) + + self._commands[command.name] = command + @property def prompt(self): """Entrypoint for customizing prompt @@ -189,33 +192,6 @@ def prompt(self): """ return self._prompt - def command( - self, - name: str, - description: str = "", - guides: Dict[str, Iterable[Callable]] = None, - ) -> Callable: - """Decorator for bounding command with handling function.""" - - def wrapper(func: Callable): - if name not in self._commands: - self._commands[name] = Command(name, func, description, guides) - else: - raise RiposteException(f"'{name}' command already exists.") - return func - - return wrapper - - def complete(self, command: str) -> Callable: - """Decorator for bounding complete function with `Command`.""" - - def wrapper(func: Callable): - cmd = self._get_command(command) - cmd.attach_completer(func) - return func - - return wrapper - def _process(self) -> None: """Process input provided by the input stream. diff --git a/tests/conftest.py b/tests/conftest.py index a93b916..3dedde3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import pytest -from riposte import Riposte +from riposte import Group, Riposte from riposte.command import Command @@ -31,3 +31,8 @@ def command(): func=Mock(name="mocked_handling_function", __annotations__={}), description="foo description", ) + + +@pytest.fixture +def group(): + return Group() diff --git a/tests/test_riposte.py b/tests/test_riposte.py index 818ffdb..e27a911 100644 --- a/tests/test_riposte.py +++ b/tests/test_riposte.py @@ -3,7 +3,7 @@ import pytest -from riposte import Riposte, input_streams +from riposte import Group, Riposte, input_streams from riposte.command import Command from riposte.exceptions import CommandError, RiposteException @@ -222,3 +222,36 @@ def test_parse_cli_arguments_file(mocked_input_streams, repl: Riposte): Path(arguments.file) ) assert repl.input_stream is mocked_input_streams.file_input.return_value + + +def test_register_group(repl: Riposte, group: Group): + @group.command("foo") + def foo(): + pass + + @repl.command("bar") + def bar(): + pass + + commands_before_register = repl._commands + + repl.register_group(group) + + assert repl._commands == {"foo": foo, **commands_before_register} + + +def test_register_group_existing_command(repl: Riposte, group: Group): + @group.command("foo") + def foo(): + pass + + @repl.command("foo") + def bar(): + pass + + commands_before_register = repl._commands + + with pytest.raises(RiposteException): + repl.register_group(group) + + assert repl._commands == commands_before_register