From ed172d62692b97441be365226d2dd0973210ec69 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 12 Sep 2023 10:00:47 +0100 Subject: [PATCH] ByteBuffer handling for Jetty WebSocket messages Closes gh-31182 --- .../adapter/JettyWebSocketHandlerAdapter.java | 255 ++++++++++++++++-- .../jetty/JettyWebSocketHandlerAdapter.java | 14 +- 2 files changed, 238 insertions(+), 31 deletions(-) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java index ad917a681729..71c0a2c840c1 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/JettyWebSocketHandlerAdapter.java @@ -17,8 +17,10 @@ package org.springframework.web.reactive.socket.adapter; import java.nio.ByteBuffer; +import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.function.Function; +import java.util.function.IntPredicate; import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.Frame; @@ -31,14 +33,15 @@ import org.eclipse.jetty.websocket.api.annotations.WebSocket; import org.eclipse.jetty.websocket.core.OpCode; +import org.springframework.core.io.buffer.CloseableDataBuffer; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.web.reactive.socket.CloseStatus; import org.springframework.web.reactive.socket.WebSocketHandler; import org.springframework.web.reactive.socket.WebSocketMessage; import org.springframework.web.reactive.socket.WebSocketMessage.Type; -import org.springframework.web.reactive.socket.WebSocketSession; /** * Jetty {@link WebSocket @WebSocket} handler that delegates events to a @@ -83,17 +86,20 @@ public void onWebSocketOpen(Session session) { @OnWebSocketMessage public void onWebSocketText(String message) { if (this.delegateSession != null) { - WebSocketMessage webSocketMessage = toMessage(Type.TEXT, message); + byte[] bytes = message.getBytes(StandardCharsets.UTF_8); + DataBuffer buffer = this.delegateSession.bufferFactory().wrap(bytes); + WebSocketMessage webSocketMessage = new WebSocketMessage(Type.TEXT, buffer); this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); } } @OnWebSocketMessage - public void onWebSocketBinary(ByteBuffer buffer, Callback callback) { + public void onWebSocketBinary(ByteBuffer byteBuffer, Callback callback) { if (this.delegateSession != null) { - WebSocketMessage webSocketMessage = toMessage(Type.BINARY, buffer); + DataBuffer buffer = this.delegateSession.bufferFactory().wrap(byteBuffer); + buffer = new JettyDataBuffer(buffer, callback); + WebSocketMessage webSocketMessage = new WebSocketMessage(Type.BINARY, buffer); this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); - callback.succeed(); } } @@ -101,35 +107,15 @@ public void onWebSocketBinary(ByteBuffer buffer, Callback callback) { public void onWebSocketFrame(Frame frame, Callback callback) { if (this.delegateSession != null) { if (OpCode.PONG == frame.getOpCode()) { - ByteBuffer buffer = (frame.getPayload() != null ? frame.getPayload() : EMPTY_PAYLOAD); - WebSocketMessage webSocketMessage = toMessage(Type.PONG, buffer); + ByteBuffer byteBuffer = (frame.getPayload() != null ? frame.getPayload() : EMPTY_PAYLOAD); + DataBuffer buffer = this.delegateSession.bufferFactory().wrap(byteBuffer); + buffer = new JettyDataBuffer(buffer, callback); + WebSocketMessage webSocketMessage = new WebSocketMessage(Type.PONG, buffer); this.delegateSession.handleMessage(webSocketMessage.getType(), webSocketMessage); - callback.succeed(); } } } - private WebSocketMessage toMessage(Type type, T message) { - WebSocketSession session = this.delegateSession; - Assert.state(session != null, "Cannot create message without a session"); - if (Type.TEXT.equals(type)) { - byte[] bytes = ((String) message).getBytes(StandardCharsets.UTF_8); - DataBuffer buffer = session.bufferFactory().wrap(bytes); - return new WebSocketMessage(Type.TEXT, buffer); - } - else if (Type.BINARY.equals(type)) { - DataBuffer buffer = session.bufferFactory().wrap((ByteBuffer) message); - return new WebSocketMessage(Type.BINARY, buffer); - } - else if (Type.PONG.equals(type)) { - DataBuffer buffer = session.bufferFactory().wrap((ByteBuffer) message); - return new WebSocketMessage(Type.PONG, buffer); - } - else { - throw new IllegalArgumentException("Unexpected message type: " + message); - } - } - @OnWebSocketClose public void onWebSocketClose(int statusCode, String reason) { if (this.delegateSession != null) { @@ -144,4 +130,215 @@ public void onWebSocketError(Throwable cause) { } } + + private static final class JettyDataBuffer implements CloseableDataBuffer { + + private final DataBuffer delegate; + + private final Callback callback; + + public JettyDataBuffer(DataBuffer delegate, Callback callback) { + Assert.notNull(delegate, "'delegate` must not be null"); + Assert.notNull(callback, "Callback must not be null"); + this.delegate = delegate; + this.callback = callback; + } + + @Override + public void close() { + this.callback.succeed(); + } + + // delegation + + @Override + public DataBufferFactory factory() { + return this.delegate.factory(); + } + + @Override + public int indexOf(IntPredicate predicate, int fromIndex) { + return this.delegate.indexOf(predicate, fromIndex); + } + + @Override + public int lastIndexOf(IntPredicate predicate, int fromIndex) { + return this.delegate.lastIndexOf(predicate, fromIndex); + } + + @Override + public int readableByteCount() { + return this.delegate.readableByteCount(); + } + + @Override + public int writableByteCount() { + return this.delegate.writableByteCount(); + } + + @Override + public int capacity() { + return this.delegate.capacity(); + } + + @Override + @Deprecated + public DataBuffer capacity(int capacity) { + this.delegate.capacity(capacity); + return this; + } + + @Override + public DataBuffer ensureWritable(int capacity) { + this.delegate.ensureWritable(capacity); + return this; + } + + @Override + public int readPosition() { + return this.delegate.readPosition(); + } + + @Override + public DataBuffer readPosition(int readPosition) { + this.delegate.readPosition(readPosition); + return this; + } + + @Override + public int writePosition() { + return this.delegate.writePosition(); + } + + @Override + public DataBuffer writePosition(int writePosition) { + this.delegate.writePosition(writePosition); + return this; + } + + @Override + public byte getByte(int index) { + return this.delegate.getByte(index); + } + + @Override + public byte read() { + return this.delegate.read(); + } + + @Override + public DataBuffer read(byte[] destination) { + this.delegate.read(destination); + return this; + } + + @Override + public DataBuffer read(byte[] destination, int offset, int length) { + this.delegate.read(destination, offset, length); + return this; + } + + @Override + public DataBuffer write(byte b) { + this.delegate.write(b); + return this; + } + + @Override + public DataBuffer write(byte[] source) { + this.delegate.write(source); + return this; + } + + @Override + public DataBuffer write(byte[] source, int offset, int length) { + this.delegate.write(source, offset, length); + return this; + } + + @Override + public DataBuffer write(DataBuffer... buffers) { + this.delegate.write(buffers); + return this; + } + + @Override + public DataBuffer write(ByteBuffer... buffers) { + this.delegate.write(buffers); + return this; + } + + @Override + @Deprecated + public DataBuffer slice(int index, int length) { + DataBuffer delegateSlice = this.delegate.slice(index, length); + return new JettyDataBuffer(delegateSlice, this.callback); + } + + @Override + public DataBuffer split(int index) { + DataBuffer delegateSplit = this.delegate.split(index); + return new JettyDataBuffer(delegateSplit, this.callback); + } + + @Override + @Deprecated + public ByteBuffer asByteBuffer() { + return this.delegate.asByteBuffer(); + } + + @Override + @Deprecated + public ByteBuffer asByteBuffer(int index, int length) { + return this.delegate.asByteBuffer(index, length); + } + + @Override + @Deprecated + public ByteBuffer toByteBuffer(int index, int length) { + return this.delegate.toByteBuffer(index, length); + } + + @Override + public void toByteBuffer(int srcPos, ByteBuffer dest, int destPos, int length) { + this.delegate.toByteBuffer(srcPos, dest, destPos, length); + } + + @Override + public ByteBufferIterator readableByteBuffers() { + ByteBufferIterator delegateIterator = this.delegate.readableByteBuffers(); + return new JettyByteBufferIterator(delegateIterator); + } + + @Override + public ByteBufferIterator writableByteBuffers() { + ByteBufferIterator delegateIterator = this.delegate.writableByteBuffers(); + return new JettyByteBufferIterator(delegateIterator); + } + + @Override + public String toString(int index, int length, Charset charset) { + return this.delegate.toString(index, length, charset); + } + + + private record JettyByteBufferIterator(ByteBufferIterator delegate) implements ByteBufferIterator { + + @Override + public void close() { + this.delegate.close(); + } + + @Override + public boolean hasNext() { + return this.delegate.hasNext(); + } + + @Override + public ByteBuffer next() { + return this.delegate.next(); + } + } + } + } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketHandlerAdapter.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketHandlerAdapter.java index 334fe8b25c8a..a18e931c72b0 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketHandlerAdapter.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketHandlerAdapter.java @@ -90,11 +90,13 @@ public void onWebSocketText(String payload) { @OnWebSocketMessage public void onWebSocketBinary(ByteBuffer payload, Callback callback) { - BinaryMessage message = new BinaryMessage(payload, true); + BinaryMessage message = new BinaryMessage(copyByteBuffer(payload), true); try { this.webSocketHandler.handleMessage(this.wsSession, message); + callback.succeed(); } catch (Exception ex) { + callback.fail(ex); ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.wsSession, ex, logger); } } @@ -103,16 +105,24 @@ public void onWebSocketBinary(ByteBuffer payload, Callback callback) { public void onWebSocketFrame(Frame frame, Callback callback) { if (OpCode.PONG == frame.getOpCode()) { ByteBuffer payload = frame.getPayload() != null ? frame.getPayload() : EMPTY_PAYLOAD; - PongMessage message = new PongMessage(payload); + PongMessage message = new PongMessage(copyByteBuffer(payload)); try { this.webSocketHandler.handleMessage(this.wsSession, message); + callback.succeed(); } catch (Exception ex) { + callback.fail(ex); ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.wsSession, ex, logger); } } } + private static ByteBuffer copyByteBuffer(ByteBuffer src) { + ByteBuffer dest = ByteBuffer.allocate(src.capacity()); + dest.put(0, src, 0, src.remaining()); + return dest; + } + @OnWebSocketClose public void onWebSocketClose(int statusCode, String reason) { CloseStatus closeStatus = new CloseStatus(statusCode, reason);