Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Better testing framework for homeserver-using things #3446

Merged
merged 3 commits into from
Jun 27, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added changelog.d/3446.misc
Empty file.
181 changes: 181 additions & 0 deletions tests/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from io import BytesIO

import attr
import json
from six import text_type

from twisted.python.failure import Failure
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactorClock

from synapse.http.site import SynapseRequest
from twisted.internet import threads
from tests.utils import setup_test_homeserver as _sth


@attr.s
class FakeChannel(object):
"""
A fake Twisted Web Channel (the part that interfaces with the
wire).
"""

result = attr.ib(factory=dict)

@property
def json_body(self):
if not self.result:
raise Exception("No result yet.")
return json.loads(self.result["body"])

def writeHeaders(self, version, code, reason, headers):
self.result["version"] = version
self.result["code"] = code
self.result["reason"] = reason
self.result["headers"] = headers

def write(self, content):
if "body" not in self.result:
self.result["body"] = b""

self.result["body"] += content

def requestDone(self, _self):
self.result["done"] = True

def getPeer(self):
return None

def getHost(self):
return None

@property
def transport(self):
return self


class FakeSite:
"""
A fake Twisted Web Site, with mocks of the extra things that
Synapse adds.
"""

server_version_string = b"1"
site_tag = "test"

@property
def access_logger(self):
class FakeLogger:
def info(self, *args, **kwargs):
pass

return FakeLogger()


def make_request(method, path, content=b""):
"""
Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath.
"""

if isinstance(content, text_type):
content = content.encode('utf8')

site = FakeSite()
channel = FakeChannel()

req = SynapseRequest(site, channel)
req.process = lambda: b""
req.content = BytesIO(content)
req.requestReceived(method, path, b"1.1")

return req, channel


def wait_until_result(clock, channel, timeout=100):
"""
Wait until the channel has a result.
"""
clock.run()
x = 0

while not channel.result:
x += 1

if x > timeout:
raise Exception("Timed out waiting for request to finish.")

clock.advance(0.1)


class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
A MemoryReactorClock that supports callFromThread.
"""
def callFromThread(self, callback, *args, **kwargs):
"""
Make the callback fire in the next reactor iteration.
"""
d = Deferred()
d.addCallback(lambda x: callback(*args, **kwargs))
self.callLater(0, d.callback, True)
return d


def setup_test_homeserver(*args, **kwargs):
"""
Set up a synchronous test server, driven by the reactor used by
the homeserver.
"""
d = _sth(*args, **kwargs).result

# Make the thread pool synchronous.
clock = d.get_clock()
pool = d.get_db_pool()

def runWithConnection(func, *args, **kwargs):
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
pool._runWithConnection,
func,
*args,
**kwargs
)

def runInteraction(interaction, *args, **kwargs):
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
pool._runInteraction,
interaction,
*args,
**kwargs
)

pool.runWithConnection = runWithConnection
pool.runInteraction = runInteraction

class ThreadPool:
"""
Threadless thread pool.
"""
def start(self):
pass

def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
def _(res):
if isinstance(res, Failure):
onResult(False, res)
else:
onResult(True, res)

d = Deferred()
d.addCallback(lambda x: function(*args, **kwargs))
d.addBoth(_)
clock._reactor.callLater(0, d.callback, True)
return d

clock.threadpool = ThreadPool()
pool.threadpool = ThreadPool()
return d
128 changes: 128 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import json
import re

from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactorClock

from synapse.util import Clock
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
from tests import unittest
from tests.server import make_request, setup_test_homeserver


class JsonResourceTests(unittest.TestCase):
def setUp(self):
self.reactor = MemoryReactorClock()
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
http_client=None, clock=self.hs_clock, reactor=self.reactor
)

def test_handler_for_request(self):
"""
JsonResource.handler_for_request gives correctly decoded URL args to
the callback, while Twisted will give the raw bytes of URL query
arguments.
"""
got_kwargs = {}

def _callback(request, **kwargs):
got_kwargs.update(kwargs)
return (200, kwargs)

res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo/(?P<room_id>[^/]*)$")], _callback)

request, channel = make_request(b"GET", b"/foo/%E2%98%83?a=%E2%98%83")
request.render(res)

self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"})

def test_callback_direct_exception(self):
"""
If the web callback raises an uncaught exception, it will be translated
into a 500.
"""

def _callback(request, **kwargs):
raise Exception("boo")

res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback)

request, channel = make_request(b"GET", b"/foo")
request.render(res)

self.assertEqual(channel.result["code"], b'500')

def test_callback_indirect_exception(self):
"""
If the web callback raises an uncaught exception in a Deferred, it will
be translated into a 500.
"""

def _throw(*args):
raise Exception("boo")

def _callback(request, **kwargs):
d = Deferred()
d.addCallback(_throw)
self.reactor.callLater(1, d.callback, True)
return d

res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback)

request, channel = make_request(b"GET", b"/foo")
request.render(res)

# No error has been raised yet
self.assertTrue("code" not in channel.result)

# Advance time, now there's an error
self.reactor.advance(1)
self.assertEqual(channel.result["code"], b'500')

def test_callback_synapseerror(self):
"""
If the web callback raises a SynapseError, it returns the appropriate
status code and message set in it.
"""

def _callback(request, **kwargs):
raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN)

res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback)

request, channel = make_request(b"GET", b"/foo")
request.render(res)

self.assertEqual(channel.result["code"], b'403')
reply_body = json.loads(channel.result["body"])
self.assertEqual(reply_body["error"], "Forbidden!!one!")
self.assertEqual(reply_body["errcode"], "M_FORBIDDEN")

def test_no_handler(self):
"""
If there is no handler to process the request, Synapse will return 400.
"""

def _callback(request, **kwargs):
"""
Not ever actually called!
"""
self.fail("shouldn't ever get here")

res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/foo$")], _callback)

request, channel = make_request(b"GET", b"/foobar")
request.render(res)

self.assertEqual(channel.result["code"], b'400')
reply_body = json.loads(channel.result["body"])
self.assertEqual(reply_body["error"], "Unrecognized request")
self.assertEqual(reply_body["errcode"], "M_UNRECOGNIZED")