diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java index ec4c8cec2..f921365da 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java @@ -27,7 +27,7 @@ import io.rsocket.exceptions.Exceptions; import io.rsocket.frame.*; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.LimitableRequestPublisher; +import io.rsocket.internal.RateLimitableRequestPublisher; import io.rsocket.internal.SynchronizedIntObjectHashMap; import io.rsocket.internal.UnboundedProcessor; import io.rsocket.internal.UnicastMonoProcessor; @@ -47,6 +47,7 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import reactor.core.publisher.*; +import reactor.util.concurrent.Queues; /** * Requester Side of a RSocket socket. Sends {@link ByteBuf}s to a {@link RSocketResponder} of peer @@ -60,7 +61,7 @@ class RSocketRequester implements RSocket { private final PayloadDecoder payloadDecoder; private final Consumer errorConsumer; private final StreamIdSupplier streamIdSupplier; - private final IntObjectMap senders; + private final IntObjectMap senders; private final IntObjectMap> receivers; private final UnboundedProcessor sendProcessor; private final RequesterLeaseHandler leaseHandler; @@ -131,7 +132,7 @@ private void handleSendProcessorError(Throwable t) { } }); - senders.values().forEach(LimitableRequestPublisher::cancel); + senders.values().forEach(RateLimitableRequestPublisher::cancel); } private void handleSendProcessorCancel(SignalType t) { @@ -150,7 +151,7 @@ private void handleSendProcessorCancel(SignalType t) { } }); - senders.values().forEach(LimitableRequestPublisher::cancel); + senders.values().forEach(RateLimitableRequestPublisher::cancel); } @Override @@ -343,8 +344,8 @@ public void accept(long n) { request .transform( f -> { - LimitableRequestPublisher wrapped = - LimitableRequestPublisher.wrap(f); + RateLimitableRequestPublisher wrapped = + RateLimitableRequestPublisher.wrap(f, Queues.SMALL_BUFFER_SIZE); // Need to set this to one for first the frame wrapped.request(1); senders.put(streamId, wrapped); @@ -421,7 +422,7 @@ protected void hookOnError(Throwable t) { .doFinally( s -> { receivers.remove(streamId); - LimitableRequestPublisher sender = senders.remove(streamId); + RateLimitableRequestPublisher sender = senders.remove(streamId); if (sender != null) { sender.cancel(); } @@ -489,7 +490,7 @@ private void setTerminationError(Throwable error) { } private synchronized void cleanUpLimitableRequestPublisher( - LimitableRequestPublisher limitableRequestPublisher) { + RateLimitableRequestPublisher limitableRequestPublisher) { try { limitableRequestPublisher.cancel(); } catch (Throwable t) { @@ -561,7 +562,7 @@ private void handleFrame(int streamId, FrameType type, ByteBuf frame) { break; case CANCEL: { - LimitableRequestPublisher sender = senders.remove(streamId); + RateLimitableRequestPublisher sender = senders.remove(streamId); if (sender != null) { sender.cancel(); } @@ -572,7 +573,7 @@ private void handleFrame(int streamId, FrameType type, ByteBuf frame) { break; case REQUEST_N: { - LimitableRequestPublisher sender = senders.get(streamId); + RateLimitableRequestPublisher sender = senders.get(streamId); if (sender != null) { int n = RequestNFrameFlyweight.requestN(frame); sender.request(n >= Integer.MAX_VALUE ? Long.MAX_VALUE : n); diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/RSocketResponder.java index 3bd221d64..490b00967 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketResponder.java @@ -23,7 +23,7 @@ import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.frame.*; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.LimitableRequestPublisher; +import io.rsocket.internal.RateLimitableRequestPublisher; import io.rsocket.internal.SynchronizedIntObjectHashMap; import io.rsocket.internal.UnboundedProcessor; import io.rsocket.lease.ResponderLeaseHandler; @@ -35,6 +35,7 @@ import reactor.core.Disposable; import reactor.core.Exceptions; import reactor.core.publisher.*; +import reactor.util.concurrent.Queues; /** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */ class RSocketResponder implements ResponderRSocket { @@ -46,7 +47,7 @@ class RSocketResponder implements ResponderRSocket { private final Consumer errorConsumer; private final ResponderLeaseHandler leaseHandler; - private final IntObjectMap sendingLimitableSubscriptions; + private final IntObjectMap sendingLimitableSubscriptions; private final IntObjectMap sendingSubscriptions; private final IntObjectMap> channelProcessors; @@ -435,8 +436,8 @@ private void handleStream(int streamId, Flux response, int initialReque response .transform( frameFlux -> { - LimitableRequestPublisher payloads = - LimitableRequestPublisher.wrap(frameFlux); + RateLimitableRequestPublisher payloads = + RateLimitableRequestPublisher.wrap(frameFlux, Queues.SMALL_BUFFER_SIZE); sendingLimitableSubscriptions.put(streamId, payloads); payloads.request( initialRequestN >= Integer.MAX_VALUE ? Long.MAX_VALUE : initialRequestN); diff --git a/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestPublisher.java b/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestPublisher.java new file mode 100755 index 000000000..cdb0d0c0c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestPublisher.java @@ -0,0 +1,242 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal; + +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import javax.annotation.Nullable; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; + +/** */ +public class RateLimitableRequestPublisher extends Flux implements Subscription { + + private static final int NOT_CANCELED_STATE = 0; + private static final int CANCELED_STATE = 1; + + private final Publisher source; + + private volatile int canceled; + private static final AtomicIntegerFieldUpdater CANCELED = + AtomicIntegerFieldUpdater.newUpdater(RateLimitableRequestPublisher.class, "canceled"); + + private final long prefetch; + private final long limit; + + private long externalRequested; // need sync + private int pendingToFulfil; // need sync since should be checked/zerroed in onNext + // and increased in request + private int deliveredElements; // no need to sync since increased zerroed only in + // the request method + + private boolean subscribed; + + private @Nullable Subscription internalSubscription; + + private RateLimitableRequestPublisher(Publisher source, long prefetch) { + this.source = source; + this.prefetch = prefetch; + this.limit = prefetch == Integer.MAX_VALUE ? Integer.MAX_VALUE : (prefetch - (prefetch >> 2)); + } + + public static RateLimitableRequestPublisher wrap(Publisher source, long prefetch) { + return new RateLimitableRequestPublisher<>(source, prefetch); + } + + @Override + public void subscribe(CoreSubscriber destination) { + synchronized (this) { + if (subscribed) { + throw new IllegalStateException("only one subscriber at a time"); + } + + subscribed = true; + } + final InnerOperator s = new InnerOperator(destination); + + source.subscribe(s); + destination.onSubscribe(s); + } + + @Override + public void request(long n) { + synchronized (this) { + long requested = externalRequested; + if (requested == Long.MAX_VALUE) { + return; + } + externalRequested = Operators.addCap(n, requested); + } + + requestN(); + } + + private void requestN() { + final long r; + final Subscription s; + + synchronized (this) { + s = internalSubscription; + if (s == null) { + return; + } + + final long er = externalRequested; + final long p = prefetch; + final int pendingFulfil = pendingToFulfil; + + if (er != Long.MAX_VALUE || p != Integer.MAX_VALUE) { + // shortcut + if (pendingFulfil == p) { + return; + } + + r = Math.min(p - pendingFulfil, er); + if (er != Long.MAX_VALUE) { + externalRequested -= r; + } + if (p != Integer.MAX_VALUE) { + pendingToFulfil += r; + } + } else { + r = Long.MAX_VALUE; + } + } + + if (r > 0) { + s.request(r); + } + } + + public void cancel() { + if (!isCanceled() && CANCELED.compareAndSet(this, NOT_CANCELED_STATE, CANCELED_STATE)) { + Subscription s; + + synchronized (this) { + s = internalSubscription; + internalSubscription = null; + subscribed = false; + } + + if (s != null) { + s.cancel(); + } + } + } + + private boolean isCanceled() { + return canceled == CANCELED_STATE; + } + + private class InnerOperator implements CoreSubscriber, Subscription { + final Subscriber destination; + + private InnerOperator(Subscriber destination) { + this.destination = destination; + } + + @Override + public void onSubscribe(Subscription s) { + synchronized (RateLimitableRequestPublisher.this) { + RateLimitableRequestPublisher.this.internalSubscription = s; + + if (isCanceled()) { + s.cancel(); + subscribed = false; + RateLimitableRequestPublisher.this.internalSubscription = null; + } + } + + requestN(); + } + + @Override + public void onNext(T t) { + try { + destination.onNext(t); + + if (prefetch == Integer.MAX_VALUE) { + return; + } + + final long l = limit; + int d = deliveredElements + 1; + + if (d == l) { + d = 0; + final long r; + final Subscription s; + + synchronized (RateLimitableRequestPublisher.this) { + long er = externalRequested; + s = internalSubscription; + + if (s == null) { + return; + } + + if (er >= l) { + er -= l; + // keep pendingToFulfil as is since it is eq to prefetch + r = l; + } else { + pendingToFulfil -= l; + if (er > 0) { + r = er; + er = 0; + pendingToFulfil += r; + } else { + r = 0; + } + } + + externalRequested = er; + } + + if (r > 0) { + s.request(r); + } + } + + deliveredElements = d; + } catch (Throwable e) { + onError(e); + } + } + + @Override + public void onError(Throwable t) { + destination.onError(t); + } + + @Override + public void onComplete() { + destination.onComplete(); + } + + @Override + public void request(long n) {} + + @Override + public void cancel() { + RateLimitableRequestPublisher.this.cancel(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/RateLimitableRequestPublisherTest.java b/rsocket-core/src/test/java/io/rsocket/internal/RateLimitableRequestPublisherTest.java new file mode 100644 index 000000000..af4c528e9 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/RateLimitableRequestPublisherTest.java @@ -0,0 +1,140 @@ +package io.rsocket.internal; + +import static org.junit.jupiter.api.Assertions.*; + +import java.time.Duration; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Consumer; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; + +class RateLimitableRequestPublisherTest { + + @Test + public void testThatRequest1WillBePropagatedUpstream() { + Flux source = + Flux.just(1) + .subscribeOn(Schedulers.parallel()) + .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); + + RateLimitableRequestPublisher rateLimitableRequestPublisher = + RateLimitableRequestPublisher.wrap(source, 128); + + StepVerifier.create(rateLimitableRequestPublisher) + .then(() -> rateLimitableRequestPublisher.request(1)) + .expectNext(1) + .expectComplete() + .verify(Duration.ofMillis(1000)); + } + + @Test + public void testThatRequest256WillBePropagatedToUpstreamWithLimitedRate() { + Flux source = + Flux.range(0, 256) + .subscribeOn(Schedulers.parallel()) + .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); + + RateLimitableRequestPublisher rateLimitableRequestPublisher = + RateLimitableRequestPublisher.wrap(source, 128); + + StepVerifier.create(rateLimitableRequestPublisher) + .then(() -> rateLimitableRequestPublisher.request(256)) + .expectNextCount(256) + .expectComplete() + .verify(Duration.ofMillis(1000)); + } + + @Test + public void testThatRequest256WillBePropagatedToUpstreamWithLimitedRateInFewSteps() { + Flux source = + Flux.range(0, 256) + .subscribeOn(Schedulers.parallel()) + .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); + + RateLimitableRequestPublisher rateLimitableRequestPublisher = + RateLimitableRequestPublisher.wrap(source, 128); + + StepVerifier.create(rateLimitableRequestPublisher) + .then(() -> rateLimitableRequestPublisher.request(10)) + .expectNextCount(5) + .then(() -> rateLimitableRequestPublisher.request(128)) + .expectNextCount(133) + .expectNoEvent(Duration.ofMillis(10)) + .then(() -> rateLimitableRequestPublisher.request(Long.MAX_VALUE)) + .expectNextCount(118) + .expectComplete() + .verify(Duration.ofMillis(1000)); + } + + @Test + public void testThatRequestInRandomFashionWillBePropagatedToUpstreamWithLimitedRateInFewSteps() { + Flux source = + Flux.range(0, 10000000) + .subscribeOn(Schedulers.parallel()) + .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); + + RateLimitableRequestPublisher rateLimitableRequestPublisher = + RateLimitableRequestPublisher.wrap(source, 128); + + StepVerifier.create(rateLimitableRequestPublisher) + .then( + () -> + Flux.interval(Duration.ofMillis(1000)) + .onBackpressureDrop() + .subscribe( + new Consumer() { + int count = 10000000; + + @Override + public void accept(Long __) { + int random = ThreadLocalRandom.current().nextInt(1, 512); + + long request = Math.min(random, count); + + count -= request; + + rateLimitableRequestPublisher.request(count); + } + })) + .expectNextCount(10000000) + .expectComplete() + .verify(Duration.ofMillis(30000)); + } + + @Test + public void testThatRequestLongMaxValueWillBeDeliveredInSeparateChunks() { + Flux source = + Flux.range(0, 10000000) + .subscribeOn(Schedulers.parallel()) + .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); + + RateLimitableRequestPublisher rateLimitableRequestPublisher = + RateLimitableRequestPublisher.wrap(source, 128); + + StepVerifier.create(rateLimitableRequestPublisher) + .then(() -> rateLimitableRequestPublisher.request(Long.MAX_VALUE)) + .expectNextCount(10000000) + .expectComplete() + .verify(Duration.ofMillis(30000)); + } + + @Test + public void testThatRequestLongMaxWithIntegerMaxValuePrefetchWillBeDeliveredAsLongMaxValue() { + Flux source = + Flux.range(0, 10000000) + .subscribeOn(Schedulers.parallel()) + .doOnRequest(r -> Assertions.assertThat(r).isEqualTo(Long.MAX_VALUE)); + + RateLimitableRequestPublisher rateLimitableRequestPublisher = + RateLimitableRequestPublisher.wrap(source, Integer.MAX_VALUE); + + StepVerifier.create(rateLimitableRequestPublisher) + .then(() -> rateLimitableRequestPublisher.request(Long.MAX_VALUE)) + .expectNextCount(10000000) + .expectComplete() + .verify(Duration.ofMillis(30000)); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java index fd48cd9d3..6298b0c3a 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java @@ -27,6 +27,7 @@ import org.slf4j.LoggerFactory; import reactor.core.publisher.DirectProcessor; import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; @@ -40,7 +41,9 @@ public class TestDuplexConnection implements DuplexConnection { private final LinkedBlockingQueue sent; private final DirectProcessor sentPublisher; + private final FluxSink sendSink; private final DirectProcessor received; + private final FluxSink receivedSink; private final MonoProcessor onClose; private final ConcurrentLinkedQueue> sendSubscribers; private volatile double availability = 1; @@ -49,7 +52,9 @@ public class TestDuplexConnection implements DuplexConnection { public TestDuplexConnection() { sent = new LinkedBlockingQueue<>(); received = DirectProcessor.create(); + receivedSink = received.sink(); sentPublisher = DirectProcessor.create(); + sendSink = sentPublisher.sink(); sendSubscribers = new ConcurrentLinkedQueue<>(); onClose = MonoProcessor.create(); } @@ -65,7 +70,7 @@ public Mono send(Publisher frames) { .doOnNext( frame -> { sent.offer(frame); - sentPublisher.onNext(frame); + sendSink.next(frame); }) .doOnError(throwable -> logger.error("Error in send stream on test connection.", throwable)) .subscribe(subscriber); @@ -116,7 +121,7 @@ public Publisher getSentAsPublisher() { public void addToReceivedBuffer(ByteBuf... received) { for (ByteBuf frame : received) { - this.received.onNext(frame); + this.receivedSink.next(frame); } } diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java index 627b1d7da..3ef5fb7c8 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java @@ -95,7 +95,7 @@ public void startup() { requestCount = new AtomicInteger(); disconnectionCounter = new CountDownLatch(1); - TcpServerTransport serverTransport = TcpServerTransport.create(0); + TcpServerTransport serverTransport = TcpServerTransport.create("localhost", 0); server = RSocketFactory.receive() diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java index 6c8f0e8fa..7a30a7fd1 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java @@ -21,7 +21,7 @@ public class InteractionsLoadTest { @Test @SlowTest public void channel() { - TcpServerTransport serverTransport = TcpServerTransport.create(0); + TcpServerTransport serverTransport = TcpServerTransport.create("localhost", 0); CloseableChannel server = RSocketFactory.receive() diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java index 41e437fee..9e7f5b0a7 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java @@ -46,7 +46,7 @@ public class TcpIntegrationTest { @Before public void startup() { - TcpServerTransport serverTransport = TcpServerTransport.create(0); + TcpServerTransport serverTransport = TcpServerTransport.create("localhost", 0); server = RSocketFactory.receive() .acceptor((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java index 62d7da336..575993c18 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java @@ -57,7 +57,7 @@ public void startup() { this.responseMessage = responseMessage.toString(); this.metaData = metaData.toString(); - TcpServerTransport serverTransport = TcpServerTransport.create(randomPort); + TcpServerTransport serverTransport = TcpServerTransport.create("localhost", randomPort); server = RSocketFactory.receive() .fragment(frameSize) diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java index 53b164247..b40f35e51 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java @@ -38,7 +38,7 @@ public static void main(String... args) { serverRSocketFactory .frameDecoder(PayloadDecoder.ZERO_COPY) .acceptor(new PingHandler()) - .transport(TcpServerTransport.create(port)) + .transport(TcpServerTransport.create("localhost", port)) .start() .block() .onClose() diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java index 84c185e26..b6cbfea34 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java @@ -70,7 +70,7 @@ void createNullTcpClient() { @DisplayName("creates server with port") @Test void createPort() { - assertThat(TcpServerTransport.create(8000)).isNotNull(); + assertThat(TcpServerTransport.create("localhost", 8000)).isNotNull(); } @DisplayName("creates client with TcpServer") @@ -97,7 +97,7 @@ void start() { @Test void startNullAcceptor() { assertThatNullPointerException() - .isThrownBy(() -> TcpServerTransport.create(8000).start(null, 0)) + .isThrownBy(() -> TcpServerTransport.create("localhost", 8000).start(null, 0)) .withMessage("acceptor must not be null"); } }