diff --git a/core/src/main/java/io/undertow/websockets/core/WebSocketChannel.java b/core/src/main/java/io/undertow/websockets/core/WebSocketChannel.java index aacebd8e51..4b76338091 100644 --- a/core/src/main/java/io/undertow/websockets/core/WebSocketChannel.java +++ b/core/src/main/java/io/undertow/websockets/core/WebSocketChannel.java @@ -24,6 +24,7 @@ import io.undertow.websockets.extensions.ExtensionFunction; import org.xnio.ChannelExceptionHandler; import org.xnio.ChannelListener; +import org.xnio.ChannelListener.SimpleSetter; import org.xnio.ChannelListeners; import org.xnio.IoUtils; import org.xnio.OptionMap; @@ -82,6 +83,7 @@ public abstract class WebSocketChannel extends AbstractFramedChannel peerConnections; + private static final CloseMessage CLOSE_MSG = new CloseMessage(CloseMessage.GOING_AWAY, WebSocketMessages.MESSAGES.messageCloseWebSocket()); /** * Create a new {@link WebSocketChannel} * 8 @@ -158,6 +160,15 @@ protected void lastDataRead() { } catch (IOException e) { IoUtils.safeClose(this); } + final ChannelListener listener = ((SimpleSetter)getReceiveSetter()).get(); + if(listener instanceof AbstractReceiveListener) { + final AbstractReceiveListener abstractReceiveListener = (AbstractReceiveListener) listener; + try { + abstractReceiveListener.onCloseMessage(CLOSE_MSG, this); + } catch(Exception e) { + e.printStackTrace(); + } + } } } diff --git a/core/src/main/java/io/undertow/websockets/core/WebSocketMessages.java b/core/src/main/java/io/undertow/websockets/core/WebSocketMessages.java index fc17a8387c..491a33f0e0 100644 --- a/core/src/main/java/io/undertow/websockets/core/WebSocketMessages.java +++ b/core/src/main/java/io/undertow/websockets/core/WebSocketMessages.java @@ -171,4 +171,7 @@ public interface WebSocketMessages { @Message(id = 2045, value = "Unable to send on newly created channel!") IllegalStateException unableToSendOnNewChannel(); + + @Message(id = 2046, value = "Closing WebSocket, peer went away.") + String messageCloseWebSocket(); } diff --git a/core/src/test/java/io/undertow/websockets/core/protocol/AbstractWebSocketServerTest.java b/core/src/test/java/io/undertow/websockets/core/protocol/AbstractWebSocketServerTest.java index ac52118f0b..4561aec3a9 100644 --- a/core/src/test/java/io/undertow/websockets/core/protocol/AbstractWebSocketServerTest.java +++ b/core/src/test/java/io/undertow/websockets/core/protocol/AbstractWebSocketServerTest.java @@ -26,8 +26,10 @@ import io.undertow.websockets.core.AbstractReceiveListener; import io.undertow.websockets.core.BufferedBinaryMessage; import io.undertow.websockets.core.BufferedTextMessage; +import io.undertow.websockets.core.CloseMessage; import io.undertow.websockets.core.WebSocketCallback; import io.undertow.websockets.core.WebSocketChannel; +import io.undertow.websockets.core.WebSocketMessages; import io.undertow.websockets.core.WebSockets; import io.undertow.websockets.spi.WebSocketHttpExchange; import io.undertow.websockets.utils.FrameChecker; @@ -46,6 +48,7 @@ import java.io.IOException; import java.net.URI; import java.nio.ByteBuffer; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -167,6 +170,50 @@ protected void onFullCloseMessage(WebSocketChannel channel, BufferedBinaryMessag client.destroy(); } + @Test + public void testCloseOnPeerGone() throws Exception { + if (getVersion() == WebSocketVersion.V00) { + // ignore 00 tests for now + return; + } + final AtomicBoolean connected = new AtomicBoolean(false); + final FutureResult latch = new FutureResult(); + DefaultServer.setRootHandler(new WebSocketProtocolHandshakeHandler(new WebSocketConnectionCallback() { + @Override + public void onConnect(final WebSocketHttpExchange exchange, final WebSocketChannel channel) { + connected.set(true); + channel.getReceiveSetter().set(new AbstractReceiveListener() { + + @Override + protected void onFullTextMessage(WebSocketChannel channel, BufferedTextMessage message) { + Assert.fail(); + } + + @Override + protected void onCloseMessage(CloseMessage msg, WebSocketChannel channel) { + latch.setResult(msg); + } + + @Override + protected void onError(WebSocketChannel channel, Throwable t) { + Assert.fail(); + } + }); + channel.resumeReceives(); + } + })); + + WebSocketTestClient client = new WebSocketTestClient(getVersion(), + new URI("ws://" + NetworkUtils.formatPossibleIpv6Address(DefaultServer.getHostAddress("default")) + ":" + + DefaultServer.getHostPort("default") + "/")); + client.connect(); + client.destroy(true); + latch.getIoFuture().await(5000, TimeUnit.MILLISECONDS); + final CloseMessage msg = latch.getIoFuture().get(); + Assert.assertNotNull(msg); + Assert.assertEquals(WebSocketMessages.MESSAGES.messageCloseWebSocket(), msg.getReason()); + } + protected WebSocketVersion getVersion() { return WebSocketVersion.V00; } diff --git a/core/src/test/java/io/undertow/websockets/utils/WebSocketTestClient.java b/core/src/test/java/io/undertow/websockets/utils/WebSocketTestClient.java index b50c785369..373c186b55 100644 --- a/core/src/test/java/io/undertow/websockets/utils/WebSocketTestClient.java +++ b/core/src/test/java/io/undertow/websockets/utils/WebSocketTestClient.java @@ -137,7 +137,11 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E * Destroy the client and also close open connections if any exist */ public void destroy() { - if (!closed) { + this.destroy(false); + } + + public void destroy(boolean dirty) { + if (!closed && !dirty) { final CountDownLatch latch = new CountDownLatch(1); send(new CloseWebSocketFrame(), new FrameListener() { @Override