Skip to content

Commit

Permalink
Merge pull request #4486 from eclipse/jetty-9.4.x-4475-WebSocketStrea…
Browse files Browse the repository at this point in the history
…mMessageOrder

Issue #4475 - fix WebSocket streaming message ordering
  • Loading branch information
lachlan-roberts authored Jan 23, 2020
2 parents a76fd0e + 08b1be6 commit b649641
Show file tree
Hide file tree
Showing 7 changed files with 672 additions and 274 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,27 +116,24 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException
if (activeMessage == null)
{
if (LOG.isDebugEnabled())
{
LOG.debug("Binary Message InputStream");
}
final MessageInputStream stream = new MessageInputStream();

final MessageInputStream stream = new MessageInputStream(session);
activeMessage = stream;

// Always dispatch streaming read to another thread.
dispatch(new Runnable()
dispatch(() ->
{
@Override
public void run()
try
{
try
{
events.callBinaryStream(jsrsession.getAsyncRemote(), websocket, stream);
}
catch (Throwable e)
{
onFatalError(e);
}
events.callBinaryStream(jsrsession.getAsyncRemote(), websocket, stream);
}
catch (Throwable e)
{
session.close(e);
}

stream.close();
});
}
}
Expand Down Expand Up @@ -330,28 +327,25 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException
if (activeMessage == null)
{
if (LOG.isDebugEnabled())
{
LOG.debug("Text Message Writer");
}

final MessageReader stream = new MessageReader(new MessageInputStream());
activeMessage = stream;
MessageInputStream inputStream = new MessageInputStream(session);
final MessageReader reader = new MessageReader(inputStream);
activeMessage = inputStream;

// Always dispatch streaming read to another thread.
dispatch(new Runnable()
dispatch(() ->
{
@Override
public void run()
try
{
try
{
events.callTextStream(jsrsession.getAsyncRemote(), websocket, stream);
}
catch (Throwable e)
{
onFatalError(e);
}
events.callTextStream(jsrsession.getAsyncRemote(), websocket, reader);
}
catch (Throwable e)
{
session.close(e);
}

inputStream.close();
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,22 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException
}
else if (wrapper.wantsStreams())
{
final MessageInputStream stream = new MessageInputStream();
activeMessage = stream;
dispatch(new Runnable()
@SuppressWarnings("unchecked")
MessageHandler.Whole<InputStream> handler = (Whole<InputStream>)wrapper.getHandler();
MessageInputStream inputStream = new MessageInputStream(session);
activeMessage = inputStream;
dispatch(() ->
{
@SuppressWarnings("unchecked")
@Override
public void run()
try
{
MessageHandler.Whole<InputStream> handler = (Whole<InputStream>)wrapper.getHandler();
handler.onMessage(stream);
handler.onMessage(inputStream);
}
catch (Throwable t)
{
session.close(t);
}

inputStream.close();
});
}
else
Expand Down Expand Up @@ -190,18 +195,23 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException
}
else if (wrapper.wantsStreams())
{
final MessageReader stream = new MessageReader(new MessageInputStream());
activeMessage = stream;

dispatch(new Runnable()
@SuppressWarnings("unchecked")
MessageHandler.Whole<Reader> handler = (Whole<Reader>)wrapper.getHandler();
MessageInputStream inputStream = new MessageInputStream(session);
MessageReader reader = new MessageReader(inputStream);
activeMessage = reader;
dispatch(() ->
{
@SuppressWarnings("unchecked")
@Override
public void run()
try
{
MessageHandler.Whole<Reader> handler = (Whole<Reader>)wrapper.getHandler();
handler.onMessage(stream);
handler.onMessage(reader);
}
catch (Throwable t)
{
session.close(t);
}

inputStream.close();
});
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@

import java.io.IOException;
import java.io.Reader;
import java.io.StringWriter;
import java.io.Writer;
import java.net.URI;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpoint;
import javax.websocket.ContainerProvider;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.OnMessage;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
Expand All @@ -36,18 +40,24 @@
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.websocket.jsr356.server.deploy.WebSocketServerContainerInitializer;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class TextStreamTest
{
private static final String PATH = "/echo";
private static final String CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
private static final BlockingArrayQueue<QueuedTextStreamer> serverEndpoints = new BlockingArrayQueue<>();

private Server server;
private ServerConnector connector;
Expand All @@ -62,8 +72,9 @@ public void prepare() throws Exception

ServletContextHandler context = new ServletContextHandler(server, "/", true, false);
ServerContainer container = WebSocketServerContainerInitializer.configureContext(context);
ServerEndpointConfig config = ServerEndpointConfig.Builder.create(ServerTextStreamer.class, PATH).build();
container.addEndpoint(config);
container.addEndpoint(ServerEndpointConfig.Builder.create(ServerTextStreamer.class, PATH).build());
container.addEndpoint(ServerEndpointConfig.Builder.create(QueuedTextStreamer.class, "/test").build());
container.addEndpoint(ServerEndpointConfig.Builder.create(QueuedPartialTextStreamer.class, "/partial").build());

server.start();

Expand Down Expand Up @@ -125,6 +136,76 @@ public void testMoreThanLargestMessageOneByteAtATime() throws Exception
assertArrayEquals(data, client.getEcho());
}

@Test
public void testMessageOrdering() throws Exception
{
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/test");
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, uri);

final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
session.getBasicRemote().sendText(Integer.toString(i));
session.close();

QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
assertNotNull(queuedTextStreamer);
for (int i = 0; i < numLoops; i++)
{
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
assertThat(msg, Matchers.is(Integer.toString(i)));
}
}

@Test
public void testFragmentedMessageOrdering() throws Exception
{
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/test");
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, uri);

final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
{
session.getBasicRemote().sendText("firstFrame" + i, false);
session.getBasicRemote().sendText("|secondFrame" + i, false);
session.getBasicRemote().sendText("|finalFrame" + i, true);
}
session.close();

QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
assertNotNull(queuedTextStreamer);
for (int i = 0; i < numLoops; i++)
{
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
String expected = "firstFrame" + i + "|secondFrame" + i + "|finalFrame" + i;
assertThat(msg, Matchers.is(expected));
}
}

@Test
public void testMessageOrderingDoNotReadToEOF() throws Exception
{
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/partial");
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, uri);

final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
{
session.getBasicRemote().sendText(i + "|-----");
}
session.close();

QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
assertNotNull(queuedTextStreamer);
for (int i = 0; i < numLoops; i++)
{
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
assertThat(msg, Matchers.is(Integer.toString(i)));
}
}

private char[] randomChars(int size)
{
char[] data = new char[size];
Expand Down Expand Up @@ -183,4 +264,62 @@ public void echo(Session session, Reader input) throws IOException
}
}
}

public static class QueuedTextStreamer extends Endpoint implements MessageHandler.Whole<Reader>
{
protected BlockingArrayQueue<String> messages = new BlockingArrayQueue<>();

public QueuedTextStreamer()
{
serverEndpoints.add(this);
}

@Override
public void onOpen(Session session, EndpointConfig config)
{
session.addMessageHandler(this);
}

@Override
public void onMessage(Reader input)
{
try
{
Thread.sleep(Math.abs(new Random().nextLong() % 200));
messages.add(IO.toString(input));
}
catch (Exception e)
{
e.printStackTrace();
}
}
}

public static class QueuedPartialTextStreamer extends QueuedTextStreamer
{
@Override
public void onMessage(Reader input)
{
try
{
Thread.sleep(Math.abs(new Random().nextLong() % 200));

// Do not read to EOF but just the first '|'.
StringWriter writer = new StringWriter();
while (true)
{
int read = input.read();
if (read < 0 || read == '|')
break;
writer.write(read);
}

messages.add(writer.toString());
}
catch (Exception e)
{
e.printStackTrace();
}
}
}
}
Loading

0 comments on commit b649641

Please sign in to comment.