diff --git a/puterbot/models.py b/puterbot/models.py index bb9ec0ea1..bea995e9d 100644 --- a/puterbot/models.py +++ b/puterbot/models.py @@ -3,9 +3,11 @@ from loguru import logger from pynput import keyboard from PIL import Image, ImageChops +import numpy as np import sqlalchemy as sa from puterbot.db import Base +from puterbot.utils import take_screenshot class Recording(Base): @@ -136,6 +138,9 @@ class Screenshot(Base): png_data = sa.Column(sa.LargeBinary) # TODO: replace prev with prev_timestamp? + # TODO: convert to png_data on save + sct_img = None + prev = None _image = None _diff = None @@ -144,8 +149,17 @@ class Screenshot(Base): @property def image(self): if not self._image: - buffer = io.BytesIO(self.png_data) - self._image = Image.open(buffer) + if self.sct_img: + self._image = Image.frombytes( + "RGB", + self.sct_img.size, + self.sct_img.bgra, + "raw", + "BGRX", + ) + else: + buffer = io.BytesIO(self.png_data) + self._image = Image.open(buffer) return self._image @property @@ -161,6 +175,16 @@ def diff_mask(self): self._diff_mask = self._diff.convert("1") return self._diff_mask + @property + def array(self): + return np.array(self.image) + + @classmethod + def take_screenshot(cls): + sct_img = take_screenshot() + screenshot = Screenshot(sct_img=sct_img) + return screenshot + class WindowEvent(Base): __tablename__ = "window_event" diff --git a/puterbot/strategies/base.py b/puterbot/strategies/base.py index 8f159593e..c44bbf869 100644 --- a/puterbot/strategies/base.py +++ b/puterbot/strategies/base.py @@ -10,9 +10,8 @@ import mss.base import numpy as np -from puterbot.models import Recording, InputEvent +from puterbot.models import InputEvent, Recording, Screenshot from puterbot.playback import play_input_event -from puterbot.utils import get_screenshot MAX_FRAME_TIMES = 1000 @@ -34,7 +33,7 @@ def __init__( @abstractmethod def get_next_input_event( self, - screenshot: mss.base.ScreenShot, + screenshot: Screenshot, ) -> InputEvent: pass @@ -42,7 +41,7 @@ def run(self): keyboard_controller = keyboard.Controller() mouse_controller = mouse.Controller() while True: - screenshot = get_screenshot() + screenshot = Screenshot.take_screenshot() self.screenshots.append(screenshot) try: input_event = self.get_next_input_event(screenshot) @@ -63,7 +62,7 @@ def log_fps(self): dts = np.diff(self.frame_times) if len(dts) > 1: mean_dt = np.mean(dts) - fps = len(dts) / mean_dt + fps = 1 / mean_dt logger.info(f"{fps=:.2f}") if len(self.frame_times) > self.max_frame_times: self.frame_times.pop(0) diff --git a/puterbot/strategies/ocr_mixin.py b/puterbot/strategies/ocr_mixin.py index e55a82c4e..57d06869a 100644 --- a/puterbot/strategies/ocr_mixin.py +++ b/puterbot/strategies/ocr_mixin.py @@ -16,11 +16,10 @@ class MyReplayStrategy(OCRReplayStrategyMixin): from PIL import Image from rapidocr_onnxruntime import RapidOCR from sklearn.cluster import DBSCAN -import mss.base import numpy as np import pandas as pd -from puterbot.models import Recording +from puterbot.models import Recording, Screenshot from puterbot.strategies.base import BaseReplayStrategy @@ -39,14 +38,10 @@ def __init__( def get_text( self, - screenshot: mss.base.ScreenShot + screenshot: Screenshot ): # TOOD: improve performance - image = Image.frombytes( - "RGB", screenshot.size, screenshot.bgra, "raw", "BGRX" - ) - arr = np.array(image) - result, elapse = self.ocr(arr) + result, elapse = self.ocr(screenshot.array) #det_elapse, cls_elapse, rec_elapse = elapse #all_elapse = det_elapse + cls_elapse + rec_elapse logger.debug(f"{result=}") diff --git a/puterbot/utils.py b/puterbot/utils.py index c0ed9c5cb..ca1869afd 100644 --- a/puterbot/utils.py +++ b/puterbot/utils.py @@ -363,12 +363,12 @@ def evenly_spaced(arr, N): return [val for idx, val in enumerate(arr) if idx in idxs] -def get_screenshot() -> mss.base.ScreenShot: +def take_screenshot() -> mss.base.ScreenShot: with mss.mss() as sct: # monitor 0 is all in one monitor = sct.monitors[0] - screenshot = sct.grab(monitor) - return screenshot + sct_img = sct.grab(monitor) + return sct_img def get_strategy_class_by_name():