diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java index c241752667ce..08a27792f5ce 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -123,7 +123,7 @@ public final void onDataAvailable() { } /** - * Subclasses can call this method to delegate a contain notification when + * Subclasses can call this method to delegate a container notification when * all data has been read. */ public void onAllDataRead() { @@ -362,6 +362,12 @@ void onError(AbstractListenerReadPublisher publisher, Throwable ex) { publisher.errorPending = ex; publisher.handlePendingCompletionOrError(); } + + @Override + void cancel(AbstractListenerReadPublisher publisher) { + publisher.completionPending = true; + publisher.handlePendingCompletionOrError(); + } }, NO_DEMAND { @@ -435,6 +441,13 @@ void onError(AbstractListenerReadPublisher publisher, Throwable ex) { publisher.errorPending = ex; publisher.handlePendingCompletionOrError(); } + + @Override + void cancel(AbstractListenerReadPublisher publisher) { + publisher.discardData(); + publisher.completionPending = true; + publisher.handlePendingCompletionOrError(); + } }, COMPLETED { diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java index 8f42edbf676e..e6ac7463f422 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ListenerReadPublisherTests.java @@ -104,6 +104,9 @@ protected void checkOnDataAvailable() { @Override protected DataBuffer read() { + if (this.discardCalls != 0) { + return null; + } this.readCalls++; return mock(); } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java index ab2d9171ca94..4c54f726e02f 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/adapter/AbstractListenerWebSocketSession.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -290,8 +290,10 @@ else if (rsReadLogger.isTraceEnabled()) { @Override protected void discardData() { + Queue queue = this.pendingMessages; + this.pendingMessages = Queues.empty().get(); // prevent further reading while (true) { - WebSocketMessage message = (WebSocketMessage) this.pendingMessages.poll(); + WebSocketMessage message = (WebSocketMessage) queue.poll(); if (message == null) { return; }