Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #4475 - fix WebSocket streaming message ordering #4486

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,27 +116,24 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException
if (activeMessage == null)
{
if (LOG.isDebugEnabled())
{
LOG.debug("Binary Message InputStream");
}
final MessageInputStream stream = new MessageInputStream();

final MessageInputStream stream = new MessageInputStream(session);
activeMessage = stream;

// Always dispatch streaming read to another thread.
dispatch(new Runnable()
dispatch(() ->
{
@Override
public void run()
try
{
try
{
events.callBinaryStream(jsrsession.getAsyncRemote(), websocket, stream);
}
catch (Throwable e)
{
onFatalError(e);
}
events.callBinaryStream(jsrsession.getAsyncRemote(), websocket, stream);
}
catch (Throwable e)
{
session.close(e);
}

stream.close();
});
}
}
Expand Down Expand Up @@ -330,28 +327,25 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException
if (activeMessage == null)
{
if (LOG.isDebugEnabled())
{
LOG.debug("Text Message Writer");
}

final MessageReader stream = new MessageReader(new MessageInputStream());
activeMessage = stream;
MessageInputStream inputStream = new MessageInputStream(session);
final MessageReader reader = new MessageReader(inputStream);
activeMessage = inputStream;

// Always dispatch streaming read to another thread.
dispatch(new Runnable()
dispatch(() ->
{
@Override
public void run()
try
{
try
{
events.callTextStream(jsrsession.getAsyncRemote(), websocket, stream);
}
catch (Throwable e)
{
onFatalError(e);
}
events.callTextStream(jsrsession.getAsyncRemote(), websocket, reader);
}
catch (Throwable e)
{
session.close(e);
}

inputStream.close();
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,22 @@ public void onBinaryFrame(ByteBuffer buffer, boolean fin) throws IOException
}
else if (wrapper.wantsStreams())
{
final MessageInputStream stream = new MessageInputStream();
activeMessage = stream;
dispatch(new Runnable()
@SuppressWarnings("unchecked")
MessageHandler.Whole<InputStream> handler = (Whole<InputStream>)wrapper.getHandler();
MessageInputStream inputStream = new MessageInputStream(session);
activeMessage = inputStream;
dispatch(() ->
{
@SuppressWarnings("unchecked")
@Override
public void run()
try
{
MessageHandler.Whole<InputStream> handler = (Whole<InputStream>)wrapper.getHandler();
handler.onMessage(stream);
handler.onMessage(inputStream);
}
catch (Throwable t)
{
session.close(t);
}

inputStream.close();
});
}
else
Expand Down Expand Up @@ -190,18 +195,23 @@ public void onTextFrame(ByteBuffer buffer, boolean fin) throws IOException
}
else if (wrapper.wantsStreams())
{
final MessageReader stream = new MessageReader(new MessageInputStream());
activeMessage = stream;

dispatch(new Runnable()
@SuppressWarnings("unchecked")
MessageHandler.Whole<Reader> handler = (Whole<Reader>)wrapper.getHandler();
MessageInputStream inputStream = new MessageInputStream(session);
MessageReader reader = new MessageReader(inputStream);
activeMessage = reader;
dispatch(() ->
{
@SuppressWarnings("unchecked")
@Override
public void run()
try
{
MessageHandler.Whole<Reader> handler = (Whole<Reader>)wrapper.getHandler();
handler.onMessage(stream);
handler.onMessage(reader);
}
catch (Throwable t)
{
session.close(t);
}

inputStream.close();
});
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@

import java.io.IOException;
import java.io.Reader;
import java.io.StringWriter;
import java.io.Writer;
import java.net.URI;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpoint;
import javax.websocket.ContainerProvider;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.OnMessage;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
Expand All @@ -36,18 +40,24 @@
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.websocket.jsr356.server.deploy.WebSocketServerContainerInitializer;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class TextStreamTest
{
private static final String PATH = "/echo";
private static final String CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
private static final BlockingArrayQueue<QueuedTextStreamer> serverEndpoints = new BlockingArrayQueue<>();

private Server server;
private ServerConnector connector;
Expand All @@ -62,8 +72,9 @@ public void prepare() throws Exception

ServletContextHandler context = new ServletContextHandler(server, "/", true, false);
ServerContainer container = WebSocketServerContainerInitializer.configureContext(context);
ServerEndpointConfig config = ServerEndpointConfig.Builder.create(ServerTextStreamer.class, PATH).build();
container.addEndpoint(config);
container.addEndpoint(ServerEndpointConfig.Builder.create(ServerTextStreamer.class, PATH).build());
container.addEndpoint(ServerEndpointConfig.Builder.create(QueuedTextStreamer.class, "/test").build());
container.addEndpoint(ServerEndpointConfig.Builder.create(QueuedPartialTextStreamer.class, "/partial").build());

server.start();

Expand Down Expand Up @@ -125,6 +136,76 @@ public void testMoreThanLargestMessageOneByteAtATime() throws Exception
assertArrayEquals(data, client.getEcho());
}

@Test
public void testMessageOrdering() throws Exception
{
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/test");
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, uri);

final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
session.getBasicRemote().sendText(Integer.toString(i));
session.close();

QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
assertNotNull(queuedTextStreamer);
for (int i = 0; i < numLoops; i++)
{
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
assertThat(msg, Matchers.is(Integer.toString(i)));
}
}

@Test
public void testFragmentedMessageOrdering() throws Exception
{
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/test");
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, uri);

final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
{
session.getBasicRemote().sendText("firstFrame" + i, false);
session.getBasicRemote().sendText("|secondFrame" + i, false);
session.getBasicRemote().sendText("|finalFrame" + i, true);
}
session.close();

QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
assertNotNull(queuedTextStreamer);
for (int i = 0; i < numLoops; i++)
{
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
String expected = "firstFrame" + i + "|secondFrame" + i + "|finalFrame" + i;
assertThat(msg, Matchers.is(expected));
}
}

@Test
public void testMessageOrderingDoNotReadToEOF() throws Exception
{
URI uri = URI.create("ws://localhost:" + connector.getLocalPort() + "/partial");
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, uri);

final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
{
session.getBasicRemote().sendText(i + "|-----");
}
session.close();

QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
assertNotNull(queuedTextStreamer);
for (int i = 0; i < numLoops; i++)
{
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
assertThat(msg, Matchers.is(Integer.toString(i)));
}
}

private char[] randomChars(int size)
{
char[] data = new char[size];
Expand Down Expand Up @@ -183,4 +264,62 @@ public void echo(Session session, Reader input) throws IOException
}
}
}

public static class QueuedTextStreamer extends Endpoint implements MessageHandler.Whole<Reader>
{
protected BlockingArrayQueue<String> messages = new BlockingArrayQueue<>();

public QueuedTextStreamer()
{
serverEndpoints.add(this);
}

@Override
public void onOpen(Session session, EndpointConfig config)
{
session.addMessageHandler(this);
}

@Override
public void onMessage(Reader input)
{
try
{
Thread.sleep(Math.abs(new Random().nextLong() % 200));
messages.add(IO.toString(input));
}
catch (Exception e)
{
e.printStackTrace();
}
}
}

public static class QueuedPartialTextStreamer extends QueuedTextStreamer
{
@Override
public void onMessage(Reader input)
{
try
{
Thread.sleep(Math.abs(new Random().nextLong() % 200));

// Do not read to EOF but just the first '|'.
StringWriter writer = new StringWriter();
while (true)
{
int read = input.read();
if (read < 0 || read == '|')
break;
writer.write(read);
}

messages.add(writer.toString());
}
catch (Exception e)
{
e.printStackTrace();
}
}
}
}
Loading