From f018c6c1508b7a6f94964f77acae6727c3561557 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Wed, 4 Sep 2024 15:43:55 +0200 Subject: [PATCH] fix: fixed typing in ncurses.py --- gptme/ncurses.py | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/gptme/ncurses.py b/gptme/ncurses.py index 33513b1d..aea69505 100644 --- a/gptme/ncurses.py +++ b/gptme/ncurses.py @@ -1,28 +1,29 @@ import curses import textwrap +from typing import List, Optional class Message: - def __init__(self, content, role="user"): - self.content = content - self.expanded = False - self.role = role + def __init__(self, content: str, role: str = "user"): + self.content: str = content + self.expanded: bool = False + self.role: str = role class MessageApp: def __init__(self, stdscr): self.stdscr = stdscr - self.messages = [] - self.input_buffer = "" - self.cursor_y = 0 - self.cursor_x = 0 - self.scroll_offset = 0 - self.mode = "normal" - self.selected_message = None - self.current_role = "user" - - def add_message(self, content): + self.messages: List[Message] = [] + self.input_buffer: str = "" + self.cursor_y: int = 0 + self.cursor_x: int = 0 + self.scroll_offset: int = 0 + self.mode: str = "normal" + self.selected_message: Optional[Message] = None + self.current_role: str = "user" + + def add_message(self, content: str) -> None: self.messages.append(Message(content, self.current_role)) - def draw(self): + def draw(self) -> None: self.stdscr.clear() height, width = self.stdscr.getmaxyx() @@ -71,7 +72,7 @@ def draw(self): self.stdscr.refresh() - def run(self): + def run(self) -> None: curses.curs_set(1) curses.start_color() curses.init_pair(curses.COLOR_GREEN, curses.COLOR_GREEN, curses.COLOR_BLACK) @@ -129,23 +130,23 @@ def run(self): if key == 27: # ESC self.mode = "normal" self.selected_message = None - elif key == ord('e'): + elif key == ord('e') and self.selected_message is not None: self.mode = "edit" self.input_buffer = self.selected_message.content self.cursor_x = len(self.input_buffer) - elif key == ord('x'): + elif key == ord('x') and self.selected_message is not None: self.selected_message.expanded = not self.selected_message.expanded - elif key == ord('d'): + elif key == ord('d') and self.selected_message is not None: self.messages.remove(self.selected_message) if self.messages: self.selected_message = self.messages[0] else: self.selected_message = None self.mode = "normal" - elif key == curses.KEY_UP and self.messages: + elif key == curses.KEY_UP and self.messages and self.selected_message is not None: idx = self.messages.index(self.selected_message) self.selected_message = self.messages[max(0, idx - 1)] - elif key == curses.KEY_DOWN and self.messages: + elif key == curses.KEY_DOWN and self.messages and self.selected_message is not None: idx = self.messages.index(self.selected_message) self.selected_message = self.messages[min(len(self.messages) - 1, idx + 1)] @@ -154,7 +155,7 @@ def run(self): self.mode = "select" self.input_buffer = "" self.cursor_x = 0 - elif key == 10: # Enter + elif key == 10 and self.selected_message is not None: # Enter self.selected_message.content = self.input_buffer self.mode = "select" self.input_buffer = ""