diff --git a/rsocket-core/src/main/java/io/rsocket/util/ReconnectingRSocket.java b/rsocket-core/src/main/java/io/rsocket/util/ReconnectingRSocket.java new file mode 100644 index 000000000..7e7f41212 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/ReconnectingRSocket.java @@ -0,0 +1,539 @@ +package io.rsocket.util; + +import io.netty.util.internal.ThreadLocalRandom; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.ConnectionCloseException; +import io.rsocket.exceptions.ConnectionErrorException; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.Objects; +import java.util.concurrent.CancellationException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Operators; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +@SuppressWarnings("unchecked") +public class ReconnectingRSocket implements CoreSubscriber, RSocket, Runnable { + + private final long backoffMinInMillis; + private final long backoffMaxInMillis; + private final Mono source; + private final Scheduler scheduler; + private final Predicate errorPredicate; + private final MonoProcessor onDispose; + + RSocket value; + + volatile Consumer[] subscribers; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater SUBSCRIBERS = + AtomicReferenceFieldUpdater.newUpdater( + ReconnectingRSocket.class, Consumer[].class, "subscribers"); + + @SuppressWarnings("rawtypes") + static final Consumer[] EMPTY_UNSUBSCRIBED = new Consumer[0]; + + @SuppressWarnings("rawtypes") + static final Consumer[] EMPTY_SUBSCRIBED = new Consumer[0]; + + @SuppressWarnings("rawtypes") + static final Consumer[] TERMINATED = new Consumer[0]; + + static final ClosedChannelException ON_CLOSE_EXCEPTION = new ClosedChannelException(); + + public static Builder builder() { + return new Builder(); + } + + ReconnectingRSocket( + Mono source, + Predicate errorPredicate, + Scheduler scheduler, + long backoffMinInMillis, + long backoffMaxInMillis) { + + this.source = source; + this.backoffMinInMillis = backoffMinInMillis; + this.backoffMaxInMillis = backoffMaxInMillis; + this.scheduler = scheduler; + this.errorPredicate = errorPredicate; + this.onDispose = MonoProcessor.create(); + + SUBSCRIBERS.lazySet(this, EMPTY_UNSUBSCRIBED); + } + + public void subscribe(Consumer actual) { + if (!add(actual)) { + actual.accept(value); + } + } + + @Override + public void onSubscribe(Subscription subscription) { + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onComplete() { + final RSocket value = this.value; + + if (value == null) { + reconnect(); + } else { + value.onClose().subscribe(null, null, () -> resubscribeWhen(ON_CLOSE_EXCEPTION)); + Consumer[] array = SUBSCRIBERS.getAndSet(this, TERMINATED); + for (Consumer as : array) { + as.accept(value); + } + } + } + + @Override + public void onError(Throwable t) { + reconnect(); + } + + @Override + @SuppressWarnings("unchecked") + public void onNext(@Nullable RSocket value) { + this.value = value; + } + + @Override + public void run() { + source.subscribe(this); + } + + private void reconnect() { + ThreadLocalRandom random = ThreadLocalRandom.current(); + long nextRandomDelay = random.nextLong(backoffMinInMillis, backoffMaxInMillis); + scheduler.schedule(this, nextRandomDelay, TimeUnit.MILLISECONDS); + } + + private boolean resubscribeWhen(Throwable throwable) { + if (onDispose.isDisposed()) { + return false; + } + + if (errorPredicate.test(throwable)) { + final Consumer[] subscribers = this.subscribers; + final RSocket current = this.value; + if ((current == null || current.isDisposed()) + && subscribers == TERMINATED + && SUBSCRIBERS.compareAndSet(this, TERMINATED, EMPTY_SUBSCRIBED)) { + this.value = null; + reconnect(); + } + return true; + } + return false; + } + + @Override + public Mono fireAndForget(Payload payload) { + return new FlatMapInner<>( + this, rsocket -> rsocket.fireAndForget(payload), this::resubscribeWhen); + } + + @Override + public Mono requestResponse(Payload payload) { + return new FlatMapInner<>( + this, rsocket -> rsocket.requestResponse(payload), this::resubscribeWhen); + } + + @Override + public Flux requestStream(Payload payload) { + return new FlatMapManyInner<>( + this, rSocket -> rSocket.requestStream(payload), this::resubscribeWhen); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return new FlatMapManyInner<>( + this, rSocket -> rSocket.requestChannel(payloads), this::resubscribeWhen); + } + + @Override + public Mono metadataPush(Payload payload) { + return new FlatMapInner<>( + this, rsocket -> rsocket.metadataPush(payload), this::resubscribeWhen); + } + + @Override + public double availability() { + RSocket rsocket = this.value; + return rsocket != null ? rsocket.availability() : 0d; + } + + @Override + public void dispose() { + onDispose.dispose(); + RSocket value = this.value; + this.value = null; + if (value != null) { + value.dispose(); + } + } + + @Override + public boolean isDisposed() { + return onDispose.isDisposed(); + } + + @Override + public Mono onClose() { + return onDispose; + } + + boolean add(Consumer ps) { + for (; ; ) { + Consumer[] a = subscribers; + + if (a == TERMINATED) { + return false; + } + + int n = a.length; + @SuppressWarnings("unchecked") + Consumer[] b = new Consumer[n + 1]; + System.arraycopy(a, 0, b, 0, n); + b[n] = ps; + + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + if (a == EMPTY_UNSUBSCRIBED) { + source.subscribe(this); + } + return true; + } + } + } + + static final class FlatMapInner extends Mono + implements CoreSubscriber, Consumer, Subscription, Scannable { + + final ReconnectingRSocket parent; + final Function> mapper; + final Predicate errorPredicate; + + boolean done; + + volatile int state; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(FlatMapInner.class, "state"); + + static final int NONE = 0; + static final int SUBSCRIBED = 1; + static final int CANCELLED = 2; + + CoreSubscriber actual; + Subscription s; + + FlatMapInner( + ReconnectingRSocket parent, + Function> mapper, + Predicate errorPredicate) { + this.parent = parent; + this.mapper = mapper; + this.errorPredicate = errorPredicate; + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (state == NONE && STATE.compareAndSet(this, NONE, SUBSCRIBED)) { + this.actual = actual; + parent.subscribe(this); + } else { + Operators.error(actual, new IllegalStateException("Only a single Subscriber allowed")); + } + } + + @Override + public void accept(RSocket rSocket) { + if (rSocket == null) { + Operators.error(actual, new CancellationException("Disposed")); + } + + Mono source; + try { + source = this.mapper.apply(rSocket); + source.subscribe(this); + } catch (Throwable e) { + Exceptions.throwIfFatal(e); + Operators.error(actual, e); + } + } + + @Override + public Context currentContext() { + return actual.currentContext(); + } + + @Nullable + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return s; + if (key == Attr.ACTUAL) return parent; + if (key == Attr.TERMINATED) return done; + if (key == Attr.CANCELLED) return s == Operators.cancelledSubscription(); + + return null; + } + + @Override + public void onSubscribe(Subscription s) { + this.s = s; + actual.onSubscribe(this); + } + + @Override + public void onNext(T payload) { + if (done) { + Operators.onNextDropped(payload, actual.currentContext()); + return; + } + done = true; + actual.onNext(payload); + } + + @Override + public void onError(Throwable t) { + if (done) { + Operators.onErrorDropped(t, actual.currentContext()); + return; + } + + final CoreSubscriber actual = this.actual; + + if (errorPredicate.test(t)) { + this.actual = null; + STATE.compareAndSet(this, SUBSCRIBED, NONE); + } else { + done = true; + } + actual.onError(t); + } + + @Override + public void onComplete() { + if (done) { + return; + } + done = true; + actual.onComplete(); + } + + @Override + public void request(long n) { + s.request(n); + } + + public void cancel() { + if (STATE.getAndSet(this, CANCELLED) != CANCELLED) { + s.cancel(); + } + } + } + + static final class FlatMapManyInner extends Flux + implements CoreSubscriber, Consumer, Subscription, Scannable { + + final ReconnectingRSocket parent; + final Function> mapper; + final Predicate errorPredicate; + + boolean done; + + volatile int state; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(FlatMapManyInner.class, "state"); + + static final int NONE = 0; + static final int SUBSCRIBED = 1; + static final int CANCELLED = 2; + + CoreSubscriber actual; + Subscription s; + + FlatMapManyInner( + ReconnectingRSocket parent, + Function> mapper, + Predicate errorPredicate) { + this.parent = parent; + this.mapper = mapper; + this.errorPredicate = errorPredicate; + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (state == NONE && STATE.compareAndSet(this, NONE, SUBSCRIBED)) { + this.actual = actual; + parent.subscribe(this); + } else { + Operators.error(actual, new IllegalStateException("Only a single Subscriber allowed")); + } + } + + @Override + public void accept(RSocket rSocket) { + if (rSocket == null) { + Operators.error(actual, new CancellationException("Disposed")); + } + + Flux source; + try { + source = this.mapper.apply(rSocket); + source.subscribe(this); + } catch (Throwable e) { + Exceptions.throwIfFatal(e); + Operators.error(actual, e); + } + } + + @Override + public Context currentContext() { + return actual.currentContext(); + } + + @Nullable + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return s; + if (key == Attr.ACTUAL) return parent; + if (key == Attr.TERMINATED) return done; + if (key == Attr.CANCELLED) return s == Operators.cancelledSubscription(); + + return null; + } + + @Override + public void onSubscribe(Subscription s) { + this.s = s; + actual.onSubscribe(this); + } + + @Override + public void onNext(T payload) { + actual.onNext(payload); + } + + @Override + public void onError(Throwable t) { + if (done) { + Operators.onErrorDropped(t, actual.currentContext()); + return; + } + + final CoreSubscriber actual = this.actual; + + if (errorPredicate.test(t)) { + this.actual = null; + STATE.compareAndSet(this, SUBSCRIBED, NONE); + } else { + done = true; + } + actual.onError(t); + } + + @Override + public void onComplete() { + if (done) { + return; + } + done = true; + actual.onComplete(); + } + + @Override + public void request(long n) { + s.request(n); + } + + public void cancel() { + if (STATE.getAndSet(this, CANCELLED) != CANCELLED) { + s.cancel(); + } + } + } + + public static class Builder { + + private static final Predicate DEFAULT_ERROR_PREDICATE = + throwable -> + throwable instanceof ClosedChannelException + || throwable instanceof ConnectionCloseException + || throwable instanceof ConnectionErrorException; + + Duration backoffMin; + Duration backoffMax; + Supplier> sourceSupplier; + Scheduler scheduler = Schedulers.parallel(); + Predicate errorPredicate = DEFAULT_ERROR_PREDICATE; + + public Builder withSourceRSocket(Supplier> sourceSupplier) { + this.sourceSupplier = sourceSupplier; + return this; + } + + public Builder withRetryPeriod(Duration period) { + return withRetryPeriod(period, period); + } + + public Builder withRetryPeriod(Duration periodMin, Duration periodMax) { + backoffMin = periodMin; + backoffMax = periodMax; + return this; + } + + public Builder withCustomRetryOnErrorPredicate(Predicate predicate) { + errorPredicate = predicate; + return this; + } + + public Builder withDefaultRetryOnErrorPredicate() { + return withCustomRetryOnErrorPredicate(DEFAULT_ERROR_PREDICATE); + } + + public Builder withRetryOnScheduler(Scheduler scheduler) { + this.scheduler = scheduler; + return this; + } + + public ReconnectingRSocket build() { + Objects.requireNonNull(backoffMin, "Specify required retry period"); + Objects.requireNonNull(backoffMax, "Specify required Max retry period"); + Objects.requireNonNull(errorPredicate, "Specify required retryOnError predicate"); + Objects.requireNonNull(sourceSupplier, "Specify required source RSocket supplier"); + Objects.requireNonNull(scheduler, "Specify required Scheduler"); + + return new ReconnectingRSocket( + Mono.defer(sourceSupplier), + errorPredicate, + scheduler, + backoffMin.toMillis(), + backoffMax.toMillis()); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/util/ReconnectingRSocketTest.java b/rsocket-core/src/test/java/io/rsocket/util/ReconnectingRSocketTest.java new file mode 100644 index 000000000..009851e6f --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/util/ReconnectingRSocketTest.java @@ -0,0 +1,172 @@ +package io.rsocket.util; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.function.BiFunction; +import java.util.function.Supplier; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import reactor.core.CorePublisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; + +class ReconnectingRSocketTest { + + @SuppressWarnings("unchecked") + static Stream, CorePublisher>> invocations() { + return Stream.of( + (rSocket, arg) -> rSocket.fireAndForget(arg.block()), + (rSocket, arg) -> rSocket.requestResponse(arg.block()), + (rSocket, arg) -> rSocket.requestStream(arg.block()), + RSocket::requestChannel); + } + + private static final Payload TEST_PAYLOAD = DefaultPayload.create(""); + private static final Mono MONO_TEST_PAYLOAD = Mono.just(TEST_PAYLOAD); + + @DisplayName("Verifies that subscriptions to the given source RSocket only on method call") + @ParameterizedTest + @MethodSource("invocations") + public void testSubscribesOnFirstMethodCall( + BiFunction, CorePublisher> invocation) { + + RSocket rSocketMock = Mockito.mock(RSocket.class); + Mockito.when(rSocketMock.fireAndForget(Mockito.any(Payload.class))).thenReturn(Mono.never()); + Mockito.when(rSocketMock.requestResponse(Mockito.any(Payload.class))).thenReturn(Mono.never()); + Mockito.when(rSocketMock.requestStream(Mockito.any(Payload.class))).thenReturn(Flux.never()); + Mockito.when(rSocketMock.requestChannel(Mockito.any())).thenReturn(Flux.never()); + Mockito.when(rSocketMock.onClose()).thenReturn(Mono.never()); + ReconnectingRSocket reconnectingRSocket = + new ReconnectingRSocket( + Mono.fromSupplier(() -> rSocketMock), t -> true, Schedulers.parallel(), 1000, 1500); + + Mockito.verifyZeroInteractions(rSocketMock); + Assertions.assertThat(reconnectingRSocket.value).isNull(); + + CorePublisher corePublisher = invocation.apply(reconnectingRSocket, MONO_TEST_PAYLOAD); + + Mockito.verifyZeroInteractions(rSocketMock); + Assertions.assertThat(reconnectingRSocket.value).isNull(); + + if (corePublisher instanceof Mono) { + ((Mono) corePublisher).subscribe(); + } else { + ((Flux) corePublisher).subscribe(); + } + + invocation.apply(Mockito.verify(rSocketMock), MONO_TEST_PAYLOAD); + Assertions.assertThat(reconnectingRSocket.value).isEqualTo(rSocketMock); + Assertions.assertThat(reconnectingRSocket.subscribers) + .isEqualTo(ReconnectingRSocket.TERMINATED); + } + + @DisplayName( + "Verifies that ReconnectingRSocket reconnect when the source " + + "RSocketMono is empty or error one") + @Test + @SuppressWarnings("unchecked") + public void testReconnectsWhenGotCompletion() { + RSocket rSocketMock = Mockito.mock(RSocket.class); + Mockito.when(rSocketMock.fireAndForget(Mockito.any(Payload.class))).thenReturn(Mono.empty()); + Mockito.when(rSocketMock.onClose()).thenReturn(Mono.never()); + Supplier> rSocketMonoMock = Mockito.mock(Supplier.class); + Mockito.when(rSocketMonoMock.get()) + .thenReturn(Mono.error(new RuntimeException()), Mono.empty(), Mono.just(rSocketMock)); + + ReconnectingRSocket reconnectingRSocket = + new ReconnectingRSocket( + Mono.defer(rSocketMonoMock), t -> true, Schedulers.parallel(), 10, 20); + + Mockito.verifyZeroInteractions(rSocketMock); + Assertions.assertThat(reconnectingRSocket.value).isNull(); + + Mono fnfMono = reconnectingRSocket.fireAndForget(TEST_PAYLOAD); + + Mockito.verifyZeroInteractions(rSocketMock); + Assertions.assertThat(reconnectingRSocket.value).isNull(); + + StepVerifier.create(fnfMono).verifyComplete(); + + Mockito.verify(rSocketMock).fireAndForget(TEST_PAYLOAD); + Assertions.assertThat(reconnectingRSocket.value).isEqualTo(rSocketMock); + Assertions.assertThat(reconnectingRSocket.subscribers) + .isEqualTo(ReconnectingRSocket.TERMINATED); + + Mockito.verify(rSocketMonoMock, Mockito.times(3)).get(); + } + + @DisplayName( + "Verifies that ReconnectingRSocket reconnect when got reconnectable " + + "exception in the logical stream") + @ParameterizedTest + @MethodSource("invocations") + @SuppressWarnings("unchecked") + public void testReconnectsWhenGotLogicalStreamError( + BiFunction, CorePublisher> invocation) { + RSocket rSocketMock = Mockito.mock(RSocket.class); + Mockito.when(rSocketMock.fireAndForget(Mockito.any(Payload.class))) + .thenReturn( + Mono.error(new ClosedChannelException()) + .doOnError(e -> Mockito.when(rSocketMock.isDisposed()).thenReturn(true)), + Mono.empty()); + Mockito.when(rSocketMock.requestResponse(Mockito.any(Payload.class))) + .thenReturn( + Mono.error(new ClosedChannelException()) + .doOnError(e -> Mockito.when(rSocketMock.isDisposed()).thenReturn(true)), + Mono.empty()); + Mockito.when(rSocketMock.requestStream(Mockito.any(Payload.class))) + .thenReturn( + Flux.error(new ClosedChannelException()) + .doOnError(e -> Mockito.when(rSocketMock.isDisposed()).thenReturn(true)), + Flux.empty()); + Mockito.when(rSocketMock.requestChannel(Mockito.any())) + .thenReturn( + Flux.error(new ClosedChannelException()) + .doOnError(e -> Mockito.when(rSocketMock.isDisposed()).thenReturn(true)), + Flux.empty()); + Mockito.when(rSocketMock.onClose()).thenReturn(Mono.never()); + Supplier> rSocketMonoMock = Mockito.mock(Supplier.class); + Mockito.when(rSocketMonoMock.get()) + .thenReturn( + Mono.error(new RuntimeException()), + Mono.empty(), + Mono.just(rSocketMock), + Mono.just(rSocketMock)); + + ReconnectingRSocket reconnectingRSocket = + new ReconnectingRSocket( + Mono.defer(() -> rSocketMonoMock.get()), t -> true, Schedulers.parallel(), 10, 20); + + Mockito.verifyZeroInteractions(rSocketMock); + Assertions.assertThat(reconnectingRSocket.value).isNull(); + + CorePublisher corePublisher = invocation.apply(reconnectingRSocket, MONO_TEST_PAYLOAD); + + Mockito.verifyZeroInteractions(rSocketMock); + Assertions.assertThat(reconnectingRSocket.value).isNull(); + + if (corePublisher instanceof Mono) { + corePublisher = ((Mono) corePublisher).retryBackoff(1, Duration.ofMillis(10)); + } else { + corePublisher = ((Flux) corePublisher).retryBackoff(1, Duration.ofMillis(10)); + } + + StepVerifier.create(corePublisher).verifyComplete(); + + invocation.apply(Mockito.verify(rSocketMock, Mockito.times(2)), MONO_TEST_PAYLOAD); + Assertions.assertThat(reconnectingRSocket.value).isEqualTo(rSocketMock); + Assertions.assertThat(reconnectingRSocket.subscribers) + .isEqualTo(ReconnectingRSocket.TERMINATED); + + Mockito.verify(rSocketMonoMock, Mockito.times(4)).get(); + } +}