From e3afd5aeb9a91f76649fad9e496ea1f61e24e3f0 Mon Sep 17 00:00:00 2001 From: jesicasusanto Date: Tue, 25 Jul 2023 10:21:25 -0700 Subject: [PATCH] fix merge conflict --- openadapt/models.py | 109 +++++++++++++++++++++++-------- openadapt/strategies/stateful.py | 86 +++++++++++++++--------- 2 files changed, 138 insertions(+), 57 deletions(-) diff --git a/openadapt/models.py b/openadapt/models.py index 50c5c7983..39848dcc8 100644 --- a/openadapt/models.py +++ b/openadapt/models.py @@ -1,26 +1,38 @@ +"""This module defines the models used in the OpenAdapt system.""" + +from typing import Union import io from loguru import logger -from pynput import keyboard from PIL import Image, ImageChops +from pynput import keyboard import numpy as np import sqlalchemy as sa -from openadapt import config, db, utils, window +from openadapt import config, db, window # https://groups.google.com/g/sqlalchemy/c/wlr7sShU6-k class ForceFloat(sa.TypeDecorator): + """Custom SQLAlchemy type decorator for floating-point numbers.""" + impl = sa.Numeric(10, 2, asdecimal=False) cache_ok = True - def process_result_value(self, value, dialect): + def process_result_value( + self, + value: int | float | str | None, + dialect: str, + ) -> float | None: + """Convert the result value to float.""" if value is not None: value = float(value) return value class Recording(db.Base): + """Class representing a recording in the database.""" + __tablename__ = "recording" id = sa.Column(sa.Integer, primary_key=True) @@ -51,7 +63,8 @@ class Recording(db.Base): _processed_action_events = None @property - def processed_action_events(self): + def processed_action_events(self) -> list: + """Get the processed action events for the recording.""" from openadapt import events if not self._processed_action_events: @@ -60,6 +73,8 @@ def processed_action_events(self): class ActionEvent(db.Base): + """Class representing an action event in the database.""" + __tablename__ = "action_event" id = sa.Column(sa.Integer, primary_key=True) @@ -86,8 +101,12 @@ class ActionEvent(db.Base): children = sa.orm.relationship("ActionEvent") # TODO: replacing the above line with the following two results in an error: # AttributeError: 'list' object has no attribute '_sa_instance_state' - # children = sa.orm.relationship("ActionEvent", remote_side=[id], back_populates="parent") - # parent = sa.orm.relationship("ActionEvent", remote_side=[parent_id], back_populates="children") + # children = sa.orm.relationship( + # "ActionEvent", remote_side=[id], back_populates="parent" + # ) + # parent = sa.orm.relationship( + # "ActionEvent", remote_side=[parent_id], back_populates="children" + # ) # noqa: E501 recording = sa.orm.relationship("Recording", back_populates="action_events") screenshot = sa.orm.relationship("Screenshot", back_populates="action_event") @@ -95,7 +114,10 @@ class ActionEvent(db.Base): # TODO: playback_timestamp / original_timestamp - def _key(self, key_name, key_char, key_vk): + def _key( + self, key_name: str, key_char: str, key_vk: str + ) -> Union[keyboard.Key, keyboard.KeyCode, str, None]: + """Helper method to determine the key attribute based on available data.""" if key_name: key = keyboard.Key[key_name] elif key_char: @@ -108,7 +130,8 @@ def _key(self, key_name, key_char, key_vk): return key @property - def key(self): + def key(self) -> Union[keyboard.Key, keyboard.KeyCode, str, None]: + """Get the key associated with the action event.""" logger.trace(f"{self.name=} {self.key_name=} {self.key_char=} {self.key_vk=}") return self._key( self.key_name, @@ -117,7 +140,8 @@ def key(self): ) @property - def canonical_key(self): + def canonical_key(self) -> Union[keyboard.Key, keyboard.KeyCode, str, None]: + """Get the canonical key associated with the action event.""" logger.trace( f"{self.name=} " f"{self.canonical_key_name=} " @@ -130,7 +154,8 @@ def canonical_key(self): self.canonical_key_vk, ) - def _text(self, canonical=False): + def _text(self, canonical: bool = False) -> str | None: + """Helper method to generate the text representation of the action event.""" sep = config.ACTION_TEXT_SEP name_prefix = config.ACTION_TEXT_NAME_PREFIX name_suffix = config.ACTION_TEXT_NAME_SUFFIX @@ -163,14 +188,17 @@ def _text(self, canonical=False): return text @property - def text(self): + def text(self) -> str: + """Get the text representation of the action event.""" return self._text() @property - def canonical_text(self): + def canonical_text(self) -> str: + """Get the canonical text representation of the action event.""" return self._text(canonical=True) - def __str__(self): + def __str__(self) -> str: + """Return a string representation of the action event.""" attr_names = [ "name", "mouse_x", @@ -193,15 +221,23 @@ def __str__(self): return rval @classmethod - def from_children(cls, children_dicts): - if children_dicts: - children = [ActionEvent(**child_dict) for child_dict in children_dicts] - return ActionEvent(children=children) - else: - return None + def from_children(cls: list, children_dicts: list) -> "ActionEvent": + """Create an ActionEvent instance from a list of child event dictionaries. + + Args: + children_dicts (list): List of dictionaries representing child events. + + Returns: + ActionEvent: An instance of ActionEvent with the specified child events. + + """ + children = [ActionEvent(**child_dict) for child_dict in children_dicts] + return ActionEvent(children=children) class Screenshot(db.Base): + """Class representing a screenshot in the database.""" + __tablename__ = "screenshot" id = sa.Column(sa.Integer, primary_key=True) @@ -222,7 +258,8 @@ class Screenshot(db.Base): _diff_mask = None @property - def image(self): + def image(self) -> Image: + """Get the image associated with the screenshot.""" if not self._image: if self.sct_img: self._image = Image.frombytes( @@ -238,30 +275,41 @@ def image(self): return self._image @property - def diff(self): + def diff(self) -> Image: + """Get the difference between the current and previous screenshot.""" if not self._diff: assert self.prev, "Attempted to compute diff before setting prev" self._diff = ImageChops.difference(self.image, self.prev.image) return self._diff @property - def diff_mask(self): + def diff_mask(self) -> Image: + """Get the difference mask of the screenshot.""" if not self._diff_mask: if self.diff: self._diff_mask = self.diff.convert("1") return self._diff_mask @property - def array(self): + def array(self) -> np.ndarray: + """Get the NumPy array representation of the image.""" return np.array(self.image) @classmethod - def take_screenshot(cls): + def take_screenshot(cls: "Screenshot") -> "Screenshot": + """Capture a screenshot.""" + # avoid circular import + from openadapt import utils + sct_img = utils.take_screenshot() screenshot = Screenshot(sct_img=sct_img) return screenshot - def crop_active_window(self, action_event): + def crop_active_window(self, action_event: ActionEvent) -> None: + """Crop the screenshot to the active window defined by the action event.""" + # avoid circular import + from openadapt import utils + window_event = action_event.window_event width_ratio, height_ratio = utils.get_scale_ratios(action_event) @@ -275,6 +323,8 @@ def crop_active_window(self, action_event): class WindowEvent(db.Base): + """Class representing a window event in the database.""" + __tablename__ = "window_event" id = sa.Column(sa.Integer, primary_key=True) @@ -292,11 +342,14 @@ class WindowEvent(db.Base): action_events = sa.orm.relationship("ActionEvent", back_populates="window_event") @classmethod - def get_active_window_event(cls): + def get_active_window_event(cls: "WindowEvent") -> "WindowEvent": + """Get the active window event.""" return WindowEvent(**window.get_active_window_data()) class PerformanceStat(db.Base): + """Class representing a performance statistic in the database.""" + __tablename__ = "performance_stat" id = sa.Column(sa.Integer, primary_key=True) @@ -308,9 +361,11 @@ class PerformanceStat(db.Base): class MemoryStat(db.Base): + """Class representing a memory usage statistic in the database.""" + __tablename__ = "memory_stat" id = sa.Column(sa.Integer, primary_key=True) recording_timestamp = sa.Column(sa.Integer) memory_usage_bytes = sa.Column(ForceFloat) - timestamp = sa.Column(ForceFloat) + timestamp = sa.Column(ForceFloat) \ No newline at end of file diff --git a/openadapt/strategies/stateful.py b/openadapt/strategies/stateful.py index 17867746f..dff3a6d53 100644 --- a/openadapt/strategies/stateful.py +++ b/openadapt/strategies/stateful.py @@ -1,5 +1,4 @@ -""" -LLM with window states. +"""LLM with window states. Usage: @@ -9,15 +8,14 @@ from copy import deepcopy from pprint import pformat -# import datetime - from loguru import logger import deepdiff -import numpy as np -from openadapt import config, events, models, strategies, utils +from openadapt import models, strategies, utils from openadapt.strategies.mixins.openai import OpenAIReplayStrategyMixin +# import datetime + IGNORE_BOUNDARY_WINDOWS = True @@ -26,10 +24,17 @@ class StatefulReplayStrategy( OpenAIReplayStrategyMixin, strategies.base.BaseReplayStrategy, ): + """LLM with window states.""" + def __init__( self, recording: models.Recording, - ): + ) -> None: + """Initialize the StatefulReplayStrategy. + + Args: + recording (models.Recording): The recording object. + """ super().__init__(recording) self.recording_window_state_diffs = get_window_state_diffs( recording.processed_action_events @@ -39,7 +44,8 @@ def __init__( for action_event in self.recording.processed_action_events ][:-1] self.recording_action_diff_tups = zip( - self.recording_window_state_diffs, self.recording_action_strs + self.recording_window_state_diffs, + self.recording_action_strs, ) self.recording_action_idx = 0 @@ -47,7 +53,16 @@ def get_next_action_event( self, active_screenshot: models.Screenshot, active_window: models.WindowEvent, - ): + ) -> models.ActionEvent: + """Get the next ActionEvent for replay. + + Args: + active_screenshot (models.Screenshot): The active screenshot object. + active_window (models.WindowEvent): The active window event object. + + Returns: + models.ActionEvent: The next ActionEvent for replay. + """ logger.debug(f"{self.recording_action_idx=}") if self.recording_action_idx == len(self.recording.processed_action_events): raise StopIteration() @@ -105,16 +120,16 @@ def get_next_action_event( # and not isinstance(getattr(models.WindowEvent, key), property) } ) - if reference_window_dict and "state" in reference_window_dict: - reference_window_dict["state"].pop("data") - if active_window_dict and "state" in active_window_dict: - active_window_dict["state"].pop("data") + reference_window_dict["state"].pop("data") + active_window_dict["state"].pop("data") prompt = ( f"{reference_window_dict=}\n" f"{reference_action_dicts=}\n" f"{active_window_dict=}\n" - "Provide valid Python3 code containing the action dicts by completing the following, and nothing else:\n" + "Provide valid Python3 code containing the action dicts" + " by completing the following," + " and nothing else:\n" "active_action_dicts=" ) system_message = ( @@ -138,7 +153,15 @@ def get_next_action_event( return active_action -def get_action_dict_from_completion(completion): +def get_action_dict_from_completion(completion: str) -> dict[models.ActionEvent]: + """Convert the completion to a dictionary containing action information. + + Args: + completion (str): The completion provided by the user. + + Returns: + dict: The action dictionary. + """ try: action = eval(completion) except Exception as exc: @@ -148,24 +171,27 @@ def get_action_dict_from_completion(completion): def get_window_state_diffs( - action_events, - ignore_boundary_windows=IGNORE_BOUNDARY_WINDOWS, -): + action_events: list[models.ActionEvent], + ignore_boundary_windows: bool = IGNORE_BOUNDARY_WINDOWS, +) -> list[deepdiff.DeepDiff]: + """Get the differences in window state between consecutive action events. + + Args: + action_events (list[models.ActionEvent]): The list of action events. + ignore_boundary_windows (bool): Flag to ignore boundary windows. + Defaults to True. + + Returns: + list[deepdiff.DeepDiff]: list of deep diffs for window state differences. + """ ignore_window_ids = set() if ignore_boundary_windows: first_window_event = action_events[0].window_event - if first_window_event.state: - first_window_id = first_window_event.state["window_id"] - else: - first_window_id = None - first_window_title = first_window_event.title + first_window_id = first_window_event.state["window_id"] + first_window_title = first_window_event.title last_window_event = action_events[-1].window_event - if last_window_event.state: - last_window_id = last_window_event.state["window_id"] - last_window_title = last_window_event.title - else: - last_window_id = None - last_window_title = last_window_event.title + last_window_id = last_window_event.state["window_id"] + last_window_title = last_window_event.title if first_window_id != last_window_id: logger.warning(f"{first_window_id=} != {last_window_id=}") ignore_window_ids.add(first_window_id) @@ -183,4 +209,4 @@ def get_window_state_diffs( window_event_states, window_event_states[1:] ) ] - return diffs + return diffs \ No newline at end of file