Skip to content

Commit

Permalink
websocket_test: Remove most manual closes
Browse files Browse the repository at this point in the history
At one time this was necessary to prevent spurious warnings at
shutdown, but not any more (and I intend to address warnings like this
with a more general solution).
  • Loading branch information
bdarnell committed Dec 10, 2018
1 parent c350dc9 commit ae9a2da
Showing 1 changed file with 17 additions and 71 deletions.
88 changes: 17 additions & 71 deletions tornado/test/websocket_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,21 @@
class TestWebSocketHandler(WebSocketHandler):
"""Base class for testing handlers that exposes the on_close event.
This allows for deterministic cleanup of the associated socket.
This allows for tests to see the close code and reason on the
server side.
"""

def initialize(self, close_future, compression_options=None):
def initialize(self, close_future=None, compression_options=None):
self.close_future = close_future
self.compression_options = compression_options

def get_compression_options(self):
return self.compression_options

def on_close(self):
self.close_future.set_result((self.close_code, self.close_reason))
if self.close_future is not None:
self.close_future.set_result((self.close_code, self.close_reason))


class EchoHandler(TestWebSocketHandler):
Expand Down Expand Up @@ -125,10 +128,8 @@ def open(self, arg):


class CoroutineOnMessageHandler(TestWebSocketHandler):
def initialize(self, close_future, compression_options=None):
super(CoroutineOnMessageHandler, self).initialize(
close_future, compression_options
)
def initialize(self, **kwargs):
super(CoroutineOnMessageHandler, self).initialize(**kwargs)
self.sleeping = 0

@gen.coroutine
Expand Down Expand Up @@ -191,16 +192,6 @@ def ws_connect(self, path, **kwargs):
)
raise gen.Return(ws)

@gen.coroutine
def close(self, ws):
"""Close a websocket connection and wait for the server side.
If we don't wait here, there are sometimes leak warnings in the
tests.
"""
ws.close()
yield self.close_future


class WebSocketTest(WebSocketBaseTestCase):
def get_app(self):
Expand Down Expand Up @@ -296,7 +287,6 @@ def test_websocket_gen(self):
yield ws.write_message("hello")
response = yield ws.read_message()
self.assertEqual(response, "hello")
yield self.close(ws)

def test_websocket_callbacks(self):
websocket_connect(
Expand All @@ -317,23 +307,20 @@ def test_binary_message(self):
ws.write_message(b"hello \xe9", binary=True)
response = yield ws.read_message()
self.assertEqual(response, b"hello \xe9")
yield self.close(ws)

@gen_test
def test_unicode_message(self):
ws = yield self.ws_connect("/echo")
ws.write_message(u"hello \u00e9")
response = yield ws.read_message()
self.assertEqual(response, u"hello \u00e9")
yield self.close(ws)

@gen_test
def test_render_message(self):
ws = yield self.ws_connect("/render")
ws.write_message("hello")
response = yield ws.read_message()
self.assertEqual(response, "<b>hello</b>")
yield self.close(ws)

@gen_test
def test_error_in_on_message(self):
Expand All @@ -342,7 +329,6 @@ def test_error_in_on_message(self):
with ExpectLog(app_log, "Uncaught exception"):
response = yield ws.read_message()
self.assertIs(response, None)
yield self.close(ws)

@gen_test
def test_websocket_http_fail(self):
Expand Down Expand Up @@ -372,7 +358,6 @@ def test_websocket_close_buffered_data(self):
ws.write_message("world")
# Close the underlying stream.
ws.stream.close()
yield self.close_future

@gen_test
def test_websocket_headers(self):
Expand All @@ -385,7 +370,6 @@ def test_websocket_headers(self):
)
response = yield ws.read_message()
self.assertEqual(response, "hello")
yield self.close(ws)

@gen_test
def test_websocket_header_echo(self):
Expand All @@ -402,7 +386,6 @@ def test_websocket_header_echo(self):
self.assertEqual(
ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
)
yield self.close(ws)

@gen_test
def test_server_close_reason(self):
Expand Down Expand Up @@ -472,7 +455,6 @@ def test_check_origin_valid_no_path(self):
ws.write_message("hello")
response = yield ws.read_message()
self.assertEqual(response, "hello")
yield self.close(ws)

@gen_test
def test_check_origin_valid_with_path(self):
Expand All @@ -485,7 +467,6 @@ def test_check_origin_valid_with_path(self):
ws.write_message("hello")
response = yield ws.read_message()
self.assertEqual(response, "hello")
yield self.close(ws)

@gen_test
def test_check_origin_invalid_partial_url(self):
Expand Down Expand Up @@ -534,15 +515,13 @@ def test_subprotocols(self):
self.assertEqual(ws.selected_subprotocol, "goodproto")
res = yield ws.read_message()
self.assertEqual(res, "subprotocol=goodproto")
yield self.close(ws)

@gen_test
def test_subprotocols_not_offered(self):
ws = yield self.ws_connect("/subprotocol")
self.assertIs(ws.selected_subprotocol, None)
res = yield ws.read_message()
self.assertEqual(res, "subprotocol=None")
yield self.close(ws)

@gen_test
def test_open_coroutine(self):
Expand All @@ -552,12 +531,11 @@ def test_open_coroutine(self):
self.message_sent.set()
res = yield ws.read_message()
self.assertEqual(res, "ok")
yield self.close(ws)


class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
def initialize(self, close_future, compression_options=None):
super().initialize(close_future, compression_options)
def initialize(self, **kwargs):
super().initialize(**kwargs)
self.sleeping = 0

async def on_message(self, message):
Expand All @@ -571,16 +549,7 @@ async def on_message(self, message):

class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
def get_app(self):
self.close_future = Future() # type: Future[None]
return Application(
[
(
"/native",
NativeCoroutineOnMessageHandler,
dict(close_future=self.close_future),
)
]
)
return Application([("/native", NativeCoroutineOnMessageHandler)])

@gen_test
def test_native_coroutine(self):
Expand All @@ -598,8 +567,6 @@ class CompressionTestMixin(object):
MESSAGE = "Hello world. Testing 123 123"

def get_app(self):
self.close_future = Future() # type: Future[None]

class LimitedHandler(TestWebSocketHandler):
@property
def max_message_size(self):
Expand All @@ -613,18 +580,12 @@ def on_message(self, message):
(
"/echo",
EchoHandler,
dict(
close_future=self.close_future,
compression_options=self.get_server_compression_options(),
),
dict(compression_options=self.get_server_compression_options()),
),
(
"/limited",
LimitedHandler,
dict(
close_future=self.close_future,
compression_options=self.get_server_compression_options(),
),
dict(compression_options=self.get_server_compression_options()),
),
]
)
Expand All @@ -649,7 +610,6 @@ def test_message_sizes(self):
self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
self.verify_wire_bytes(ws.protocol._wire_bytes_in, ws.protocol._wire_bytes_out)
yield self.close(ws)

@gen_test
def test_size_limit(self):
Expand All @@ -665,7 +625,6 @@ def test_size_limit(self):
ws.write_message("a" * 2048)
response = yield ws.read_message()
self.assertIsNone(response)
yield self.close(ws)


class UncompressedTestMixin(CompressionTestMixin):
Expand Down Expand Up @@ -743,19 +702,14 @@ class PingHandler(TestWebSocketHandler):
def on_pong(self, data):
self.write_message("got pong")

self.close_future = Future() # type: Future[None]
return Application(
[("/", PingHandler, dict(close_future=self.close_future))],
websocket_ping_interval=0.01,
)
return Application([("/", PingHandler)], websocket_ping_interval=0.01)

@gen_test
def test_server_ping(self):
ws = yield self.ws_connect("/")
for i in range(3):
response = yield ws.read_message()
self.assertEqual(response, "got pong")
yield self.close(ws)
# TODO: test that the connection gets closed if ping responses stop.


Expand All @@ -765,16 +719,14 @@ class PingHandler(TestWebSocketHandler):
def on_ping(self, data):
self.write_message("got ping")

self.close_future = Future() # type: Future[None]
return Application([("/", PingHandler, dict(close_future=self.close_future))])
return Application([("/", PingHandler)])

@gen_test
def test_client_ping(self):
ws = yield self.ws_connect("/", ping_interval=0.01)
for i in range(3):
response = yield ws.read_message()
self.assertEqual(response, "got ping")
yield self.close(ws)
# TODO: test that the connection gets closed if ping responses stop.


Expand All @@ -784,8 +736,7 @@ class PingHandler(TestWebSocketHandler):
def on_ping(self, data):
self.write_message(data, binary=isinstance(data, bytes))

self.close_future = Future() # type: Future[None]
return Application([("/", PingHandler, dict(close_future=self.close_future))])
return Application([("/", PingHandler)])

@gen_test
def test_manual_ping(self):
Expand All @@ -801,16 +752,11 @@ def test_manual_ping(self):
ws.ping(b"binary hello")
resp = yield ws.read_message()
self.assertEqual(resp, b"binary hello")
yield self.close(ws)


class MaxMessageSizeTest(WebSocketBaseTestCase):
def get_app(self):
self.close_future = Future() # type: Future[None]
return Application(
[("/", EchoHandler, dict(close_future=self.close_future))],
websocket_max_message_size=1024,
)
return Application([("/", EchoHandler)], websocket_max_message_size=1024)

@gen_test
def test_large_message(self):
Expand Down

0 comments on commit ae9a2da

Please sign in to comment.