Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: improve the mock server experience #1602

Merged
merged 2 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ jobs:
- '3.6'
- 'pypy3.10'
env:
PYTHON_SLACK_SDK_MOCK_SERVER_MODE: 'threading'
CI_LARGE_SOCKET_MODE_PAYLOAD_TESTING_DISABLED: '1'
CI_UNSTABLE_TESTS_SKIP_ENABLED: '1'
FORCE_COLOR: '1'
Expand Down Expand Up @@ -59,3 +58,4 @@ jobs:
token: ${{ secrets.CODECOV_TOKEN }}
# python setup.py validate generates the coverage file
files: ./coverage.xml

15 changes: 0 additions & 15 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,5 @@ def restore_os_env(old_env: dict) -> None:
os.environ.update(old_env)


def get_mock_server_mode() -> str:
"""Returns a str representing the mode.

:return: threading/multiprocessing
"""
mode = os.environ.get("PYTHON_SLACK_SDK_MOCK_SERVER_MODE")
if mode is None:
# We used to use "multiprocessing"" for macOS until Big Sur 11.1
# Since 11.1, the "multiprocessing" mode started failing a lot...
# Therefore, we switched the default mode back to "threading".
return "threading"
else:
return mode


def is_ci_unstable_test_skip_enabled() -> bool:
return os.environ.get("CI_UNSTABLE_TESTS_SKIP_ENABLED") == "1"
87 changes: 87 additions & 0 deletions tests/mock_web_api_server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import asyncio
from http.server import SimpleHTTPRequestHandler
from queue import Queue
import threading
import time
from typing import Type
from unittest import TestCase

from tests.mock_web_api_server.received_requests import ReceivedRequests
from tests.mock_web_api_server.mock_server_thread import MockServerThread


def setup_mock_web_api_server(test: TestCase, handler: Type[SimpleHTTPRequestHandler], port: int = 8888):
test.server_started = threading.Event()
test.received_requests = ReceivedRequests(Queue())
test.thread = MockServerThread(queue=test.received_requests.queue, test=test, handler=handler, port=port)
test.thread.start()
test.server_started.wait()


def cleanup_mock_web_api_server(test: TestCase):
test.thread.stop()
test.thread = None


def assert_received_request_count(test: TestCase, path: str, min_count: int, timeout: float = 1):
start_time = time.time()
error = None
while time.time() - start_time < timeout:
try:
received_count = test.received_requests.get(path, 0)
assert (
received_count == min_count
), f"Expected {min_count} '{path}' {'requests' if min_count > 1 else 'request'}, but got {received_count}!"
return
except Exception as e:
error = e
# waiting for some requests to be received
time.sleep(0.05)

if error is not None:
raise error


def assert_auth_test_count(test: TestCase, expected_count: int):
assert_received_request_count(test, "/auth.test", expected_count, 0.5)


#########
# async #
#########


def setup_mock_web_api_server_async(test: TestCase, handler: Type[SimpleHTTPRequestHandler], port: int = 8888):
test.server_started = threading.Event()
test.received_requests = ReceivedRequests(asyncio.Queue())
test.thread = MockServerThread(queue=test.received_requests.queue, test=test, handler=handler, port=port)
test.thread.start()
test.server_started.wait()


def cleanup_mock_web_api_server_async(test: TestCase):
test.thread.stop_unsafe()
test.thread = None


async def assert_received_request_count_async(test: TestCase, path: str, min_count: int, timeout: float = 1):
start_time = time.time()
error = None
while time.time() - start_time < timeout:
try:
received_count = await test.received_requests.get_async(path, 0)
assert (
received_count == min_count
), f"Expected {min_count} '{path}' {'requests' if min_count > 1 else 'request'}, but got {received_count}!"
return
except Exception as e:
error = e
# waiting for mock_received_requests updates
await asyncio.sleep(0.05)

if error is not None:
raise error


async def assert_auth_test_count_async(test: TestCase, expected_count: int):
await assert_received_request_count_async(test, "/auth.test", expected_count, 0.5)
41 changes: 41 additions & 0 deletions tests/mock_web_api_server/mock_server_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from asyncio import Queue
import asyncio
from http.server import HTTPServer, SimpleHTTPRequestHandler
import threading
from typing import Type, Union
from unittest import TestCase


class MockServerThread(threading.Thread):
def __init__(
self, queue: Union[Queue, asyncio.Queue], test: TestCase, handler: Type[SimpleHTTPRequestHandler], port: int = 8888
):
threading.Thread.__init__(self)
self.handler = handler
self.test = test
self.queue = queue
self.port = port

def run(self):
self.server = HTTPServer(("localhost", self.port), self.handler)
self.server.queue = self.queue
self.test.server_url = f"http://localhost:{str(self.port)}"
self.test.host, self.test.port = self.server.socket.getsockname()
self.test.server_started.set() # threading.Event()

self.test = None
try:
self.server.serve_forever(0.05)
finally:
self.server.server_close()

def stop(self):
with self.server.queue.mutex:
del self.server.queue
self.server.shutdown()
self.join()

def stop_unsafe(self):
del self.server.queue
self.server.shutdown()
self.join()
21 changes: 21 additions & 0 deletions tests/mock_web_api_server/received_requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import asyncio
from queue import Queue
from typing import Optional, Union


class ReceivedRequests:
def __init__(self, queue: Union[Queue, asyncio.Queue]):
self.queue = queue
self.received_requests: dict = {}

def get(self, key: str, default: Optional[int] = None) -> Optional[int]:
while not self.queue.empty():
path = self.queue.get()
self.received_requests[path] = self.received_requests.get(path, 0) + 1
return self.received_requests.get(key, default)

async def get_async(self, key: str, default: Optional[int] = None) -> Optional[int]:
while not self.queue.empty():
path = await self.queue.get()
self.received_requests[path] = self.received_requests.get(path, 0) + 1
return self.received_requests.get(key, default)
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
import json
import logging
import re
import sys
import threading
import time
from http import HTTPStatus
from http.server import HTTPServer, SimpleHTTPRequestHandler
from multiprocessing.context import Process
from typing import Type
from unittest import TestCase
from urllib.request import Request, urlopen

from tests.helpers import get_mock_server_mode
from http.server import SimpleHTTPRequestHandler


class MockHandler(SimpleHTTPRequestHandler):
Expand All @@ -32,6 +24,8 @@ def set_common_headers(self):
self.end_headers()

def do_GET(self):
# put_nowait is common between Queue & asyncio.Queue, it does not need to be awaited
self.server.queue.put_nowait(self.path)
if self.path == "/received_requests.json":
self.send_response(200)
self.set_common_headers()
Expand Down Expand Up @@ -97,123 +91,3 @@ def do_GET(self):
except Exception as e:
self.logger.error(str(e), exc_info=True)
raise


class MockServerProcessTarget:
def __init__(self, handler: Type[SimpleHTTPRequestHandler] = MockHandler):
self.handler = handler

def run(self):
self.handler.received_requests = {}
self.server = HTTPServer(("localhost", 8888), self.handler)
try:
self.server.serve_forever(0.05)
finally:
self.server.server_close()

def stop(self):
self.handler.received_requests = {}
self.server.shutdown()
self.join()


class MonitorThread(threading.Thread):
def __init__(self, test: TestCase, handler: Type[SimpleHTTPRequestHandler] = MockHandler):
threading.Thread.__init__(self, daemon=True)
self.handler = handler
self.test = test
self.test.mock_received_requests = None
self.is_running = True

def run(self) -> None:
while self.is_running:
try:
req = Request(f"{self.test.server_url}/received_requests.json")
resp = urlopen(req, timeout=1)
self.test.mock_received_requests = json.loads(resp.read().decode("utf-8"))
except Exception as e:
# skip logging for the initial request
if self.test.mock_received_requests is not None:
logging.getLogger(__name__).exception(e)
time.sleep(0.01)

def stop(self):
self.is_running = False
self.join()


class MockServerThread(threading.Thread):
def __init__(self, test: TestCase, handler: Type[SimpleHTTPRequestHandler] = MockHandler):
threading.Thread.__init__(self)
self.handler = handler
self.test = test

def run(self):
self.server = HTTPServer(("localhost", 8888), self.handler)
self.test.server_url = "http://localhost:8888"
self.test.host, self.test.port = self.server.socket.getsockname()
self.test.server_started.set() # threading.Event()

self.test = None
try:
self.server.serve_forever()
finally:
self.server.server_close()

def stop(self):
self.server.shutdown()
self.join()


def setup_mock_web_api_server(test: TestCase):
if get_mock_server_mode() == "threading":
test.server_started = threading.Event()
test.thread = MockServerThread(test)
test.thread.start()
test.server_started.wait()
else:
# start a mock server as another process
target = MockServerProcessTarget()
test.server_url = "http://localhost:8888"
test.host, test.port = "localhost", 8888
test.process = Process(target=target.run, daemon=True)
test.process.start()

time.sleep(0.1)

# start a thread in the current process
# this thread fetches mock_received_requests from the remote process
test.monitor_thread = MonitorThread(test)
test.monitor_thread.start()
count = 0
# wait until the first successful data retrieval
while test.mock_received_requests is None:
time.sleep(0.01)
count += 1
if count >= 100:
raise Exception("The mock server is not yet running!")


def cleanup_mock_web_api_server(test: TestCase):
if get_mock_server_mode() == "threading":
test.thread.stop()
test.thread = None
else:
# stop the thread to fetch mock_received_requests from the remote process
test.monitor_thread.stop()

retry_count = 0
# terminate the process
while test.process.is_alive():
test.process.terminate()
time.sleep(0.01)
retry_count += 1
if retry_count >= 100:
raise Exception("Failed to stop the mock server!")

# Python 3.6 does not have this method
if sys.version_info.major == 3 and sys.version_info.minor > 6:
# cleanup the process's resources
test.process.close()

test.process = None
8 changes: 3 additions & 5 deletions tests/slack_sdk/audit_logs/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
from urllib.error import URLError

from slack_sdk.audit_logs import AuditLogsClient, AuditLogsResponse
from tests.slack_sdk.audit_logs.mock_web_api_server import (
cleanup_mock_web_api_server,
setup_mock_web_api_server,
)
from tests.slack_sdk.audit_logs.mock_web_api_handler import MockHandler
from tests.mock_web_api_server import setup_mock_web_api_server, cleanup_mock_web_api_server


class TestAuditLogsClient(unittest.TestCase):
def setUp(self):
self.client = AuditLogsClient(token="xoxp-", base_url="http://localhost:8888/")
setup_mock_web_api_server(self)
setup_mock_web_api_server(self, MockHandler)

def tearDown(self):
cleanup_mock_web_api_server(self)
Expand Down
8 changes: 3 additions & 5 deletions tests/slack_sdk/audit_logs/test_client_http_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@

from slack_sdk.audit_logs import AuditLogsClient
from slack_sdk.http_retry import RateLimitErrorRetryHandler
from tests.slack_sdk.audit_logs.mock_web_api_server import (
cleanup_mock_web_api_server,
setup_mock_web_api_server,
)
from tests.slack_sdk.audit_logs.mock_web_api_handler import MockHandler
from tests.mock_web_api_server import setup_mock_web_api_server, cleanup_mock_web_api_server
from ..my_retry_handler import MyRetryHandler


class TestAuditLogsClient_HttpRetries(unittest.TestCase):
def setUp(self):
setup_mock_web_api_server(self)
setup_mock_web_api_server(self, MockHandler)

def tearDown(self):
cleanup_mock_web_api_server(self)
Expand Down
8 changes: 3 additions & 5 deletions tests/slack_sdk/oauth/token_rotation/test_token_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
from slack_sdk.oauth.installation_store import Installation
from slack_sdk.oauth.token_rotation import TokenRotator
from slack_sdk.web import WebClient
from tests.slack_sdk.web.mock_web_api_server import (
setup_mock_web_api_server,
cleanup_mock_web_api_server,
)
from tests.slack_sdk.web.mock_web_api_handler import MockHandler
from tests.mock_web_api_server import setup_mock_web_api_server, cleanup_mock_web_api_server


class TestTokenRotator(unittest.TestCase):
def setUp(self):
setup_mock_web_api_server(self)
setup_mock_web_api_server(self, MockHandler)
self.token_rotator = TokenRotator(
client=WebClient(base_url="http://localhost:8888", token=None),
client_id="111.222",
Expand Down
Loading
Loading