Skip to content

Commit

Permalink
fix merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
jesicasusanto committed Jul 25, 2023
1 parent dbed721 commit e3afd5a
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 57 deletions.
109 changes: 82 additions & 27 deletions openadapt/models.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,38 @@
"""This module defines the models used in the OpenAdapt system."""

from typing import Union
import io

from loguru import logger
from pynput import keyboard
from PIL import Image, ImageChops
from pynput import keyboard
import numpy as np
import sqlalchemy as sa

from openadapt import config, db, utils, window
from openadapt import config, db, window


# https://groups.google.com/g/sqlalchemy/c/wlr7sShU6-k
class ForceFloat(sa.TypeDecorator):
"""Custom SQLAlchemy type decorator for floating-point numbers."""

impl = sa.Numeric(10, 2, asdecimal=False)
cache_ok = True

def process_result_value(self, value, dialect):
def process_result_value(
self,
value: int | float | str | None,
dialect: str,
) -> float | None:
"""Convert the result value to float."""
if value is not None:
value = float(value)
return value


class Recording(db.Base):
"""Class representing a recording in the database."""

__tablename__ = "recording"

id = sa.Column(sa.Integer, primary_key=True)
Expand Down Expand Up @@ -51,7 +63,8 @@ class Recording(db.Base):
_processed_action_events = None

@property
def processed_action_events(self):
def processed_action_events(self) -> list:
"""Get the processed action events for the recording."""
from openadapt import events

if not self._processed_action_events:
Expand All @@ -60,6 +73,8 @@ def processed_action_events(self):


class ActionEvent(db.Base):
"""Class representing an action event in the database."""

__tablename__ = "action_event"

id = sa.Column(sa.Integer, primary_key=True)
Expand All @@ -86,16 +101,23 @@ class ActionEvent(db.Base):
children = sa.orm.relationship("ActionEvent")
# TODO: replacing the above line with the following two results in an error:
# AttributeError: 'list' object has no attribute '_sa_instance_state'
# children = sa.orm.relationship("ActionEvent", remote_side=[id], back_populates="parent")
# parent = sa.orm.relationship("ActionEvent", remote_side=[parent_id], back_populates="children")
# children = sa.orm.relationship(
# "ActionEvent", remote_side=[id], back_populates="parent"
# )
# parent = sa.orm.relationship(
# "ActionEvent", remote_side=[parent_id], back_populates="children"
# ) # noqa: E501

recording = sa.orm.relationship("Recording", back_populates="action_events")
screenshot = sa.orm.relationship("Screenshot", back_populates="action_event")
window_event = sa.orm.relationship("WindowEvent", back_populates="action_events")

# TODO: playback_timestamp / original_timestamp

def _key(self, key_name, key_char, key_vk):
def _key(
self, key_name: str, key_char: str, key_vk: str
) -> Union[keyboard.Key, keyboard.KeyCode, str, None]:
"""Helper method to determine the key attribute based on available data."""
if key_name:
key = keyboard.Key[key_name]
elif key_char:
Expand All @@ -108,7 +130,8 @@ def _key(self, key_name, key_char, key_vk):
return key

@property
def key(self):
def key(self) -> Union[keyboard.Key, keyboard.KeyCode, str, None]:
"""Get the key associated with the action event."""
logger.trace(f"{self.name=} {self.key_name=} {self.key_char=} {self.key_vk=}")
return self._key(
self.key_name,
Expand All @@ -117,7 +140,8 @@ def key(self):
)

@property
def canonical_key(self):
def canonical_key(self) -> Union[keyboard.Key, keyboard.KeyCode, str, None]:
"""Get the canonical key associated with the action event."""
logger.trace(
f"{self.name=} "
f"{self.canonical_key_name=} "
Expand All @@ -130,7 +154,8 @@ def canonical_key(self):
self.canonical_key_vk,
)

def _text(self, canonical=False):
def _text(self, canonical: bool = False) -> str | None:
"""Helper method to generate the text representation of the action event."""
sep = config.ACTION_TEXT_SEP
name_prefix = config.ACTION_TEXT_NAME_PREFIX
name_suffix = config.ACTION_TEXT_NAME_SUFFIX
Expand Down Expand Up @@ -163,14 +188,17 @@ def _text(self, canonical=False):
return text

@property
def text(self):
def text(self) -> str:
"""Get the text representation of the action event."""
return self._text()

@property
def canonical_text(self):
def canonical_text(self) -> str:
"""Get the canonical text representation of the action event."""
return self._text(canonical=True)

def __str__(self):
def __str__(self) -> str:
"""Return a string representation of the action event."""
attr_names = [
"name",
"mouse_x",
Expand All @@ -193,15 +221,23 @@ def __str__(self):
return rval

@classmethod
def from_children(cls, children_dicts):
if children_dicts:
children = [ActionEvent(**child_dict) for child_dict in children_dicts]
return ActionEvent(children=children)
else:
return None
def from_children(cls: list, children_dicts: list) -> "ActionEvent":
"""Create an ActionEvent instance from a list of child event dictionaries.
Args:
children_dicts (list): List of dictionaries representing child events.
Returns:
ActionEvent: An instance of ActionEvent with the specified child events.
"""
children = [ActionEvent(**child_dict) for child_dict in children_dicts]
return ActionEvent(children=children)


class Screenshot(db.Base):
"""Class representing a screenshot in the database."""

__tablename__ = "screenshot"

id = sa.Column(sa.Integer, primary_key=True)
Expand All @@ -222,7 +258,8 @@ class Screenshot(db.Base):
_diff_mask = None

@property
def image(self):
def image(self) -> Image:
"""Get the image associated with the screenshot."""
if not self._image:
if self.sct_img:
self._image = Image.frombytes(
Expand All @@ -238,30 +275,41 @@ def image(self):
return self._image

@property
def diff(self):
def diff(self) -> Image:
"""Get the difference between the current and previous screenshot."""
if not self._diff:
assert self.prev, "Attempted to compute diff before setting prev"
self._diff = ImageChops.difference(self.image, self.prev.image)
return self._diff

@property
def diff_mask(self):
def diff_mask(self) -> Image:
"""Get the difference mask of the screenshot."""
if not self._diff_mask:
if self.diff:
self._diff_mask = self.diff.convert("1")
return self._diff_mask

@property
def array(self):
def array(self) -> np.ndarray:
"""Get the NumPy array representation of the image."""
return np.array(self.image)

@classmethod
def take_screenshot(cls):
def take_screenshot(cls: "Screenshot") -> "Screenshot":
"""Capture a screenshot."""
# avoid circular import
from openadapt import utils

sct_img = utils.take_screenshot()
screenshot = Screenshot(sct_img=sct_img)
return screenshot

def crop_active_window(self, action_event):
def crop_active_window(self, action_event: ActionEvent) -> None:
"""Crop the screenshot to the active window defined by the action event."""
# avoid circular import
from openadapt import utils

window_event = action_event.window_event
width_ratio, height_ratio = utils.get_scale_ratios(action_event)

Expand All @@ -275,6 +323,8 @@ def crop_active_window(self, action_event):


class WindowEvent(db.Base):
"""Class representing a window event in the database."""

__tablename__ = "window_event"

id = sa.Column(sa.Integer, primary_key=True)
Expand All @@ -292,11 +342,14 @@ class WindowEvent(db.Base):
action_events = sa.orm.relationship("ActionEvent", back_populates="window_event")

@classmethod
def get_active_window_event(cls):
def get_active_window_event(cls: "WindowEvent") -> "WindowEvent":
"""Get the active window event."""
return WindowEvent(**window.get_active_window_data())


class PerformanceStat(db.Base):
"""Class representing a performance statistic in the database."""

__tablename__ = "performance_stat"

id = sa.Column(sa.Integer, primary_key=True)
Expand All @@ -308,9 +361,11 @@ class PerformanceStat(db.Base):


class MemoryStat(db.Base):
"""Class representing a memory usage statistic in the database."""

__tablename__ = "memory_stat"

id = sa.Column(sa.Integer, primary_key=True)
recording_timestamp = sa.Column(sa.Integer)
memory_usage_bytes = sa.Column(ForceFloat)
timestamp = sa.Column(ForceFloat)
timestamp = sa.Column(ForceFloat)
Loading

0 comments on commit e3afd5a

Please sign in to comment.