diff --git a/gptcli/cli.py b/gptcli/cli.py index 65e1040..cd591d4 100644 --- a/gptcli/cli.py +++ b/gptcli/cli.py @@ -9,6 +9,7 @@ from rich.console import Console from rich.live import Live from rich.markdown import Markdown +from .markdown import CustomMarkdown from rich.text import Text from gptcli.session import (ALL_COMMANDS, COMMAND_CLEAR, COMMAND_QUIT, @@ -16,11 +17,7 @@ ResponseStreamer, UserInputProvider) TERMINAL_WELCOME = """ -Hi! I'm here to help. Type `:q` or Ctrl-D to exit, `:c` or Ctrl-C and Enter to clear -the conversation, `:r` or Ctrl-R to re-generate the last response. -To enter multi-line mode, enter a backslash `\\` followed by a new line. -Exit the multi-line mode by pressing ESC and then Enter (Meta+Enter). -Try `:?` for help. +Assistant: """ @@ -43,7 +40,7 @@ def print(self, text: str): self.current_text += text if self.markdown: assert self.live - content = Markdown(self.current_text, style="green") + content = CustomMarkdown(self.current_text, style="green") self.live.update(content) self.live.refresh() else: @@ -84,7 +81,7 @@ def __init__(self, markdown: bool): def on_chat_start(self): console = Console(width=80) - console.print(Markdown(TERMINAL_WELCOME)) + console.print(CustomMarkdown(TERMINAL_WELCOME)) def on_chat_clear(self): self.console.print("[bold]Cleared the conversation.[/bold]") diff --git a/gptcli/markdown.py b/gptcli/markdown.py new file mode 100644 index 0000000..873d7cb --- /dev/null +++ b/gptcli/markdown.py @@ -0,0 +1,677 @@ +from __future__ import annotations + +from typing import ClassVar, Dict, Iterable, List, Optional, Type, Union + +from markdown_it import MarkdownIt +from markdown_it.token import Token + + +from rich import box +from rich._loop import loop_first +from rich._stack import Stack +from rich.console import Console, ConsoleOptions, JustifyMethod, RenderResult +from rich.containers import Renderables +from rich.jupyter import JupyterMixin +from rich.panel import Panel +from rich.rule import Rule +from rich.segment import Segment +from rich.style import Style, StyleStack +from rich.syntax import Syntax +from rich.text import Text, TextType + +class MarkdownElement: + new_line: ClassVar[bool] = True + + @classmethod + def create(cls, markdown: "CustomMarkdown", token: Token) -> "MarkdownElement": + """Factory to create markdown element, + + Args: + markdown (Markdown): The parent Markdown object. + token (Token): A node from markdown-it. + + Returns: + MarkdownElement: A new markdown element + """ + return cls() + + def on_enter(self, context: "MarkdownContext") -> None: + """Called when the node is entered. + + Args: + context (MarkdownContext): The markdown context. + """ + + def on_text(self, context: "MarkdownContext", text: TextType) -> None: + """Called when text is parsed. + + Args: + context (MarkdownContext): The markdown context. + """ + + def on_leave(self, context: "MarkdownContext") -> None: + """Called when the parser leaves the element. + + Args: + context (MarkdownContext): [description] + """ + + def on_child_close( + self, context: "MarkdownContext", child: "MarkdownElement" + ) -> bool: + """Called when a child element is closed. + + This method allows a parent element to take over rendering of its children. + + Args: + context (MarkdownContext): The markdown context. + child (MarkdownElement): The child markdown element. + + Returns: + bool: Return True to render the element, or False to not render the element. + """ + return True + + def __rich_console__( + self, console: "Console", options: "ConsoleOptions" + ) -> "RenderResult": + return () + + +class UnknownElement(MarkdownElement): + """An unknown element. + + Hopefully there will be no unknown elements, and we will have a MarkdownElement for + everything in the document. + + """ + + +class TextElement(MarkdownElement): + """Base class for elements that render text.""" + + style_name = "none" + + def on_enter(self, context: "MarkdownContext") -> None: + self.style = context.enter_style(self.style_name) + self.text = Text(justify="left") + + def on_text(self, context: "MarkdownContext", text: TextType) -> None: + self.text.append(text, context.current_style if isinstance(text, str) else None) + + def on_leave(self, context: "MarkdownContext") -> None: + context.leave_style() + + +class Paragraph(TextElement): + """A Paragraph.""" + + style_name = "markdown.paragraph" + justify: JustifyMethod + + @classmethod + def create(cls, markdown: "CustomMarkdown", token: Token) -> "Paragraph": + return cls(justify=markdown.justify or "left") + + def __init__(self, justify: JustifyMethod) -> None: + self.justify = justify + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + self.text.justify = self.justify + yield self.text + + +class Heading(TextElement): + """A heading.""" + + @classmethod + def create(cls, markdown: "CustomMarkdown", token: Token) -> "Heading": + return cls(token.tag) + + def on_enter(self, context: "MarkdownContext") -> None: + self.text = Text() + context.enter_style(self.style_name) + + def __init__(self, tag: str) -> None: + self.tag = tag + self.style_name = f"markdown.{tag}" + super().__init__() + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + text = self.text + text.justify = "center" + if self.tag == "h1": + # Draw a border around h1s + yield Panel( + text, + box=box.HEAVY, + style="markdown.h1.border", + ) + else: + # Styled text for h2 and beyond + if self.tag == "h2": + yield Text("") + yield text + + +class CodeBlock(TextElement): + """A code block with syntax highlighting.""" + + style_name = "markdown.code_block" + + @classmethod + def create(cls, markdown: "CustomMarkdown", token: Token) -> "CodeBlock": + node_info = token.info or "" + lexer_name = node_info.partition(" ")[0] + return cls(lexer_name or "default", markdown.code_theme) + + def __init__(self, lexer_name: str, theme: str) -> None: + self.lexer_name = lexer_name + self.theme = theme + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + code = str(self.text).rstrip() + syntax = Syntax( + code, self.lexer_name, theme=self.theme, word_wrap=True, padding=0 + ) + yield syntax + + +class BlockQuote(TextElement): + """A block quote.""" + + style_name = "markdown.block_quote" + + def __init__(self) -> None: + self.elements: Renderables = Renderables() + + def on_child_close( + self, context: "MarkdownContext", child: "MarkdownElement" + ) -> bool: + self.elements.append(child) + return False + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + render_options = options.update(width=options.max_width - 4) + lines = console.render_lines(self.elements, render_options, style=self.style) + style = self.style + new_line = Segment("\n") + padding = Segment("▌ ", style) + for line in lines: + yield padding + yield from line + yield new_line + + +class HorizontalRule(MarkdownElement): + """A horizontal rule to divide sections.""" + + new_line = False + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + style = console.get_style("markdown.hr", default="none") + yield Rule(style=style) + + +class ListElement(MarkdownElement): + """A list element.""" + + @classmethod + def create(cls, markdown: "CustomMarkdown", token: Token) -> "ListElement": + return cls(token.type, int(token.attrs.get("start", 1))) + + def __init__(self, list_type: str, list_start: int | None) -> None: + self.items: List[ListItem] = [] + self.list_type = list_type + self.list_start = list_start + + def on_child_close( + self, context: "MarkdownContext", child: "MarkdownElement" + ) -> bool: + assert isinstance(child, ListItem) + self.items.append(child) + return False + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + if self.list_type == "bullet_list_open": + for item in self.items: + yield from item.render_bullet(console, options) + else: + number = 1 if self.list_start is None else self.list_start + last_number = number + len(self.items) + for index, item in enumerate(self.items): + yield from item.render_number( + console, options, number + index, last_number + ) + + +class ListItem(TextElement): + """An item in a list.""" + + style_name = "markdown.item" + + def __init__(self) -> None: + self.elements: Renderables = Renderables() + + def on_child_close( + self, context: "MarkdownContext", child: "MarkdownElement" + ) -> bool: + self.elements.append(child) + return False + + def render_bullet(self, console: Console, options: ConsoleOptions) -> RenderResult: + render_options = options.update(width=options.max_width - 3) + lines = console.render_lines(self.elements, render_options, style=self.style) + bullet_style = console.get_style("markdown.item.bullet", default="none") + + bullet = Segment(" • ", bullet_style) + padding = Segment(" " * 3, bullet_style) + new_line = Segment("\n") + for first, line in loop_first(lines): + yield bullet if first else padding + yield from line + yield new_line + + def render_number( + self, console: Console, options: ConsoleOptions, number: int, last_number: int + ) -> RenderResult: + number_width = len(str(last_number)) + 2 + render_options = options.update(width=options.max_width - number_width) + lines = console.render_lines(self.elements, render_options, style=self.style) + number_style = console.get_style("markdown.item.number", default="none") + + new_line = Segment("\n") + padding = Segment(" " * number_width, number_style) + numeral = Segment(f"{number}".rjust(number_width - 1) + " ", number_style) + for first, line in loop_first(lines): + yield numeral if first else padding + yield from line + yield new_line + + +class Link(TextElement): + @classmethod + def create(cls, markdown: "CustomMarkdown", token: Token) -> "MarkdownElement": + url = token.attrs.get("href", "#") + return cls(token.content, str(url)) + + def __init__(self, text: str, href: str): + self.text = Text(text) + self.href = href + + +class ImageItem(TextElement): + """Renders a placeholder for an image.""" + + new_line = False + + @classmethod + def create(cls, markdown: "CustomMarkdown", token: Token) -> "MarkdownElement": + """Factory to create markdown element, + + Args: + markdown (Markdown): The parent Markdown object. + token (Any): A token from markdown-it. + + Returns: + MarkdownElement: A new markdown element + """ + return cls(str(token.attrs.get("src", "")), markdown.hyperlinks) + + def __init__(self, destination: str, hyperlinks: bool) -> None: + self.destination = destination + self.hyperlinks = hyperlinks + self.link: Optional[str] = None + super().__init__() + + def on_enter(self, context: "MarkdownContext") -> None: + self.link = context.current_style.link + self.text = Text(justify="left") + super().on_enter(context) + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + link_style = Style(link=self.link or self.destination or None) + title = self.text or Text(self.destination.strip("/").rsplit("/", 1)[-1]) + if self.hyperlinks: + title.stylize(link_style) + text = Text.assemble("🌆 ", title, " ", end="") + yield text + + +class MarkdownContext: + """Manages the console render state.""" + + def __init__( + self, + console: Console, + options: ConsoleOptions, + style: Style, + inline_code_lexer: Optional[str] = None, + inline_code_theme: str = "monokai", + ) -> None: + self.console = console + self.options = options + self.style_stack: StyleStack = StyleStack(style) + self.stack: Stack[MarkdownElement] = Stack() + + self._syntax: Optional[Syntax] = None + if inline_code_lexer is not None: + self._syntax = Syntax("", inline_code_lexer, theme=inline_code_theme) + + @property + def current_style(self) -> Style: + """Current style which is the product of all styles on the stack.""" + return self.style_stack.current + + def on_text(self, text: str, node_type: str) -> None: + """Called when the parser visits text.""" + if node_type in {"fence", "code_inline"} and self._syntax is not None: + highlight_text = self._syntax.highlight(text) + highlight_text.rstrip() + self.stack.top.on_text( + self, Text.assemble(highlight_text, style=self.style_stack.current) + ) + else: + self.stack.top.on_text(self, text) + + def enter_style(self, style_name: Union[str, Style]) -> Style: + """Enter a style context.""" + style = self.console.get_style(style_name, default="none") + self.style_stack.push(style) + return self.current_style + + def leave_style(self) -> Style: + """Leave a style context.""" + style = self.style_stack.pop() + return style + + +class CustomMarkdown(JupyterMixin): + """A Markdown renderable. + + Args: + markup (str): A string containing markdown. + code_theme (str, optional): Pygments theme for code blocks. Defaults to "monokai". + justify (JustifyMethod, optional): Justify value for paragraphs. Defaults to None. + style (Union[str, Style], optional): Optional style to apply to markdown. + hyperlinks (bool, optional): Enable hyperlinks. Defaults to ``True``. + inline_code_lexer: (str, optional): Lexer to use if inline code highlighting is + enabled. Defaults to None. + inline_code_theme: (Optional[str], optional): Pygments theme for inline code + highlighting, or None for no highlighting. Defaults to None. + """ + + elements: ClassVar[Dict[str, Type[MarkdownElement]]] = { + "paragraph_open": Paragraph, + "heading_open": Heading, + "fence": CodeBlock, + "code_block": CodeBlock, + "blockquote_open": BlockQuote, + "hr": HorizontalRule, + "bullet_list_open": ListElement, + "ordered_list_open": ListElement, + "list_item_open": ListItem, + "image": ImageItem, + } + + inlines = {"em", "strong", "code", "s"} + + def __init__( + self, + markup: str, + code_theme: str = "monokai", + justify: Optional[JustifyMethod] = None, + style: Union[str, Style] = "none", + hyperlinks: bool = True, + inline_code_lexer: Optional[str] = None, + inline_code_theme: Optional[str] = None, + ) -> None: + parser = MarkdownIt().enable("strikethrough") + self.markup = markup + self.parsed = parser.parse(markup) + self.code_theme = code_theme + self.justify: Optional[JustifyMethod] = justify + self.style = style + self.hyperlinks = hyperlinks + self.inline_code_lexer = inline_code_lexer + self.inline_code_theme = inline_code_theme or code_theme + + def _flatten_tokens(self, tokens: Iterable[Token]) -> Iterable[Token]: + """Flattens the token stream.""" + for token in tokens: + is_fence = token.type == "fence" + is_image = token.tag == "img" + if token.children and not (is_image or is_fence): + yield from self._flatten_tokens(token.children) + else: + yield token + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + """Render markdown to the console.""" + style = console.get_style(self.style, default="none") + options = options.update(height=None) + context = MarkdownContext( + console, + options, + style, + inline_code_lexer=self.inline_code_lexer, + inline_code_theme=self.inline_code_theme, + ) + tokens = self.parsed + inline_style_tags = self.inlines + new_line = False + _new_line_segment = Segment.line() + + for token in self._flatten_tokens(tokens): + node_type = token.type + tag = token.tag + + entering = token.nesting == 1 + exiting = token.nesting == -1 + self_closing = token.nesting == 0 + + if node_type == "text": + context.on_text(token.content, node_type) + elif node_type == "hardbreak": + context.on_text("\n", node_type) + elif node_type == "softbreak": + context.on_text(" ", node_type) + elif node_type == "link_open": + href = str(token.attrs.get("href", "")) + if self.hyperlinks: + link_style = console.get_style("markdown.link_url", default="none") + link_style += Style(link=href) + context.enter_style(link_style) + else: + context.stack.push(Link.create(self, token)) + elif node_type == "link_close": + if self.hyperlinks: + context.leave_style() + else: + element = context.stack.pop() + assert isinstance(element, Link) + link_style = console.get_style("markdown.link", default="none") + context.enter_style(link_style) + context.on_text(element.text.plain, node_type) + context.leave_style() + context.on_text(" (", node_type) + link_url_style = console.get_style( + "markdown.link_url", default="none" + ) + context.enter_style(link_url_style) + context.on_text(element.href, node_type) + context.leave_style() + context.on_text(")", node_type) + elif ( + tag in inline_style_tags + and node_type != "fence" + and node_type != "code_block" + ): + if entering: + # If it's an opening inline token e.g. strong, em, etc. + # Then we move into a style context i.e. push to stack. + context.enter_style(f"markdown.{tag}") + elif exiting: + # If it's a closing inline style, then we pop the style + # off of the stack, to move out of the context of it... + context.leave_style() + else: + # If it's a self-closing inline style e.g. `code_inline` + context.enter_style(f"markdown.{tag}") + if token.content: + context.on_text(token.content, node_type) + context.leave_style() + else: + # Map the markdown tag -> MarkdownElement renderable + element_class = self.elements.get(token.type) or UnknownElement + element = element_class.create(self, token) + + if entering or self_closing: + context.stack.push(element) + element.on_enter(context) + + if exiting: # CLOSING tag + element = context.stack.pop() + + should_render = not context.stack or ( + context.stack + and context.stack.top.on_child_close(context, element) + ) + + if should_render: + if new_line: + yield _new_line_segment + yield from console.render(element, context.options) + elif self_closing: # SELF-CLOSING tags (e.g. text, code, image) + context.stack.pop() + text = token.content + if text is not None: + element.on_text(context, text) + + should_render = ( + not context.stack + or context.stack + and context.stack.top.on_child_close(context, element) + ) + if should_render: + if new_line: + yield _new_line_segment + yield from console.render(element, context.options) + + if exiting or self_closing: + element.on_leave(context) + new_line = element.new_line + + +if __name__ == "__main__": # pragma: no cover + + import argparse + import sys + + parser = argparse.ArgumentParser( + description="Render Markdown to the console with Rich" + ) + parser.add_argument( + "path", + metavar="PATH", + help="path to markdown file, or - for stdin", + ) + parser.add_argument( + "-c", + "--force-color", + dest="force_color", + action="store_true", + default=None, + help="force color for non-terminals", + ) + parser.add_argument( + "-t", + "--code-theme", + dest="code_theme", + default="monokai", + help="pygments code theme", + ) + parser.add_argument( + "-i", + "--inline-code-lexer", + dest="inline_code_lexer", + default=None, + help="inline_code_lexer", + ) + parser.add_argument( + "-y", + "--hyperlinks", + dest="hyperlinks", + action="store_true", + help="enable hyperlinks", + ) + parser.add_argument( + "-w", + "--width", + type=int, + dest="width", + default=None, + help="width of output (default will auto-detect)", + ) + parser.add_argument( + "-j", + "--justify", + dest="justify", + action="store_true", + help="enable full text justify", + ) + parser.add_argument( + "-p", + "--page", + dest="page", + action="store_true", + help="use pager to scroll output", + ) + args = parser.parse_args() + + from rich.console import Console + + if args.path == "-": + markdown_body = sys.stdin.read() + else: + with open(args.path, "rt", encoding="utf-8") as markdown_file: + markdown_body = markdown_file.read() + markdown = CustomMarkdown( + markdown_body, + justify="full" if args.justify else "left", + code_theme=args.code_theme, + hyperlinks=args.hyperlinks, + inline_code_lexer=args.inline_code_lexer, + ) + if args.page: + import io + import pydoc + + fileio = io.StringIO() + console = Console( + file=fileio, force_terminal=args.force_color, width=args.width + ) + console.print(markdown) + pydoc.pager(fileio.getvalue()) + + else: + console = Console( + force_terminal=args.force_color, width=args.width, record=True + ) + console.print(markdown)