From 7cdc46ac30a749589029231ef36cf8e256c3b3b2 Mon Sep 17 00:00:00 2001 From: Steffen Neubauer Date: Mon, 13 Nov 2017 01:46:38 +0000 Subject: [PATCH] Simplify code and add tests --- click/_termui_impl.py | 64 ++++++++++++++++++++----------------------- click/termui.py | 23 ++++++++++------ tests/test_utils.py | 21 ++++++++++++-- 3 files changed, 61 insertions(+), 47 deletions(-) diff --git a/click/_termui_impl.py b/click/_termui_impl.py index 98d33d91c..923a99aa6 100644 --- a/click/_termui_impl.py +++ b/click/_termui_impl.py @@ -16,7 +16,7 @@ import inspect from ._compat import _default_text_stdout, range_type, PY2, isatty, \ open_stream, strip_ansi, term_len, get_best_encoding, WIN, int_types, \ - CYGWIN + string_types, text_type, CYGWIN from .utils import echo from .exceptions import ClickException @@ -273,54 +273,47 @@ def next(self): del next -def pager(text, color=None): +def pager(generator, color=None): """Decide what method to use for paging through text.""" stdout = _default_text_stdout() if not isatty(sys.stdin) or not isatty(stdout): - return _nullpager(stdout, text, color) + return _nullpager(stdout, generator, color) pager_cmd = (os.environ.get('PAGER', None) or '').strip() if pager_cmd: if WIN: - return _tempfilepager(text, pager_cmd, color) - return _pipepager(text, pager_cmd, color) + return _tempfilepager(generator, pager_cmd, color) + return _pipepager(generator, pager_cmd, color) if os.environ.get('TERM') in ('dumb', 'emacs'): - return _nullpager(stdout, text, color) + return _nullpager(stdout, generator, color) if WIN or sys.platform.startswith('os2'): - return _tempfilepager(text, 'more <', color) + return _tempfilepager(generator, 'more <', color) if hasattr(os, 'system') and os.system('(less) 2>/dev/null') == 0: - return _pipepager(text, 'less', color) + return _pipepager(generator, 'less', color) import tempfile fd, filename = tempfile.mkstemp() os.close(fd) try: if hasattr(os, 'system') and os.system('more "%s"' % filename) == 0: - return _pipepager(text, 'more', color) - return _nullpager(stdout, text, color) + return _pipepager(generator, 'more', color) + return _nullpager(stdout, generator, color) finally: os.unlink(filename) -def _pipe_make_gen(text): - """Converts string or generator of strings into generator""" - if inspect.isgeneratorfunction(text): - for chunk in text(): yield chunk - elif inspect.isgenerator(text): - for chunk in text: yield chunk - else: - yield text - +def text_generator(any_gen, append=None): + """Convert any generator to a text generator""" + for o in any_gen: + if not isinstance(o, string_types): + yield text_type(o) + else: + yield o -def _pipe_make_str(text): - """Converts string or generator of strings into string""" - if inspect.isgeneratorfunction(text): - return "".join(text()) - if inspect.isgenerator(text): - return "".join(text) - return text + if append: + yield append -def _pipepager(text, cmd, color): +def _pipepager(generator, cmd, color): """Page through text by feeding it to another program. Invoking a pager through this might support colors. """ @@ -342,11 +335,11 @@ def _pipepager(text, cmd, color): env=env) encoding = get_best_encoding(c.stdin) try: - for chunk in _pipe_make_gen(text): + for text in generator: if not color: - chunk = strip_ansi(chunk) + text = strip_ansi(text) - c.stdin.write(chunk.encode(encoding, 'replace')) + c.stdin.write(chunk.encode(text, 'replace')) except (IOError, KeyboardInterrupt): pass else: @@ -369,11 +362,12 @@ def _pipepager(text, cmd, color): break -def _tempfilepager(text, cmd, color): +def _tempfilepager(generator, cmd, color): """Page through text by invoking a program on a temporary file.""" import tempfile filename = tempfile.mktemp() - text = _pipe_make_str(text) + # TODO: This never terminates if the passed generator never terminates. + text = "".join(generator) if not color: text = strip_ansi(text) encoding = get_best_encoding(sys.stdout) @@ -385,11 +379,11 @@ def _tempfilepager(text, cmd, color): os.unlink(filename) -def _nullpager(stream, text, color): +def _nullpager(stream, generator, color): """Simply print unformatted text. This is the ultimate fallback.""" - for chunk in _pipe_make_gen(text): + for text in generator: if not color: - chunk = strip_ansi(chunk) + text = strip_ansi(text) stream.write(text) diff --git a/click/termui.py b/click/termui.py index 49ce54299..2e56bbe31 100644 --- a/click/termui.py +++ b/click/termui.py @@ -3,7 +3,7 @@ import struct import inspect -from ._compat import raw_input, text_type, string_types, \ +from ._compat import raw_input, string_types, \ isatty, strip_ansi, get_winterm_size, DEFAULT_COLUMNS, WIN from .utils import echo from .exceptions import Abort, UsageError @@ -204,24 +204,29 @@ def ioctl_gwinsz(fd): return int(cr[1]), int(cr[0]) -def echo_via_pager(text, color=None): +def echo_via_pager(text_or_generator, color=None): """This function takes a text and shows it via an environment specific pager on stdout. .. versionchanged:: 3.0 Added the `color` flag. - :param text: the text to page. + :param text_or_generator: the text to page, or alternatively, a + generator emitting the text to page. :param color: controls if the pager supports ANSI colors or not. The default is autodetection. """ color = resolve_color_default(color) - if not inspect.isgenerator(text) \ - and not inspect.isgeneratorfunction(text) \ - and not isinstance(text, string_types): - text = text_type(text) - from ._termui_impl import pager - return pager(text, color) + + if inspect.isgenerator(text_or_generator): + gen = text_or_generator + elif inspect.isgeneratorfunction(text_or_generator): + gen = text_or_generator() + else: + gen = (t for t in [text_or_generator]) + + from ._termui_impl import pager, text_generator + return pager(text_generator(gen, append='\n'), color) def progressbar(iterable=None, length=None, label=None, show_eta=True, diff --git a/tests/test_utils.py b/tests/test_utils.py index 88923adbc..49fb00ba5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -146,14 +146,29 @@ def f(_): assert out == 'Password: \nScrew you.\n' +def _test_gen_func(): + yield 'a' + yield 'b' + yield 'c' + + @pytest.mark.skipif(WIN, reason='Different behavior on windows.') @pytest.mark.parametrize('cat', ['cat', 'cat ', 'cat ']) -def test_echo_via_pager(monkeypatch, capfd, cat): +@pytest.mark.parametrize('test', [ + ('haha\n', 'haha'), + ('gen\n', (c for c in 'gen')), + ('abc\n', _test_gen_func), +]) +def test_echo_via_pager(monkeypatch, capfd, cat, test): monkeypatch.setitem(os.environ, 'PAGER', cat) monkeypatch.setattr(click._termui_impl, 'isatty', lambda x: True) - click.echo_via_pager('haha') + + expected_output = test[0] + test_input = test[1] + + click.echo_via_pager(test_input) out, err = capfd.readouterr() - assert out == 'haha\n' + assert out == expected_output @pytest.mark.skipif(WIN, reason='Test does not make sense on Windows.')