Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
jesicasusanto committed Jul 18, 2023
1 parent 0c6283b commit 1027546
Showing 1 changed file with 56 additions and 46 deletions.
102 changes: 56 additions & 46 deletions openadapt/strategies/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from copy import deepcopy
from pprint import pformat
#import datetime

# import datetime

from loguru import logger
import deepdiff
Expand All @@ -25,7 +26,6 @@ class StatefulReplayStrategy(
OpenAIReplayStrategyMixin,
strategies.base.BaseReplayStrategy,
):

def __init__(
self,
recording: models.Recording,
Expand All @@ -39,8 +39,7 @@ def __init__(
for action_event in self.recording.processed_action_events
][:-1]
self.recording_action_diff_tups = zip(
self.recording_window_state_diffs,
self.recording_action_strs
self.recording_window_state_diffs, self.recording_action_strs
)
self.recording_action_idx = 0

Expand All @@ -52,53 +51,63 @@ def get_next_action_event(
logger.debug(f"{self.recording_action_idx=}")
if self.recording_action_idx == len(self.recording.processed_action_events):
raise StopIteration()
reference_action = (
self.recording.processed_action_events[self.recording_action_idx]
)
reference_action = self.recording.processed_action_events[
self.recording_action_idx
]
reference_window = reference_action.window_event

reference_window_dict = deepcopy({
key: val
for key, val in utils.row2dict(reference_window, follow=False).items()
if val is not None
and not key.endswith("timestamp")
and not key.endswith("id")
#and not isinstance(getattr(models.WindowEvent, key), property)
})
reference_window_dict = deepcopy(
{
key: val
for key, val in utils.row2dict(reference_window, follow=False).items()
if val is not None
and not key.endswith("timestamp")
and not key.endswith("id")
# and not isinstance(getattr(models.WindowEvent, key), property)
}
)
if reference_action.children:
reference_action_dicts = [
deepcopy({
key: val
for key, val in utils.row2dict(child, follow=False).items()
if val is not None
and not key.endswith("timestamp")
and not key.endswith("id")
and not isinstance(getattr(models.ActionEvent, key), property)
})
deepcopy(
{
key: val
for key, val in utils.row2dict(child, follow=False).items()
if val is not None
and not key.endswith("timestamp")
and not key.endswith("id")
and not isinstance(getattr(models.ActionEvent, key), property)
}
)
for child in reference_action.children
]
else:
reference_action_dicts = [
deepcopy({
key: val
for key, val in utils.row2dict(reference_action, follow=False).items()
if val is not None
and not key.endswith("timestamp")
and not key.endswith("id")
#and not isinstance(getattr(models.ActionEvent, key), property)
})
deepcopy(
{
key: val
for key, val in utils.row2dict(
reference_action, follow=False
).items()
if val is not None
and not key.endswith("timestamp")
and not key.endswith("id")
# and not isinstance(getattr(models.ActionEvent, key), property)
}
)
]
active_window_dict = deepcopy({
key: val
for key, val in utils.row2dict(active_window, follow=False).items()
if val is not None
and not key.endswith("timestamp")
and not key.endswith("id")
#and not isinstance(getattr(models.WindowEvent, key), property)
})
active_window_dict = deepcopy(
{
key: val
for key, val in utils.row2dict(active_window, follow=False).items()
if val is not None
and not key.endswith("timestamp")
and not key.endswith("id")
# and not isinstance(getattr(models.WindowEvent, key), property)
}
)
if reference_window_dict and "state" in reference_window_dict:
reference_window_dict["state"].pop("data")
if active_window_dict and "state" in active_window_dict :
if active_window_dict and "state" in active_window_dict:
active_window_dict["state"].pop("data")

prompt = (
Expand Down Expand Up @@ -145,16 +154,16 @@ def get_window_state_diffs(
ignore_window_ids = set()
if ignore_boundary_windows:
first_window_event = action_events[0].window_event
if first_window_event.state :
if first_window_event.state:
first_window_id = first_window_event.state["window_id"]
else :
else:
first_window_id = None
first_window_title = first_window_event.title
last_window_event = action_events[-1].window_event
if last_window_event.state :
if last_window_event.state:
last_window_id = last_window_event.state["window_id"]
last_window_title = last_window_event.title
else :
else:
last_window_id = None
last_window_title = last_window_event.title
if first_window_id != last_window_id:
Expand All @@ -164,7 +173,8 @@ def get_window_state_diffs(
logger.info(f"ignoring {first_window_title=} {last_window_title=}")
window_event_states = [
action_event.window_event.state
if action_event.window_event.state is not None and action_event.window_event.state["window_id"] not in ignore_window_ids
if action_event.window_event.state is not None
and action_event.window_event.state["window_id"] not in ignore_window_ids
else {}
for action_event in action_events
]
Expand All @@ -174,4 +184,4 @@ def get_window_state_diffs(
window_event_states, window_event_states[1:]
)
]
return diffs
return diffs

0 comments on commit 1027546

Please sign in to comment.