Skip to content

Commit

Permalink
feat(action verification): implement replay action verification (#857)
Browse files Browse the repository at this point in the history
* add is_action_event_complete

* retry_with_exceptions in apply_replay_instructions

* fix parse_code_snippet

* add error_reporting.py

* refactor video.py

* black/flake8

* add module docstring

* CHECK_ACTION_COMPLETE
  • Loading branch information
abrichr authored Jul 24, 2024
1 parent 001c8fa commit b288c07
Show file tree
Hide file tree
Showing 11 changed files with 225 additions and 91 deletions.
50 changes: 1 addition & 49 deletions openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,13 @@
import os
import pathlib
import shutil
import webbrowser

from loguru import logger
from pydantic import field_validator
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource
from PySide6.QtWidgets import QMessageBox, QPushButton
import git
import sentry_sdk

from openadapt.build_utils import get_root_dir_path, is_running_from_executable
from openadapt.build_utils import get_root_dir_path

CONFIG_DEFAULTS_FILE_PATH = (
pathlib.Path(__file__).parent / "config.defaults.json"
Expand Down Expand Up @@ -411,47 +407,3 @@ def print_config() -> None:
if not key.startswith("_") and key.isupper():
val = maybe_obfuscate(key, val)
logger.info(f"{key}={val}")

if config.ERROR_REPORTING_ENABLED:
if is_running_from_executable():
is_reporting_branch = True
else:
active_branch_name = git.Repo(PARENT_DIR_PATH).active_branch.name
logger.info(f"{active_branch_name=}")
is_reporting_branch = (
active_branch_name == config.ERROR_REPORTING_BRANCH
)
logger.info(f"{is_reporting_branch=}")
if is_reporting_branch:

def show_alert() -> None:
"""Show an alert to the user."""
msg = QMessageBox()
msg.setIcon(QMessageBox.Warning)
msg.setText("""
An error has occurred. The development team has been notified.
Please join the discord server to get help or send an email to
[email protected]
""")
discord_button = QPushButton("Join the discord server")
discord_button.clicked.connect(
lambda: webbrowser.open("https://discord.gg/yF527cQbDG")
)
msg.addButton(discord_button, QMessageBox.ActionRole)
msg.addButton(QMessageBox.Ok)
msg.exec()

def before_send_event(event: Any, hint: Any) -> Any:
"""Handle the event before sending it to Sentry."""
try:
show_alert()
except Exception:
pass
return event

sentry_sdk.init(
dsn=config.ERROR_REPORTING_DSN,
traces_sample_rate=1.0,
before_send=before_send_event,
ignore_errors=[KeyboardInterrupt],
)
2 changes: 2 additions & 0 deletions openadapt/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
multiprocessing.freeze_support()

from openadapt.build_utils import redirect_stdout_stderr
from openadapt.error_reporting import configure_error_reporting
from openadapt.custom_logger import logger


Expand All @@ -19,6 +20,7 @@ def run_openadapt() -> None:
from openadapt.config import print_config

print_config()
configure_error_reporting()
load_alembic_context()
tray._run()
except Exception as exc:
Expand Down
67 changes: 67 additions & 0 deletions openadapt/error_reporting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Module for error reporting logic."""

from typing import Any

from loguru import logger
from PySide6.QtGui import QIcon
from PySide6.QtWidgets import QMessageBox, QPushButton
import git
import sentry_sdk
import webbrowser

from openadapt.build_utils import is_running_from_executable
from openadapt.config import PARENT_DIR_PATH, config


def configure_error_reporting() -> None:
    """Initialize Sentry error reporting if it is enabled for this run.

    Nothing is configured when ``config.ERROR_REPORTING_ENABLED`` is falsy.
    When running from a packaged executable, reporting is always initialized.
    Otherwise it is initialized only when the git checkout at
    ``PARENT_DIR_PATH`` is on the branch named by
    ``config.ERROR_REPORTING_BRANCH``.

    If the active branch cannot be determined (the path is not a git
    repository, does not exist, or HEAD is detached), a warning is logged and
    error reporting is left unconfigured instead of raising.
    """
    logger.info(f"{config.ERROR_REPORTING_ENABLED=}")
    if not config.ERROR_REPORTING_ENABLED:
        return

    if is_running_from_executable():
        is_reporting_branch = True
    else:
        try:
            active_branch_name = git.Repo(PARENT_DIR_PATH).active_branch.name
        except (
            git.exc.InvalidGitRepositoryError,
            git.exc.NoSuchPathError,
            # GitPython raises TypeError when HEAD is detached.
            TypeError,
        ) as exc:
            # Failing to set up error reporting must not take the app down.
            logger.warning(
                f"Unable to determine the active git branch ({exc!r}); "
                "error reporting will not be configured."
            )
            return
        logger.info(f"{active_branch_name=}")
        is_reporting_branch = active_branch_name == config.ERROR_REPORTING_BRANCH
        logger.info(f"{is_reporting_branch=}")

    if is_reporting_branch:
        sentry_sdk.init(
            dsn=config.ERROR_REPORTING_DSN,
            traces_sample_rate=1.0,
            before_send=before_send_event,
            ignore_errors=[KeyboardInterrupt],
        )


def show_alert() -> None:
    """Display a modal warning dialog telling the user an error occurred.

    The dialog uses the tray icon as its window icon and offers two buttons:
    one that opens the Discord invite link in the default web browser, and a
    standard "Ok" button. The call blocks until the dialog is dismissed.
    """
    # TODO: move to config
    from openadapt.app.tray import ICON_PATH

    alert_text = """
An error has occurred. The development team has been notified.
Please join the discord server to get help or send an email to
[email protected]
"""
    discord_url = "https://discord.gg/yF527cQbDG"

    dialog = QMessageBox()
    dialog.setIcon(QMessageBox.Warning)
    dialog.setWindowIcon(QIcon(ICON_PATH))
    dialog.setText(alert_text)

    join_button = QPushButton("Join the discord server")
    join_button.clicked.connect(lambda: webbrowser.open(discord_url))
    dialog.addButton(join_button, QMessageBox.ActionRole)
    dialog.addButton(QMessageBox.Ok)

    # Blocks until the user closes the dialog.
    dialog.exec()


def before_send_event(event: Any, hint: Any) -> Any:
    """Alert the user, then pass the event through to Sentry unchanged.

    Args:
        event (Any): The event about to be sent.
        hint (Any): Additional information supplied with the event (unused).

    Returns:
        (Any) the same ``event`` object that was passed in.
    """
    try:
        show_alert()
    except Exception:
        # Showing the dialog is best effort only: a failure to display it
        # must never prevent the event from being reported.
        pass
    return event
9 changes: 8 additions & 1 deletion openadapt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,14 @@ def to_prompt_dict(self, include_data: bool = True) -> dict[str, Any]:
if "state" in window_dict:
if include_data:
key_suffixes = [
"value", "h", "w", "x", "y", "description", "title", "help",
"value",
"h",
"w",
"x",
"y",
"description",
"title",
"help",
]
if sys.platform == "win32":
logger.warning(
Expand Down
9 changes: 9 additions & 0 deletions openadapt/prompts/apply_replay_instructions.j2
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,13 @@ Do NOT provide available_segment_descriptions in your response.

Respond with json and nothing else.

{% if exceptions %}
Your previous attempts at this produced the following exceptions:
{% for exception in exceptions %}
<exception>
{{ exception }}
</exception>
{% endfor %}
{% endif %}

My career depends on this. Lives are at stake.
22 changes: 22 additions & 0 deletions openadapt/prompts/is_action_complete.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
Consider the actions that you previously generated:

```json
{{ actions }}
```

The attached image is a screenshot of the current state of the system, immediately
after the last action in the sequence was played.

Your task is to:
1. Describe what you would expect to see in the screenshot after the last action in the
sequence is complete, and
2. Determine whether the last action has completed by looking at the attached
screenshot. For example, if you expect that the sequence of actions would result in
opening a particular application, you should determine whether that application has
finished opening.

Respond with JSON and nothing else. The JSON should have the following keys:
- "expected_state": Natural language description of what you would expect to see.
- "is_complete": Boolean indicating whether the last action is complete or not.

My career depends on this. Lives are at stake.
2 changes: 2 additions & 0 deletions openadapt/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from openadapt import utils
from openadapt.config import CAPTURE_DIR_PATH, print_config
from openadapt.db import crud
from openadapt.error_reporting import configure_error_reporting
from openadapt.models import Recording

LOG_LEVEL = "INFO"
Expand Down Expand Up @@ -50,6 +51,7 @@ def replay(
"""
utils.configure_logging(logger, LOG_LEVEL)
print_config()
configure_error_reporting()
posthog.capture(event="replay.started", properties={"strategy_name": strategy_name})

if status_pipe:
Expand Down
52 changes: 51 additions & 1 deletion openadapt/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from oa_pynput import keyboard, mouse
import numpy as np

from openadapt import models, playback, utils
from openadapt import adapters, models, playback, utils
from openadapt.custom_logger import logger

CHECK_ACTION_COMPLETE = True
MAX_FRAME_TIMES = 1000


Expand Down Expand Up @@ -55,6 +56,16 @@ def run(self) -> None:
mouse_controller = mouse.Controller()
while True:
screenshot = models.Screenshot.take_screenshot()

# check if previous action is complete
if CHECK_ACTION_COMPLETE:
is_action_complete = prompt_is_action_complete(
screenshot,
self.action_events,
)
if not is_action_complete:
continue

self.screenshots.append(screenshot)
window_event = models.WindowEvent.get_active_window_event()
self.window_events.append(window_event)
Expand Down Expand Up @@ -108,3 +119,42 @@ def log_fps(self) -> None:
logger.info(f"{fps=:.2f}")
if len(self.frame_times) > self.max_frame_times:
self.frame_times.pop(0)


def prompt_is_action_complete(
    current_screenshot: models.Screenshot,
    played_actions: list[models.ActionEvent],
) -> bool:
    """Determine whether the last action is complete.

    Renders the system and "is action complete" prompt templates, sends them
    to the default prompt adapter together with the current screenshot image,
    and reads the verdict from the parsed response.

    Args:
        current_screenshot (models.Screenshot): The current Screenshot.
        played_actions (list[models.ActionEvent]): The list of previously played
            ActionEvents.

    Returns:
        (bool) whether or not the last played action has completed. True is
            returned immediately when no actions have been played.
    """
    # Nothing has been played yet, so there is nothing to wait for.
    if not played_actions:
        return True

    system_prompt = utils.render_template_from_file("prompts/system.j2")
    prompt_dicts = [action.to_prompt_dict() for action in played_actions]
    prompt = utils.render_template_from_file(
        "prompts/is_action_complete.j2",
        actions={"actions": prompt_dicts},
    )

    response = adapters.get_default_prompt_adapter().prompt(
        prompt,
        system_prompt=system_prompt,
        images=[current_screenshot.image],
    )
    parsed = utils.parse_code_snippet(response)

    expected_state = parsed["expected_state"]
    is_complete = parsed["is_complete"]
    logger.info(f"{expected_state=} {is_complete=}")
    return is_complete
6 changes: 5 additions & 1 deletion openadapt/strategies/visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,19 @@ def add_active_segment_descriptions(action_events: list[models.ActionEvent]) ->
action.available_segment_descriptions = window_segmentation.descriptions


@utils.retry_with_exceptions()
def apply_replay_instructions(
action_events: list[models.ActionEvent],
replay_instructions: str,
# retain_window_events: bool = False,
exceptions: list[Exception],
) -> None:
"""Modify the given ActionEvents according to the given replay instructions.
Args:
action_events: list of action events to be modified in place.
replay_instructions: instructions for how action events should be modified.
exceptions: list of exceptions that were produced attempting to run this
function.
"""
action_dicts = [action.to_prompt_dict() for action in action_events]
actions_dict = {"actions": action_dicts}
Expand All @@ -131,6 +134,7 @@ def apply_replay_instructions(
"prompts/apply_replay_instructions.j2",
actions=actions_dict,
replay_instructions=replay_instructions,
exceptions=exceptions,
)
prompt_adapter = adapters.get_default_prompt_adapter()
content = prompt_adapter.prompt(
Expand Down
Loading

0 comments on commit b288c07

Please sign in to comment.