Skip to content

Commit

Permalink
Add ability to group comands
Browse files Browse the repository at this point in the history
  • Loading branch information
fwkz committed Jan 17, 2024
1 parent b90c33e commit a394769
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 40 deletions.
1 change: 1 addition & 0 deletions riposte/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .group import Group # noqa
from .riposte import Riposte # noqa
52 changes: 14 additions & 38 deletions riposte/riposte.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from pathlib import Path
import readline
import shlex
from typing import Callable, Dict, Iterable, List, Optional, Sequence
from typing import List, Optional, Sequence

from . import input_streams
from .command import Command
from .exceptions import CommandError, RiposteException, StopRiposteException
from .group import Group
from .printer.mixins import PrinterMixin
from .printer.thread import PrinterThread

Expand All @@ -16,22 +16,23 @@ def is_libedit():
return "libedit" in readline.__doc__


class Riposte(PrinterMixin):
class Riposte(Group, PrinterMixin):
def __init__(
self,
prompt: str = "riposte:~ $ ",
banner: Optional[str] = None,
history_file: Path = Path.home() / ".riposte",
history_length: int = 100,
):
super().__init__()

self.banner = banner
self.print_banner = True
self.parser = None
self.arguments = None
self.input_stream = input_streams.prompt_input(lambda: self.prompt)

self._prompt = prompt
self._commands: Dict[str, Command] = {}

self.setup_cli()

Expand Down Expand Up @@ -145,13 +146,6 @@ def _parse_line(line: str) -> List[str]:
except ValueError as err:
raise RiposteException(err)

def _get_command(self, command_name: str) -> Command:
"""Resolve command name into registered `Command` object."""
try:
return self._commands[command_name]
except KeyError:
raise CommandError(f"Unknown command: {command_name}")

def setup_cli(self):
"""Initialize CLI
Expand Down Expand Up @@ -180,6 +174,15 @@ def parse_cli_arguments(self) -> None:
Path(self.arguments.file)
)

def register_group(self, group: Group) -> None:
for command in group._commands.values():
if command.name in self._commands:
raise RiposteException(
f"'{command.name}' command already exists."
)

self._commands[command.name] = command

@property
def prompt(self):
"""Entrypoint for customizing prompt
Expand All @@ -189,33 +192,6 @@ def prompt(self):
"""
return self._prompt

def command(
self,
name: str,
description: str = "",
guides: Dict[str, Iterable[Callable]] = None,
) -> Callable:
"""Decorator for bounding command with handling function."""

def wrapper(func: Callable):
if name not in self._commands:
self._commands[name] = Command(name, func, description, guides)
else:
raise RiposteException(f"'{name}' command already exists.")
return func

return wrapper

def complete(self, command: str) -> Callable:
"""Decorator for bounding complete function with `Command`."""

def wrapper(func: Callable):
cmd = self._get_command(command)
cmd.attach_completer(func)
return func

return wrapper

def _process(self) -> None:
"""Process input provided by the input stream.
Expand Down
7 changes: 6 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from riposte import Riposte
from riposte import Group, Riposte
from riposte.command import Command


Expand Down Expand Up @@ -31,3 +31,8 @@ def command():
func=Mock(name="mocked_handling_function", __annotations__={}),
description="foo description",
)


@pytest.fixture
def group():
return Group()
35 changes: 34 additions & 1 deletion tests/test_riposte.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from riposte import Riposte, input_streams
from riposte import Group, Riposte, input_streams
from riposte.command import Command
from riposte.exceptions import CommandError, RiposteException

Expand Down Expand Up @@ -222,3 +222,36 @@ def test_parse_cli_arguments_file(mocked_input_streams, repl: Riposte):
Path(arguments.file)
)
assert repl.input_stream is mocked_input_streams.file_input.return_value


def test_register_group(repl: Riposte, group: Group):
@group.command("foo")
def foo():
pass

@repl.command("bar")
def bar():
pass

commands_before_register = repl._commands

repl.register_group(group)

assert repl._commands == {"foo": foo, **commands_before_register}


def test_register_group_existing_command(repl: Riposte, group: Group):
@group.command("foo")
def foo():
pass

@repl.command("foo")
def bar():
pass

commands_before_register = repl._commands

with pytest.raises(RiposteException):
repl.register_group(group)

assert repl._commands == commands_before_register

0 comments on commit a394769

Please sign in to comment.