diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java index 70b2c0b2e545..66206f667953 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java @@ -143,7 +143,7 @@ private Message decodeMessage(ByteBuffer byteBuffer, @Nullable MultiValu StompCommand stompCommand = StompCommand.valueOf(command); headerAccessor = StompHeaderAccessor.create(stompCommand); initHeaders(headerAccessor); - readHeaders(byteBuffer, headerAccessor); + readHeaders(stompCommand, byteBuffer, headerAccessor); payload = readPayload(byteBuffer, headerAccessor); } if (payload != null) { @@ -215,7 +215,9 @@ private String readCommand(ByteBuffer byteBuffer) { return StreamUtils.copyToString(command, StandardCharsets.UTF_8); } - private void readHeaders(ByteBuffer byteBuffer, StompHeaderAccessor headerAccessor) { + private void readHeaders(StompCommand stompCommand, ByteBuffer byteBuffer, StompHeaderAccessor headerAccessor) { + boolean shouldUnescape = (stompCommand != StompCommand.CONNECT && stompCommand != StompCommand.STOMP + && stompCommand != StompCommand.CONNECTED); while (true) { ByteArrayOutputStream headerStream = new ByteArrayOutputStream(256); boolean headerComplete = false; @@ -236,8 +238,8 @@ private void readHeaders(ByteBuffer byteBuffer, StompHeaderAccessor headerAccess } } else { - String headerName = unescape(header.substring(0, colonIndex)); - String headerValue = unescape(header.substring(colonIndex + 1)); + String headerName = shouldUnescape ? unescape(header.substring(0, colonIndex)) : header.substring(0, colonIndex); + String headerValue = shouldUnescape ? unescape(header.substring(colonIndex + 1)) : header.substring(colonIndex + 1); try { headerAccessor.addNativeHeader(headerName, headerValue); } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompDecoderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompDecoderTests.java index 439b157d2d6e..b0797dc6c171 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompDecoderTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/StompDecoderTests.java @@ -159,6 +159,23 @@ public void decodeFrameWithEscapedHeaders() { assertThat(headers.getFirstNativeHeader("a:\r\n\\b")).isEqualTo("alpha:bravo\r\n\\"); } + @Test + public void decodeFrameWithHeaderWithBackslashValue() { + String accept = "accept-version:1.1\n"; + String keyAndValueWithBackslash = "key:\\value\n"; + + Message frame = decode("CONNECT\n" + accept + keyAndValueWithBackslash + "\n\0"); + StompHeaderAccessor headers = StompHeaderAccessor.wrap(frame); + + assertThat(headers.getCommand()).isEqualTo(StompCommand.CONNECT); + + assertThat(headers.toNativeHeaderMap().size()).isEqualTo(2); + assertThat(headers.getFirstNativeHeader("accept-version")).isEqualTo("1.1"); + assertThat(headers.getFirstNativeHeader("key")).isEqualTo("\\value"); + + assertThat(frame.getPayload().length).isEqualTo(0); + } + @Test public void decodeFrameBodyNotAllowed() { assertThatExceptionOfType(StompConversionException.class).isThrownBy(() ->