From bebd5c5b54d0aaed6dc98f9f704d9a014aa7aad6 Mon Sep 17 00:00:00 2001 From: Maksym Ostroverkhov Date: Thu, 9 Feb 2023 13:11:41 +0200 Subject: [PATCH] extend websocket handshake test suite --- .../websocketx/WebSocketHandshakeTest.java | 176 ++++++++++++++++++ 1 file changed, 176 insertions(+) diff --git a/netty-websocket-http1-test/src/test/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketHandshakeTest.java b/netty-websocket-http1-test/src/test/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketHandshakeTest.java index 71cd984..ed642e8 100644 --- a/netty-websocket-http1-test/src/test/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketHandshakeTest.java +++ b/netty-websocket-http1-test/src/test/java/com/jauntsdn/netty/handler/codec/http/websocketx/WebSocketHandshakeTest.java @@ -19,7 +19,9 @@ import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; @@ -28,13 +30,25 @@ import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakeException; import io.netty.handler.codec.http.websocketx.WebSocketDecoderConfig; +import io.netty.util.AttributeKey; import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.DefaultPromise; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; import java.net.SocketAddress; +import java.nio.channels.ClosedChannelException; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; import org.assertj.core.api.Assertions; @@ -174,6 +188,123 @@ void serverBuilderMissingHandler() { }); } + @Timeout(15) + @Test + void clientTimeout() throws InterruptedException { + Channel s = + server = + new ServerBootstrap() + .group(new NioEventLoopGroup(1)) + .channel(NioServerSocketChannel.class) + .childHandler( + new ChannelInitializer() { + + @Override + protected void initChannel(SocketChannel ch) { + ChannelPipeline pipeline = ch.pipeline(); + pipeline.addLast( + new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ReferenceCountUtil.safeRelease(msg); + } + }); + } + }) + .bind("localhost", 0) + .sync() + .channel(); + + AttributeKey handshakeKey = AttributeKey.newInstance("handshake"); + + Channel client = + new Bootstrap() + .group(new NioEventLoopGroup(1)) + .channel(NioSocketChannel.class) + .handler( + new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + + HttpClientCodec http1Codec = new HttpClientCodec(); + HttpObjectAggregator http1Aggregator = new HttpObjectAggregator(65536); + + WebSocketClientProtocolHandler webSocketProtocolHandler = + WebSocketClientProtocolHandler.create() + .handshakeTimeoutMillis(1) + .allowMaskMismatch(true) + .webSocketHandler( + (ctx, webSocketFrameFactory) -> { + throw new AssertionError("should not be called"); + }) + .build(); + + ChannelPipeline pipeline = ch.pipeline(); + pipeline.addLast(http1Codec, http1Aggregator, webSocketProtocolHandler); + + ChannelFuture handshake = webSocketProtocolHandler.handshakeCompleted(); + ch.attr(handshakeKey).set(handshake); + } + }) + .connect(s.localAddress()) + .sync() + .channel(); + + ChannelFuture handshakeFuture = client.attr(handshakeKey).get(); + handshakeFuture.await(); + Throwable cause = handshakeFuture.cause(); + Assertions.assertThat(cause).isNotNull(); + Assertions.assertThat(cause).isInstanceOf(WebSocketClientHandshakeException.class); + client.closeFuture().await(); + Assertions.assertThat(client.isOpen()).isFalse(); + } + + @Timeout(15) + @Test + void serverNonWebSocketRequest() throws InterruptedException { + WebSocketDecoderConfig decoderConfig = webSocketDecoderConfig(false, true, 125); + TestWebSocketHandler serverHandler = new TestWebSocketHandler(); + Channel s = server = testServer("/", decoderConfig, serverHandler); + + AttributeKey> handshakeKey = AttributeKey.newInstance("response"); + + Channel client = + new Bootstrap() + .group(new NioEventLoopGroup(1)) + .channel(NioSocketChannel.class) + .handler( + new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + + HttpClientCodec http1Codec = new HttpClientCodec(); + HttpObjectAggregator http1Aggregator = new HttpObjectAggregator(65536); + NonWebSocketRequestHandler nonWebSocketRequestHandler = + new NonWebSocketRequestHandler(); + ChannelPipeline pipeline = ch.pipeline(); + pipeline.addLast(http1Codec, http1Aggregator, nonWebSocketRequestHandler); + + Future handshake = nonWebSocketRequestHandler.response(); + ch.attr(handshakeKey).set(handshake); + } + }) + .connect(s.localAddress()) + .sync() + .channel(); + + Future responseFuture = client.attr(handshakeKey).get(); + responseFuture.await(); + FullHttpResponse response = responseFuture.getNow(); + try { + Assertions.assertThat(response).isNotNull(); + Assertions.assertThat(response.status()).isEqualTo(HttpResponseStatus.BAD_REQUEST); + } finally { + response.release(); + } + client.closeFuture().await(); + Assertions.assertThat(client.isOpen()).isFalse(); + } + static Channel testClient( SocketAddress address, String path, @@ -236,6 +367,51 @@ static Channel testServer( .channel(); } + static class NonWebSocketRequestHandler extends ChannelInboundHandlerAdapter { + private Promise responsePromise; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof FullHttpResponse) { + responsePromise.trySuccess((FullHttpResponse) msg); + return; + } + super.channelRead(ctx, msg); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + responsePromise.tryFailure(new ClosedChannelException()); + super.channelInactive(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + responsePromise.tryFailure(cause); + super.exceptionCaught(ctx, cause); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + super.handlerAdded(ctx); + responsePromise = new DefaultPromise<>(ctx.executor()); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + super.channelActive(ctx); + FullHttpRequest request = + new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.POST, "/", Unpooled.EMPTY_BUFFER); + + ctx.writeAndFlush(request); + } + + Future response() { + return responsePromise; + } + } + static class TestAcceptor extends ChannelInitializer { private final String path; private final WebSocketDecoderConfig webSocketDecoderConfig;