Skip to content

Commit

Permalink
Improve concurrency on Windows (#286)
Browse files Browse the repository at this point in the history
* Get future result in a callback

* POC with asyncio queue

* Clean up

* Update windows test wrapper

* WindowsProactorEventLoopPolicy on Windows

* Test setting future exception

* Update docstring

* Different errors depending on the platform
  • Loading branch information
elacuesta authored Jul 3, 2024
1 parent 5926a20 commit 164ff9a
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 32 deletions.
75 changes: 53 additions & 22 deletions scrapy_playwright/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import concurrent
import logging
import platform
import threading
Expand Down Expand Up @@ -98,34 +97,66 @@ async def _get_header_value(

if platform.system() == "Windows":

class _WindowsAdapter:
"""Utility class to redirect coroutines to an asyncio event loop running
in a different thread. This allows to use a ProactorEventLoop, which is
supported by Playwright on Windows.
class _ThreadedLoopAdapter:
"""Utility class to start an asyncio event loop in a new thread and redirect coroutines.
This allows to run Playwright in a different loop than the Scrapy crawler, allowing to
use ProactorEventLoop which is supported by Playwright on Windows.
"""

loop = None
thread = None
_loop: asyncio.AbstractEventLoop
_thread: threading.Thread
_coro_queue: asyncio.Queue = asyncio.Queue()
_stop_event: asyncio.Event = asyncio.Event()

@classmethod
def get_event_loop(cls) -> asyncio.AbstractEventLoop:
if cls.thread is None:
if cls.loop is None:
policy = asyncio.WindowsProactorEventLoopPolicy() # type: ignore
cls.loop = policy.new_event_loop()
asyncio.set_event_loop(cls.loop)
if not cls.loop.is_running():
cls.thread = threading.Thread(target=cls.loop.run_forever, daemon=True)
cls.thread.start()
logger.info("Started loop on separate thread: %s", cls.loop)
return cls.loop
async def _handle_coro(cls, coro, future) -> None:
try:
future.set_result(await coro)
except Exception as exc:
future.set_exception(exc)

@classmethod
async def get_result(cls, coro) -> concurrent.futures.Future:
return asyncio.run_coroutine_threadsafe(coro=coro, loop=cls.get_event_loop()).result()
async def _process_queue(cls) -> None:
while not cls._stop_event.is_set():
coro, future = await cls._coro_queue.get()
asyncio.create_task(cls._handle_coro(coro, future))
cls._coro_queue.task_done()

def _deferred_from_coro(coro) -> Deferred:
return scrapy.utils.defer.deferred_from_coro(_WindowsAdapter.get_result(coro))
@classmethod
def _deferred_from_coro(cls, coro) -> Deferred:
    """Schedule *coro* on the threaded event loop and return a Deferred.

    The coroutine is enqueued (thread-safely) for ``_process_queue`` running
    on ``cls._loop``; its outcome is relayed into *future* by
    ``_handle_coro``, and the future is wrapped so Twisted callers can wait
    on it as a Deferred.
    """
    # NOTE(review): the Future is created on the calling thread while its
    # set_result/set_exception are invoked from the threaded loop —
    # presumably safe because deferred_from_coro awaits it on the caller's
    # loop, but confirm the cross-thread wakeup semantics.
    future: asyncio.Future = asyncio.Future()
    asyncio.run_coroutine_threadsafe(cls._coro_queue.put((coro, future)), cls._loop)
    return scrapy.utils.defer.deferred_from_coro(future)

@classmethod
def start(cls) -> None:
    """Start a ProactorEventLoop in a dedicated daemon thread and begin
    consuming the coroutine queue on it.

    Playwright on Windows requires a ProactorEventLoop, which cannot be the
    Scrapy/Twisted reactor loop — hence the separate thread.
    """
    # NOTE(review): not guarded against being called twice — presumably
    # only one download handler exists per process; verify against callers.
    policy = asyncio.WindowsProactorEventLoopPolicy()  # type: ignore[attr-defined]
    cls._loop = policy.new_event_loop()
    asyncio.set_event_loop(cls._loop)

    # Daemon thread so a missed stop() cannot keep the process alive.
    cls._thread = threading.Thread(target=cls._loop.run_forever, daemon=True)
    cls._thread.start()
    logger.info("Started loop on separate thread: %s", cls._loop)

    # run_coroutine_threadsafe wakes the already-running loop, so it is
    # safe to schedule the queue consumer after the thread has started.
    asyncio.run_coroutine_threadsafe(cls._process_queue(), cls._loop)

@classmethod
def stop(cls) -> None:
    """Stop accepting work, drain the queue, and shut down the loop thread."""
    # Set the event from inside the loop's own thread: asyncio primitives
    # are not thread-safe and must not be mutated from a foreign thread.
    cls._loop.call_soon_threadsafe(cls._stop_event.set)
    # Block until every queued (coro, future) pair has been dequeued.
    # Without waiting on .result(), loop.stop below could run before the
    # join() coroutine ever executed, making the drain a no-op.
    asyncio.run_coroutine_threadsafe(cls._coro_queue.join(), cls._loop).result()
    cls._loop.call_soon_threadsafe(cls._loop.stop)
    cls._thread.join()

_deferred_from_coro = _ThreadedLoopAdapter._deferred_from_coro
else:

class _ThreadedLoopAdapter: # type: ignore[no-redef]
@classmethod
def start(cls) -> None:
pass

@classmethod
def stop(cls) -> None:
pass

_deferred_from_coro = scrapy.utils.defer.deferred_from_coro
5 changes: 4 additions & 1 deletion scrapy_playwright/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@
from scrapy_playwright.headers import use_scrapy_headers
from scrapy_playwright.page import PageMethod
from scrapy_playwright._utils import (
_ThreadedLoopAdapter,
_deferred_from_coro,
_encode_body,
_get_header_value,
_get_page_content,
_is_safe_close_error,
_maybe_await,
_deferred_from_coro,
)


Expand Down Expand Up @@ -102,6 +103,7 @@ class ScrapyPlaywrightDownloadHandler(HTTPDownloadHandler):

def __init__(self, crawler: Crawler) -> None:
super().__init__(settings=crawler.settings, crawler=crawler)
_ThreadedLoopAdapter.start()
if platform.system() != "Windows":
verify_installed_reactor("twisted.internet.asyncioreactor.AsyncioSelectorReactor")
crawler.signals.connect(self._engine_started, signals.engine_started)
Expand Down Expand Up @@ -293,6 +295,7 @@ def close(self) -> Deferred:
logger.info("Closing download handler")
yield super().close()
yield _deferred_from_coro(self._close())
_ThreadedLoopAdapter.stop()

async def _close(self) -> None:
await asyncio.gather(*[ctx.context.close() for ctx in self.context_wrappers.values()])
Expand Down
11 changes: 7 additions & 4 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import inspect
import logging
import platform
Expand All @@ -13,17 +14,19 @@


if platform.system() == "Windows":
from scrapy_playwright._utils import _WindowsAdapter
from scrapy_playwright._utils import _ThreadedLoopAdapter

def allow_windows(test_method):
"""Wrap tests with the _WindowsAdapter class on Windows."""
"""Wrap tests with the _ThreadedLoopAdapter class on Windows."""
if not inspect.iscoroutinefunction(test_method):
raise RuntimeError(f"{test_method} must be an async def method")

@wraps(test_method)
async def wrapped(self, *args, **kwargs):
logger.debug("Calling _WindowsAdapter.get_result for %r", self)
await _WindowsAdapter.get_result(test_method(self, *args, **kwargs))
_ThreadedLoopAdapter.start()
coro = test_method(self, *args, **kwargs)
asyncio.run_coroutine_threadsafe(coro=coro, loop=_ThreadedLoopAdapter._loop).result()
_ThreadedLoopAdapter.stop()

return wrapped

Expand Down
26 changes: 21 additions & 5 deletions tests/tests_twisted/test_mixed_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@ class MixedRequestsTestCase(TestCase):
'_download_request', which is a coroutine ('download_request' returns a Deferred).
"""

timeout_ms = 500

@defer.inlineCallbacks
def setUp(self):
self.server = StaticMockServer()
self.server.__enter__()
self.handler = ScrapyPlaywrightDownloadHandler.from_crawler(get_crawler())
self.handler = ScrapyPlaywrightDownloadHandler.from_crawler(
get_crawler(settings_dict={"PLAYWRIGHT_DEFAULT_NAVIGATION_TIMEOUT": self.timeout_ms})
)
yield self.handler._engine_started()

@defer.inlineCallbacks
Expand All @@ -29,26 +33,38 @@ def tearDown(self):

@defer.inlineCallbacks
def test_download_request(self):
def _test_regular(response, request):
def _check_regular(response, request):
self.assertIsInstance(response, Response)
self.assertEqual(response.css("a::text").getall(), ["Lorem Ipsum", "Infinite Scroll"])
self.assertEqual(response.url, request.url)
self.assertEqual(response.status, 200)
self.assertNotIn("playwright", response.flags)

def _test_playwright(response, request):
def _check_playwright_ok(response, request):
self.assertIsInstance(response, Response)
self.assertEqual(response.css("a::text").getall(), ["Lorem Ipsum", "Infinite Scroll"])
self.assertEqual(response.url, request.url)
self.assertEqual(response.status, 200)
self.assertIn("playwright", response.flags)

def _check_playwright_error(failure, url):
# different errors depending on the platform
self.assertTrue(
f"Page.goto: net::ERR_CONNECTION_REFUSED at {url}" in str(failure.value)
or f"Page.goto: Timeout {self.timeout_ms}ms exceeded" in str(failure.value)
)

req1 = Request(self.server.urljoin("/index.html"))
yield self.handler.download_request(req1, Spider("foo")).addCallback(
_test_regular, request=req1
_check_regular, request=req1
)

req2 = Request(self.server.urljoin("/index.html"), meta={"playwright": True})
yield self.handler.download_request(req2, Spider("foo")).addCallback(
_test_playwright, request=req2
_check_playwright_ok, request=req2
)

req3 = Request("http://localhost:12345/asdf", meta={"playwright": True})
yield self.handler.download_request(req3, Spider("foo")).addErrback(
_check_playwright_error, url=req3.url
)

0 comments on commit 164ff9a

Please sign in to comment.