diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index 704828e3ed..d63a665af8 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -39,10 +39,12 @@ class TestWebSocketHandler(WebSocketHandler): """Base class for testing handlers that exposes the on_close event. - This allows for deterministic cleanup of the associated socket. + This allows for tests to see the close code and reason on the + server side. + """ - def initialize(self, close_future, compression_options=None): + def initialize(self, close_future=None, compression_options=None): self.close_future = close_future self.compression_options = compression_options @@ -50,7 +52,8 @@ def get_compression_options(self): return self.compression_options def on_close(self): - self.close_future.set_result((self.close_code, self.close_reason)) + if self.close_future is not None: + self.close_future.set_result((self.close_code, self.close_reason)) class EchoHandler(TestWebSocketHandler): @@ -125,10 +128,8 @@ def open(self, arg): class CoroutineOnMessageHandler(TestWebSocketHandler): - def initialize(self, close_future, compression_options=None): - super(CoroutineOnMessageHandler, self).initialize( - close_future, compression_options - ) + def initialize(self, **kwargs): + super(CoroutineOnMessageHandler, self).initialize(**kwargs) self.sleeping = 0 @gen.coroutine @@ -191,16 +192,6 @@ def ws_connect(self, path, **kwargs): ) raise gen.Return(ws) - @gen.coroutine - def close(self, ws): - """Close a websocket connection and wait for the server side. - - If we don't wait here, there are sometimes leak warnings in the - tests. - """ - ws.close() - yield self.close_future - class WebSocketTest(WebSocketBaseTestCase): def get_app(self): @@ -296,7 +287,6 @@ def test_websocket_gen(self): yield ws.write_message("hello") response = yield ws.read_message() self.assertEqual(response, "hello") - yield self.close(ws) def test_websocket_callbacks(self): websocket_connect( @@ -317,7 +307,6 @@ def test_binary_message(self): ws.write_message(b"hello \xe9", binary=True) response = yield ws.read_message() self.assertEqual(response, b"hello \xe9") - yield self.close(ws) @gen_test def test_unicode_message(self): @@ -325,7 +314,6 @@ def test_unicode_message(self): ws.write_message(u"hello \u00e9") response = yield ws.read_message() self.assertEqual(response, u"hello \u00e9") - yield self.close(ws) @gen_test def test_render_message(self): @@ -333,7 +321,6 @@ def test_render_message(self): ws.write_message("hello") response = yield ws.read_message() self.assertEqual(response, "hello") - yield self.close(ws) @gen_test def test_error_in_on_message(self): @@ -342,7 +329,6 @@ def test_error_in_on_message(self): with ExpectLog(app_log, "Uncaught exception"): response = yield ws.read_message() self.assertIs(response, None) - yield self.close(ws) @gen_test def test_websocket_http_fail(self): @@ -372,7 +358,6 @@ def test_websocket_close_buffered_data(self): ws.write_message("world") # Close the underlying stream. ws.stream.close() - yield self.close_future @gen_test def test_websocket_headers(self): @@ -385,7 +370,6 @@ def test_websocket_headers(self): ) response = yield ws.read_message() self.assertEqual(response, "hello") - yield self.close(ws) @gen_test def test_websocket_header_echo(self): @@ -402,7 +386,6 @@ def test_websocket_header_echo(self): self.assertEqual( ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value" ) - yield self.close(ws) @gen_test def test_server_close_reason(self): @@ -472,7 +455,6 @@ def test_check_origin_valid_no_path(self): ws.write_message("hello") response = yield ws.read_message() self.assertEqual(response, "hello") - yield self.close(ws) @gen_test def test_check_origin_valid_with_path(self): @@ -485,7 +467,6 @@ def test_check_origin_valid_with_path(self): ws.write_message("hello") response = yield ws.read_message() self.assertEqual(response, "hello") - yield self.close(ws) @gen_test def test_check_origin_invalid_partial_url(self): @@ -534,7 +515,6 @@ def test_subprotocols(self): self.assertEqual(ws.selected_subprotocol, "goodproto") res = yield ws.read_message() self.assertEqual(res, "subprotocol=goodproto") - yield self.close(ws) @gen_test def test_subprotocols_not_offered(self): @@ -542,7 +522,6 @@ def test_subprotocols_not_offered(self): self.assertIs(ws.selected_subprotocol, None) res = yield ws.read_message() self.assertEqual(res, "subprotocol=None") - yield self.close(ws) @gen_test def test_open_coroutine(self): @@ -552,12 +531,11 @@ def test_open_coroutine(self): self.message_sent.set() res = yield ws.read_message() self.assertEqual(res, "ok") - yield self.close(ws) class NativeCoroutineOnMessageHandler(TestWebSocketHandler): - def initialize(self, close_future, compression_options=None): - super().initialize(close_future, compression_options) + def initialize(self, **kwargs): + super().initialize(**kwargs) self.sleeping = 0 async def on_message(self, message): @@ -571,16 +549,7 @@ async def on_message(self, message): class WebSocketNativeCoroutineTest(WebSocketBaseTestCase): def get_app(self): - self.close_future = Future() # type: Future[None] - return Application( - [ - ( - "/native", - NativeCoroutineOnMessageHandler, - dict(close_future=self.close_future), - ) - ] - ) + return Application([("/native", NativeCoroutineOnMessageHandler)]) @gen_test def test_native_coroutine(self): @@ -598,8 +567,6 @@ class CompressionTestMixin(object): MESSAGE = "Hello world. Testing 123 123" def get_app(self): - self.close_future = Future() # type: Future[None] - class LimitedHandler(TestWebSocketHandler): @property def max_message_size(self): @@ -613,18 +580,12 @@ def on_message(self, message): ( "/echo", EchoHandler, - dict( - close_future=self.close_future, - compression_options=self.get_server_compression_options(), - ), + dict(compression_options=self.get_server_compression_options()), ), ( "/limited", LimitedHandler, - dict( - close_future=self.close_future, - compression_options=self.get_server_compression_options(), - ), + dict(compression_options=self.get_server_compression_options()), ), ] ) @@ -649,7 +610,6 @@ def test_message_sizes(self): self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3) self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3) self.verify_wire_bytes(ws.protocol._wire_bytes_in, ws.protocol._wire_bytes_out) - yield self.close(ws) @gen_test def test_size_limit(self): @@ -665,7 +625,6 @@ def test_size_limit(self): ws.write_message("a" * 2048) response = yield ws.read_message() self.assertIsNone(response) - yield self.close(ws) class UncompressedTestMixin(CompressionTestMixin): @@ -743,11 +702,7 @@ class PingHandler(TestWebSocketHandler): def on_pong(self, data): self.write_message("got pong") - self.close_future = Future() # type: Future[None] - return Application( - [("/", PingHandler, dict(close_future=self.close_future))], - websocket_ping_interval=0.01, - ) + return Application([("/", PingHandler)], websocket_ping_interval=0.01) @gen_test def test_server_ping(self): @@ -755,7 +710,6 @@ def test_server_ping(self): for i in range(3): response = yield ws.read_message() self.assertEqual(response, "got pong") - yield self.close(ws) # TODO: test that the connection gets closed if ping responses stop. @@ -765,8 +719,7 @@ class PingHandler(TestWebSocketHandler): def on_ping(self, data): self.write_message("got ping") - self.close_future = Future() # type: Future[None] - return Application([("/", PingHandler, dict(close_future=self.close_future))]) + return Application([("/", PingHandler)]) @gen_test def test_client_ping(self): @@ -774,7 +727,6 @@ def test_client_ping(self): for i in range(3): response = yield ws.read_message() self.assertEqual(response, "got ping") - yield self.close(ws) # TODO: test that the connection gets closed if ping responses stop. @@ -784,8 +736,7 @@ class PingHandler(TestWebSocketHandler): def on_ping(self, data): self.write_message(data, binary=isinstance(data, bytes)) - self.close_future = Future() # type: Future[None] - return Application([("/", PingHandler, dict(close_future=self.close_future))]) + return Application([("/", PingHandler)]) @gen_test def test_manual_ping(self): @@ -801,16 +752,11 @@ def test_manual_ping(self): ws.ping(b"binary hello") resp = yield ws.read_message() self.assertEqual(resp, b"binary hello") - yield self.close(ws) class MaxMessageSizeTest(WebSocketBaseTestCase): def get_app(self): - self.close_future = Future() # type: Future[None] - return Application( - [("/", EchoHandler, dict(close_future=self.close_future))], - websocket_max_message_size=1024, - ) + return Application([("/", EchoHandler)], websocket_max_message_size=1024) @gen_test def test_large_message(self):