Skip to content

Commit

Permalink
Polishing
Browse files Browse the repository at this point in the history
  • Loading branch information
rstoyanchev committed Oct 11, 2023
1 parent a205eab commit 9eb39e1
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -273,7 +273,7 @@ public static SimpMessageHeaderAccessor create(SimpMessageType messageType) {
}

/**
* Create an instance from the payload and headers of the given Message.
* Create an instance by copying the headers of a Message.
*/
public static SimpMessageHeaderAccessor wrap(Message<?> message) {
return new SimpMessageHeaderAccessor(message);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -130,18 +130,17 @@ public UserDestinationResult resolveDestination(Message<?> message) {
return null;
}
String user = parseResult.getUser();
String sourceDestination = parseResult.getSourceDestination();
String sourceDest = parseResult.getSourceDestination();
Set<String> targetSet = new HashSet<>();
for (String sessionId : parseResult.getSessionIds()) {
String actualDestination = parseResult.getActualDestination();
String targetDestination = getTargetDestination(
sourceDestination, actualDestination, sessionId, user);
if (targetDestination != null) {
targetSet.add(targetDestination);
String actualDest = parseResult.getActualDestination();
String targetDest = getTargetDestination(sourceDest, actualDest, sessionId, user);
if (targetDest != null) {
targetSet.add(targetDest);
}
}
String subscribeDestination = parseResult.getSubscribeDestination();
return new UserDestinationResult(sourceDestination, targetSet, subscribeDestination, user);
String subscribeDest = parseResult.getSubscribeDestination();
return new UserDestinationResult(sourceDest, targetSet, subscribeDest, user);
}

@Nullable
Expand Down Expand Up @@ -283,22 +282,37 @@ public ParseResult(String sourceDest, String actualDest, String subscribeDest,
this.user = user;
}

/**
* The destination from the source message, e.g. "/user/{user}/queue/position-updates".
*/
public String getSourceDestination() {
return this.sourceDestination;
}

/**
* The actual destination, without any user prefix, e.g. "/queue/position-updates".
*/
public String getActualDestination() {
return this.actualDestination;
}

/**
* The user destination as it would be on a subscription, "/user/queue/position-updates".
*/
public String getSubscribeDestination() {
return this.subscribeDestination;
}

/**
* The session id or id's for the user.
*/
public Set<String> getSessionIds() {
return this.sessionIds;
}

/**
* The name of the user associated with the session.
*/
@Nullable
public String getUser() {
return this.user;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
/**
* {@code MessageHandler} with support for "user" destinations.
*
* <p>Listens for messages with "user" destinations, translates their destination
* to actual target destinations unique to the active session(s) of a user, and
* then sends the resolved messages to the broker channel to be delivered.
* <p>Listen for messages with "user" destinations, translate the destination to
* a target destination that's unique to the active user session(s), and send
* to the broker channel for delivery.
*
* @author Rossen Stoyanchev
* @since 4.0
Expand Down Expand Up @@ -75,24 +75,24 @@ public class UserDestinationMessageHandler implements MessageHandler, SmartLifec


/**
* Create an instance with the given client and broker channels subscribing
* to handle messages from each and then sending any resolved messages to the
* broker channel.
* Create an instance with the given client and broker channels to subscribe to,
* and then send resolved messages to the broker channel.
* @param clientInboundChannel messages received from clients.
* @param brokerChannel messages sent to the broker.
* @param resolver the resolver for "user" destinations.
* @param destinationResolver the resolver for "user" destinations.
*/
public UserDestinationMessageHandler(SubscribableChannel clientInboundChannel,
SubscribableChannel brokerChannel, UserDestinationResolver resolver) {
public UserDestinationMessageHandler(
SubscribableChannel clientInboundChannel, SubscribableChannel brokerChannel,
UserDestinationResolver destinationResolver) {

Assert.notNull(clientInboundChannel, "'clientInChannel' must not be null");
Assert.notNull(brokerChannel, "'brokerChannel' must not be null");
Assert.notNull(resolver, "resolver must not be null");
Assert.notNull(destinationResolver, "resolver must not be null");

this.clientInboundChannel = clientInboundChannel;
this.brokerChannel = brokerChannel;
this.messagingTemplate = new SimpMessagingTemplate(brokerChannel);
this.destinationResolver = resolver;
this.destinationResolver = destinationResolver;
}


Expand Down Expand Up @@ -182,16 +182,16 @@ public final boolean isRunning() {


@Override
public void handleMessage(Message<?> message) throws MessagingException {
Message<?> messageToUse = message;
public void handleMessage(Message<?> sourceMessage) throws MessagingException {
Message<?> message = sourceMessage;
if (this.broadcastHandler != null) {
messageToUse = this.broadcastHandler.preHandle(message);
if (messageToUse == null) {
message = this.broadcastHandler.preHandle(sourceMessage);
if (message == null) {
return;
}
}

UserDestinationResult result = this.destinationResolver.resolveDestination(messageToUse);
UserDestinationResult result = this.destinationResolver.resolveDestination(message);
if (result == null) {
return;
}
Expand All @@ -201,22 +201,22 @@ public void handleMessage(Message<?> message) throws MessagingException {
logger.trace("No active sessions for user destination: " + result.getSourceDestination());
}
if (this.broadcastHandler != null) {
this.broadcastHandler.handleUnresolved(messageToUse);
this.broadcastHandler.handleUnresolved(message);
}
return;
}

SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.wrap(messageToUse);
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.wrap(message);
initHeaders(accessor);
accessor.setNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION, result.getSubscribeDestination());
accessor.setLeaveMutable(true);

messageToUse = MessageBuilder.createMessage(messageToUse.getPayload(), accessor.getMessageHeaders());
message = MessageBuilder.createMessage(message.getPayload(), accessor.getMessageHeaders());
if (logger.isTraceEnabled()) {
logger.trace("Translated " + result.getSourceDestination() + " -> " + result.getTargetDestinations());
}
for (String target : result.getTargetDestinations()) {
this.messagingTemplate.send(target, messageToUse);
this.messagingTemplate.send(target, message);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ class StompBrokerRelayMessageHandlerTests {

private StompBrokerRelayMessageHandler brokerRelay;

private StubMessageChannel outboundChannel = new StubMessageChannel();
private final StubMessageChannel outboundChannel = new StubMessageChannel();

private StubTcpOperations tcpClient = new StubTcpOperations();
private final StubTcpOperations tcpClient = new StubTcpOperations();

private ArgumentCaptor<Runnable> messageCountTaskCaptor = ArgumentCaptor.forClass(Runnable.class);
private final ArgumentCaptor<Runnable> messageCountTaskCaptor = ArgumentCaptor.forClass(Runnable.class);


@BeforeEach
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.mockito.Mockito;

import org.springframework.core.testfixture.security.TestPrincipal;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.StubMessageChannel;
import org.springframework.messaging.SubscribableChannel;
Expand Down Expand Up @@ -50,7 +51,8 @@ class UserDestinationMessageHandlerTests {

private final SubscribableChannel brokerChannel = mock();

private final UserDestinationMessageHandler handler = new UserDestinationMessageHandler(new StubMessageChannel(), this.brokerChannel, new DefaultUserDestinationResolver(this.registry));
private final UserDestinationMessageHandler handler = new UserDestinationMessageHandler(
new StubMessageChannel(), this.brokerChannel, new DefaultUserDestinationResolver(this.registry));


@Test
Expand Down Expand Up @@ -184,7 +186,9 @@ void ignoreMessage() {
}


private Message<?> createWith(SimpMessageType type, String user, String sessionId, String destination) {
private Message<?> createWith(
SimpMessageType type, @Nullable String user, @Nullable String sessionId, @Nullable String destination) {

SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(type);
if (destination != null) {
headers.setDestination(destination);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
@Nullable
private MessageHeaderInitializer headerInitializer;

private boolean preserveReceiveOrder;

private final Map<String, MessageChannel> messageChannels = new ConcurrentHashMap<>();
@Nullable
private Map<String, MessageChannel> orderedHandlingMessageChannels;

private final Map<String, Principal> stompAuthentications = new ConcurrentHashMap<>();

Expand Down Expand Up @@ -209,7 +208,7 @@ public MessageHeaderInitializer getHeaderInitializer() {
* @since 6.1
*/
public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
this.preserveReceiveOrder = preserveReceiveOrder;
this.orderedHandlingMessageChannels = (preserveReceiveOrder ? new ConcurrentHashMap<>() : null);
}

/**
Expand All @@ -218,7 +217,7 @@ public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
* @since 6.1
*/
public boolean isPreserveReceiveOrder() {
return this.preserveReceiveOrder;
return (this.orderedHandlingMessageChannels != null);
}

@Override
Expand Down Expand Up @@ -253,7 +252,7 @@ public Stats getStats() {
*/
@Override
public void handleMessageFromClient(WebSocketSession session,
WebSocketMessage<?> webSocketMessage, MessageChannel outputChannel) {
WebSocketMessage<?> webSocketMessage, MessageChannel targetChannel) {

List<Message<byte[]>> messages;
try {
Expand Down Expand Up @@ -296,11 +295,11 @@ else if (webSocketMessage instanceof BinaryMessage binaryMessage) {
return;
}

MessageChannel channelToUse =
(this.messageChannels.computeIfAbsent(session.getId(),
id -> this.preserveReceiveOrder ?
new OrderedMessageChannelDecorator(outputChannel, logger) :
outputChannel));
MessageChannel channelToUse = targetChannel;
if (this.orderedHandlingMessageChannels != null) {
channelToUse = this.orderedHandlingMessageChannels.computeIfAbsent(
session.getId(), id -> new OrderedMessageChannelDecorator(targetChannel, logger));
}

for (Message<byte[]> message : messages) {
StompHeaderAccessor headerAccessor =
Expand All @@ -324,7 +323,7 @@ else if (webSocketMessage instanceof BinaryMessage binaryMessage) {
});
}
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
if (!detectImmutableMessageInterceptor(outputChannel)) {
if (!detectImmutableMessageInterceptor(targetChannel)) {
headerAccessor.setImmutable();
}

Expand Down Expand Up @@ -686,7 +685,9 @@ public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus,
outputChannel.send(message);
}
finally {
this.messageChannels.remove(session.getId());
if (this.orderedHandlingMessageChannels != null) {
this.orderedHandlingMessageChannels.remove(session.getId());
}
this.stompAuthentications.remove(session.getId());
SimpAttributesContextHolder.resetAttributes();
simpAttributes.sessionCompleted();
Expand Down

0 comments on commit 9eb39e1

Please sign in to comment.