From 7ef115a6172222d7d624ce92d0ec38a60d0e86f1 Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Sat, 22 Jun 2024 19:15:37 -0400 Subject: [PATCH] feat(SegmentReplayStrategy, drivers): add strategies.replay; refactor adapters -> drivers + adapters (#714) * implemented * add get_active_window_data parameter include_window_data; fix ActionEvent.from_dict to handle multiple separators; add test_models.py * update get_default_prompt_adapter * add TODO * tests.openadapt.adapters -> drivers * utils.get_marked_image, .extract_code_block; .strip_backticks * working segment.py (misses final click in calculator task) * include_replay_instructions; dev_mode=False * fix test_openai.py: ValueError -> Exception * replay.py --record -> --capture * black/flake8 * remove import * INCLUDE_CURRENT_SCREENSHOT; handle mouse events without x/y * add models.Replay; print_config in replay.py --- openadapt/adapters/__init__.py | 10 +- openadapt/adapters/prompt.py | 47 +++ openadapt/adapters/ultralytics.py | 24 +- openadapt/{adapters => drivers}/anthropic.py | 0 openadapt/{adapters => drivers}/google.py | 0 openadapt/{adapters => drivers}/openai.py | 23 +- openadapt/models.py | 16 +- openadapt/playback.py | 5 +- openadapt/plotting.py | 56 ++- .../prompts/describe_recording--segment.j2 | 16 + .../prompts/generate_action_event--segment.j2 | 73 ++++ openadapt/replay.py | 17 +- openadapt/strategies/__init__.py | 1 + openadapt/strategies/base.py | 1 + openadapt/strategies/segment.py | 345 ++++++++++++++++++ openadapt/strategies/visual.py | 136 ++++--- openadapt/utils.py | 41 +++ openadapt/video.py | 8 +- openadapt/window/__init__.py | 5 +- .../{adapters => drivers}/test_anthropic.py | 6 +- .../{adapters => drivers}/test_google.py | 4 +- .../{adapters => drivers}/test_openai.py | 6 +- 22 files changed, 740 insertions(+), 100 deletions(-) create mode 100644 openadapt/adapters/prompt.py rename openadapt/{adapters => drivers}/anthropic.py (100%) rename openadapt/{adapters => drivers}/google.py (100%) rename openadapt/{adapters => drivers}/openai.py (95%) create mode 100644 openadapt/prompts/describe_recording--segment.j2 create mode 100644 openadapt/prompts/generate_action_event--segment.j2 create mode 100644 openadapt/strategies/segment.py rename tests/openadapt/{adapters => drivers}/test_anthropic.py (73%) rename tests/openadapt/{adapters => drivers}/test_google.py (89%) rename tests/openadapt/{adapters => drivers}/test_openai.py (86%) diff --git a/openadapt/adapters/__init__.py b/openadapt/adapters/__init__.py index 79a6b96f1..c123eafe1 100644 --- a/openadapt/adapters/__init__.py +++ b/openadapt/adapters/__init__.py @@ -4,22 +4,20 @@ from openadapt.config import config -from . import anthropic, google, openai, replicate, som, ultralytics +from . import prompt, replicate, som, ultralytics +# TODO: remove def get_default_prompt_adapter() -> ModuleType: """Returns the default prompt adapter module. Returns: The module corresponding to the default prompt adapter. """ - return { - "openai": openai, - "anthropic": anthropic, - "google": google, - }[config.DEFAULT_ADAPTER] + return prompt +# TODO: refactor to follow adapters.prompt def get_default_segmentation_adapter() -> ModuleType: """Returns the default image segmentation adapter module. 
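The new `adapters.prompt` module (next file) replaces the per-provider branching in `get_default_prompt_adapter`: callers go through a single entry point that tries each driver in `DRIVER_ORDER` until one returns a completion. A minimal usage sketch; the screenshot path and prompt text are illustrative, not part of this patch:

```python
from PIL import Image

from openadapt import adapters

# adapters.prompt tries openai, then google, then anthropic (DRIVER_ORDER),
# falling through to the next driver whenever one raises.
description = adapters.prompt.prompt(
    "Describe the contents of this screenshot.",
    images=[Image.open("screenshot.png")],  # illustrative path
    system_prompt="You are a UI automation assistant.",
)
print(description)
```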
diff --git a/openadapt/adapters/prompt.py b/openadapt/adapters/prompt.py
new file mode 100644
index 000000000..95a0f1d1d
--- /dev/null
+++ b/openadapt/adapters/prompt.py
@@ -0,0 +1,47 @@
+"""Adapter for prompting foundation models."""
+
+from loguru import logger
+from types import ModuleType
+from PIL import Image
+
+
+from openadapt.drivers import anthropic, google, openai
+
+
+# Define a list of drivers in the order they should be tried
+DRIVER_ORDER: list[ModuleType] = [openai, google, anthropic]
+
+
+def prompt(
+    text: str,
+    images: list[Image.Image] | None = None,
+    system_prompt: str | None = None,
+) -> str:
+    """Attempt to fetch a prompt completion from various services in order of priority.
+
+    Args:
+        text: The main text prompt.
+        images: list of images to include in the prompt.
+        system_prompt: An optional system-level prompt.
+
+    Returns:
+        The result from the first successful driver.
+    """
+    text = text.strip()
+    for driver in DRIVER_ORDER:
+        try:
+            logger.info(f"Trying driver: {driver.__name__}")
+            return driver.prompt(text, images=images, system_prompt=system_prompt)
+        except Exception as e:
+            logger.exception(e)
+            logger.error(f"Driver {driver.__name__} failed with error: {e}")
+            # Fall through to the next driver in DRIVER_ORDER rather than
+            # stopping in a debugger here; the case where every driver
+            # fails is handled by the raise below.
+            continue
+    raise Exception("All drivers failed to provide a response")
+
+
+if __name__ == "__main__":
+    # This could be extended to use command-line arguments or other input methods
+    print(prompt("Describe the solar system."))
diff --git a/openadapt/adapters/ultralytics.py b/openadapt/adapters/ultralytics.py
index 24be674e3..9673aa2cb 100644
--- a/openadapt/adapters/ultralytics.py
+++ b/openadapt/adapters/ultralytics.py
@@ -77,10 +77,26 @@ def do_fastsam(
     retina_masks: bool = True,
     imgsz: int | tuple[int, int] | None = 1024,
     # threshold below which boxes will be filtered out
-    conf: float = 0.4,
+    min_confidence_threshold: float = 0.4,
     # discards all overlapping boxes with IoU > iou_threshold
-    iou: float = 0.9,
+    max_iou_threshold: float = 0.9,
 ) -> Image:
+    """Get segmented image via FastSAM.
+
+    For usage of thresholds see:
+    github.com/ultralytics/ultralytics/blob/dacbd48fcf8407098166c6812eeb751deaac0faf
+    /ultralytics/utils/ops.py#L164
+
+    Args:
+        TODO
+        min_confidence_threshold (float, optional): The minimum confidence score
+            that a detection must meet or exceed to be considered valid. Detections
+            below this threshold will not be marked. Defaults to 0.4.
+        max_iou_threshold (float, optional): The maximum allowed Intersection over
+            Union (IoU) value for overlapping detections. Detections that exceed this
+            IoU threshold are considered for suppression, keeping only the
+            detection with the highest confidence. Defaults to 0.9.
+ """ model = FastSAM(model_name) imgsz = imgsz or image.size @@ -91,8 +107,8 @@ def do_fastsam( device=device, retina_masks=retina_masks, imgsz=imgsz, - conf=conf, - iou=iou, + conf=min_confidence_threshold, + iou=max_iou_threshold, ) # Prepare a Prompt Process object diff --git a/openadapt/adapters/anthropic.py b/openadapt/drivers/anthropic.py similarity index 100% rename from openadapt/adapters/anthropic.py rename to openadapt/drivers/anthropic.py diff --git a/openadapt/adapters/google.py b/openadapt/drivers/google.py similarity index 100% rename from openadapt/adapters/google.py rename to openadapt/drivers/google.py diff --git a/openadapt/adapters/openai.py b/openadapt/drivers/openai.py similarity index 95% rename from openadapt/adapters/openai.py rename to openadapt/drivers/openai.py index 940e0550b..e5b3d28d3 100644 --- a/openadapt/adapters/openai.py +++ b/openadapt/drivers/openai.py @@ -123,7 +123,12 @@ def get_response( headers=headers, json=payload, ) - return response + result = response.json() + if "error" in result: + error = result["error"] + message = error["message"] + raise Exception(message) + return result def get_completion(payload: dict, dev_mode: bool = False) -> str: @@ -136,15 +141,10 @@ def get_completion(payload: dict, dev_mode: bool = False) -> str: Returns: (str) first message from the response """ - response = get_response(payload) - response.raise_for_status() - result = response.json() - logger.info(f"result=\n{pformat(result)}") - if "error" in result: - error = result["error"] - message = error["message"] - # TODO: fail after maximum number of attempts - if "retry your request" in message: + try: + result = get_response(payload) + except Exception as exc: + if "retry your request" in str(exc): return get_completion(payload) elif dev_mode: import ipdb @@ -152,7 +152,8 @@ def get_completion(payload: dict, dev_mode: bool = False) -> str: ipdb.set_trace() # TODO: handle more errors else: - raise ValueError(result["error"]["message"]) + raise exc + logger.info(f"result=\n{pformat(result)}") choices = result["choices"] choice = choices[0] message = choice["message"] diff --git a/openadapt/models.py b/openadapt/models.py index 33ce49a79..76d424425 100644 --- a/openadapt/models.py +++ b/openadapt/models.py @@ -388,11 +388,11 @@ def from_dict( suffix_len = len(name_suffix) key_names = utils.split_by_separators( - action_dict["text"][prefix_len:-suffix_len], + action_dict.get("text", "")[prefix_len:-suffix_len], key_seps, ) canonical_key_names = utils.split_by_separators( - action_dict["canonical_text"][prefix_len:-suffix_len], + action_dict.get("canonical_text", "")[prefix_len:-suffix_len], key_seps, ) logger.info(f"{key_names=}") @@ -920,6 +920,18 @@ def asdict(self) -> dict: } +class Replay(db.Base): + """Class representing a replay in the database.""" + + __tablename__ = "replay" + + id = sa.Column(sa.Integer, primary_key=True) + timestamp = sa.Column(ForceFloat) + strategy_name = sa.Column(sa.String) + strategy_args = sa.Column(sa.JSON) + git_hash = sa.Column(sa.String) + + def copy_sa_instance(sa_instance: db.Base, **kwargs: dict) -> db.Base: """Copy a SQLAlchemy instance. 
diff --git a/openadapt/playback.py b/openadapt/playback.py
index bcd9776c6..5d2e7270c 100644
--- a/openadapt/playback.py
+++ b/openadapt/playback.py
@@ -27,7 +27,10 @@ def play_mouse_event(event: ActionEvent, mouse_controller: mouse.Controller) ->
     pressed = event.mouse_pressed
     logger.debug(f"{name=} {x=} {y=} {dx=} {dy=} {button_name=} {pressed=}")
 
-    mouse_controller.position = (x, y)
+    if all([val is not None for val in (x, y)]):
+        mouse_controller.position = (x, y)
+    else:
+        logger.warning(f"{x=} {y=}")
     if name == "move":
         pass
     elif name == "click":
diff --git a/openadapt/plotting.py b/openadapt/plotting.py
index ba6260474..f476c3b82 100644
--- a/openadapt/plotting.py
+++ b/openadapt/plotting.py
@@ -13,7 +13,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
 
-from openadapt import common, models, utils
+from openadapt import common, contrib, models, utils
 from openadapt.config import PERFORMANCE_PLOTS_DIR_PATH, config
 from openadapt.models import ActionEvent
 
@@ -764,3 +764,57 @@ def plot_segments(
     plt.imshow(image)
     plt.axis("off")
     plt.show()
+
+
+def get_marked_image(
+    original_image: Image.Image,
+    masks: list[np.ndarray],
+    include_masks: bool = True,
+    include_marks: bool = True,
+) -> Image.Image:
+    """Get a Set-of-Mark image using the original SoM visualizer.
+
+    Args:
+        original_image (Image.Image): The original PIL image.
+        masks (list[np.ndarray]): A list of masks representing segments in the
+            original image.
+        include_masks (bool, optional): If True, masks will be included in the
+            output visualizations. Defaults to True.
+        include_marks (bool, optional): If True, marks will be included in the
+            output visualizations. Defaults to True.
+
+    Returns:
+        Image.Image: The marked image, with marks and/or masks applied according
+            to the include flags.
+    """
+    image_arr = np.asarray(original_image)
+
+    # The rest of this function is copied from
+    # github.com/microsoft/SoM/blob/main/task_adapter/sam/tasks/inference_sam_m2m_auto.py
+
+    # metadata = MetadataCatalog.get('coco_2017_train_panoptic')
+    metadata = None
+    visual = contrib.som.visualizer.Visualizer(image_arr, metadata=metadata)
+    mask_map = np.zeros(image_arr.shape, dtype=np.uint8)
+    label_mode = "1"
+    alpha = 0.1
+    anno_mode = []
+    if include_masks:
+        anno_mode.append("Mask")
+    if include_marks:
+        anno_mode.append("Mark")
+    for i, mask in enumerate(masks):
+        label = i + 1
+        demo = visual.draw_binary_mask_with_number(
+            mask,
+            text=str(label),
+            label_mode=label_mode,
+            alpha=alpha,
+            anno_mode=anno_mode,
+        )
+        mask_map[mask == 1] = label
+
+    im = demo.get_image()
+    marked_image = Image.fromarray(im)
+
+    return marked_image
diff --git a/openadapt/prompts/describe_recording--segment.j2 b/openadapt/prompts/describe_recording--segment.j2
new file mode 100644
index 000000000..8668e7d0f
--- /dev/null
+++ b/openadapt/prompts/describe_recording--segment.j2
@@ -0,0 +1,16 @@
+Consider the actions in the recording and states of the active window immediately
+before each action was taken:
+
+```json
+{{ action_windows }}
+```
+
+Consider the attached screenshots taken immediately before each action. The order of
+the screenshots matches the order of the actions above.
+
+Provide a detailed natural language description of everything that happened
+in this recording. This description will be embedded in the context for a future prompt
+to replay the recording (subject to proposed modifications in natural language) on a
+live system, so make sure to include everything you will need to know.
+
+My career depends on this. Lives are at stake.
diff --git a/openadapt/prompts/generate_action_event--segment.j2 b/openadapt/prompts/generate_action_event--segment.j2
new file mode 100644
index 000000000..004112a43
--- /dev/null
+++ b/openadapt/prompts/generate_action_event--segment.j2
@@ -0,0 +1,73 @@
+{% if include_raw_recording %}
+Consider the previously recorded actions:
+
+```json
+{{ recorded_actions }}
+```
+{% endif %}
+
+
+{% if include_raw_recording_description %}
+Consider the following description of the previously recorded actions:
+
+```json
+{{ recording_description }}
+```
+{% endif %}
+
+
+{% if include_replay_instructions %}
+Consider the user's proposed modifications in natural language instructions:
+
+```text
+{{ replay_instructions }}
+```
+{% endif %}
+
+
+{% if include_modified_recording %}
+Consider this updated list of actions that have been modified such that replaying them
+would have accomplished the user's instructions:
+
+```json
+{{ modified_actions }}
+```
+{% endif %}
+
+
+{% if include_modified_recording_description %}
+Consider the following description of the updated list of actions that have been
+modified such that replaying them would have accomplished the user's instructions:
+
+```json
+{{ modified_recording_description }}
+```
+{% endif %}
+
+
+Consider the actions you've produced (and we have played back) so far:
+
+```json
+{{ replayed_actions }}
+```
+
+{% if include_active_window %}
+Consider the current active window:
+```json
+{{ current_window }}
+```
+{% endif %}
+
+
+The attached image is a screenshot of the current state of the system.
+
+Provide the next action to be replayed in order to accomplish the user's replay
+instructions.
+
+Do NOT provide available_segment_descriptions in your response.
+
+Respond with JSON and nothing else.
+
+If you wish to terminate the recording, return an empty object.
+
+My career depends on this. Lives are at stake.
diff --git a/openadapt/replay.py b/openadapt/replay.py
index 1029102a8..79d9bf425 100644
--- a/openadapt/replay.py
+++ b/openadapt/replay.py
@@ -16,8 +16,8 @@
 with redirect_stdout_stderr():
     import fire
 
-from openadapt import capture, utils
-from openadapt.config import CAPTURE_DIR_PATH
+from openadapt import capture as _capture, utils
+from openadapt.config import CAPTURE_DIR_PATH, print_config
 from openadapt.db import crud
 from openadapt.models import Recording
 
@@ -29,7 +29,7 @@
 @logger.catch
 def replay(
     strategy_name: str,
-    record: bool = False,
+    capture: bool = False,
     timestamp: str | None = None,
     recording: Recording = None,
     status_pipe: multiprocessing.connection.Connection | None = None,
@@ -41,7 +41,7 @@ def replay(
         strategy_name (str): Name of the replay strategy to use.
         timestamp (str, optional): Timestamp of the recording to replay.
         recording (Recording, optional): Recording to replay.
-        record (bool, optional): Flag indicating whether to record the replay.
+        capture (bool, optional): Flag indicating whether to capture the replay.
         status_pipe: A connection to communicate replay status.
         kwargs: Keyword arguments to pass to strategy.
 
@@ -49,6 +49,7 @@ def replay(
         bool: True if replay was successful, None otherwise.
     """
     utils.configure_logging(logger, LOG_LEVEL)
+    print_config()
     posthog.capture(event="replay.started", properties={"strategy_name": strategy_name})
 
     if status_pipe:
@@ -85,8 +86,8 @@ def replay(
     handler = None
 
     rval = True
-    if record:
-        capture.start(audio=False, camera=False)
+    if capture:
+        _capture.start(audio=False, camera=False)
         # TODO: handle this more robustly
         sleep(1)
     file_name = f"log-{strategy_name}-{recording.timestamp}.log"
@@ -107,9 +108,9 @@ def replay(
         properties={"strategy_name": strategy_name, "success": rval},
     )
 
-    if record:
+    if capture:
         sleep(1)
-        capture.stop()
+        _capture.stop()
 
     logger.remove(handler)
     return rval
diff --git a/openadapt/strategies/__init__.py b/openadapt/strategies/__init__.py
index 553a739e6..fecc7c056 100644
--- a/openadapt/strategies/__init__.py
+++ b/openadapt/strategies/__init__.py
@@ -9,6 +9,7 @@
 # disabled because importing is expensive
 # from openadapt.strategies.demo import DemoReplayStrategy
 from openadapt.strategies.naive import NaiveReplayStrategy
+from openadapt.strategies.segment import SegmentReplayStrategy
 from openadapt.strategies.stateful import StatefulReplayStrategy
 from openadapt.strategies.vanilla import VanillaReplayStrategy
 from openadapt.strategies.visual import VisualReplayStrategy
diff --git a/openadapt/strategies/base.py b/openadapt/strategies/base.py
index 95ca88f7d..96f8b012f 100644
--- a/openadapt/strategies/base.py
+++ b/openadapt/strategies/base.py
@@ -95,6 +95,7 @@ def run(self) -> None:
                 import ipdb
 
                 ipdb.set_trace()
+                foo = 1  # noqa
 
     def log_fps(self) -> None:
         """Log the frames per second (FPS) rate."""
diff --git a/openadapt/strategies/segment.py b/openadapt/strategies/segment.py
new file mode 100644
index 000000000..931f1f80f
--- /dev/null
+++ b/openadapt/strategies/segment.py
@@ -0,0 +1,345 @@
+"""Extends vanilla strategy to use segmentation to convert description -> coordinate.
+
+Uses FastSAM for segmentation.
+"""
+
+from pprint import pformat
+
+from loguru import logger
+
+from openadapt import adapters, common, models, strategies, utils
+from openadapt.strategies.visual import (
+    add_active_segment_descriptions,
+    get_window_segmentation,
+    apply_replay_instructions,
+)
+
+
+INCLUDE_RAW_RECORDING = False
+INCLUDE_RAW_RECORDING_DESCRIPTION = False
+INCLUDE_MODIFIED_RECORDING = True
+INCLUDE_MODIFIED_RECORDING_DESCRIPTION = False
+INCLUDE_REPLAY_INSTRUCTIONS = False
+INCLUDE_WINDOW = False
+INCLUDE_WINDOW_DATA = False
+FILTER_MASKS = True
+INCLUDE_CURRENT_SCREENSHOT = False
+
+
+class SegmentReplayStrategy(strategies.base.BaseReplayStrategy):
+    """Segment replay strategy that performs segmentation in addition to vanilla."""
+
+    def __init__(
+        self,
+        recording: models.Recording,
+        instructions: str = "",
+    ) -> None:
+        """Initialize the SegmentReplayStrategy.
+
+        Args:
+            recording (models.Recording): The recording object.
+            instructions (str): Natural language instructions
+                for how recording should be replayed.
+ """ + super().__init__(recording) + self.replay_instructions = instructions + self.action_history = [] + self.action_event_idx = 0 + + add_active_segment_descriptions(recording.processed_action_events) + self.modified_actions = apply_replay_instructions( + recording.processed_action_events, + self.replay_instructions, + ) + + if INCLUDE_RAW_RECORDING_DESCRIPTION: + self.recording_description = describe_recording( + self.recording.processed_action_events + ) + else: + self.recording_description = None + + if INCLUDE_MODIFIED_RECORDING_DESCRIPTION: + self.modified_recording_description = describe_recording( + self.modified_actions + ) + else: + self.modified_recording_description = None + + def get_next_action_event( + self, + screenshot: models.Screenshot, + window_event: models.WindowEvent, + include_raw_recording: bool = INCLUDE_RAW_RECORDING, + include_raw_recording_description: bool = INCLUDE_RAW_RECORDING_DESCRIPTION, + include_modified_recording: bool = INCLUDE_MODIFIED_RECORDING, + include_modified_recording_description: bool = ( + INCLUDE_MODIFIED_RECORDING_DESCRIPTION + ), + include_active_window: bool = INCLUDE_WINDOW, + include_active_window_data: bool = INCLUDE_WINDOW_DATA, + include_replay_instructions: bool = INCLUDE_REPLAY_INSTRUCTIONS, + include_current_screenshot: bool = INCLUDE_CURRENT_SCREENSHOT, + ) -> models.ActionEvent | None: + """Get the next ActionEvent for replay. + + Args: + screenshot (models.Screenshot): The screenshot object. + window_event (models.WindowEvent): The window event object. + include_raw_recording (bool): Whether to include the raw recording in the + prompt. + include_raw_recording_description (bool): Whether to include the raw + recording description in the prompt. + include_modified_recording (bool): Whether to include the modified + recording in the prompt. + include_modified_recording_description (bool): Whether to include the + modified recording description in the prompt. + include_active_window (bool): Whether to include window metadata in the + prompt. + include_active_window_data (bool): Whether to retain window a11y data in + the prompt. + include_replay_instructions (bool): Whether to include replay instructions + in the prompt. + include_current_screenshot (bool): Whether to include the current screenshot + in the prompt. + + Returns: + models.ActionEvent or None: The next ActionEvent for replay or None + if there are no more events. 
+ """ + reference_actions = self.recording.processed_action_events + num_action_events = max( + len(reference_actions), + len(self.modified_actions), + ) + self.action_event_idx += 1 + if self.action_event_idx >= num_action_events: + raise StopIteration() + logger.debug(f"{self.action_event_idx=} of {num_action_events=}") + + generated_action_event = generate_action_event( + screenshot, + window_event, + reference_actions, + self.modified_actions, + self.action_history, + self.replay_instructions, + self.recording_description, + self.modified_recording_description, + include_raw_recording, + include_raw_recording_description, + include_modified_recording, + include_modified_recording_description, + include_active_window, + include_active_window_data, + include_replay_instructions, + include_current_screenshot, + ) + if not generated_action_event: + raise StopIteration() + + # convert segment -> coordinate + # (based on visual.py) + active_window = models.WindowEvent.get_active_window_event( + include_active_window_data + ) + + active_screenshot = models.Screenshot.take_screenshot() + logger.info(f"{active_window=}") + + if ( + generated_action_event.name in common.MOUSE_EVENTS + and generated_action_event.active_segment_description + ): + generated_action_event.screenshot = active_screenshot + generated_action_event.window_event = active_window + generated_action_event.recording = self.recording + exceptions = [] + while True: + active_window_segmentation = get_window_segmentation( + generated_action_event, + exceptions=exceptions, + ) + try: + target_segment_idx = active_window_segmentation.descriptions.index( + generated_action_event.active_segment_description + ) + except ValueError as exc: + exceptions.append(exc) + # TODO XXX this does not update the prompts, even though it should + logger.exception(exc) + import ipdb + + ipdb.set_trace() + logger.warning(f"{exc=} {len(exceptions)=}") + else: + break + target_centroid = active_window_segmentation.centroids[target_segment_idx] + # = scale_ratio * + width_ratio, height_ratio = utils.get_scale_ratios(generated_action_event) + target_mouse_x = target_centroid[0] / width_ratio + active_window.left + target_mouse_y = target_centroid[1] / height_ratio + active_window.top + generated_action_event.mouse_x = target_mouse_x + generated_action_event.mouse_y = target_mouse_y + else: + # just click wherever the mouse already is + pass + + self.action_history.append(generated_action_event) + return generated_action_event + + def __del__(self) -> None: + """Log the action history.""" + action_history_dicts = [ + action.to_prompt_dict() for action in self.action_history + ] + logger.info(f"action_history=\n{pformat(action_history_dicts)}") + + +def describe_recording( + action_events: list[models.ActionEvent], + include_window: bool = INCLUDE_WINDOW, + include_window_data: bool = INCLUDE_WINDOW_DATA, +) -> str: + """Generate a natural language description of the actions in the recording. + + Given the recorded states and actions, describe what happened. + + Args: + action_events (list[models.ActionEvent]): the list of actions to describe. + include_window (bool): flag indicating whether to include window metadata. + include_window_data (bool): flag indicating whether to include accessibility + API data in each window event. + + Returns: + (str) natural language description of the what happened in the recording. 
+ """ + action_dicts = [action.to_prompt_dict() for action in action_events] + window_dicts = [ + ( + action.window_event.to_prompt_dict(include_window_data) + # this may be a modified action, in which case there is no window event + if action.window_event + else {} + ) + for action in action_events + ] + action_window_dicts = [ + { + "action": action_dict, + "window": window_dict if include_window else {}, + } + for action_dict, window_dict in zip(action_dicts, window_dicts) + ] + images = [action.screenshot.image for action in action_events if action.screenshot] + system_prompt = utils.render_template_from_file( + "prompts/system.j2", + ) + prompt = utils.render_template_from_file( + "prompts/describe_recording.j2", + action_windows=action_window_dicts, + ) + prompt_adapter = adapters.get_default_prompt_adapter() + recording_description = prompt_adapter.prompt( + prompt, + images=images, + system_prompt=system_prompt, + ) + return recording_description + + +def generate_action_event( + current_screenshot: models.Screenshot, + current_window_event: models.WindowEvent, + recorded_actions: list[models.ActionEvent], + modified_actions: list[models.ActionEvent], + replayed_actions: list[models.ActionEvent], + replay_instructions: str, + recording_description: str, + modified_recording_description: str, + include_raw_recording: bool, + include_raw_recording_description: bool, + include_modified_recording: bool, + include_modified_recording_description: bool, + include_active_window: bool, + include_active_window_data: bool, + include_replay_instructions: str, + include_current_screenshot: bool, +) -> models.ActionEvent: + """Modify the given ActionEvents according to the given replay instructions. + + Given the description of what happened, proposed modifications in natural language + instructions, the current state, and the actions produced so far, produce the next + action. + + Args: + current_screenshot (models.Screenshot): current state screenshot + current_window_event (models.WindowEvent): current state window data + recorded_actions (list[models.ActionEvent]): list of action events from the + recording + replayed_actions (list[models.ActionEvent]): list of actions produced during + current replay + replay_instructions (str): proposed modifications in natural language + instructions + include_raw_recording (bool): Whether to include the raw recording in the + prompt. + include_raw_recording_description (bool): Whether to include the raw + recording description in the prompt. + include_modified_recording (bool): Whether to include the modified + recording in the prompt. + include_modified_recording_description (bool): Whether to include the + modified recording description in the prompt. + include_active_window (bool): Whether to include window metadata in the + prompt. + include_active_window_data (bool): Whether to retain window a11y data in + the prompt. + include_replay_instructions (bool): Whether to include replay instructions + in the prompt. + include_current_screenshot (bool): Whether to include the current screenshot + in the prompt. 
+ + Returns: + (models.ActionEvent) the next action event to be played, produced by the model + """ + current_image = current_screenshot.image + current_window_dict = current_window_event.to_prompt_dict( + include_active_window_data, + ) + recorded_action_dicts = [action.to_prompt_dict() for action in recorded_actions] + replayed_action_dicts = [action.to_prompt_dict() for action in replayed_actions] + modified_action_dicts = [action.to_prompt_dict() for action in modified_actions] + + system_prompt = utils.render_template_from_file( + "prompts/system.j2", + ) + prompt = utils.render_template_from_file( + "prompts/generate_action_event--segment.j2", + current_window=current_window_dict, + recorded_actions=recorded_action_dicts, + modified_actions=modified_action_dicts, + replayed_actions=replayed_action_dicts, + replay_instructions=replay_instructions, + recording_description=recording_description, + modified_recording_description=modified_recording_description, + include_raw_recording=include_raw_recording, + include_raw_recording_description=include_raw_recording_description, + include_modified_recording=include_modified_recording, + include_modified_recording_description=include_modified_recording_description, + include_active_window=include_active_window, + include_replay_instructions=include_replay_instructions, + ) + prompt_adapter = adapters.get_default_prompt_adapter() + images = [current_image] if include_current_screenshot else [] + content = prompt_adapter.prompt( + prompt, + images=images, + system_prompt=system_prompt, + ) + action_dict = utils.parse_code_snippet(content) + logger.info(f"{action_dict=}") + if not action_dict: + # allow early stopping + return None + action = models.ActionEvent.from_dict(action_dict) + logger.info(f"{action=}") + return action diff --git a/openadapt/strategies/visual.py b/openadapt/strategies/visual.py index a19aa93d2..7aaa00cab 100644 --- a/openadapt/strategies/visual.py +++ b/openadapt/strategies/visual.py @@ -52,7 +52,15 @@ from PIL import Image, ImageDraw import numpy as np -from openadapt import adapters, common, models, plotting, strategies, utils, vision +from openadapt import ( + adapters, + common, + models, + plotting, + strategies, + utils, + vision, +) DEBUG = False DEBUG_REPLAY = False @@ -68,6 +76,7 @@ class Segmentation: Attributes: image: The original image used to generate segments. + marked_image: The marked image (for Set-of-Mark prompting). masked_images: A list of PIL Image objects that have been masked based on segmentation. descriptions: Descriptions of each segmented region, correlating with each @@ -81,6 +90,7 @@ class Segmentation: """ image: Image.Image + marked_image: Image.Image masked_images: list[Image.Image] descriptions: list[str] bounding_boxes: list[dict[str, float]] # "top", "left", "height", "width" @@ -112,6 +122,7 @@ def add_active_segment_descriptions(action_events: list[models.ActionEvent]) -> def apply_replay_instructions( action_events: list[models.ActionEvent], replay_instructions: str, + # retain_window_events: bool = False, ) -> None: """Modify the given ActionEvents according to the given replay instructions. 
@@ -132,7 +143,7 @@ def apply_replay_instructions( prompt_adapter = adapters.get_default_prompt_adapter() content = prompt_adapter.prompt( prompt, - system_prompt, + system_prompt=system_prompt, ) content_dict = utils.parse_code_snippet(content) try: @@ -167,6 +178,7 @@ def __init__( """ super().__init__(recording) self.recording_action_idx = 0 + self.action_history = [] add_active_segment_descriptions(recording.processed_action_events) self.modified_actions = apply_replay_instructions( recording.processed_action_events, @@ -235,8 +247,16 @@ def get_next_action_event( target_mouse_y = target_centroid[1] / height_ratio + active_window.top modified_reference_action.mouse_x = target_mouse_x modified_reference_action.mouse_y = target_mouse_y + self.action_history.append(modified_reference_action) return modified_reference_action + def __del__(self) -> None: + """Log the action history.""" + action_history_dicts = [ + action.to_prompt_dict() for action in self.action_history + ] + logger.info(f"action_history=\n{pformat(action_history_dicts)}") + def get_active_segment( action: models.ActionEvent, @@ -420,8 +440,13 @@ def get_window_segmentation( len(descriptions), len(centroids), ) + marked_image = plotting.get_marked_image( + original_image, + refined_masks, # masks, + ) segmentation = Segmentation( original_image, + marked_image, masked_images, descriptions, bounding_boxes, @@ -451,65 +476,62 @@ def prompt_for_descriptions( Returns: list of descriptions for each masked image. """ - prompt_adapter = adapters.get_default_prompt_adapter() + # TODO: move inside adapters.prompt + for driver in adapters.prompt.DRIVER_ORDER: + # off by one to account for original image + if driver.MAX_IMAGES and (len(masked_images) + 1 > driver.MAX_IMAGES): + masked_images_batches = utils.split_list( + masked_images, + driver.MAX_IMAGES - 1, + ) + descriptions = [] + for masked_images_batch in masked_images_batches: + descriptions_batch = prompt_for_descriptions( + original_image, + masked_images_batch, + active_segment_description, + exceptions, + ) + descriptions += descriptions_batch + return descriptions - # TODO: move inside adapters - # off by one to account for original image - if prompt_adapter.MAX_IMAGES and ( - len(masked_images) + 1 > prompt_adapter.MAX_IMAGES - ): - masked_images_batches = utils.split_list( - masked_images, - prompt_adapter.MAX_IMAGES - 1, + images = [original_image] + masked_images + system_prompt = utils.render_template_from_file( + "prompts/system.j2", + ) + logger.info(f"system_prompt=\n{system_prompt}") + num_segments = len(masked_images) + prompt = utils.render_template_from_file( + "prompts/description.j2", + active_segment_description=active_segment_description, + num_segments=num_segments, + exceptions=exceptions, + ).strip() + logger.info(f"prompt=\n{prompt}") + logger.info(f"{len(images)=}") + descriptions_json = driver.prompt( + prompt, + system_prompt, + images, ) - descriptions = [] - for masked_images_batch in masked_images_batches: - descriptions_batch = prompt_for_descriptions( + descriptions = utils.parse_code_snippet(descriptions_json)["descriptions"] + logger.info(f"{descriptions=}") + try: + assert len(descriptions) == len(masked_images), ( + len(descriptions), + len(masked_images), + ) + except Exception as exc: + exceptions = exceptions or [] + exceptions.append(exc) + logger.info(f"exceptions=\n{pformat(exceptions)}") + return prompt_for_descriptions( original_image, - masked_images_batch, + masked_images, active_segment_description, exceptions, ) - descriptions 
+= descriptions_batch
-        return descriptions
 
-    images = [original_image] + masked_images
-    system_prompt = utils.render_template_from_file(
-        "prompts/system.j2",
-    )
-    logger.info(f"system_prompt=\n{system_prompt}")
-    num_segments = len(masked_images)
-    prompt = utils.render_template_from_file(
-        "prompts/description.j2",
-        active_segment_description=active_segment_description,
-        num_segments=num_segments,
-        exceptions=exceptions,
-    )
-    logger.info(f"prompt=\n{prompt}")
-    logger.info(f"{len(images)=}")
-    descriptions_json = prompt_adapter.prompt(
-        prompt,
-        system_prompt,
-        images,
-    )
-    descriptions = utils.parse_code_snippet(descriptions_json)["descriptions"]
-    logger.info(f"{descriptions=}")
-    try:
-        assert len(descriptions) == len(masked_images), (
-            len(descriptions),
-            len(masked_images),
-        )
-    except Exception as exc:
-        exceptions = exceptions or []
-        exceptions.append(exc)
-        logger.info(f"exceptions=\n{pformat(exceptions)}")
-        return prompt_for_descriptions(
-            original_image,
-            masked_images,
-            active_segment_description,
-            exceptions,
-        )
-
-    # remove indexes
-    descriptions = [desc for idx, desc in descriptions]
-    return descriptions
+        # remove indexes
+        descriptions = [desc for idx, desc in descriptions]
+        return descriptions
diff --git a/openadapt/utils.py b/openadapt/utils.py
index 95abdc7bd..80c8f1b1d 100644
--- a/openadapt/utils.py
+++ b/openadapt/utils.py
@@ -606,6 +606,13 @@ def parse_code_snippet(snippet: str) -> dict:
         python_code = snippet.replace("```python\n", "").replace("```", "").strip()
         return ast.literal_eval(python_code)
     else:
+        processed_snippet = extract_code_block(snippet)
+        # The model may wrap the fenced block in prose: extract the outermost
+        # ``` block and retry. Recurse only when extraction actually changed
+        # the snippet; otherwise fall through to the warning below instead of
+        # looping forever on input that contains no fenced block.
+        if processed_snippet and processed_snippet != snippet:
+            return parse_code_snippet(processed_snippet)
         msg = f"Unsupported {snippet=}"
         logger.warning(msg)
         return None
@@ -614,6 +621,40 @@ def parse_code_snippet(snippet: str) -> dict:
         raise exc
 
 
+def extract_code_block(text: str) -> str:
+    """Extract the text enclosed by the outermost backticks.
+
+    Includes the backticks themselves.
+
+    Args:
+        text (str): The input text containing potential code blocks enclosed by
+            backticks.
+
+    Returns:
+        str: The text enclosed by the outermost backticks, or an empty string
+            if no complete block is found.
+
+    Raises:
+        ValueError: If the number of backtick lines is uneven.
+    """
+    backticks = "```"
+    lines = text.splitlines()
+    backtick_idxs = [
+        idx for idx, line in enumerate(lines) if line.startswith(backticks)
+    ]
+
+    if len(backtick_idxs) % 2 != 0:
+        raise ValueError("Uneven number of backtick lines")
+
+    if len(backtick_idxs) < 2:
+        return ""  # No enclosing backticks found, return empty string
+
+    # Extract only the lines between the first and last backtick line,
+    # including the backticks
+    start_idx, end_idx = backtick_idxs[0], backtick_idxs[-1]
+    return "\n".join(lines[start_idx : end_idx + 1])
+
+
 def split_list(input_list: list, size: int) -> list[list]:
     """Splits a list into a list of lists, where each inner list has a maximum size.
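The `extract_code_block` helper trims any prose a model wraps around a fenced block before `parse_code_snippet` retries on clean input. A quick illustration, assuming `parse_code_snippet`'s existing handling of ```json fences; the response text is constructed programmatically only to avoid nesting literal fences in this example:

```python
from openadapt import utils

fence = "`" * 3
response = "\n".join(
    [
        "Sure, here is the next action:",
        "",
        fence + "json",
        '{"name": "click", "mouse_button": "left"}',
        fence,
        "",
        "Let me know if anything is unclear.",
    ]
)

# extract_code_block keeps only the outermost fenced span, fences included,
# so the retry in parse_code_snippet operates on clean input.
block = utils.extract_code_block(response)
action_dict = utils.parse_code_snippet(block)
assert action_dict == {"name": "click", "mouse_button": "left"}
```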
diff --git a/openadapt/video.py b/openadapt/video.py index 18d9a3b03..be5f9970e 100644 --- a/openadapt/video.py +++ b/openadapt/video.py @@ -313,7 +313,13 @@ def extract_frames( video_container.close() - logger.info(f"frame_differences=\n{pformat(frame_differences)}") + logger.debug(f"frame_differences=\n{pformat(frame_differences)}") + invalid_frame_differences = { + timestamp: difference + for timestamp, difference in frame_differences.items() + if difference > tolerance + } + logger.info(f"invalid_frame_differences=\n{pformat(invalid_frame_differences)}") # Check if all timestamps have been matched for timestamp, frame in timestamp_frames.items(): diff --git a/openadapt/window/__init__.py b/openadapt/window/__init__.py index 5755633c3..99b42ec02 100644 --- a/openadapt/window/__init__.py +++ b/openadapt/window/__init__.py @@ -24,9 +24,12 @@ def get_active_window_data( ) -> dict[str, Any] | None: """Get data of the active window. + Args: + include_window_data (bool): whether to include a11y data. + Returns: dict or None: A dictionary containing information about the active window, - or None if the state is not available. + or None if the state is not available. """ state = get_active_window_state(include_window_data) if not state: diff --git a/tests/openadapt/adapters/test_anthropic.py b/tests/openadapt/drivers/test_anthropic.py similarity index 73% rename from tests/openadapt/adapters/test_anthropic.py rename to tests/openadapt/drivers/test_anthropic.py index c2e095051..c69d1ff9e 100644 --- a/tests/openadapt/adapters/test_anthropic.py +++ b/tests/openadapt/drivers/test_anthropic.py @@ -1,18 +1,18 @@ -"""Tests for adapters.anthropic.""" +"""Tests for drivers.anthropic.""" from PIL import Image import pytest import anthropic -from openadapt import adapters +from openadapt import drivers def test_prompt(calculator_image: Image) -> None: """Test image prompt.""" prompt = "What is this a screenshot of?" 
try:
-        result = adapters.anthropic.prompt(prompt, images=[calculator_image])
+        result = drivers.anthropic.prompt(prompt, images=[calculator_image])
         assert "calculator" in result.lower(), result
     except anthropic.AuthenticationError as e:
         pytest.xfail(f"Anthropic AuthenticationError occurred: {e}")
diff --git a/tests/openadapt/adapters/test_google.py b/tests/openadapt/drivers/test_google.py
similarity index 89%
rename from tests/openadapt/adapters/test_google.py
rename to tests/openadapt/drivers/test_google.py
index cfcf4a448..e93648ab9 100644
--- a/tests/openadapt/adapters/test_google.py
+++ b/tests/openadapt/drivers/test_google.py
@@ -1,10 +1,10 @@
-"""Tests for adapters.google."""
+"""Tests for drivers.google."""
 
 from google.api_core.exceptions import DeadlineExceeded, InvalidArgument
 from PIL import Image
 import pytest
 
-from openadapt.adapters import google
+from openadapt.drivers import google
 
 
 def test_prompt(calculator_image: Image) -> None:
diff --git a/tests/openadapt/adapters/test_openai.py b/tests/openadapt/drivers/test_openai.py
similarity index 86%
rename from tests/openadapt/adapters/test_openai.py
rename to tests/openadapt/drivers/test_openai.py
index 574683d3e..14cb7b639 100644
--- a/tests/openadapt/adapters/test_openai.py
+++ b/tests/openadapt/drivers/test_openai.py
@@ -1,11 +1,11 @@
-"""Tests for adapters.openai."""
+"""Tests for drivers.openai."""
 
 import pytest
 
 from PIL import Image
 import requests
 
-from openadapt.adapters import openai
+from openadapt.drivers import openai
 
 
 def test_prompt(calculator_image: Image) -> None:
@@ -14,7 +14,7 @@ def test_prompt(calculator_image: Image) -> None:
     try:
         result = openai.prompt(prompt, images=[calculator_image])
         assert "calculator" in result.lower(), result
-    except ValueError as e:
+    except Exception as e:
         if "Incorrect API key" in str(e):
-            pytest.xfail(f"ValueError due to incorrect API key: {e}")
+            pytest.xfail(f"Exception due to incorrect API key: {e}")
         else:
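Putting the renamed flag together with the new strategy, a replay that also captures its own execution would be invoked roughly as follows; the strategy name matches the class added in this patch, while the instructions string is illustrative:

```python
from openadapt.replay import replay

# The former --record flag is now --capture; extra keyword arguments are
# forwarded to the strategy, so SegmentReplayStrategy receives `instructions`.
success = replay(
    strategy_name="SegmentReplayStrategy",
    capture=True,  # also record a capture (no audio/camera) of the replay itself
    instructions="Compute 6 * 7 instead of the recorded 2 + 2.",
)
print(f"{success=}")
```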