Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug when process OOMs due to retained LimitableRequestPublishers in RSocketServer #638

Merged
merged 9 commits into from
May 21, 2019
13 changes: 3 additions & 10 deletions rsocket-core/src/main/java/io/rsocket/RSocketClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,8 @@ class RSocketClient implements RSocket {
this.sendProcessor = new UnboundedProcessor<>();

connection.onClose().doFinally(signalType -> terminate()).subscribe(null, errorConsumer);

sendProcessor
.doOnRequest(
r -> {
for (LimitableRequestPublisher lrp : senders.values()) {
lrp.increaseInternalLimit(r);
}
})
.transform(connection::send)
connection
.send(sendProcessor)
.doFinally(this::handleSendProcessorCancel)
.subscribe(null, this::handleSendProcessorError);

Expand Down Expand Up @@ -329,7 +322,7 @@ public void accept(long n) {
.transform(
f -> {
LimitableRequestPublisher<Payload> wrapped =
LimitableRequestPublisher.wrap(f, sendProcessor.available());
LimitableRequestPublisher.wrap(f);
// Need to set this to one for first the frame
wrapped.request(1);
senders.put(streamId, wrapped);
Expand Down
28 changes: 10 additions & 18 deletions rsocket-core/src/main/java/io/rsocket/RSocketServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,8 @@ class RSocketServer implements ResponderRSocket {
// connections
this.sendProcessor = new UnboundedProcessor<>();

sendProcessor
.doOnRequest(
r -> {
for (LimitableRequestPublisher lrp : sendingLimitableSubscriptions.values()) {
lrp.increaseInternalLimit(r);
}
})
.transform(connection::send)
connection
.send(sendProcessor)
.doFinally(this::handleSendProcessorCancel)
.subscribe(null, this::handleSendProcessorError);

Expand Down Expand Up @@ -322,16 +316,14 @@ private void handleFrame(ByteBuf frame) {
handleRequestN(streamId, frame);
break;
case REQUEST_STREAM:
handleStream(
streamId,
requestStream(payloadDecoder.apply(frame)),
RequestStreamFrameFlyweight.initialRequestN(frame));
int streamInitialRequestN = RequestStreamFrameFlyweight.initialRequestN(frame);
Payload streamPayload = payloadDecoder.apply(frame);
handleStream(streamId, requestStream(streamPayload), streamInitialRequestN);
break;
case REQUEST_CHANNEL:
handleChannel(
streamId,
payloadDecoder.apply(frame),
RequestChannelFrameFlyweight.initialRequestN(frame));
int channelInitialRequestN = RequestChannelFrameFlyweight.initialRequestN(frame);
Payload channelPayload = payloadDecoder.apply(frame);
handleChannel(streamId, channelPayload, channelInitialRequestN);
break;
case METADATA_PUSH:
metadataPush(payloadDecoder.apply(frame));
Expand Down Expand Up @@ -459,7 +451,7 @@ private void handleStream(int streamId, Flux<Payload> response, int initialReque
.transform(
frameFlux -> {
LimitableRequestPublisher<Payload> payloads =
LimitableRequestPublisher.wrap(frameFlux, sendProcessor.available());
LimitableRequestPublisher.wrap(frameFlux);
sendingLimitableSubscriptions.put(streamId, payloads);
payloads.request(
initialRequestN >= Integer.MAX_VALUE ? Long.MAX_VALUE : initialRequestN);
Expand Down Expand Up @@ -535,7 +527,7 @@ private void handleCancelFrame(int streamId) {
Subscription subscription = sendingSubscriptions.remove(streamId);

if (subscription == null) {
subscription = sendingLimitableSubscriptions.get(streamId);
subscription = sendingLimitableSubscriptions.remove(streamId);
}

if (subscription != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ public static <T> LimitableRequestPublisher<T> wrap(Publisher<T> source, long pr
return new LimitableRequestPublisher<>(source, prefetch);
}

public static <T> LimitableRequestPublisher<T> wrap(Publisher<T> source) {
return wrap(source, Long.MAX_VALUE);
}

@Override
public void subscribe(CoreSubscriber<? super T> destination) {
synchronized (this) {
Expand Down
29 changes: 0 additions & 29 deletions rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,12 @@
import io.rsocket.exceptions.ApplicationErrorException;
import io.rsocket.exceptions.RejectedSetupException;
import io.rsocket.frame.*;
import io.rsocket.test.util.TestDuplexConnection;
import io.rsocket.test.util.TestSubscriber;
import io.rsocket.util.DefaultPayload;
import io.rsocket.util.EmptyPayload;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.stream.Collectors;
import org.assertj.core.api.Assertions;
import org.junit.Rule;
Expand Down Expand Up @@ -214,32 +211,6 @@ public void testChannelRequestServerSideCancellation() {
Assertions.assertThat(request.isDisposed()).isTrue();
}

@Test(timeout = 2_000)
@SuppressWarnings("unchecked")
public void
testClientSideRequestChannelShouldNotHangInfinitelySendingElementsAndShouldProduceDataValuingConnectionBackpressure() {
final Queue<Long> requests = new ConcurrentLinkedQueue<>();
rule.connection.dispose();
rule.connection = new TestDuplexConnection();
rule.connection.setInitialSendRequestN(256);
rule.init();

rule.socket
.requestChannel(
Flux.<Payload>generate(s -> s.next(EmptyPayload.INSTANCE)).doOnRequest(requests::add))
.subscribe();

int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL);

assertThat("Unexpected error.", rule.errors, is(empty()));

rule.connection.addToReceivedBuffer(
RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId, 2));
rule.connection.addToReceivedBuffer(
RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId, Integer.MAX_VALUE));
Assertions.assertThat(requests).containsOnly(1L, 2L, 253L);
}

public int sendRequestResponse(Publisher<Payload> response) {
Subscriber<Payload> sub = TestSubscriber.create();
response.subscribe(sub);
Expand Down
88 changes: 0 additions & 88 deletions rsocket-core/src/test/java/io/rsocket/RSocketServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,12 @@
import io.rsocket.util.DefaultPayload;
import io.rsocket.util.EmptyPayload;
import java.util.Collection;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import org.assertj.core.api.Assertions;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.Mockito;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public class RSocketServerTest {
Expand Down Expand Up @@ -111,89 +106,6 @@ public Mono<Payload> requestResponse(Payload payload) {
assertThat("Subscription not cancelled.", cancelled.get(), is(true));
}

@Test(timeout = 2_000)
@SuppressWarnings("unchecked")
public void
testServerSideRequestStreamShouldNotHangInfinitelySendingElementsAndShouldProduceDataValuingConnectionBackpressure() {
final int streamId = 5;
final Queue<Object> received = new ConcurrentLinkedQueue<>();
final Queue<Long> requests = new ConcurrentLinkedQueue<>();

rule.setAcceptingSocket(
new AbstractRSocket() {
@Override
public Flux<Payload> requestStream(Payload payload) {
return Flux.<Payload>generate(s -> s.next(payload.retain())).doOnRequest(requests::add);
}
},
256);

rule.sendRequest(streamId, FrameType.REQUEST_STREAM);

assertThat("Unexpected error.", rule.errors, is(empty()));

Subscriber next = rule.connection.getSendSubscribers().iterator().next();

Mockito.doAnswer(
invocation -> {
received.add(invocation.getArgument(0));

if (received.size() == 256) {
throw new RuntimeException();
}

return null;
})
.when(next)
.onNext(Mockito.any());

rule.connection.addToReceivedBuffer(
RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId, Integer.MAX_VALUE));
Assertions.assertThat(requests).containsOnly(1L, 2L, 253L);
}

@Test(timeout = 2_000)
@SuppressWarnings("unchecked")
public void
testServerSideRequestChannelShouldNotHangInfinitelySendingElementsAndShouldProduceDataValuingConnectionBackpressure() {
final int streamId = 5;
final Queue<Object> received = new ConcurrentLinkedQueue<>();
final Queue<Long> requests = new ConcurrentLinkedQueue<>();

rule.setAcceptingSocket(
new AbstractRSocket() {
@Override
public Flux<Payload> requestChannel(Publisher<Payload> payload) {
return Flux.<Payload>generate(s -> s.next(EmptyPayload.INSTANCE))
.doOnRequest(requests::add);
}
},
256);

rule.sendRequest(streamId, FrameType.REQUEST_CHANNEL);

assertThat("Unexpected error.", rule.errors, is(empty()));

Subscriber next = rule.connection.getSendSubscribers().iterator().next();

Mockito.doAnswer(
invocation -> {
received.add(invocation.getArgument(0));

if (received.size() == 256) {
throw new RuntimeException();
}

return null;
})
.when(next)
.onNext(Mockito.any());

rule.connection.addToReceivedBuffer(
RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId, Integer.MAX_VALUE));
Assertions.assertThat(requests).containsOnly(1L, 2L, 253L);
}

public static class ServerSocketRule extends AbstractSocketRule<RSocketServer> {

private RSocket acceptingSocket;
Expand Down