Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changes LoadbalanceStrategy to accept List #919

Merged
merged 3 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
*/
package io.rsocket.loadbalance;

import java.util.List;
import java.util.function.Supplier;

@FunctionalInterface
public interface LoadbalanceStrategy {

PooledRSocket select(PooledRSocket[] availableRSockets);
WeightedRSocket select(List<WeightedRSocket> availableRSockets);

default Supplier<Stats> statsSupplier() {
return Stats::noOps;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,20 @@
import reactor.core.publisher.Operators;
import reactor.util.context.Context;

/** Default implementation of {@link PooledRSocket} stored in {@link RSocketPool} */
final class DefaultPooledRSocket extends ResolvingOperator<RSocket>
implements CoreSubscriber<RSocket>, PooledRSocket {
/** Default implementation of {@link WeightedRSocket} stored in {@link RSocketPool} */
final class PooledWeightedRSocket extends ResolvingOperator<RSocket>
implements CoreSubscriber<RSocket>, WeightedRSocket {

final RSocketPool parent;
final LoadbalanceRSocketSource loadbalanceRSocketSource;
final Stats stats;

volatile Subscription s;

static final AtomicReferenceFieldUpdater<DefaultPooledRSocket, Subscription> S =
AtomicReferenceFieldUpdater.newUpdater(DefaultPooledRSocket.class, Subscription.class, "s");
static final AtomicReferenceFieldUpdater<PooledWeightedRSocket, Subscription> S =
AtomicReferenceFieldUpdater.newUpdater(PooledWeightedRSocket.class, Subscription.class, "s");

DefaultPooledRSocket(
PooledWeightedRSocket(
RSocketPool parent, LoadbalanceRSocketSource loadbalanceRSocketSource, Stats stats) {
this.parent = parent;
this.stats = stats;
Expand Down Expand Up @@ -128,7 +128,7 @@ public void dispose() {
protected void doOnDispose() {
final RSocketPool parent = this.parent;
for (; ; ) {
final PooledRSocket[] sockets = parent.activeSockets;
final PooledWeightedRSocket[] sockets = parent.activeSockets;
final int activeSocketsCount = sockets.length;

int index = -1;
Expand All @@ -144,7 +144,7 @@ protected void doOnDispose() {
}

final int lastIndex = activeSocketsCount - 1;
final PooledRSocket[] newSockets = new PooledRSocket[lastIndex];
final PooledWeightedRSocket[] newSockets = new PooledWeightedRSocket[lastIndex];
if (index != 0) {
System.arraycopy(sockets, 0, newSockets, 0, index);
}
Expand Down Expand Up @@ -196,8 +196,7 @@ public Stats stats() {
return stats;
}

@Override
public LoadbalanceRSocketSource source() {
LoadbalanceRSocketSource source() {
return loadbalanceRSocketSource;
}

Expand All @@ -211,7 +210,7 @@ static final class RequestTrackingMonoInner<RESULT>

long startTime;

RequestTrackingMonoInner(DefaultPooledRSocket parent, Payload payload, FrameType requestType) {
RequestTrackingMonoInner(PooledWeightedRSocket parent, Payload payload, FrameType requestType) {
super(parent, payload, requestType);
}

Expand Down Expand Up @@ -245,7 +244,7 @@ public void accept(RSocket rSocket, Throwable t) {
return;
}

startTime = ((DefaultPooledRSocket) parent).stats.startRequest();
startTime = ((PooledWeightedRSocket) parent).stats.startRequest();

source.subscribe((CoreSubscriber) this);
} else {
Expand All @@ -257,7 +256,7 @@ public void accept(RSocket rSocket, Throwable t) {
public void onComplete() {
final long state = this.requested;
if (state != TERMINATED_STATE && REQUESTED.compareAndSet(this, state, TERMINATED_STATE)) {
final Stats stats = ((DefaultPooledRSocket) parent).stats;
final Stats stats = ((PooledWeightedRSocket) parent).stats;
final long now = stats.stopRequest(startTime);
stats.record(now - startTime);
super.onComplete();
Expand All @@ -268,7 +267,7 @@ public void onComplete() {
public void onError(Throwable t) {
final long state = this.requested;
if (state != TERMINATED_STATE && REQUESTED.compareAndSet(this, state, TERMINATED_STATE)) {
Stats stats = ((DefaultPooledRSocket) parent).stats;
Stats stats = ((PooledWeightedRSocket) parent).stats;
stats.stopRequest(startTime);
stats.recordError(0.0);
super.onError(t);
Expand All @@ -284,7 +283,7 @@ public void cancel() {

if (state == STATE_SUBSCRIBED) {
this.s.cancel();
((DefaultPooledRSocket) parent).stats.stopRequest(startTime);
((PooledWeightedRSocket) parent).stats.stopRequest(startTime);
} else {
this.parent.remove(this);
ReferenceCountUtil.safeRelease(this.payload);
Expand All @@ -296,7 +295,7 @@ static final class RequestTrackingFluxInner<INPUT>
extends FluxDeferredResolution<INPUT, RSocket> {

RequestTrackingFluxInner(
DefaultPooledRSocket parent, INPUT fluxOrPayload, FrameType requestType) {
PooledWeightedRSocket parent, INPUT fluxOrPayload, FrameType requestType) {
super(parent, fluxOrPayload, requestType);
}

Expand Down Expand Up @@ -329,7 +328,7 @@ public void accept(RSocket rSocket, Throwable t) {
return;
}

((DefaultPooledRSocket) parent).stats.startStream();
((PooledWeightedRSocket) parent).stats.startStream();

source.subscribe(this);
} else {
Expand All @@ -341,7 +340,7 @@ public void accept(RSocket rSocket, Throwable t) {
public void onComplete() {
final long state = this.requested;
if (state != TERMINATED_STATE && REQUESTED.compareAndSet(this, state, TERMINATED_STATE)) {
((DefaultPooledRSocket) parent).stats.stopStream();
((PooledWeightedRSocket) parent).stats.stopStream();
super.onComplete();
}
}
Expand All @@ -350,7 +349,7 @@ public void onComplete() {
public void onError(Throwable t) {
final long state = this.requested;
if (state != TERMINATED_STATE && REQUESTED.compareAndSet(this, state, TERMINATED_STATE)) {
((DefaultPooledRSocket) parent).stats.stopStream();
((PooledWeightedRSocket) parent).stats.stopStream();
super.onError(t);
}
}
Expand All @@ -364,7 +363,7 @@ public void cancel() {

if (state == STATE_SUBSCRIBED) {
this.s.cancel();
((DefaultPooledRSocket) parent).stats.stopStream();
((PooledWeightedRSocket) parent).stats.stopStream();
} else {
this.parent.remove(this);
if (requestType == FrameType.REQUEST_STREAM) {
Expand Down
149 changes: 134 additions & 15 deletions rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
import io.rsocket.RSocket;
import io.rsocket.frame.FrameType;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.concurrent.CancellationException;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Supplier;
Expand All @@ -34,20 +37,20 @@
import reactor.util.annotation.Nullable;

class RSocketPool extends ResolvingOperator<Void>
implements CoreSubscriber<List<LoadbalanceRSocketSource>> {
implements CoreSubscriber<List<LoadbalanceRSocketSource>>, List<WeightedRSocket> {

final DeferredResolutionRSocket deferredResolutionRSocket = new DeferredResolutionRSocket(this);
final LoadbalanceStrategy loadbalanceStrategy;
final Supplier<Stats> statsSupplier;

volatile PooledRSocket[] activeSockets;
volatile PooledWeightedRSocket[] activeSockets;

static final AtomicReferenceFieldUpdater<RSocketPool, PooledRSocket[]> ACTIVE_SOCKETS =
static final AtomicReferenceFieldUpdater<RSocketPool, PooledWeightedRSocket[]> ACTIVE_SOCKETS =
AtomicReferenceFieldUpdater.newUpdater(
RSocketPool.class, PooledRSocket[].class, "activeSockets");
RSocketPool.class, PooledWeightedRSocket[].class, "activeSockets");

static final PooledRSocket[] EMPTY = new PooledRSocket[0];
static final PooledRSocket[] TERMINATED = new PooledRSocket[0];
static final PooledWeightedRSocket[] EMPTY = new PooledWeightedRSocket[0];
static final PooledWeightedRSocket[] TERMINATED = new PooledWeightedRSocket[0];

volatile Subscription s;
static final AtomicReferenceFieldUpdater<RSocketPool, Subscription> S =
Expand Down Expand Up @@ -93,8 +96,8 @@ public void onNext(List<LoadbalanceRSocketSource> loadbalanceRSocketSources) {
return;
}

PooledRSocket[] previouslyActiveSockets;
PooledRSocket[] activeSockets;
PooledWeightedRSocket[] previouslyActiveSockets;
PooledWeightedRSocket[] activeSockets;
for (; ; ) {
HashMap<LoadbalanceRSocketSource, Integer> rSocketSuppliersCopy = new HashMap<>();

Expand All @@ -105,11 +108,11 @@ public void onNext(List<LoadbalanceRSocketSource> loadbalanceRSocketSources) {

// checking intersection of active RSocket with the newly received set
previouslyActiveSockets = this.activeSockets;
PooledRSocket[] nextActiveSockets =
new PooledRSocket[previouslyActiveSockets.length + rSocketSuppliersCopy.size()];
PooledWeightedRSocket[] nextActiveSockets =
new PooledWeightedRSocket[previouslyActiveSockets.length + rSocketSuppliersCopy.size()];
int position = 0;
for (int i = 0; i < previouslyActiveSockets.length; i++) {
PooledRSocket rSocket = previouslyActiveSockets[i];
PooledWeightedRSocket rSocket = previouslyActiveSockets[i];

Integer index = rSocketSuppliersCopy.remove(rSocket.source());
if (index == null) {
Expand All @@ -127,7 +130,7 @@ public void onNext(List<LoadbalanceRSocketSource> loadbalanceRSocketSources) {
} else {
// put newly create RSocket instance
nextActiveSockets[position++] =
new DefaultPooledRSocket(
new PooledWeightedRSocket(
this, loadbalanceRSocketSources.get(index), this.statsSupplier.get());
}
}
Expand All @@ -136,7 +139,7 @@ public void onNext(List<LoadbalanceRSocketSource> loadbalanceRSocketSources) {
// going though brightly new rsocket
for (LoadbalanceRSocketSource newLoadbalanceRSocketSource : rSocketSuppliersCopy.keySet()) {
nextActiveSockets[position++] =
new DefaultPooledRSocket(this, newLoadbalanceRSocketSource, this.statsSupplier.get());
new PooledWeightedRSocket(this, newLoadbalanceRSocketSource, this.statsSupplier.get());
}

// shrank to actual length
Expand Down Expand Up @@ -195,12 +198,38 @@ RSocket select() {

@Nullable
RSocket doSelect() {
PooledRSocket[] sockets = this.activeSockets;
WeightedRSocket[] sockets = this.activeSockets;
if (sockets == EMPTY) {
return null;
}

return this.loadbalanceStrategy.select(sockets);
return this.loadbalanceStrategy.select(this);
}

@Override
public WeightedRSocket get(int index) {
return activeSockets[index];
}

@Override
public int size() {
return activeSockets.length;
}

@Override
public boolean isEmpty() {
return activeSockets.length == 0;
}

@Override
public Object[] toArray() {
return activeSockets;
}

@Override
@SuppressWarnings("unchecked")
public <T> T[] toArray(T[] a) {
return (T[]) activeSockets;
}

static class DeferredResolutionRSocket implements RSocket {
Expand Down Expand Up @@ -325,4 +354,94 @@ public void accept(Void aVoid, Throwable t) {
}
}
}

@Override
public boolean contains(Object o) {
throw new UnsupportedOperationException();
}

@Override
public Iterator<WeightedRSocket> iterator() {
throw new UnsupportedOperationException();
}

@Override
public boolean add(WeightedRSocket weightedRSocket) {
throw new UnsupportedOperationException();
}

@Override
public boolean remove(Object o) {
throw new UnsupportedOperationException();
}

@Override
public boolean containsAll(Collection<?> c) {
throw new UnsupportedOperationException();
}

@Override
public boolean addAll(Collection<? extends WeightedRSocket> c) {
throw new UnsupportedOperationException();
}

@Override
public boolean addAll(int index, Collection<? extends WeightedRSocket> c) {
throw new UnsupportedOperationException();
}

@Override
public boolean removeAll(Collection<?> c) {
throw new UnsupportedOperationException();
}

@Override
public boolean retainAll(Collection<?> c) {
throw new UnsupportedOperationException();
}

@Override
public void clear() {
throw new UnsupportedOperationException();
}

@Override
public WeightedRSocket set(int index, WeightedRSocket element) {
throw new UnsupportedOperationException();
}

@Override
public void add(int index, WeightedRSocket element) {
throw new UnsupportedOperationException();
}

@Override
public WeightedRSocket remove(int index) {
throw new UnsupportedOperationException();
}

@Override
public int indexOf(Object o) {
throw new UnsupportedOperationException();
}

@Override
public int lastIndexOf(Object o) {
throw new UnsupportedOperationException();
}

@Override
public ListIterator<WeightedRSocket> listIterator() {
throw new UnsupportedOperationException();
}

@Override
public ListIterator<WeightedRSocket> listIterator(int index) {
throw new UnsupportedOperationException();
}

@Override
public List<WeightedRSocket> subList(int fromIndex, int toIndex) {
throw new UnsupportedOperationException();
}
}
Loading