diff --git a/examples/chat.py b/examples/chat.py index 952a0d2c..622b4873 100644 --- a/examples/chat.py +++ b/examples/chat.py @@ -1,5 +1,5 @@ -import sys, os +import sys, os, re sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from exllamav2 import( @@ -155,6 +155,11 @@ def get_tokenized_context(max_len): codeblock_formatter = None if args.no_code_formatting else CodeBlockFormatter() in_code_block = False +delim_buffer_array = [] +delim_pattern = re.compile(r'(`{1,3})') + +delim_overflow = "" + # Main loop print(f" -- Prompt format: {args.mode}") @@ -198,10 +203,17 @@ def get_tokenized_context(max_len): response_text += chunk responses_ids[-1] = torch.cat([responses_ids[-1], tokens], dim = -1) - # Check for code block delimiters + # Append chunk to delimiter buffer if contains delimiters + if delim_pattern.search(chunk) and len(delim_buffer_array) < 2: # dirty fix for assumption that codeblock start is never smaller than `` + ` + # add chunk + delim_buffer_array.append(chunk) + else: + delim_overflow = "".join(delim_buffer_array) + delim_buffer_array = [] - codeblock_delimiter = chunk.startswith("```") and codeblock_formatter is not None - if codeblock_delimiter: chunk = chunk[3:] # Suppress delimiter in output + # Check for code block delimiters + # if delim_buffer_array contains a full delimiter (```), codeblock true + codeblock_delimiter = "".join(delim_buffer_array).find("```") != -1 and (codeblock_formatter is not None) # Print output @@ -212,9 +224,13 @@ def get_tokenized_context(max_len): codeblock_formatter.begin() print("\n") in_code_block = True + delim_buffer_array = [] # Print unformatted - print(chunk, end = "") + # if delim buffer is > 0 do not print for now + if len(delim_buffer_array) == 0: + print(chunk, end = "") + sys.stdout.flush() else: @@ -223,6 +239,7 @@ def get_tokenized_context(max_len): if codeblock_delimiter: print("\033[0m", end = "") # Reset block color to be certain in_code_block = False + delim_buffer_array = [] # Print formatted codeblock_formatter.print_code_block(chunk) diff --git a/examples/chat_formatting.py b/examples/chat_formatting.py index 7d04ad19..229e86b0 100644 --- a/examples/chat_formatting.py +++ b/examples/chat_formatting.py @@ -15,11 +15,13 @@ # Code block formatter for black background + class BlackBackgroundTerminalFormatter(TerminalFormatter): code_pad: int = 2 block_pad_left: int = 1 + def __init__(self): super().__init__(style = "monokai") @@ -91,6 +93,7 @@ class CodeBlockFormatter: code_block_text: str lines_printed: int + last_lexer: str formatter = BlackBackgroundTerminalFormatter() @@ -100,6 +103,7 @@ def begin(self): self.code_block_text = "" self.lines_printed = 0 + self.last_lexer = get_lexer_by_name("text") self.formatter.begin() @@ -107,7 +111,7 @@ def begin(self): # Print a code block, updating the CLI in real-time def print_code_block(self, chunk): - + # Clear previously printed lines for _ in range(self.lines_printed): # -1 not needed? # Move cursor up one line @@ -126,7 +130,15 @@ def print_code_block(self, chunk): self.code_block_text += chunk # Remove language after codeblock start - code_block_text = re.sub(r'```.*?$', '```', self.code_block_text, flags=re.MULTILINE) + code_block_text = '\n'.join([''] + self.code_block_text.split('\n')[1:]) + + # Handle delim at end + if code_block_text.endswith("```"): + code_block_text = code_block_text[:-3] + + + # Get specified language + specified_lang = self.code_block_text.split('\n', 1)[0] # Get 1st line (directly after delimiter, can be language) # Split updated text into lines and find the longest line lines = code_block_text.split('\n') @@ -140,9 +152,17 @@ def print_code_block(self, chunk): # Try guessing the lexer for syntax highlighting, if we haven't guessed already try: - lexer = guess_lexer(padded_text) + if bool(specified_lang): + lexer = get_lexer_by_name(specified_lang) + self.last_lexer = lexer + elif '\n' in chunk: # Offload lexguessing to every newline + lexer = guess_lexer(padded_text) + self.last_lexer = lexer + else: + lexer = self.last_lexer except ClassNotFound: lexer = get_lexer_by_name("text") # Fallback to plain text if language isn't supported by pygments + self.last_lexer = lexer # Highlight highlighted_text = highlight(padded_text, lexer, self.formatter)