diff --git a/openadapt/models.py b/openadapt/models.py index f76528200..ba3d020b6 100644 --- a/openadapt/models.py +++ b/openadapt/models.py @@ -1,26 +1,38 @@ +"""This module defines the models used in the OpenAdapt system.""" + +from typing import Union import io from loguru import logger -from pynput import keyboard from PIL import Image, ImageChops +from pynput import keyboard import numpy as np import sqlalchemy as sa -from openadapt import config, db, utils, window +from openadapt import config, db, window # https://groups.google.com/g/sqlalchemy/c/wlr7sShU6-k class ForceFloat(sa.TypeDecorator): + """Custom SQLAlchemy type decorator for floating-point numbers.""" + impl = sa.Numeric(10, 2, asdecimal=False) cache_ok = True - def process_result_value(self, value, dialect): + def process_result_value( + self, + value: int | float | str | None, + dialect: str, + ) -> float | None: + """Convert the result value to float.""" if value is not None: value = float(value) return value class Recording(db.Base): + """Class representing a recording in the database.""" + __tablename__ = "recording" id = sa.Column(sa.Integer, primary_key=True) @@ -51,15 +63,18 @@ class Recording(db.Base): _processed_action_events = None @property - def processed_action_events(self): + def processed_action_events(self) -> list: + """Get the processed action events for the recording.""" from openadapt import events + if not self._processed_action_events: self._processed_action_events = events.get_events(self) return self._processed_action_events - class ActionEvent(db.Base): + """Class representing an action event in the database.""" + __tablename__ = "action_event" id = sa.Column(sa.Integer, primary_key=True) @@ -86,8 +101,12 @@ class ActionEvent(db.Base): children = sa.orm.relationship("ActionEvent") # TODO: replacing the above line with the following two results in an error: # AttributeError: 'list' object has no attribute '_sa_instance_state' - #children = sa.orm.relationship("ActionEvent", remote_side=[id], back_populates="parent") - #parent = sa.orm.relationship("ActionEvent", remote_side=[parent_id], back_populates="children") + # children = sa.orm.relationship( + # "ActionEvent", remote_side=[id], back_populates="parent" + # ) + # parent = sa.orm.relationship( + # "ActionEvent", remote_side=[parent_id], back_populates="children" + # ) # noqa: E501 recording = sa.orm.relationship("Recording", back_populates="action_events") screenshot = sa.orm.relationship("Screenshot", back_populates="action_event") @@ -95,7 +114,10 @@ class ActionEvent(db.Base): # TODO: playback_timestamp / original_timestamp - def _key(self, key_name, key_char, key_vk): + def _key( + self, key_name: str, key_char: str, key_vk: str + ) -> Union[keyboard.Key, keyboard.KeyCode, str, None]: + """Helper method to determine the key attribute based on available data.""" if key_name: key = keyboard.Key[key_name] elif key_char: @@ -108,10 +130,9 @@ def _key(self, key_name, key_char, key_vk): return key @property - def key(self): - logger.trace( - f"{self.name=} {self.key_name=} {self.key_char=} {self.key_vk=}" - ) + def key(self) -> Union[keyboard.Key, keyboard.KeyCode, str, None]: + """Get the key associated with the action event.""" + logger.trace(f"{self.name=} {self.key_name=} {self.key_char=} {self.key_vk=}") return self._key( self.key_name, self.key_char, @@ -119,7 +140,8 @@ def key(self): ) @property - def canonical_key(self): + def canonical_key(self) -> Union[keyboard.Key, keyboard.KeyCode, str, None]: + """Get the canonical key associated with the action event.""" logger.trace( f"{self.name=} " f"{self.canonical_key_name=} " @@ -132,7 +154,8 @@ def canonical_key(self): self.canonical_key_vk, ) - def _text(self, canonical=False): + def _text(self, canonical: bool = False) -> str | None: + """Helper method to generate the text representation of the action event.""" sep = config.ACTION_TEXT_SEP name_prefix = config.ACTION_TEXT_NAME_PREFIX name_suffix = config.ACTION_TEXT_NAME_SUFFIX @@ -157,21 +180,25 @@ def _text(self, canonical=False): else: if key_name_attr: text = f"{name_prefix}{key_attr}{name_suffix}".replace( - "Key.", "", + "Key.", + "", ) else: text = key_attr return text @property - def text(self): + def text(self) -> str: + """Get the text representation of the action event.""" return self._text() @property - def canonical_text(self): + def canonical_text(self) -> str: + """Get the canonical text representation of the action event.""" return self._text(canonical=True) - def __str__(self): + def __str__(self) -> str: + """Return a string representation of the action event.""" attr_names = [ "name", "mouse_x", @@ -183,16 +210,8 @@ def __str__(self): "text", "element_state", ] - attrs = [ - getattr(self, attr_name) - for attr_name in attr_names - ] - attrs = [ - int(attr) - if isinstance(attr, float) - else attr - for attr in attrs - ] + attrs = [getattr(self, attr_name) for attr_name in attr_names] + attrs = [int(attr) if isinstance(attr, float) else attr for attr in attrs] attrs = [ f"{attr_name}=`{attr}`" for attr_name, attr in zip(attr_names, attrs) @@ -202,21 +221,31 @@ def __str__(self): return rval @classmethod - def from_children(cls, children_dicts): - children = [ - ActionEvent(**child_dict) - for child_dict in children_dicts - ] + def from_children(cls: list, children_dicts: list) -> "ActionEvent": + """Create an ActionEvent instance from a list of child event dictionaries. + + Args: + children_dicts (list): List of dictionaries representing child events. + + Returns: + ActionEvent: An instance of ActionEvent with the specified child events. + + """ + children = [ActionEvent(**child_dict) for child_dict in children_dicts] return ActionEvent(children=children) class Screenshot(db.Base): + """Class representing a screenshot in the database.""" + __tablename__ = "screenshot" id = sa.Column(sa.Integer, primary_key=True) recording_timestamp = sa.Column(sa.ForeignKey("recording.timestamp")) timestamp = sa.Column(ForceFloat) png_data = sa.Column(sa.LargeBinary) + png_diff_data = sa.Column(sa.LargeBinary, nullable=True) + png_diff_mask_data = sa.Column(sa.LargeBinary, nullable=True) recording = sa.orm.relationship("Recording", back_populates="screenshots") action_event = sa.orm.relationship("ActionEvent", back_populates="screenshot") @@ -231,7 +260,8 @@ class Screenshot(db.Base): _diff_mask = None @property - def image(self): + def image(self) -> Image: + """Get the image associated with the screenshot.""" if not self._image: if self.sct_img: self._image = Image.frombytes( @@ -242,35 +272,49 @@ def image(self): "BGRX", ) else: - buffer = io.BytesIO(self.png_data) - self._image = Image.open(buffer) + self._image = self.convert_binary_to_png(self.png_data) return self._image @property - def diff(self): - if not self._diff: - assert self.prev, "Attempted to compute diff before setting prev" - self._diff = ImageChops.difference(self.image, self.prev.image) + def diff(self) -> Image: + """Get the difference between the current and previous screenshot.""" + if self.png_diff_data: + return self.convert_binary_to_png(self.png_diff_data) + + assert self.prev, "Attempted to compute diff before setting prev" + self._diff = ImageChops.difference(self.image, self.prev.image) return self._diff @property - def diff_mask(self): - if not self._diff_mask: - if self.diff: - self._diff_mask = self.diff.convert("1") + def diff_mask(self) -> Image: + """Get the difference mask between the current and previous screenshot.""" + if self.png_diff_mask_data: + return self.convert_binary_to_png(self.png_diff_mask_data) + + if self.diff: + self._diff_mask = self.diff.convert("1") return self._diff_mask @property - def array(self): + def array(self) -> np.ndarray: + """Get the NumPy array representation of the image.""" return np.array(self.image) @classmethod - def take_screenshot(cls): + def take_screenshot(cls: "Screenshot") -> "Screenshot": + """Capture a screenshot.""" + # avoid circular import + from openadapt import utils + sct_img = utils.take_screenshot() screenshot = Screenshot(sct_img=sct_img) return screenshot - def crop_active_window(self, action_event): + def crop_active_window(self, action_event: ActionEvent) -> None: + """Crop the screenshot to the active window defined by the action event.""" + # avoid circular import + from openadapt import utils + window_event = action_event.window_event width_ratio, height_ratio = utils.get_scale_ratios(action_event) @@ -282,8 +326,35 @@ def crop_active_window(self, action_event): box = (x0, y0, x1, y1) self._image = self._image.crop(box) + def convert_binary_to_png(self, image_binary: bytes) -> Image: + """Convert a binary image to a PNG image. + + Args: + image_binary (bytes): The binary image data. + + Returns: + Image: The PNG image. + """ + buffer = io.BytesIO(image_binary) + return Image.open(buffer) + + def convert_png_to_binary(self, image: Image) -> bytes: + """Convert a PNG image to binary image data. + + Args: + image (Image): The PNG image. + + Returns: + bytes: The binary image data. + """ + buffer = io.BytesIO() + image.save(buffer, format="PNG") + return buffer.getvalue() + class WindowEvent(db.Base): + """Class representing a window event in the database.""" + __tablename__ = "window_event" id = sa.Column(sa.Integer, primary_key=True) @@ -301,11 +372,14 @@ class WindowEvent(db.Base): action_events = sa.orm.relationship("ActionEvent", back_populates="window_event") @classmethod - def get_active_window_event(cls): + def get_active_window_event(cls: "WindowEvent") -> "WindowEvent": + """Get the active window event.""" return WindowEvent(**window.get_active_window_data()) class PerformanceStat(db.Base): + """Class representing a performance statistic in the database.""" + __tablename__ = "performance_stat" id = sa.Column(sa.Integer, primary_key=True) @@ -317,6 +391,8 @@ class PerformanceStat(db.Base): class MemoryStat(db.Base): + """Class representing a memory usage statistic in the database.""" + __tablename__ = "memory_stat" id = sa.Column(sa.Integer, primary_key=True) diff --git a/openadapt/strategies/stateful.py b/openadapt/strategies/stateful.py index 1f34e0c20..dff3a6d53 100644 --- a/openadapt/strategies/stateful.py +++ b/openadapt/strategies/stateful.py @@ -1,5 +1,4 @@ -""" -LLM with window states. +"""LLM with window states. Usage: @@ -8,15 +7,15 @@ from copy import deepcopy from pprint import pformat -#import datetime from loguru import logger import deepdiff -import numpy as np -from openadapt import config, events, models, strategies, utils +from openadapt import models, strategies, utils from openadapt.strategies.mixins.openai import OpenAIReplayStrategyMixin +# import datetime + IGNORE_BOUNDARY_WINDOWS = True @@ -25,11 +24,17 @@ class StatefulReplayStrategy( OpenAIReplayStrategyMixin, strategies.base.BaseReplayStrategy, ): + """LLM with window states.""" def __init__( self, recording: models.Recording, - ): + ) -> None: + """Initialize the StatefulReplayStrategy. + + Args: + recording (models.Recording): The recording object. + """ super().__init__(recording) self.recording_window_state_diffs = get_window_state_diffs( recording.processed_action_events @@ -40,7 +45,7 @@ def __init__( ][:-1] self.recording_action_diff_tups = zip( self.recording_window_state_diffs, - self.recording_action_strs + self.recording_action_strs, ) self.recording_action_idx = 0 @@ -48,54 +53,73 @@ def get_next_action_event( self, active_screenshot: models.Screenshot, active_window: models.WindowEvent, - ): + ) -> models.ActionEvent: + """Get the next ActionEvent for replay. + + Args: + active_screenshot (models.Screenshot): The active screenshot object. + active_window (models.WindowEvent): The active window event object. + + Returns: + models.ActionEvent: The next ActionEvent for replay. + """ logger.debug(f"{self.recording_action_idx=}") if self.recording_action_idx == len(self.recording.processed_action_events): raise StopIteration() - reference_action = ( - self.recording.processed_action_events[self.recording_action_idx] - ) + reference_action = self.recording.processed_action_events[ + self.recording_action_idx + ] reference_window = reference_action.window_event - reference_window_dict = deepcopy({ - key: val - for key, val in utils.row2dict(reference_window, follow=False).items() - if val is not None - and not key.endswith("timestamp") - and not key.endswith("id") - #and not isinstance(getattr(models.WindowEvent, key), property) - }) + reference_window_dict = deepcopy( + { + key: val + for key, val in utils.row2dict(reference_window, follow=False).items() + if val is not None + and not key.endswith("timestamp") + and not key.endswith("id") + # and not isinstance(getattr(models.WindowEvent, key), property) + } + ) if reference_action.children: reference_action_dicts = [ - deepcopy({ - key: val - for key, val in utils.row2dict(child, follow=False).items() - if val is not None - and not key.endswith("timestamp") - and not key.endswith("id") - and not isinstance(getattr(models.ActionEvent, key), property) - }) + deepcopy( + { + key: val + for key, val in utils.row2dict(child, follow=False).items() + if val is not None + and not key.endswith("timestamp") + and not key.endswith("id") + and not isinstance(getattr(models.ActionEvent, key), property) + } + ) for child in reference_action.children ] else: reference_action_dicts = [ - deepcopy({ - key: val - for key, val in utils.row2dict(reference_action, follow=False).items() - if val is not None - and not key.endswith("timestamp") - and not key.endswith("id") - #and not isinstance(getattr(models.ActionEvent, key), property) - }) + deepcopy( + { + key: val + for key, val in utils.row2dict( + reference_action, follow=False + ).items() + if val is not None + and not key.endswith("timestamp") + and not key.endswith("id") + # and not isinstance(getattr(models.ActionEvent, key), property) + } + ) ] - active_window_dict = deepcopy({ - key: val - for key, val in utils.row2dict(active_window, follow=False).items() - if val is not None - and not key.endswith("timestamp") - and not key.endswith("id") - #and not isinstance(getattr(models.WindowEvent, key), property) - }) + active_window_dict = deepcopy( + { + key: val + for key, val in utils.row2dict(active_window, follow=False).items() + if val is not None + and not key.endswith("timestamp") + and not key.endswith("id") + # and not isinstance(getattr(models.WindowEvent, key), property) + } + ) reference_window_dict["state"].pop("data") active_window_dict["state"].pop("data") @@ -103,7 +127,9 @@ def get_next_action_event( f"{reference_window_dict=}\n" f"{reference_action_dicts=}\n" f"{active_window_dict=}\n" - "Provide valid Python3 code containing the action dicts by completing the following, and nothing else:\n" + "Provide valid Python3 code containing the action dicts" + " by completing the following," + " and nothing else:\n" "active_action_dicts=" ) system_message = ( @@ -127,7 +153,15 @@ def get_next_action_event( return active_action -def get_action_dict_from_completion(completion): +def get_action_dict_from_completion(completion: str) -> dict[models.ActionEvent]: + """Convert the completion to a dictionary containing action information. + + Args: + completion (str): The completion provided by the user. + + Returns: + dict: The action dictionary. + """ try: action = eval(completion) except Exception as exc: @@ -137,9 +171,19 @@ def get_action_dict_from_completion(completion): def get_window_state_diffs( - action_events, - ignore_boundary_windows=IGNORE_BOUNDARY_WINDOWS, -): + action_events: list[models.ActionEvent], + ignore_boundary_windows: bool = IGNORE_BOUNDARY_WINDOWS, +) -> list[deepdiff.DeepDiff]: + """Get the differences in window state between consecutive action events. + + Args: + action_events (list[models.ActionEvent]): The list of action events. + ignore_boundary_windows (bool): Flag to ignore boundary windows. + Defaults to True. + + Returns: + list[deepdiff.DeepDiff]: list of deep diffs for window state differences. + """ ignore_window_ids = set() if ignore_boundary_windows: first_window_event = action_events[0].window_event diff --git a/openadapt/widget/assets/available.png b/openadapt/widget/assets/available.png new file mode 100644 index 000000000..196979f9e Binary files /dev/null and b/openadapt/widget/assets/available.png differ diff --git a/openadapt/widget/assets/logo.png b/openadapt/widget/assets/logo.png new file mode 100644 index 000000000..64aa029bd Binary files /dev/null and b/openadapt/widget/assets/logo.png differ diff --git a/openadapt/widget/assets/recording_inprogress.png b/openadapt/widget/assets/recording_inprogress.png new file mode 100644 index 000000000..3040dc65f Binary files /dev/null and b/openadapt/widget/assets/recording_inprogress.png differ diff --git a/openadapt/widget/assets/replay_available.png b/openadapt/widget/assets/replay_available.png new file mode 100644 index 000000000..30a041d69 Binary files /dev/null and b/openadapt/widget/assets/replay_available.png differ diff --git a/openadapt/widget/assets/replay_inprogress.png b/openadapt/widget/assets/replay_inprogress.png new file mode 100644 index 000000000..37f476921 Binary files /dev/null and b/openadapt/widget/assets/replay_inprogress.png differ diff --git a/openadapt/widget/assets/replay_paused.png b/openadapt/widget/assets/replay_paused.png new file mode 100644 index 000000000..c7574c0f6 Binary files /dev/null and b/openadapt/widget/assets/replay_paused.png differ diff --git a/openadapt/widget/widget.py b/openadapt/widget/widget.py new file mode 100644 index 000000000..e81dcc82b --- /dev/null +++ b/openadapt/widget/widget.py @@ -0,0 +1,87 @@ +import sys +from subprocess import Popen +import signal +import threading +import time + +from PySide6 import QtWidgets, QtCore, QtGui + + +class SystemTrayIcon(QtWidgets.QSystemTrayIcon): + def __init__(self, icon, parent=None): + QtWidgets.QSystemTrayIcon.__init__(self, icon, parent) + self.setToolTip("OpenAdapt") + self.menu = QtWidgets.QMenu() + self.action = self.menu.addAction("Exit") + self.action.triggered.connect(parent.quit) + self.setContextMenu(self.menu) + self.activated.connect(self.update_icon) + + self.icon_recording = QtGui.QIcon("assets/recording_inprogress.png") + self.icon_replay_available = QtGui.QIcon("assets/replay_available.png") + self.icon_replaying = QtGui.QIcon("assets/replay_inprogress.png") + self.icon_paused = QtGui.QIcon("assets/replay_paused.png") + self.icon_default = QtGui.QIcon("assets/logo.png") + self.setIcon(self.icon_default) + self.current_state = "default" + + def update_icon(self, reason): + if reason == QtWidgets.QSystemTrayIcon.Trigger: + if self.current_state == "default": + self.setIcon(self.icon_recording) + self.current_state = "recording_in_progress" + self.start_recording() + elif self.current_state == "recording_in_progress": + self.setIcon(self.icon_replay_available) + self.current_state = "replay_available" + self.stop_recording() + elif self.current_state == "replay_available": + self.setIcon(self.icon_replaying) + self.current_state = "replaying_in_progress" + self.replay_recording() + elif self.current_state == "replaying_in_progress": + self.setIcon(self.icon_paused) + self.current_state = "replaying_paused" + self.pause_replay() + elif self.current_state == "replaying_paused": + self.setIcon(self.icon_replaying) + self.current_state = "replaying_in_progress" + self.resume_replay() + + def start_recording(self): + # poetry run? + self.record_proc = Popen( + "python -m openadapt.record " + "test", + shell=True, + ) + + def stop_recording(self): + if self.record_proc is not None: + if sys.platform == "win32": + self.record_proc.send_signal(signal.CTRL_BREAK_EVENT) + else: + self.record_proc.send_signal(signal.SIGINT) + self.record_proc.wait() + self.record_proc = None + + def replay_recording(self): + self.replay_proc = Popen( + "python -m openadapt.replay " + "NaiveReplayStrategy", + shell=True, + ) + + def pause_replay(self): + self.replay_proc.send_signal(signal.SIGSTOP) + + def resume_replay(self): + self.replay_proc.send_signal(signal.SIGCONT) + + +def run_widget(): + app = QtWidgets.QApplication(sys.argv) + + # w = QtWidgets.QWidget() + tray_icon = SystemTrayIcon(QtGui.QIcon("assets/logo.png"), app) + tray_icon.show() + + sys.exit(app.exec()) diff --git a/openadapt/window/__init__.py b/openadapt/window/__init__.py index f722d02f7..6df73fc99 100644 --- a/openadapt/window/__init__.py +++ b/openadapt/window/__init__.py @@ -49,4 +49,4 @@ def get_active_element_state(x, y): return impl.get_active_element_state(x, y) except Exception as exc: logger.warning(f"{exc=}") - return None + return None \ No newline at end of file diff --git a/openadapt/window/_macos.py b/openadapt/window/_macos.py index 57048ba0b..a61e37c98 100644 --- a/openadapt/window/_macos.py +++ b/openadapt/window/_macos.py @@ -226,4 +226,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/openadapt/window/_windows.py b/openadapt/window/_windows.py index 40b74250b..f01e978a1 100644 --- a/openadapt/window/_windows.py +++ b/openadapt/window/_windows.py @@ -188,4 +188,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index fe4175cd9..49feac2ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ pandas = "2.0.0" presidio_analyzer = "2.2.32" presidio_anonymizer = "2.2.32" presidio_image_redactor = "0.0.46" +pyside6 = "^6.5.1.1" pytesseract = "0.3.7" pytest = "7.1.3" rapidocr-onnxruntime = "1.2.3" @@ -96,6 +97,7 @@ visualize = "openadapt.visualize:main" record = "openadapt.record:start" replay = "openadapt.replay:start" app = "openadapt.app.main:run_app" +widget = "openadapt.widget.widget:run_widget" [tool.semantic_release] version_variable = [