Skip to content

Commit

Permalink
Addresses eclipse-ee4j#630
Browse files Browse the repository at this point in the history
  • Loading branch information
dansiviter committed Jun 29, 2020
1 parent 597e6c6 commit 125b6cb
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -301,21 +301,25 @@ public void removeMessageHandler(MessageHandler handler) {
}
}

Map<Class<?>, MessageHandler> getRegisteredHandlers() {
return new HashMap<>(registeredHandlers);
}

/**
* Get all successfully registered {@link MessageHandler}s.
*
* @return unmodifiable {@link Set} of registered {@link MessageHandler}s.
*/
public Set<MessageHandler> getMessageHandlers() {
if (messageHandlerCache == null) {
messageHandlerCache = Collections.unmodifiableSet(new HashSet<MessageHandler>(registeredHandlers.values()));
messageHandlerCache = Collections.unmodifiableSet(new HashSet<>(registeredHandlers.values()));
}

return messageHandlerCache;
}

public List<Map.Entry<Class<?>, MessageHandler>> getOrderedWholeMessageHandlers() {
List<Map.Entry<Class<?>, MessageHandler>> result = new ArrayList<Map.Entry<Class<?>, MessageHandler>>();
List<Map.Entry<Class<?>, MessageHandler>> result = new ArrayList<>();
for (final Map.Entry<Class<?>, MessageHandler> entry : registeredHandlers.entrySet()) {
if (entry.getValue() instanceof MessageHandler.Whole) {
result.add(entry);
Expand All @@ -325,7 +329,7 @@ public List<Map.Entry<Class<?>, MessageHandler>> getOrderedWholeMessageHandlers(
return result;
}

static Class<?> getHandlerType(MessageHandler handler) {
private static Class<?> getHandlerType(MessageHandler handler) {
Class<?> root;
if (handler instanceof AsyncMessageHandler) {
return ((AsyncMessageHandler) handler).getType();
Expand Down
16 changes: 8 additions & 8 deletions core/src/main/java/org/glassfish/tyrus/core/TyrusSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.Map.Entry;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
Expand Down Expand Up @@ -61,7 +62,7 @@
* @author Martin Matula (martin.matula at oracle.com)
* @author Pavel Bucek (pavel.bucek at oracle.com)
*/
public class TyrusSession implements Session, DistributedSession {
public class TyrusSession implements DistributedSession {

private static final Logger LOGGER = Logger.getLogger(TyrusSession.class.getName());

Expand Down Expand Up @@ -587,9 +588,10 @@ <T> MessageHandler.Whole<T> getMessageHandler(Class<T> c) {
void notifyMessageHandlers(Object message, boolean last) {
boolean handled = false;

for (MessageHandler handler : getMessageHandlers()) {
for (Entry<Class<?>, MessageHandler> e : this.handlerManager.getRegisteredHandlers().entrySet()) {
MessageHandler handler = e.getValue();
if ((handler instanceof MessageHandler.Partial)
&& MessageHandlerManager.getHandlerType(handler).isAssignableFrom(message.getClass())) {
&& e.getKey().isAssignableFrom(message.getClass())) {

if (handler instanceof AsyncMessageHandler) {
checkMessageSize(message, ((AsyncMessageHandler) handler).getMaxMessageSize());
Expand All @@ -615,11 +617,9 @@ void notifyMessageHandlers(Object message, boolean last) {
}

void notifyPongHandler(PongMessage pongMessage) {
final Set<MessageHandler> messageHandlers = getMessageHandlers();
for (MessageHandler handler : messageHandlers) {
if (MessageHandlerManager.getHandlerType(handler).equals(PongMessage.class)) {
((MessageHandler.Whole<PongMessage>) handler).onMessage(pongMessage);
}
final MessageHandler.Whole<PongMessage> handler = getMessageHandler(PongMessage.class);
if (handler != null) {
handler.onMessage(pongMessage);
}
}

Expand Down
41 changes: 38 additions & 3 deletions core/src/test/java/org/glassfish/tyrus/core/TyrusSessionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ public void onMessage(InputStream message, boolean last) {
public void multiplePongHandlersAsync() {
Session session = createSession(endpointWrapper);


session.addMessageHandler(new MessageHandler.Partial<PongMessage>() {
@Override
public void onMessage(PongMessage message, boolean last) {
Expand All @@ -312,7 +311,6 @@ public void onMessage(PongMessage message, boolean last) {
public void multipleBasicDecodableAsync() {
Session session = createSession(endpointWrapper);


session.addMessageHandler(new MessageHandler.Partial<TyrusSessionTest>() {
@Override
public void onMessage(TyrusSessionTest message, boolean last) {
Expand Down Expand Up @@ -359,7 +357,6 @@ public void onMessage(PongMessage message) {
public void removeHandlers() {
Session session = createSession(endpointWrapper);


final MessageHandler.Partial<String> handler1 = new MessageHandler.Partial<String>() {
@Override
public void onMessage(String message, boolean last) {
Expand Down Expand Up @@ -408,6 +405,44 @@ public void idTest() {
assertFalse(session2.getId().equals(session3.getId()));
}

@Test
public void getLambdaHandlers() {
Session session = createSession(endpointWrapper);

final MessageHandler.Partial<String> handler1 = this::stringPartialHandler;
final MessageHandler.Whole<ByteBuffer> handler2 = this::bytesHandler;
final MessageHandler.Whole<PongMessage> handler3 = this::pongHandler;

session.addMessageHandler(String.class, handler1);
session.addMessageHandler(ByteBuffer.class, handler2);
session.addMessageHandler(PongMessage.class, handler3);

assertTrue(session.getMessageHandlers().contains(handler1));
assertTrue(session.getMessageHandlers().contains(handler2));
assertTrue(session.getMessageHandlers().contains(handler3));

session.removeMessageHandler(handler3);

assertTrue(session.getMessageHandlers().contains(handler1));
assertTrue(session.getMessageHandlers().contains(handler2));
assertFalse(session.getMessageHandlers().contains(handler3));

session.removeMessageHandler(handler2);

assertTrue(session.getMessageHandlers().contains(handler1));
assertFalse(session.getMessageHandlers().contains(handler2));
assertFalse(session.getMessageHandlers().contains(handler3));
}

private void stringPartialHandler(String message, boolean last) {
}

private void bytesHandler(ByteBuffer message) {
}

private void pongHandler(PongMessage message) {
}


@ServerEndpoint(value = "/echo")
private static class EchoEndpoint extends Endpoint {
Expand Down

0 comments on commit 125b6cb

Please sign in to comment.