Skip to content

Commit

Permalink
fix: move playwright to thread to avoid event loop conflicts
Browse files Browse the repository at this point in the history
Resolves #351 by moving Playwright to a separate thread to isolate its event loop
from prompt_toolkit. This prevents the asyncio.run() error that occurred when
trying to use prompt_toolkit after browser operations.

Changes:
- Created thread-based browser manager
- Updated all browser operations to use the thread
- Added proper timeout and error handling

Co-authored-by: Bob <[email protected]>
  • Loading branch information
ErikBjare and TimeToBuildBob committed Dec 18, 2024
1 parent 0ecf045 commit 5673f4d
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 52 deletions.
113 changes: 61 additions & 52 deletions gptme/tools/_browser_playwright.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,90 +9,90 @@
from dataclasses import dataclass
from pathlib import Path

from playwright.sync_api import (
ElementHandle,
Geolocation,
Page,
Playwright,
sync_playwright,
)

_p: Playwright | None = None
logger = logging.getLogger(__name__)

from playwright.sync_api import Browser, ElementHandle
from ._browser_thread import BrowserThread

def get_browser():
"""
Return a browser object.
"""
global _p
if _p is None:
logger.info("Starting browser")
_p = sync_playwright().start()
_browser: BrowserThread | None = None
logger = logging.getLogger(__name__)

atexit.register(_p.stop)
browser = _p.chromium.launch()
return browser

def get_browser() -> BrowserThread:
global _browser
if _browser is None:
logger.info("Starting browser thread")
_browser = BrowserThread()
atexit.register(_browser.stop)
return _browser

def load_page(url: str) -> Page:
browser = get_browser()

# set browser language to English such that Google uses English
coords_sf: Geolocation = {"latitude": 37.773972, "longitude": 13.39}
def _load_page(browser: Browser, url: str) -> str:
"""Load a page and return its body HTML"""
context = browser.new_context(
locale="en-US",
geolocation=coords_sf,
geolocation={"latitude": 37.773972, "longitude": 13.39},
permissions=["geolocation"],
)

# create a new page
logger.info(f"Loading page: {url}")
page = context.new_page()
page.goto(url)

return page
return page.inner_html("body")


def read_url(url: str) -> str:
"""Read the text of a webpage and return the text in Markdown format."""
page = load_page(url)

# Get the HTML of the body
body_html = page.inner_html("body")

# Convert the HTML to Markdown
markdown = html_to_markdown(body_html)

return markdown
browser = get_browser()
body_html = browser.execute(_load_page, url)
return html_to_markdown(body_html)


def search_google(query: str) -> str:
def _search_google(browser: Browser, query: str) -> str:
query = urllib.parse.quote(query)
url = f"https://www.google.com/search?q={query}&hl=en"
page = load_page(url)

context = browser.new_context(
locale="en-US",
geolocation={"latitude": 37.773972, "longitude": 13.39},
permissions=["geolocation"],
)
page = context.new_page()
page.goto(url)

els = _list_clickable_elements(page)
for el in els:
# print(f"{el['type']}: {el['text']}")
if "Accept all" in el.text:
el.element.click()
logger.debug("Accepted Google terms")
break

# list results
result_str = _list_results_google(page)
return _list_results_google(page)

return result_str

def search_google(query: str) -> str:
browser = get_browser()
return browser.execute(_search_google, query)

def search_duckduckgo(query: str) -> str:

def _search_duckduckgo(browser: Browser, query: str) -> str:
url = f"https://duckduckgo.com/?q={query}"
page = load_page(url)

context = browser.new_context(
locale="en-US",
geolocation={"latitude": 37.773972, "longitude": 13.39},
permissions=["geolocation"],
)
page = context.new_page()
page.goto(url)

return _list_results_duckduckgo(page)


def search_duckduckgo(query: str) -> str:
browser = get_browser()
return browser.execute(_search_duckduckgo, query)


@dataclass
class Element:
type: str
Expand Down Expand Up @@ -190,24 +190,33 @@ def _list_results_duckduckgo(page) -> str:
return titleurl_to_list(hits)


def screenshot_url(url: str, path: Path | str | None = None) -> Path:
def _take_screenshot(
browser: Browser, url: str, path: Path | str | None = None
) -> Path:
"""Take a screenshot of a webpage and save it to a file."""
logger.info(f"Taking screenshot of '{url}' and saving to '{path}'")
page = load_page(url)

if path is None:
path = tempfile.mktemp(suffix=".png")
else:
# create the directory if it doesn't exist
os.makedirs(os.path.dirname(path), exist_ok=True)

# Take the screenshot
context = browser.new_context()
page = context.new_page()
page.goto(url)
page.screenshot(path=path)

print(f"Screenshot saved to {path}")
return Path(path)


def screenshot_url(url: str, path: Path | str | None = None) -> Path:
"""Take a screenshot of a webpage and save it to a file."""
logger.info(f"Taking screenshot of '{url}' and saving to '{path}'")
browser = get_browser()
path = browser.execute(_take_screenshot, url, path)
print(f"Screenshot saved to {path}")
return path


def html_to_markdown(html):
# check that pandoc is installed
if not shutil.which("pandoc"):
Expand Down
102 changes: 102 additions & 0 deletions gptme/tools/_browser_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import logging
import time
from collections.abc import Callable
from dataclasses import dataclass
from queue import Empty, Queue
from threading import Event, Lock, Thread
from typing import Any, Literal, TypeVar

from playwright.sync_api import sync_playwright

logger = logging.getLogger(__name__)

T = TypeVar("T")

TIMEOUT = 30 # seconds


@dataclass
class Command:
func: Callable
args: tuple
kwargs: dict


Action = Literal["stop"]


class BrowserThread:
def __init__(self):
self.queue: Queue[tuple[Command | Action, object]] = Queue()
self.results: dict[object, tuple[Any, Exception | None]] = {}
self.lock = Lock()
self.ready = Event()
self.thread = Thread(target=self._run, daemon=True)
self.thread.start()
# Wait for browser to be ready
if not self.ready.wait(timeout=TIMEOUT):
raise TimeoutError("Browser failed to start")
logger.info("Browser thread started")

def _run(self):
try:
playwright = sync_playwright().start()
browser = playwright.chromium.launch()
logger.info("Browser launched")
self.ready.set()

while True:
try:
cmd, cmd_id = self.queue.get(timeout=1.0)
if cmd == "stop":
break

try:
result = cmd.func(browser, *cmd.args, **cmd.kwargs)
with self.lock:
self.results[cmd_id] = (result, None)
except Exception as e:
logger.exception("Error in browser thread")
with self.lock:
self.results[cmd_id] = (None, e)
except Empty:
# Timeout on queue.get, continue waiting
continue
except Exception:
logger.exception("Fatal error in browser thread")
self.ready.set() # Prevent hanging in __init__
raise
finally:
try:
browser.close()
playwright.stop()
except Exception:
logger.exception("Error stopping browser")
logger.info("Browser stopped")

def execute(self, func: Callable[..., T], *args, **kwargs) -> T:
if not self.thread.is_alive():
raise RuntimeError("Browser thread died")

cmd_id = object() # unique id
self.queue.put((Command(func, args, kwargs), cmd_id))

deadline = time.monotonic() + TIMEOUT
while time.monotonic() < deadline:
with self.lock:
if cmd_id in self.results:
result, error = self.results.pop(cmd_id)
if error:
raise error
return result
time.sleep(0.1) # Prevent busy-waiting

raise TimeoutError(f"Browser operation timed out after {TIMEOUT}s")

def stop(self):
"""Stop the browser thread"""
try:
self.queue.put(("stop", object()))
self.thread.join(timeout=TIMEOUT)
except Exception:
logger.exception("Error stopping browser thread")

0 comments on commit 5673f4d

Please sign in to comment.