diff --git a/api/src/main/java/io/grpc/EquivalentAddressGroup.java b/api/src/main/java/io/grpc/EquivalentAddressGroup.java index 969c82e010b..4b3db006684 100644 --- a/api/src/main/java/io/grpc/EquivalentAddressGroup.java +++ b/api/src/main/java/io/grpc/EquivalentAddressGroup.java @@ -128,6 +128,9 @@ public int hashCode() { */ @Override public boolean equals(Object other) { + if (this == other) { + return true; + } if (!(other instanceof EquivalentAddressGroup)) { return false; } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index da2bc072afc..50e311651c7 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -161,7 +161,7 @@ /** Unit tests for {@link ManagedChannelImpl}. */ @RunWith(JUnit4.class) // TODO(creamsoup) remove backward compatible check when fully migrated -@SuppressWarnings("deprecation") +@SuppressWarnings({"deprecation"}) public class ManagedChannelImplTest { private static final int DEFAULT_PORT = 447; diff --git a/core/src/testFixtures/java/io/grpc/internal/TestUtils.java b/core/src/testFixtures/java/io/grpc/internal/TestUtils.java index 974f36e595c..055a7f80283 100644 --- a/core/src/testFixtures/java/io/grpc/internal/TestUtils.java +++ b/core/src/testFixtures/java/io/grpc/internal/TestUtils.java @@ -21,9 +21,11 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ChannelLogger; import io.grpc.ClientStreamTracer; +import io.grpc.EquivalentAddressGroup; import io.grpc.InternalLogId; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; @@ -143,6 +145,14 @@ public Runnable answer(InvocationOnMock invocation) throws Throwable { return captor; } + @SuppressWarnings("ReferenceEquality") + public static final EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { + if (eag.getAttributes() == Attributes.EMPTY) { + return eag; + } + return new EquivalentAddressGroup(eag.getAddresses()); + } + private TestUtils() { } diff --git a/util/build.gradle b/util/build.gradle index a05c55b27bb..cdd32e0ceb5 100644 --- a/util/build.gradle +++ b/util/build.gradle @@ -1,5 +1,6 @@ plugins { id "java-library" + id "java-test-fixtures" id "maven-publish" id "me.champeau.jmh" @@ -19,11 +20,18 @@ dependencies { implementation libraries.animalsniffer.annotations, libraries.guava - testImplementation testFixtures(project(':grpc-api')), + testImplementation libraries.guava.testlib, + testFixtures(project(':grpc-api')), testFixtures(project(':grpc-core')), project(':grpc-testing') - testImplementation libraries.guava.testlib + testFixturesApi project(':grpc-core') + testFixturesImplementation libraries.guava, + libraries.junit, + libraries.mockito.core, + testFixtures(project(':grpc-api')), + testFixtures(project(':grpc-core')), + project(':grpc-testing') jmh project(':grpc-testing') signature libraries.signature.java diff --git a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java index 8f2269af261..ff27570c254 100644 --- a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java @@ -16,25 +16,32 @@ package io.grpc.util; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; +import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import io.grpc.ConnectivityState; +import io.grpc.EquivalentAddressGroup; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancerProvider; import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.SynchronizationContext.ScheduledHandle; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.internal.PickFirstLoadBalancerProvider; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -46,23 +53,26 @@ @Internal public abstract class MultiChildLoadBalancer extends LoadBalancer { - @VisibleForTesting - public static final int DELAYED_CHILD_DELETION_TIME_MINUTES = 15; private static final Logger logger = Logger.getLogger(MultiChildLoadBalancer.class.getName()); - private final Map childLbStates = new HashMap<>(); + private final Map childLbStates = new LinkedHashMap<>(); private final Helper helper; - protected final SynchronizationContext syncContext; - private final ScheduledExecutorService timeService; // Set to true if currently in the process of handling resolved addresses. - private boolean resolvingAddresses; + @VisibleForTesting + protected boolean resolvingAddresses; + + protected final PickFirstLoadBalancerProvider pickFirstLbProvider = + new PickFirstLoadBalancerProvider(); + + protected ConnectivityState currentConnectivityState; protected MultiChildLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); - this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); - this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); logger.log(Level.FINE, "Created"); } + protected abstract SubchannelPicker getSubchannelPicker( + Map childPickers); + protected SubchannelPicker getInitialPicker() { return EMPTY_PICKER; } @@ -71,12 +81,67 @@ protected SubchannelPicker getErrorPicker(Status error) { return new FixedResultPicker(PickResult.withError(error)); } - protected abstract Map getPolicySelectionMap( - ResolvedAddresses resolvedAddresses); + /** + * Generally, the only reason to override this is to expose it to a test of a LB in a different + * package. + */ + @VisibleForTesting + protected Collection getChildLbStates() { + return childLbStates.values(); + } - protected abstract SubchannelPicker getSubchannelPicker( - Map childPickers); + /** + * Generally, the only reason to override this is to expose it to a test of a LB in a + * different package. + */ + + protected ChildLbState getChildLbState(Object key) { + if (key == null) { + return null; + } + if (key instanceof EquivalentAddressGroup) { + key = new Endpoint((EquivalentAddressGroup) key); + } + return childLbStates.get(key); + } + + /** + * Generally, the only reason to override this is to expose it to a test of a LB in a different + * package. + */ + protected ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) { + return getChildLbState(new Endpoint(eag)); + } + + /** + * Override to utilize parsing of the policy configuration or alternative helper/lb generation. + */ + protected Map createChildLbMap(ResolvedAddresses resolvedAddresses) { + Map childLbMap = new HashMap<>(); + List addresses = resolvedAddresses.getAddresses(); + for (EquivalentAddressGroup eag : addresses) { + Endpoint endpoint = new Endpoint(eag); // keys need to be just addresses + ChildLbState existingChildLbState = childLbStates.get(endpoint); + if (existingChildLbState != null) { + childLbMap.put(endpoint, existingChildLbState); + } else { + childLbMap.put(endpoint, createChildLbState(endpoint, null, getInitialPicker())); + } + } + return childLbMap; + } + + /** + * Override to create an instance of a subclass. + */ + protected ChildLbState createChildLbState(Object key, Object policyConfig, + SubchannelPicker initialPicker) { + return new ChildLbState(key, pickFirstLbProvider, policyConfig, initialPicker); + } + /** + * Override to completely replace the default logic or to do additional activities. + */ @Override public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { try { @@ -87,25 +152,71 @@ public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { } } + /** + * Override this if your keys are not of type Endpoint. + * @param key Key to identify the ChildLbState + * @param resolvedAddresses list of addresses which include attributes + * @param childConfig a load balancing policy config. This field is optional. + * @return a fully loaded ResolvedAddresses object for the specified key + */ + protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses, + Object childConfig) { + if (key instanceof EquivalentAddressGroup) { + key = new Endpoint((EquivalentAddressGroup) key); + } + checkArgument(key instanceof Endpoint, "key is wrong type"); + + // Retrieve the non-stripped version + EquivalentAddressGroup eagToUse = null; + for (EquivalentAddressGroup currEag : resolvedAddresses.getAddresses()) { + if (key.equals(new Endpoint(currEag))) { + eagToUse = currEag; + break; + } + } + + checkNotNull(eagToUse, key + " no longer present in load balancer children"); + + return resolvedAddresses.toBuilder() + .setAddresses(Collections.singletonList(eagToUse)) + .setLoadBalancingPolicyConfig(childConfig) + .build(); + } + private boolean acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) { logger.log(Level.FINE, "Received resolution result: {0}", resolvedAddresses); - Map newChildPolicies = getPolicySelectionMap(resolvedAddresses); - for (Map.Entry entry : newChildPolicies.entrySet()) { + Map newChildren = createChildLbMap(resolvedAddresses); + + if (newChildren.isEmpty()) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription( + "NameResolver returned no usable address. " + resolvedAddresses)); + return false; + } + + // Do adds and updates + for (Map.Entry entry : newChildren.entrySet()) { final Object key = entry.getKey(); - LoadBalancerProvider childPolicyProvider = entry.getValue().getProvider(); + LoadBalancerProvider childPolicyProvider = entry.getValue().getPolicyProvider(); Object childConfig = entry.getValue().getConfig(); if (!childLbStates.containsKey(key)) { - childLbStates.put(key, new ChildLbState(key, childPolicyProvider, getInitialPicker())); + childLbStates.put(key, entry.getValue()); } else { - childLbStates.get(key).reactivate(childPolicyProvider); + // Reuse the existing one + ChildLbState existingChildLbState = childLbStates.get(key); + if (existingChildLbState.isDeactivated()) { + existingChildLbState.reactivate(childPolicyProvider); + } } + LoadBalancer childLb = childLbStates.get(key).lb; - ResolvedAddresses childAddresses = - resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build(); - childLb.handleResolvedAddresses(childAddresses); + ResolvedAddresses childAddresses = getChildAddresses(key, resolvedAddresses, childConfig); + childLbStates.get(key).setResolvedAddresses(childAddresses); // update child state + childLb.handleResolvedAddresses(childAddresses); // update child LB } - for (Object key : childLbStates.keySet()) { - if (!newChildPolicies.containsKey(key)) { + + // Do removals + for (Object key : ImmutableList.copyOf(childLbStates.keySet())) { + if (!newChildren.containsKey(key)) { childLbStates.get(key).deactivate(); } } @@ -117,19 +228,23 @@ private boolean acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddres @Override public void handleNameResolutionError(Status error) { - logger.log(Level.WARNING, "Received name resolution error: {0}", error); - boolean gotoTransientFailure = true; - for (ChildLbState state : childLbStates.values()) { - if (!state.deactivated) { - gotoTransientFailure = false; - state.lb.handleNameResolutionError(error); - } - } - if (gotoTransientFailure) { - helper.updateBalancingState(TRANSIENT_FAILURE, getErrorPicker(error)); + if (currentConnectivityState != READY) { + updateHelperBalancingState(TRANSIENT_FAILURE, getErrorPicker(error)); } } + protected void handleNameResolutionError(ChildLbState child, Status error) { + child.lb.handleNameResolutionError(error); + } + + /** + * If true, then when a subchannel state changes to idle, the corresponding child will + * have requestConnection called on its LB. + */ + protected boolean reconnectOnIdle() { + return true; + } + @Override public void shutdown() { logger.log(Level.INFO, "Shutdown"); @@ -139,10 +254,10 @@ public void shutdown() { childLbStates.clear(); } - private void updateOverallBalancingState() { + protected void updateOverallBalancingState() { ConnectivityState overallState = null; final Map childPickers = new HashMap<>(); - for (ChildLbState childLbState : childLbStates.values()) { + for (ChildLbState childLbState : getChildLbStates()) { if (childLbState.deactivated) { continue; } @@ -151,11 +266,17 @@ private void updateOverallBalancingState() { } if (overallState != null) { helper.updateBalancingState(overallState, getSubchannelPicker(childPickers)); + currentConnectivityState = overallState; } } + protected final void updateHelperBalancingState(ConnectivityState newState, + SubchannelPicker newPicker) { + helper.updateBalancingState(newState, newPicker); + } + @Nullable - private static ConnectivityState aggregateState( + protected static ConnectivityState aggregateState( @Nullable ConnectivityState overallState, ConnectivityState childState) { if (overallState == null) { return childState; @@ -172,70 +293,155 @@ private static ConnectivityState aggregateState( return overallState; } - private final class ChildLbState { + protected Helper getHelper() { + return helper; + } + + protected void removeChild(Object key) { + childLbStates.remove(key); + } + + /** + * Filters out non-ready and deactivated child load balancers (subchannels). + */ + protected List getReadyChildren() { + List activeChildren = new ArrayList<>(); + for (ChildLbState child : getChildLbStates()) { + if (!child.isDeactivated() && child.getCurrentState() == READY) { + activeChildren.add(child); + } + } + return activeChildren; + } + + /** + * This represents the state of load balancer children. Each endpoint (represented by an + * EquivalentAddressGroup or EDS string) will have a separate ChildLbState which in turn will + * define a GracefulSwitchLoadBalancer. When the GracefulSwitchLoadBalancer is activated, a + * single PickFirstLoadBalancer will be created which will then create a subchannel and start + * trying to connect to it. + * + *

A ChildLbStateHelper is the glue between ChildLbState and the helpers associated with the + * petiole policy above and the PickFirstLoadBalancer's helper below. + * + *

If you wish to store additional state information related to each subchannel, then extend + * this class. + */ + public class ChildLbState { private final Object key; + private ResolvedAddresses resolvedAddresses; + private final Object config; private final GracefulSwitchLoadBalancer lb; private LoadBalancerProvider policyProvider; private ConnectivityState currentState = CONNECTING; private SubchannelPicker currentPicker; private boolean deactivated; - @Nullable - ScheduledHandle deletionTimer; - ChildLbState(Object key, LoadBalancerProvider policyProvider, SubchannelPicker initialPicker) { + public ChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig, + SubchannelPicker initialPicker) { this.key = key; this.policyProvider = policyProvider; lb = new GracefulSwitchLoadBalancer(new ChildLbStateHelper()); lb.switchTo(policyProvider); currentPicker = initialPicker; + config = childConfig; } - void deactivate() { - if (deactivated) { - return; + @Override + public String toString() { + return "Address = " + key + + ", state = " + currentState + + ", picker type: " + currentPicker.getClass() + + ", lb: " + lb.delegate().getClass() + + (deactivated ? ", deactivated" : ""); + } + + public Object getKey() { + return key; + } + + Object getConfig() { + return config; + } + + public LoadBalancerProvider getPolicyProvider() { + return policyProvider; + } + + protected Subchannel getSubchannels(PickSubchannelArgs args) { + if (getCurrentPicker() == null) { + return null; } + return getCurrentPicker().pickSubchannel(args).getSubchannel(); + } - class DeletionTask implements Runnable { - @Override - public void run() { - shutdown(); - childLbStates.remove(key); - } + public ConnectivityState getCurrentState() { + return currentState; + } + + public SubchannelPicker getCurrentPicker() { + return currentPicker; + } + + public EquivalentAddressGroup getEag() { + if (resolvedAddresses == null || resolvedAddresses.getAddresses().isEmpty()) { + return null; } + return resolvedAddresses.getAddresses().get(0); + } - deletionTimer = - syncContext.schedule( - new DeletionTask(), - DELAYED_CHILD_DELETION_TIME_MINUTES, - TimeUnit.MINUTES, - timeService); + public boolean isDeactivated() { + return deactivated; + } + + protected void setDeactivated() { deactivated = true; - logger.log(Level.FINE, "Child balancer {0} deactivated", key); } - void reactivate(LoadBalancerProvider policyProvider) { - if (deletionTimer != null && deletionTimer.isPending()) { - deletionTimer.cancel(); - deactivated = false; - logger.log(Level.FINE, "Child balancer {0} reactivated", key); + protected void setResolvedAddresses(ResolvedAddresses newAddresses) { + checkNotNull(newAddresses, "Missing address list for child"); + resolvedAddresses = newAddresses; + } + + protected void deactivate() { + if (deactivated) { + return; } + + shutdown(); + childLbStates.remove(key); + deactivated = true; + logger.log(Level.FINE, "Child balancer {0} deactivated", key); + } + + protected void reactivate(LoadBalancerProvider policyProvider) { if (!this.policyProvider.getPolicyName().equals(policyProvider.getPolicyName())) { Object[] objects = { key, this.policyProvider.getPolicyName(),policyProvider.getPolicyName()}; logger.log(Level.FINE, "Child balancer {0} switching policy from {1} to {2}", objects); lb.switchTo(policyProvider); this.policyProvider = policyProvider; + } else { + logger.log(Level.FINE, "Child balancer {0} reactivated", key); + lb.acceptResolvedAddresses(resolvedAddresses); } + + deactivated = false; } - void shutdown() { - if (deletionTimer != null && deletionTimer.isPending()) { - deletionTimer.cancel(); - } + protected void shutdown() { lb.shutdown(); + this.currentState = SHUTDOWN; logger.log(Level.FINE, "Child balancer {0} deleted", key); } + /** + * ChildLbStateHelper is the glue between ChildLbState and the helpers associated with the + * petiole policy above and the PickFirstLoadBalancer's helper below. + * + *

The ChildLbState updates happen during updateBalancingState. Otherwise, it is doing + * simple forwarding. + */ private final class ChildLbStateHelper extends ForwardingLoadBalancerHelper { @Override @@ -251,6 +457,9 @@ public void updateBalancingState(final ConnectivityState newState, currentState = newState; currentPicker = newPicker; if (!deactivated && !resolvingAddresses) { + if (newState == IDLE && reconnectOnIdle()) { + lb.requestConnection(); + } updateOverallBalancingState(); } } @@ -261,4 +470,58 @@ protected Helper delegate() { } } } + + /** + * Endpoint is an optimization to quickly lookup and compare EquivalentAddressGroup address sets. + * Ignores the attributes, orders the addresses in a deterministic manner and converts each + * address into a string for easy comparison. Also caches the hashcode. + * Is used as a key for ChildLbState for most load balancers (ClusterManagerLB uses a String). + */ + protected static class Endpoint { + final String[] addrs; + final int hashCode; + + Endpoint(EquivalentAddressGroup eag) { + checkNotNull(eag, "eag"); + + addrs = new String[eag.getAddresses().size()]; + int i = 0; + for (SocketAddress address : eag.getAddresses()) { + addrs[i] = address.toString(); + } + Arrays.sort(addrs); + + hashCode = Arrays.hashCode(addrs); + } + + @Override + public int hashCode() { + return hashCode; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null) { + return false; + } + + if (!(other instanceof Endpoint)) { + return false; + } + Endpoint o = (Endpoint) other; + if (o.hashCode != hashCode || o.addrs.length != addrs.length) { + return false; + } + + return Arrays.equals(o.addrs, this.addrs); + } + + @Override + public String toString() { + return Arrays.toString(addrs); + } + } } diff --git a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java index 56097084928..3bd83ffe104 100644 --- a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java @@ -16,11 +16,9 @@ package io.grpc.util; -import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; -import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.annotations.VisibleForTesting; @@ -37,13 +35,10 @@ import io.grpc.Status; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; -import java.util.Set; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import javax.annotation.Nonnull; @@ -52,131 +47,22 @@ * EquivalentAddressGroup}s from the {@link NameResolver}. */ @Internal -public class RoundRobinLoadBalancer extends LoadBalancer { +public class RoundRobinLoadBalancer extends MultiChildLoadBalancer { @VisibleForTesting static final Attributes.Key> STATE_INFO = Attributes.Key.create("state-info"); - private final Helper helper; - private final Map subchannels = - new HashMap<>(); private final Random random; - private ConnectivityState currentState; protected RoundRobinPicker currentPicker = new EmptyPicker(EMPTY_OK); public RoundRobinLoadBalancer(Helper helper) { - this.helper = checkNotNull(helper, "helper"); + super(helper); this.random = new Random(); } @Override - public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - if (resolvedAddresses.getAddresses().isEmpty()) { - handleNameResolutionError(Status.UNAVAILABLE.withDescription( - "NameResolver returned no usable address. addrs=" + resolvedAddresses.getAddresses() - + ", attrs=" + resolvedAddresses.getAttributes())); - return false; - } - - List servers = resolvedAddresses.getAddresses(); - Set currentAddrs = subchannels.keySet(); - Map latestAddrs = stripAttrs(servers); - Set removedAddrs = setsDifference(currentAddrs, latestAddrs.keySet()); - - for (Map.Entry latestEntry : - latestAddrs.entrySet()) { - EquivalentAddressGroup strippedAddressGroup = latestEntry.getKey(); - EquivalentAddressGroup originalAddressGroup = latestEntry.getValue(); - Subchannel existingSubchannel = subchannels.get(strippedAddressGroup); - if (existingSubchannel != null) { - // EAG's Attributes may have changed. - existingSubchannel.updateAddresses(Collections.singletonList(originalAddressGroup)); - continue; - } - // Create new subchannels for new addresses. - - // NB(lukaszx0): we don't merge `attributes` with `subchannelAttr` because subchannel - // doesn't need them. They're describing the resolved server list but we're not taking - // any action based on this information. - Attributes.Builder subchannelAttrs = Attributes.newBuilder() - // NB(lukaszx0): because attributes are immutable we can't set new value for the key - // after creation but since we can mutate the values we leverage that and set - // AtomicReference which will allow mutating state info for given channel. - .set(STATE_INFO, - new Ref<>(ConnectivityStateInfo.forNonError(IDLE))); - - final Subchannel subchannel = checkNotNull( - helper.createSubchannel(CreateSubchannelArgs.newBuilder() - .setAddresses(originalAddressGroup) - .setAttributes(subchannelAttrs.build()) - .build()), - "subchannel"); - subchannel.start(new SubchannelStateListener() { - @Override - public void onSubchannelState(ConnectivityStateInfo state) { - processSubchannelState(subchannel, state); - } - }); - subchannels.put(strippedAddressGroup, subchannel); - subchannel.requestConnection(); - } - - ArrayList removedSubchannels = new ArrayList<>(); - for (EquivalentAddressGroup addressGroup : removedAddrs) { - removedSubchannels.add(subchannels.remove(addressGroup)); - } - - // Update the picker before shutting down the subchannels, to reduce the chance of the race - // between picking a subchannel and shutting it down. - updateBalancingState(); - - // Shutdown removed subchannels - for (Subchannel removedSubchannel : removedSubchannels) { - shutdownSubchannel(removedSubchannel); - } - - return true; - } - - @Override - public void handleNameResolutionError(Status error) { - if (currentState != READY) { - updateBalancingState(TRANSIENT_FAILURE, new EmptyPicker(error)); - } - } - - private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { - if (subchannels.get(stripAttrs(subchannel.getAddresses())) != subchannel) { - return; - } - if (stateInfo.getState() == TRANSIENT_FAILURE || stateInfo.getState() == IDLE) { - helper.refreshNameResolution(); - } - if (stateInfo.getState() == IDLE) { - subchannel.requestConnection(); - } - Ref subchannelStateRef = getSubchannelStateInfoRef(subchannel); - if (subchannelStateRef.value.getState().equals(TRANSIENT_FAILURE)) { - if (stateInfo.getState().equals(CONNECTING) || stateInfo.getState().equals(IDLE)) { - return; - } - } - subchannelStateRef.value = stateInfo; - updateBalancingState(); - } - - private void shutdownSubchannel(Subchannel subchannel) { - subchannel.shutdown(); - getSubchannelStateInfoRef(subchannel).value = - ConnectivityStateInfo.forNonError(SHUTDOWN); - } - - @Override - public void shutdown() { - for (Subchannel subchannel : getSubchannels()) { - shutdownSubchannel(subchannel); - } - subchannels.clear(); + protected SubchannelPicker getSubchannelPicker(Map childPickers) { + throw new UnsupportedOperationException(); // local updateOverallBalancingState doesn't use this } private static final Status EMPTY_OK = Status.OK.withDescription("no subchannels ready"); @@ -184,102 +70,54 @@ public void shutdown() { /** * Updates picker with the list of active subchannels (state == READY). */ - @SuppressWarnings("ReferenceEquality") - private void updateBalancingState() { - List activeList = filterNonFailingSubchannels(getSubchannels()); + @Override + protected void updateOverallBalancingState() { + List activeList = getReadyChildren(); if (activeList.isEmpty()) { - // No READY subchannels, determine aggregate state and error status + // No READY subchannels + + // RRLB will request connection immediately on subchannel IDLE. boolean isConnecting = false; - Status aggStatus = EMPTY_OK; - for (Subchannel subchannel : getSubchannels()) { - ConnectivityStateInfo stateInfo = getSubchannelStateInfoRef(subchannel).value; - // This subchannel IDLE is not because of channel IDLE_TIMEOUT, - // in which case LB is already shutdown. - // RRLB will request connection immediately on subchannel IDLE. - if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) { + for (ChildLbState childLbState : getChildLbStates()) { + ConnectivityState state = childLbState.getCurrentState(); + if (state == CONNECTING || state == IDLE) { isConnecting = true; + break; } - if (aggStatus == EMPTY_OK || !aggStatus.isOk()) { - aggStatus = stateInfo.getStatus(); - } } - updateBalancingState(isConnecting ? CONNECTING : TRANSIENT_FAILURE, - // If all subchannels are TRANSIENT_FAILURE, return the Status associated with - // an arbitrary subchannel, otherwise return OK. - new EmptyPicker(aggStatus)); + + if (isConnecting) { + updateBalancingState(CONNECTING, new EmptyPicker(Status.OK)); + } else { + updateBalancingState(TRANSIENT_FAILURE, createReadyPicker(getChildLbStates())); + } } else { updateBalancingState(READY, createReadyPicker(activeList)); } } private void updateBalancingState(ConnectivityState state, RoundRobinPicker picker) { - if (state != currentState || !picker.isEquivalentTo(currentPicker)) { - helper.updateBalancingState(state, picker); - currentState = state; + if (state != currentConnectivityState || !picker.isEquivalentTo(currentPicker)) { + getHelper().updateBalancingState(state, picker); + currentConnectivityState = state; currentPicker = picker; } } - protected RoundRobinPicker createReadyPicker(List activeList) { + protected RoundRobinPicker createReadyPicker(Collection children) { // initialize the Picker to a random start index to ensure that a high frequency of Picker // churn does not skew subchannel selection. - int startIndex = random.nextInt(activeList.size()); - return new ReadyPicker(activeList, startIndex); - } + int startIndex = random.nextInt(children.size()); - /** - * Filters out non-ready subchannels. - */ - private static List filterNonFailingSubchannels( - Collection subchannels) { - List readySubchannels = new ArrayList<>(subchannels.size()); - for (Subchannel subchannel : subchannels) { - if (isReady(subchannel)) { - readySubchannels.add(subchannel); - } + List pickerList = new ArrayList<>(); + for (ChildLbState child : children) { + SubchannelPicker picker = child.getCurrentPicker(); + pickerList.add(picker); } - return readySubchannels; - } - - /** - * Converts list of {@link EquivalentAddressGroup} to {@link EquivalentAddressGroup} set and - * remove all attributes. The values are the original EAGs. - */ - private static Map stripAttrs( - List groupList) { - Map addrs = new HashMap<>(groupList.size() * 2); - for (EquivalentAddressGroup group : groupList) { - addrs.put(stripAttrs(group), group); - } - return addrs; - } - - private static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { - return new EquivalentAddressGroup(eag.getAddresses()); - } - - @VisibleForTesting - protected Collection getSubchannels() { - return subchannels.values(); - } - - private static Ref getSubchannelStateInfoRef( - Subchannel subchannel) { - return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO"); - } - - // package-private to avoid synthetic access - static boolean isReady(Subchannel subchannel) { - return getSubchannelStateInfoRef(subchannel).value.getState() == READY; - } - private static Set setsDifference(Set a, Set b) { - Set aCopy = new HashSet<>(a); - aCopy.removeAll(b); - return aCopy; + return new ReadyPicker(pickerList, startIndex); } - // Only subclasses are ReadyPicker or EmptyPicker public abstract static class RoundRobinPicker extends SubchannelPicker { public abstract boolean isEquivalentTo(RoundRobinPicker picker); } @@ -289,40 +127,42 @@ static class ReadyPicker extends RoundRobinPicker { private static final AtomicIntegerFieldUpdater indexUpdater = AtomicIntegerFieldUpdater.newUpdater(ReadyPicker.class, "index"); - private final List list; // non-empty + private final List subchannelPickers; // non-empty @SuppressWarnings("unused") private volatile int index; - public ReadyPicker(List list, int startIndex) { + public ReadyPicker(List list, int startIndex) { Preconditions.checkArgument(!list.isEmpty(), "empty list"); - this.list = list; + this.subchannelPickers = list; this.index = startIndex - 1; } @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withSubchannel(nextSubchannel()); + return subchannelPickers.get(nextIndex()).pickSubchannel(args); } @Override public String toString() { - return MoreObjects.toStringHelper(ReadyPicker.class).add("list", list).toString(); + return MoreObjects.toStringHelper(ReadyPicker.class) + .add("subchannelPickers", subchannelPickers) + .toString(); } - private Subchannel nextSubchannel() { - int size = list.size(); + private int nextIndex() { + int size = subchannelPickers.size(); int i = indexUpdater.incrementAndGet(this); if (i >= size) { int oldi = i; i %= size; indexUpdater.compareAndSet(this, oldi, i); } - return list.get(i); + return i; } @VisibleForTesting - List getList() { - return list; + List getSubchannelPickers() { + return subchannelPickers; } @Override @@ -333,7 +173,8 @@ public boolean isEquivalentTo(RoundRobinPicker picker) { ReadyPicker other = (ReadyPicker) picker; // the lists cannot contain duplicate subchannels return other == this - || (list.size() == other.list.size() && new HashSet<>(list).containsAll(other.list)); + || (subchannelPickers.size() == other.subchannelPickers.size() && new HashSet<>( + subchannelPickers).containsAll(other.subchannelPickers)); } } diff --git a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java index 13f13421a1e..ac5bd8b98c4 100644 --- a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java @@ -512,7 +512,7 @@ public void successRateOneOutlier_configChange() { loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); - generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 8); + generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 12); // Move forward in time to a point where the detection timer has fired. forwardTime(config); @@ -546,7 +546,7 @@ public void successRateOneOutlier_unejected() { assertEjectedSubchannels(ImmutableSet.of(servers.get(0).getAddresses().get(0))); // Now we produce more load, but the subchannel start working and is no longer an outlier. - generateLoad(ImmutableMap.of(), 8); + generateLoad(ImmutableMap.of(), 12); // Move forward in time to a point where the detection timer has fired. fakeClock.forwardTime(config.maxEjectionTimeNanos + 1, TimeUnit.NANOSECONDS); diff --git a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index 23b6e1c10c8..25686ac8a39 100644 --- a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -22,26 +22,23 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static io.grpc.util.RoundRobinLoadBalancer.STATE_INFO; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; import com.google.common.collect.Lists; -import com.google.common.collect.Maps; import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; @@ -53,18 +50,20 @@ import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; -import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Status; +import io.grpc.internal.TestUtils; +import io.grpc.util.MultiChildLoadBalancer.ChildLbState; import io.grpc.util.RoundRobinLoadBalancer.EmptyPicker; import io.grpc.util.RoundRobinLoadBalancer.ReadyPicker; -import io.grpc.util.RoundRobinLoadBalancer.Ref; import java.net.SocketAddress; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -75,10 +74,8 @@ import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; -import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; -import org.mockito.stubbing.Answer; /** Unit test for {@link RoundRobinLoadBalancer}. */ @RunWith(JUnit4.class) @@ -89,9 +86,8 @@ public class RoundRobinLoadBalancerTest { private RoundRobinLoadBalancer loadBalancer; private final List servers = Lists.newArrayList(); - private final Map, Subchannel> subchannels = Maps.newLinkedHashMap(); - private final Map subchannelStateListeners = - Maps.newLinkedHashMap(); + private final Map, Subchannel> subchannels = + new ConcurrentHashMap<>(); private final Attributes affinity = Attributes.newBuilder().set(MAJOR_KEY, "I got the keys").build(); @@ -101,8 +97,8 @@ public class RoundRobinLoadBalancerTest { private ArgumentCaptor stateCaptor; @Captor private ArgumentCaptor createArgsCaptor; - @Mock - private Helper mockHelper; + private TestHelper testHelperInst = new TestHelper(); + private Helper mockHelper = mock(Helper.class, delegatesTo(testHelperInst)); @Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown(). private PickSubchannelArgs mockArgs; @@ -113,34 +109,16 @@ public void setUp() { SocketAddress addr = new FakeSocketAddress("server" + i); EquivalentAddressGroup eag = new EquivalentAddressGroup(addr); servers.add(eag); - Subchannel sc = mock(Subchannel.class); - subchannels.put(Arrays.asList(eag), sc); } - when(mockHelper.createSubchannel(any(CreateSubchannelArgs.class))) - .then(new Answer() { - @Override - public Subchannel answer(InvocationOnMock invocation) throws Throwable { - CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; - final Subchannel subchannel = subchannels.get(args.getAddresses()); - when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); - when(subchannel.getAttributes()).thenReturn(args.getAttributes()); - doAnswer( - new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - subchannelStateListeners.put( - subchannel, (SubchannelStateListener) invocation.getArguments()[0]); - return null; - } - }).when(subchannel).start(any(SubchannelStateListener.class)); - return subchannel; - } - }); - loadBalancer = new RoundRobinLoadBalancer(mockHelper); } + private boolean acceptAddresses(List eagList, Attributes attrs) { + return loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(eagList).setAttributes(attrs).build()); + } + @After public void tearDown() throws Exception { verifyNoMoreInteractions(mockArgs); @@ -148,10 +126,9 @@ public void tearDown() throws Exception { @Test public void pickAfterResolved() throws Exception { - final Subchannel readySubchannel = subchannels.values().iterator().next(); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); + boolean addressesAccepted = acceptAddresses(servers, affinity); assertThat(addressesAccepted).isTrue(); + final Subchannel readySubchannel = subchannels.values().iterator().next(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); verify(mockHelper, times(3)).createSubchannel(createArgsCaptor.capture()); @@ -178,10 +155,6 @@ public void pickAfterResolved() throws Exception { @Test public void pickAfterResolvedUpdatedHosts() throws Exception { - Subchannel removedSubchannel = mock(Subchannel.class); - Subchannel oldSubchannel = mock(Subchannel.class); - Subchannel newSubchannel = mock(Subchannel.class); - Attributes.Key key = Attributes.Key.create("check-that-it-is-propagated"); FakeSocketAddress removedAddr = new FakeSocketAddress("removed"); EquivalentAddressGroup removedEag = new EquivalentAddressGroup(removedAddr); @@ -193,6 +166,13 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { EquivalentAddressGroup newEag = new EquivalentAddressGroup( newAddr, Attributes.newBuilder().set(key, "newattr").build()); + Subchannel removedSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(removedEag).build()); + Subchannel oldSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(oldEag1).build()); + Subchannel newSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(newEag).build()); + subchannels.put(Collections.singletonList(removedEag), removedSubchannel); subchannels.put(Collections.singletonList(oldEag1), oldSubchannel); subchannels.put(Collections.singletonList(newEag), newSubchannel); @@ -201,9 +181,7 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(currentServers).setAttributes(affinity) - .build()); + boolean addressesAccepted = acceptAddresses(currentServers, affinity); assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); @@ -218,8 +196,11 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { verify(removedSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).requestConnection(); - assertThat(loadBalancer.getSubchannels()).containsExactly(removedSubchannel, - oldSubchannel); + assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(2); + assertThat(loadBalancer.getChildLbStateEag(removedEag).getCurrentPicker().pickSubchannel(null) + .getSubchannel()).isEqualTo(removedSubchannel); + assertThat(loadBalancer.getChildLbStateEag(oldEag1).getCurrentPicker().pickSubchannel(null) + .getSubchannel()).isEqualTo(oldSubchannel); // This time with Attributes List latestServers = Lists.newArrayList(oldEag2, newEag); @@ -232,13 +213,15 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { verify(oldSubchannel, times(1)).updateAddresses(Arrays.asList(oldEag2)); verify(removedSubchannel, times(1)).shutdown(); - deliverSubchannelState(removedSubchannel, ConnectivityStateInfo.forNonError(SHUTDOWN)); deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY)); - assertThat(loadBalancer.getSubchannels()).containsExactly(oldSubchannel, - newSubchannel); + assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(2); + assertThat(loadBalancer.getChildLbStateEag(newEag).getCurrentPicker() + .pickSubchannel(null).getSubchannel()).isEqualTo(newSubchannel); + assertThat(loadBalancer.getChildLbStateEag(oldEag2).getCurrentPicker() + .pickSubchannel(null).getSubchannel()).isEqualTo(oldSubchannel); - verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(mockHelper, times(6)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); picker = pickerCaptor.getValue(); @@ -250,29 +233,26 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { @Test public void pickAfterStateChange() throws Exception { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) - .build()); + boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); assertThat(addressesAccepted).isTrue(); - Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); - Ref subchannelStateInfo = subchannel.getAttributes().get( - STATE_INFO); + + // TODO figure out if this method testing the right things + + ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); + Subchannel subchannel = childLbState.getCurrentPicker().pickSubchannel(null).getSubchannel(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); - assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(IDLE)); + assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING); - deliverSubchannelState(subchannel, - ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class); - assertThat(subchannelStateInfo.value).isEqualTo( - ConnectivityStateInfo.forNonError(READY)); + assertThat(childLbState.getCurrentState()).isEqualTo(READY); Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"); deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); - assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); - assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error); + assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); inOrder.verify(mockHelper).refreshNameResolution(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class); @@ -280,8 +260,7 @@ public void pickAfterStateChange() throws Exception { deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(mockHelper).refreshNameResolution(); - assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); - assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error); + assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); verify(subchannel, times(2)).requestConnection(); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); @@ -291,15 +270,14 @@ public void pickAfterStateChange() throws Exception { @Test public void ignoreShutdownSubchannelStateChange() { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) - .build()); + boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); loadBalancer.shutdown(); - for (Subchannel sc : loadBalancer.getSubchannels()) { - verify(sc).shutdown(); + for (ChildLbState child : loadBalancer.getChildLbStates()) { + Subchannel sc = child.getCurrentPicker().pickSubchannel(null).getSubchannel(); + verify(child).shutdown(); // When the subchannel is being shut down, a SHUTDOWN connectivity state is delivered // back to the subchannel state listener. deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(SHUTDOWN)); @@ -311,36 +289,34 @@ public void ignoreShutdownSubchannelStateChange() { @Test public void stayTransientFailureUntilReady() { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) - .build()); + boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + Map childToSubChannelMap = new HashMap<>(); // Simulate state transitions for each subchannel individually. - for (Subchannel sc : loadBalancer.getSubchannels()) { + for ( ChildLbState child : loadBalancer.getChildLbStates()) { + Subchannel sc = child.getSubchannels(mockArgs); + childToSubChannelMap.put(child, sc); Status error = Status.UNKNOWN.withDescription("connection broken"); deliverSubchannelState( sc, ConnectivityStateInfo.forTransientFailure(error)); + assertEquals(TRANSIENT_FAILURE, child.getCurrentState()); inOrder.verify(mockHelper).refreshNameResolution(); deliverSubchannelState( sc, ConnectivityStateInfo.forNonError(CONNECTING)); - Ref scStateInfo = sc.getAttributes().get( - STATE_INFO); - assertThat(scStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); - assertThat(scStateInfo.value.getStatus()).isEqualTo(error); + assertEquals(TRANSIENT_FAILURE, child.getCurrentState()); } - inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(EmptyPicker.class)); + inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(ReadyPicker.class)); inOrder.verifyNoMoreInteractions(); - Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); + ChildLbState child = loadBalancer.getChildLbStates().iterator().next(); + Subchannel subchannel = childToSubChannelMap.get(child); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - Ref subchannelStateInfo = subchannel.getAttributes().get( - STATE_INFO); - assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(READY)); + assertThat(child.getCurrentState()).isEqualTo(READY); inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class)); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); @@ -350,16 +326,15 @@ public void stayTransientFailureUntilReady() { @Test public void refreshNameResolutionWhenSubchannelConnectionBroken() { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) - .build()); + boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); assertThat(addressesAccepted).isTrue(); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); // Simulate state transitions for each subchannel individually. - for (Subchannel sc : loadBalancer.getSubchannels()) { + for (ChildLbState child : loadBalancer.getChildLbStates()) { + Subchannel sc = child.getSubchannels(mockArgs); verify(sc).requestConnection(); deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING)); Status error = Status.UNKNOWN.withDescription("connection broken"); @@ -383,11 +358,12 @@ public void pickerRoundRobin() throws Exception { Subchannel subchannel1 = mock(Subchannel.class); Subchannel subchannel2 = mock(Subchannel.class); - ReadyPicker picker = new ReadyPicker(Collections.unmodifiableList( - Lists.newArrayList(subchannel, subchannel1, subchannel2)), - 0 /* startIndex */); + ArrayList pickers = Lists.newArrayList( + TestUtils.pickerOf(subchannel), TestUtils.pickerOf(subchannel1), + TestUtils.pickerOf(subchannel2)); - assertThat(picker.getList()).containsExactly(subchannel, subchannel1, subchannel2); + ReadyPicker picker = new ReadyPicker(Collections.unmodifiableList(pickers), + 0 /* startIndex */); assertEquals(subchannel, picker.pickSubchannel(mockArgs).getSubchannel()); assertEquals(subchannel1, picker.pickSubchannel(mockArgs).getSubchannel()); @@ -399,7 +375,7 @@ public void pickerRoundRobin() throws Exception { public void pickerEmptyList() throws Exception { SubchannelPicker picker = new EmptyPicker(Status.UNKNOWN); - assertEquals(null, picker.pickSubchannel(mockArgs).getSubchannel()); + assertNull(picker.pickSubchannel(mockArgs).getSubchannel()); assertEquals(Status.UNKNOWN, picker.pickSubchannel(mockArgs).getStatus()); } @@ -417,12 +393,13 @@ public void nameResolutionErrorWithNoChannels() throws Exception { @Test public void nameResolutionErrorWithActiveChannels() throws Exception { + boolean addressesAccepted = acceptAddresses(servers, affinity); final Subchannel readySubchannel = subchannels.values().iterator().next(); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); assertThat(addressesAccepted).isTrue(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); + loadBalancer.resolvingAddresses = true; loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError")); + loadBalancer.resolvingAddresses = false; verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(mockHelper, times(2)) @@ -443,15 +420,14 @@ public void nameResolutionErrorWithActiveChannels() throws Exception { @Test public void subchannelStateIsolation() throws Exception { + boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); + assertThat(addressesAccepted).isTrue(); + Iterator subchannelIterator = subchannels.values().iterator(); Subchannel sc1 = subchannelIterator.next(); Subchannel sc2 = subchannelIterator.next(); Subchannel sc3 = subchannelIterator.next(); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) - .build()); - assertThat(addressesAccepted).isTrue(); verify(sc1, times(1)).requestConnection(); verify(sc2, times(1)).requestConnection(); verify(sc3, times(1)).requestConnection(); @@ -491,7 +467,7 @@ public void subchannelStateIsolation() throws Exception { public void readyPicker_emptyList() { // ready picker list must be non-empty try { - new ReadyPicker(Collections.emptyList(), 0); + new ReadyPicker(Collections.emptyList(), 0); fail(); } catch (IllegalArgumentException expected) { } @@ -503,9 +479,10 @@ public void internalPickerComparisons() { EmptyPicker emptyOk2 = new EmptyPicker(Status.OK.withDescription("different OK")); EmptyPicker emptyErr = new EmptyPicker(Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯")); + acceptAddresses(servers, Attributes.EMPTY); // create subchannels Iterator subchannelIterator = subchannels.values().iterator(); - Subchannel sc1 = subchannelIterator.next(); - Subchannel sc2 = subchannelIterator.next(); + SubchannelPicker sc1 = TestUtils.pickerOf(subchannelIterator.next()); + SubchannelPicker sc2 = TestUtils.pickerOf(subchannelIterator.next()); ReadyPicker ready1 = new ReadyPicker(Arrays.asList(sc1, sc2), 0); ReadyPicker ready2 = new ReadyPicker(Arrays.asList(sc1), 0); ReadyPicker ready3 = new ReadyPicker(Arrays.asList(sc2, sc1), 1); @@ -526,18 +503,26 @@ public void internalPickerComparisons() { public void emptyAddresses() { assertThat(loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() - .setAddresses(Collections.emptyList()) + .setAddresses(Collections.emptyList()) .setAttributes(affinity) .build())).isFalse(); } - private static List getList(SubchannelPicker picker) { - return picker instanceof ReadyPicker ? ((ReadyPicker) picker).getList() : - Collections.emptyList(); + private List getList(SubchannelPicker picker) { + + if (picker instanceof ReadyPicker) { + List subchannelList = new ArrayList<>(); + for (SubchannelPicker childPicker : ((ReadyPicker) picker).getSubchannelPickers()) { + subchannelList.add(childPicker.pickSubchannel(mockArgs).getSubchannel()); + } + return subchannelList; + } else { + return new ArrayList<>(); + } } private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { - subchannelStateListeners.get(subchannel).onSubchannelState(newState); + testHelperInst.deliverSubchannelState(subchannel, newState); } private static class FakeSocketAddress extends SocketAddress { @@ -552,4 +537,12 @@ public String toString() { return "FakeSocketAddress-" + name; } } + + private class TestHelper extends AbstractTestHelper { + + @Override + public Map, Subchannel> getSubchannelMap() { + return subchannels; + } + } } diff --git a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java new file mode 100644 index 00000000000..2afb133877b --- /dev/null +++ b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java @@ -0,0 +1,196 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.Mockito.mock; + +import com.google.common.collect.Maps; +import io.grpc.Attributes; +import io.grpc.Channel; +import io.grpc.ChannelLogger; +import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A real class that can be used as a delegate of a mock Helper to provide more real representation + * and track the subchannels as is needed with petiole policies where the subchannels are no + * longer direct children of the loadbalancer. + *
+ * To use it replace
+ * \@mock Helper mockHelper
+ * with
+ *

Helper mockHelper = mock(Helper.class, delegatesTo(new TestHelper()));

+ *
+ * TestHelper will need to define accessors for the maps that information is store within as + * those maps need to be defined in the Test class. + */ +public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper { + + private final Map mockToRealSubChannelMap = new HashMap<>(); + private final Map subchannelStateListeners = + Maps.newLinkedHashMap(); + + public abstract Map, Subchannel> getSubchannelMap(); + + public Map getMockToRealSubChannelMap() { + return mockToRealSubChannelMap; + } + + public Subchannel getRealForMockSubChannel(Subchannel mock) { + Subchannel realSc = getMockToRealSubChannelMap().get(mock); + if (realSc == null) { + realSc = mock; + } + return realSc; + } + + public Map getSubchannelStateListeners() { + return subchannelStateListeners; + } + + public void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { + Subchannel realSc = getMockToRealSubChannelMap().get(subchannel); + if (realSc == null) { + realSc = subchannel; + } + SubchannelStateListener listener = getSubchannelStateListeners().get(realSc); + if (listener == null) { + throw new IllegalArgumentException("subchannel does not have a matching listener"); + } + listener.onSubchannelState(newState); + } + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + // do nothing, should have been done in the wrapper helpers + } + + @Override + protected Helper delegate() { + throw new UnsupportedOperationException("This helper class is only for use in this test"); + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + Subchannel subchannel = getSubchannelMap().get(args.getAddresses()); + if (subchannel == null) { + TestSubchannel delegate = new TestSubchannel(args); + subchannel = mock(Subchannel.class, delegatesTo(delegate)); + getSubchannelMap().put(args.getAddresses(), subchannel); + getMockToRealSubChannelMap().put(subchannel, delegate); + } + + return subchannel; + } + + @Override + public void refreshNameResolution() { + // no-op + } + + public void setChannel(Subchannel subchannel, Channel channel) { + ((TestSubchannel)subchannel).channel = channel; + } + + @Override + public String toString() { + return "Test Helper"; + } + + private class TestSubchannel extends ForwardingSubchannel { + CreateSubchannelArgs args; + Channel channel; + + public TestSubchannel(CreateSubchannelArgs args) { + this.args = args; + } + + @Override + protected Subchannel delegate() { + throw new UnsupportedOperationException("Only to be used in tests"); + } + + @Override + public List getAllAddresses() { + return args.getAddresses(); + } + + @Override + public Attributes getAttributes() { + return args.getAttributes(); + } + + @Override + public void requestConnection() { + // Ignore, we will manually update state + } + + @Override + public void updateAddresses(List addrs) { + if (args.getAddresses().equals(addrs)) { + return; // no changes so it's a no-op + } + + List oldAddrs = args.getAddresses(); + Subchannel oldTarget = getSubchannelMap().get(oldAddrs); + + this.args = args.toBuilder().setAddresses(addrs).build(); + getSubchannelMap().put(addrs, oldTarget); + getSubchannelMap().remove(oldAddrs); + } + + @Override + public void start(SubchannelStateListener listener) { + getSubchannelStateListeners().put(this, listener); + } + + @Override + public void shutdown() { + getSubchannelStateListeners().remove(this); + for (EquivalentAddressGroup eag : getAllAddresses()) { + getSubchannelMap().remove(Collections.singletonList(eag)); + } + } + + @Override + public Channel asChannel() { + return channel; + } + + @Override + public ChannelLogger getChannelLogger() { + return mock(ChannelLogger.class); + } + + @Override + public String toString() { + return "Mock Subchannel" + args.toString(); + } + } +} + diff --git a/xds/build.gradle b/xds/build.gradle index 3f3cf6a0f6e..a6db9db9937 100644 --- a/xds/build.gradle +++ b/xds/build.gradle @@ -58,7 +58,8 @@ dependencies { def nettyDependency = implementation project(':grpc-netty') testImplementation project(':grpc-rls') - testImplementation testFixtures(project(':grpc-core')) + testImplementation testFixtures(project(':grpc-core')), + testFixtures(project(':grpc-util')) annotationProcessor libraries.auto.value // At runtime use the epoll included in grpc-netty-shaded diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java index a4489204236..62fab0d12a6 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java @@ -16,36 +16,77 @@ package io.grpc.xds; +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; + +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import io.grpc.InternalLogId; +import io.grpc.LoadBalancerProvider; import io.grpc.Status; +import io.grpc.SynchronizationContext; +import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.util.MultiChildLoadBalancer; import io.grpc.xds.ClusterManagerLoadBalancerProvider.ClusterManagerConfig; import io.grpc.xds.XdsLogger.XdsLogLevel; import java.util.HashMap; import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; /** - * The top-level load balancing policy. + * The top-level load balancing policy for use in XDS. + * This policy does not immediately delete its children. Instead, it marks them deactivated + * and starts a timer for deletion. If a subsequent address update restores the child, then it is + * simply reactivated instead of built from scratch. This is necessary because XDS can frequently + * remove and then add back a server as machines are rebooted or repurposed for load management. + * + *

Note that this LB does not automatically reconnect children who go into IDLE status */ class ClusterManagerLoadBalancer extends MultiChildLoadBalancer { + // 15 minutes is long enough for a reboot and the services to restart while not so long that + // many children are waiting for cleanup. + @VisibleForTesting + public static final int DELAYED_CHILD_DELETION_TIME_MINUTES = 15; + protected final SynchronizationContext syncContext; + private final ScheduledExecutorService timeService; private final XdsLogger logger; ClusterManagerLoadBalancer(Helper helper) { super(helper); + this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); + this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); logger = XdsLogger.withLogId( InternalLogId.allocate("cluster_manager-lb", helper.getAuthority())); + logger.log(XdsLogLevel.INFO, "Created"); } @Override - protected Map getPolicySelectionMap( - ResolvedAddresses resolvedAddresses) { + protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses, + Object childConfig) { + return resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build(); + } + + @Override + protected Map createChildLbMap(ResolvedAddresses resolvedAddresses) { ClusterManagerConfig config = (ClusterManagerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - Map newChildPolicies = new HashMap<>(config.childPolicies); + Map newChildPolicies = new HashMap<>(); + if (config != null) { + for (Entry entry : config.childPolicies.entrySet()) { + ChildLbState child = getChildLbState(entry.getKey()); + if (child == null) { + child = new ClusterManagerLbState(entry.getKey(), + entry.getValue().getProvider(), entry.getValue().getConfig(), getInitialPicker()); + } + newChildPolicies.put(entry.getKey(), child); + } + } logger.log( XdsLogLevel.INFO, "Received cluster_manager lb config: child names={0}", newChildPolicies.keySet()); @@ -75,4 +116,84 @@ public String toString() { } }; } + + @Override + public void handleNameResolutionError(Status error) { + logger.log(XdsLogLevel.WARNING, "Received name resolution error: {0}", error); + boolean gotoTransientFailure = true; + for (ChildLbState state : getChildLbStates()) { + if (!state.isDeactivated()) { + gotoTransientFailure = false; + handleNameResolutionError(state, error); + } + } + if (gotoTransientFailure) { + getHelper().updateBalancingState(TRANSIENT_FAILURE, getErrorPicker(error)); + } + } + + @Override + protected boolean reconnectOnIdle() { + return false; + } + + /** + * This differs from the base class in the use of the deletion timer. When it is deactivated, + * rather than immediately calling shutdown it starts a timer. If shutdown or reactivate + * are called before the timer fires, the timer is canceled. Otherwise, time timer calls shutdown + * and removes the child from the petiole policy when it is triggered. + */ + private class ClusterManagerLbState extends ChildLbState { + @Nullable + ScheduledHandle deletionTimer; + + public ClusterManagerLbState(Object key, LoadBalancerProvider policyProvider, + Object childConfig, SubchannelPicker initialPicker) { + super(key, policyProvider, childConfig, initialPicker); + } + + @Override + protected void shutdown() { + if (deletionTimer != null && deletionTimer.isPending()) { + deletionTimer.cancel(); + } + super.shutdown(); + } + + @Override + protected void reactivate(LoadBalancerProvider policyProvider) { + if (deletionTimer != null && deletionTimer.isPending()) { + deletionTimer.cancel(); + logger.log(XdsLogLevel.DEBUG, "Child balancer {0} reactivated", getKey()); + } + + super.reactivate(policyProvider); + } + + @Override + protected void deactivate() { + if (isDeactivated()) { + return; + } + + class DeletionTask implements Runnable { + + @Override + public void run() { + shutdown(); + removeChild(getKey()); + } + } + + deletionTimer = + syncContext.schedule( + new DeletionTask(), + DELAYED_CHILD_DELETION_TIME_MINUTES, + TimeUnit.MINUTES, + timeService); + setDeactivated(); + logger.log(XdsLogLevel.DEBUG, "Child balancer {0} deactivated", getKey()); + } + + } } diff --git a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java index b4aa39821d2..8c2ae612db5 100644 --- a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java @@ -21,7 +21,6 @@ import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; -import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static io.grpc.xds.LeastRequestLoadBalancerProvider.DEFAULT_CHOICE_COUNT; import static io.grpc.xds.LeastRequestLoadBalancerProvider.MAX_CHOICE_COUNT; @@ -35,20 +34,17 @@ import io.grpc.ClientStreamTracer; import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.ConnectivityState; -import io.grpc.ConnectivityStateInfo; -import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerProvider; import io.grpc.Metadata; import io.grpc.Status; +import io.grpc.util.MultiChildLoadBalancer; import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import javax.annotation.Nonnull; @@ -60,21 +56,13 @@ * The default sampling amount of two is also known as * the "power of two choices" (P2C). */ -final class LeastRequestLoadBalancer extends LoadBalancer { - @VisibleForTesting - static final Attributes.Key> STATE_INFO = - Attributes.Key.create("state-info"); - @VisibleForTesting - static final Attributes.Key IN_FLIGHTS = - Attributes.Key.create("in-flights"); +final class LeastRequestLoadBalancer extends MultiChildLoadBalancer { + private static final Status EMPTY_OK = Status.OK.withDescription("no subchannels ready"); + private static final EmptyPicker EMPTY_LR_PICKER = new EmptyPicker(EMPTY_OK); - private final Helper helper; private final ThreadSafeRandom random; - private final Map subchannels = - new HashMap<>(); - private ConnectivityState currentState; - private LeastRequestPicker currentPicker = new EmptyPicker(EMPTY_OK); + private LeastRequestPicker currentPicker = EMPTY_LR_PICKER; private int choiceCount = DEFAULT_CHOICE_COUNT; LeastRequestLoadBalancer(Helper helper) { @@ -83,255 +71,167 @@ final class LeastRequestLoadBalancer extends LoadBalancer { @VisibleForTesting LeastRequestLoadBalancer(Helper helper, ThreadSafeRandom random) { - this.helper = checkNotNull(helper, "helper"); + super(helper); this.random = checkNotNull(random, "random"); } + @Override + protected SubchannelPicker getSubchannelPicker(Map childPickers) { + throw new UnsupportedOperationException( + "LeastRequestLoadBalancer uses its ChildLbStates, not these child pickers directly"); + } + @Override public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - if (resolvedAddresses.getAddresses().isEmpty()) { - handleNameResolutionError(Status.UNAVAILABLE.withDescription( - "NameResolver returned no usable address. addrs=" + resolvedAddresses.getAddresses() - + ", attrs=" + resolvedAddresses.getAttributes())); - return false; - } + // Need to update choiceCount before calling super so that the updateBalancingState call has the + // new value. However, if the update fails we need to revert it. + int oldChoiceCount = choiceCount; LeastRequestConfig config = (LeastRequestConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - // Config may be null if least_request is used outside xDS if (config != null) { choiceCount = config.choiceCount; } - List servers = resolvedAddresses.getAddresses(); - Set currentAddrs = subchannels.keySet(); - Map latestAddrs = stripAttrs(servers); - Set removedAddrs = setsDifference(currentAddrs, latestAddrs.keySet()); - - for (Map.Entry latestEntry : - latestAddrs.entrySet()) { - EquivalentAddressGroup strippedAddressGroup = latestEntry.getKey(); - EquivalentAddressGroup originalAddressGroup = latestEntry.getValue(); - Subchannel existingSubchannel = subchannels.get(strippedAddressGroup); - if (existingSubchannel != null) { - // EAG's Attributes may have changed. - existingSubchannel.updateAddresses(Collections.singletonList(originalAddressGroup)); - continue; - } - // Create new subchannels for new addresses. - Attributes.Builder subchannelAttrs = Attributes.newBuilder() - .set(STATE_INFO, new Ref<>(ConnectivityStateInfo.forNonError(IDLE))) - // Used to track the in flight requests on this particular subchannel - .set(IN_FLIGHTS, new AtomicInteger(0)); - - final Subchannel subchannel = checkNotNull( - helper.createSubchannel(CreateSubchannelArgs.newBuilder() - .setAddresses(originalAddressGroup) - .setAttributes(subchannelAttrs.build()) - .build()), - "subchannel"); - subchannel.start(new SubchannelStateListener() { - @Override - public void onSubchannelState(ConnectivityStateInfo state) { - processSubchannelState(subchannel, state); - } - }); - subchannels.put(strippedAddressGroup, subchannel); - subchannel.requestConnection(); - } - - ArrayList removedSubchannels = new ArrayList<>(); - for (EquivalentAddressGroup addressGroup : removedAddrs) { - removedSubchannels.add(subchannels.remove(addressGroup)); - } - - // Update the picker before shutting down the subchannels, to reduce the chance of the race - // between picking a subchannel and shutting it down. - updateBalancingState(); - - // Shutdown removed subchannels - for (Subchannel removedSubchannel : removedSubchannels) { - shutdownSubchannel(removedSubchannel); - } - - return true; - } + boolean successfulUpdate = super.acceptResolvedAddresses(resolvedAddresses); - @Override - public void handleNameResolutionError(Status error) { - if (currentState != READY) { - updateBalancingState(TRANSIENT_FAILURE, new EmptyPicker(error)); + if (!successfulUpdate) { + choiceCount = oldChoiceCount; } - } - private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { - if (subchannels.get(stripAttrs(subchannel.getAddresses())) != subchannel) { - return; - } - if (stateInfo.getState() == TRANSIENT_FAILURE || stateInfo.getState() == IDLE) { - helper.refreshNameResolution(); - } - if (stateInfo.getState() == IDLE) { - subchannel.requestConnection(); - } - Ref subchannelStateRef = getSubchannelStateInfoRef(subchannel); - if (subchannelStateRef.value.getState().equals(TRANSIENT_FAILURE)) { - if (stateInfo.getState().equals(CONNECTING) || stateInfo.getState().equals(IDLE)) { - return; - } - } - subchannelStateRef.value = stateInfo; - updateBalancingState(); - } - - private void shutdownSubchannel(Subchannel subchannel) { - subchannel.shutdown(); - getSubchannelStateInfoRef(subchannel).value = - ConnectivityStateInfo.forNonError(SHUTDOWN); + return successfulUpdate; } @Override - public void shutdown() { - for (Subchannel subchannel : getSubchannels()) { - shutdownSubchannel(subchannel); - } - subchannels.clear(); + protected SubchannelPicker getErrorPicker(Status error) { + return new EmptyPicker(error); } - private static final Status EMPTY_OK = Status.OK.withDescription("no subchannels ready"); - /** * Updates picker with the list of active subchannels (state == READY). + * + *

+ * If no active subchannels exist, but some are in TRANSIENT_FAILURE then returns a picker + * with all of the children in TF so that the application code will get an error from a varying + * random one when it tries to get a subchannel. + *

*/ @SuppressWarnings("ReferenceEquality") - private void updateBalancingState() { - List activeList = filterNonFailingSubchannels(getSubchannels()); + @Override + protected void updateOverallBalancingState() { + List activeList = getReadyChildren(); if (activeList.isEmpty()) { // No READY subchannels, determine aggregate state and error status boolean isConnecting = false; - Status aggStatus = EMPTY_OK; - for (Subchannel subchannel : getSubchannels()) { - ConnectivityStateInfo stateInfo = getSubchannelStateInfoRef(subchannel).value; - // This subchannel IDLE is not because of channel IDLE_TIMEOUT, - // in which case LB is already shutdown. - // LRLB will request connection immediately on subchannel IDLE. - if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) { + List childrenInTf = new ArrayList<>(); + for (ChildLbState childLbState : getChildLbStates()) { + ConnectivityState state = childLbState.getCurrentState(); + if (state == CONNECTING || state == IDLE) { isConnecting = true; - } - if (aggStatus == EMPTY_OK || !aggStatus.isOk()) { - aggStatus = stateInfo.getStatus(); + } else if (state == TRANSIENT_FAILURE) { + childrenInTf.add(childLbState); } } - updateBalancingState(isConnecting ? CONNECTING : TRANSIENT_FAILURE, - // If all subchannels are TRANSIENT_FAILURE, return the Status associated with - // an arbitrary subchannel, otherwise return OK. - new EmptyPicker(aggStatus)); + if (isConnecting) { + updateBalancingState(CONNECTING, EMPTY_LR_PICKER); + } else { + // Give it all the failing children and let it randomly pick among them + updateBalancingState(TRANSIENT_FAILURE, + new ReadyPicker(childrenInTf, choiceCount, random)); + } } else { updateBalancingState(READY, new ReadyPicker(activeList, choiceCount, random)); } } - private void updateBalancingState(ConnectivityState state, LeastRequestPicker picker) { - if (state != currentState || !picker.isEquivalentTo(currentPicker)) { - helper.updateBalancingState(state, picker); - currentState = state; - currentPicker = picker; - } + @Override + protected ChildLbState createChildLbState(Object key, Object policyConfig, + SubchannelPicker initialPicker) { + return new LeastRequestLbState(key, pickFirstLbProvider, policyConfig, initialPicker); } - /** - * Filters out non-ready subchannels. - */ - private static List filterNonFailingSubchannels( - Collection subchannels) { - List readySubchannels = new ArrayList<>(subchannels.size()); - for (Subchannel subchannel : subchannels) { - if (isReady(subchannel)) { - readySubchannels.add(subchannel); - } + private void updateBalancingState(ConnectivityState state, LeastRequestPicker picker) { + if (state != currentConnectivityState || !picker.isEquivalentTo(currentPicker)) { + super.updateHelperBalancingState(state, picker); + currentConnectivityState = state; + currentPicker = picker; } - return readySubchannels; } /** - * Converts list of {@link EquivalentAddressGroup} to {@link EquivalentAddressGroup} set and - * remove all attributes. The values are the original EAGs. + * This should ONLY be used by tests. */ - private static Map stripAttrs( - List groupList) { - Map addrs = new HashMap<>(groupList.size() * 2); - for (EquivalentAddressGroup group : groupList) { - addrs.put(stripAttrs(group), group); - } - return addrs; - } - - private static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { - return new EquivalentAddressGroup(eag.getAddresses()); - } - @VisibleForTesting - Collection getSubchannels() { - return subchannels.values(); + void setResolvingAddresses(boolean newValue) { + super.resolvingAddresses = newValue; } - private static Ref getSubchannelStateInfoRef( - Subchannel subchannel) { - return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO"); - } - - private static AtomicInteger getInFlights(Subchannel subchannel) { - return checkNotNull(subchannel.getAttributes().get(IN_FLIGHTS), "IN_FLIGHTS"); + // Expose for tests in this package. + @Override + protected Collection getChildLbStates() { + return super.getChildLbStates(); } - // package-private to avoid synthetic access - static boolean isReady(Subchannel subchannel) { - return getSubchannelStateInfoRef(subchannel).value.getState() == READY; + // Expose for tests in this package. + @Override + protected ChildLbState getChildLbState(Object key) { + return super.getChildLbState(key); } - private static Set setsDifference(Set a, Set b) { - Set aCopy = new HashSet<>(a); - aCopy.removeAll(b); - return aCopy; + // Expose for tests in this package. + private static AtomicInteger getInFlights(ChildLbState childLbState) { + return ((LeastRequestLbState)childLbState).activeRequests; } - // Only subclasses are ReadyPicker or EmptyPicker - private abstract static class LeastRequestPicker extends SubchannelPicker { + @VisibleForTesting + abstract static class LeastRequestPicker extends SubchannelPicker { abstract boolean isEquivalentTo(LeastRequestPicker picker); } @VisibleForTesting static final class ReadyPicker extends LeastRequestPicker { - private final List list; // non-empty + private final List childLbStates; // non-empty private final int choiceCount; private final ThreadSafeRandom random; - ReadyPicker(List list, int choiceCount, ThreadSafeRandom random) { - checkArgument(!list.isEmpty(), "empty list"); - this.list = list; + ReadyPicker(List childLbStates, int choiceCount, ThreadSafeRandom random) { + checkArgument(!childLbStates.isEmpty(), "empty list"); + this.childLbStates = childLbStates; this.choiceCount = choiceCount; this.random = checkNotNull(random, "random"); } @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - final Subchannel subchannel = nextSubchannel(); - final OutstandingRequestsTracingFactory factory = - new OutstandingRequestsTracingFactory(getInFlights(subchannel)); - return PickResult.withSubchannel(subchannel, factory); + final ChildLbState childLbState = nextChildToUse(); + PickResult childResult = childLbState.getCurrentPicker().pickSubchannel(args); + + if (!childResult.getStatus().isOk() || childResult.getSubchannel() == null) { + return childResult; + } + + if (childResult.getStreamTracerFactory() != null) { + // Already wrapped, so just use the current picker for selected child + return childResult; + } else { + // Wrap the subchannel + OutstandingRequestsTracingFactory factory = + new OutstandingRequestsTracingFactory(getInFlights(childLbState)); + return PickResult.withSubchannel(childResult.getSubchannel(), factory); + } } @Override public String toString() { return MoreObjects.toStringHelper(ReadyPicker.class) - .add("list", list) + .add("list", childLbStates) .add("choiceCount", choiceCount) .toString(); } - private Subchannel nextSubchannel() { - Subchannel candidate = list.get(random.nextInt(list.size())); + private ChildLbState nextChildToUse() { + ChildLbState candidate = childLbStates.get(random.nextInt(childLbStates.size())); for (int i = 0; i < choiceCount - 1; ++i) { - Subchannel sampled = list.get(random.nextInt(list.size())); + ChildLbState sampled = childLbStates.get(random.nextInt(childLbStates.size())); if (getInFlights(sampled).get() < getInFlights(candidate).get()) { candidate = sampled; } @@ -340,10 +240,11 @@ private Subchannel nextSubchannel() { } @VisibleForTesting - List getList() { - return list; + List getChildLbStates() { + return childLbStates; } + @VisibleForTesting @Override boolean isEquivalentTo(LeastRequestPicker picker) { if (!(picker instanceof ReadyPicker)) { @@ -352,7 +253,8 @@ boolean isEquivalentTo(LeastRequestPicker picker) { ReadyPicker other = (ReadyPicker) picker; // the lists cannot contain duplicate subchannels return other == this - || ((list.size() == other.list.size() && new HashSet<>(list).containsAll(other.list)) + || ((childLbStates.size() == other.childLbStates.size() && new HashSet<>( + childLbStates).containsAll(other.childLbStates)) && choiceCount == other.choiceCount); } } @@ -381,16 +283,10 @@ boolean isEquivalentTo(LeastRequestPicker picker) { public String toString() { return MoreObjects.toStringHelper(EmptyPicker.class).add("status", status).toString(); } - } - /** - * A lighter weight Reference than AtomicReference. - */ - static final class Ref { - T value; - - Ref(T value) { - this.value = value; + @VisibleForTesting + Status getStatus() { + return status; } } @@ -435,4 +331,17 @@ public String toString() { .toString(); } } + + protected class LeastRequestLbState extends ChildLbState { + private final AtomicInteger activeRequests = new AtomicInteger(0); + + public LeastRequestLbState(Object key, LoadBalancerProvider policyProvider, + Object childConfig, SubchannelPicker initialPicker) { + super(key, policyProvider, childConfig, initialPicker); + } + + int getActiveRequests() { + return activeRequests.get(); + } + } } diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index 833683729c2..17fed2ac023 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -17,17 +17,20 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkElementIndex; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.Deadline.Ticker; import io.grpc.EquivalentAddressGroup; import io.grpc.ExperimentalApi; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerProvider; import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.SynchronizationContext; @@ -40,11 +43,13 @@ import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener; import io.grpc.xds.orca.OrcaPerRequestUtil; import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -90,6 +95,14 @@ public WeightedRoundRobinLoadBalancer(WrrHelper helper, Ticker ticker, Random ra this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker, random); } + @Override + protected ChildLbState createChildLbState(Object key, Object policyConfig, + SubchannelPicker initialPicker) { + ChildLbState childLbState = new WeightedChildLbState(key, pickFirstLbProvider, policyConfig, + initialPicker); + return childLbState; + } + @Override public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { if (resolvedAddresses.getLoadBalancingPolicyConfig() == null) { @@ -111,9 +124,96 @@ public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { } @Override - public RoundRobinPicker createReadyPicker(List activeList) { - return new WeightedRoundRobinPicker(activeList, config.enableOobLoadReport, - config.errorUtilizationPenalty); + public RoundRobinPicker createReadyPicker(Collection activeList) { + return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList), + config.enableOobLoadReport, config.errorUtilizationPenalty); + } + + // Expose for tests in this package. + @Override + protected ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) { + return super.getChildLbStateEag(eag); + } + + @VisibleForTesting + final class WeightedChildLbState extends ChildLbState { + + private final Set subchannels = new HashSet<>(); + private volatile long lastUpdated; + private volatile long nonEmptySince; + private volatile double weight = 0; + + private OrcaReportListener orcaReportListener; + + public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig, + SubchannelPicker initialPicker) { + super(key, policyProvider, childConfig, initialPicker); + } + + private double getWeight() { + if (config == null) { + return 0; + } + long now = ticker.nanoTime(); + if (now - lastUpdated >= config.weightExpirationPeriodNanos) { + nonEmptySince = infTime; + return 0; + } else if (now - nonEmptySince < config.blackoutPeriodNanos + && config.blackoutPeriodNanos > 0) { + return 0; + } else { + return weight; + } + } + + public void addSubchannel(WrrSubchannel wrrSubchannel) { + subchannels.add(wrrSubchannel); + } + + public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty) { + if (orcaReportListener != null + && orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty) { + return orcaReportListener; + } + orcaReportListener = new OrcaReportListener(errorUtilizationPenalty); + return orcaReportListener; + } + + public void removeSubchannel(WrrSubchannel wrrSubchannel) { + subchannels.remove(wrrSubchannel); + } + + final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener { + private final float errorUtilizationPenalty; + + OrcaReportListener(float errorUtilizationPenalty) { + this.errorUtilizationPenalty = errorUtilizationPenalty; + } + + @Override + public void onLoadReport(MetricReport report) { + double newWeight = 0; + // Prefer application utilization and fallback to CPU utilization if unset. + double utilization = + report.getApplicationUtilization() > 0 ? report.getApplicationUtilization() + : report.getCpuUtilization(); + if (utilization > 0 && report.getQps() > 0) { + double penalty = 0; + if (report.getEps() > 0 && errorUtilizationPenalty > 0) { + penalty = report.getEps() / report.getQps() * errorUtilizationPenalty; + } + newWeight = report.getQps() / (utilization + penalty); + } + if (newWeight == 0) { + return; + } + if (nonEmptySince == infTime) { + nonEmptySince = ticker.nanoTime(); + } + lastUpdated = ticker.nanoTime(); + weight = newWeight; + } + } } private final class UpdateWeightTask implements Runnable { @@ -128,16 +228,18 @@ public void run() { } private void afterAcceptAddresses() { - for (Subchannel subchannel : getSubchannels()) { - WrrSubchannel weightedSubchannel = (WrrSubchannel) subchannel; - if (config.enableOobLoadReport) { - OrcaOobUtil.setListener(weightedSubchannel, - weightedSubchannel.new OrcaReportListener(config.errorUtilizationPenalty), - OrcaOobUtil.OrcaReportingConfig.newBuilder() - .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS) - .build()); - } else { - OrcaOobUtil.setListener(weightedSubchannel, null, null); + for (ChildLbState child : getChildLbStates()) { + WeightedChildLbState wChild = (WeightedChildLbState) child; + for (WrrSubchannel weightedSubchannel : wChild.subchannels) { + if (config.enableOobLoadReport) { + OrcaOobUtil.setListener(weightedSubchannel, + wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty), + OrcaOobUtil.OrcaReportingConfig.newBuilder() + .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS) + .build()); + } else { + OrcaOobUtil.setListener(weightedSubchannel, null, null); + } } } } @@ -169,105 +271,69 @@ protected Helper delegate() { @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { - return wrr.new WrrSubchannel(delegate().createSubchannel(args)); + checkElementIndex(0, args.getAddresses().size(), "Empty address group"); + WeightedChildLbState childLbState = + (WeightedChildLbState) wrr.getChildLbStateEag(args.getAddresses().get(0)); + return wrr.new WrrSubchannel(delegate().createSubchannel(args), childLbState); } } @VisibleForTesting final class WrrSubchannel extends ForwardingSubchannel { private final Subchannel delegate; - private volatile long lastUpdated; - private volatile long nonEmptySince; - private volatile double weight; + private final WeightedChildLbState owner; - WrrSubchannel(Subchannel delegate) { + WrrSubchannel(Subchannel delegate, WeightedChildLbState owner) { this.delegate = checkNotNull(delegate, "delegate"); + this.owner = checkNotNull(owner, "owner"); } @Override public void start(SubchannelStateListener listener) { + owner.addSubchannel(this); delegate().start(new SubchannelStateListener() { @Override public void onSubchannelState(ConnectivityStateInfo newState) { if (newState.getState().equals(ConnectivityState.READY)) { - nonEmptySince = infTime; + owner.nonEmptySince = infTime; } listener.onSubchannelState(newState); } }); } - private double getWeight() { - if (config == null) { - return 0; - } - long now = ticker.nanoTime(); - if (now - lastUpdated >= config.weightExpirationPeriodNanos) { - nonEmptySince = infTime; - return 0; - } else if (now - nonEmptySince < config.blackoutPeriodNanos - && config.blackoutPeriodNanos > 0) { - return 0; - } else { - return weight; - } - } - @Override protected Subchannel delegate() { return delegate; } - final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener { - private final float errorUtilizationPenalty; - - OrcaReportListener(float errorUtilizationPenalty) { - this.errorUtilizationPenalty = errorUtilizationPenalty; - } - - @Override - public void onLoadReport(MetricReport report) { - double newWeight = 0; - // Prefer application utilization and fallback to CPU utilization if unset. - double utilization = - report.getApplicationUtilization() > 0 ? report.getApplicationUtilization() - : report.getCpuUtilization(); - if (utilization > 0 && report.getQps() > 0) { - double penalty = 0; - if (report.getEps() > 0 && errorUtilizationPenalty > 0) { - penalty = report.getEps() / report.getQps() * errorUtilizationPenalty; - } - newWeight = report.getQps() / (utilization + penalty); - } - if (newWeight == 0) { - return; - } - if (nonEmptySince == infTime) { - nonEmptySince = ticker.nanoTime(); - } - lastUpdated = ticker.nanoTime(); - weight = newWeight; - } + @Override + public void shutdown() { + super.shutdown(); + owner.removeSubchannel(this); } } @VisibleForTesting final class WeightedRoundRobinPicker extends RoundRobinPicker { - private final List list; + private final List children; private final Map subchannelToReportListenerMap = new HashMap<>(); private final boolean enableOobLoadReport; private final float errorUtilizationPenalty; private volatile StaticStrideScheduler scheduler; - WeightedRoundRobinPicker(List list, boolean enableOobLoadReport, + WeightedRoundRobinPicker(List children, boolean enableOobLoadReport, float errorUtilizationPenalty) { - checkNotNull(list, "list"); - Preconditions.checkArgument(!list.isEmpty(), "empty list"); - this.list = list; - for (Subchannel subchannel : list) { - this.subchannelToReportListenerMap.put(subchannel, - ((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty)); + checkNotNull(children, "children"); + Preconditions.checkArgument(!children.isEmpty(), "empty child list"); + this.children = children; + for (ChildLbState child : children) { + WeightedChildLbState wChild = (WeightedChildLbState) child; + for (WrrSubchannel subchannel : wChild.subchannels) { + this.subchannelToReportListenerMap + .put(subchannel, wChild.getOrCreateOrcaListener(errorUtilizationPenalty)); + } } this.enableOobLoadReport = enableOobLoadReport; this.errorUtilizationPenalty = errorUtilizationPenalty; @@ -276,22 +342,24 @@ final class WeightedRoundRobinPicker extends RoundRobinPicker { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - Subchannel subchannel = list.get(scheduler.pick()); + ChildLbState childLbState = children.get(scheduler.pick()); + WeightedChildLbState wChild = (WeightedChildLbState) childLbState; + PickResult pickResult = childLbState.getCurrentPicker().pickSubchannel(args); + Subchannel subchannel = pickResult.getSubchannel(); if (!enableOobLoadReport) { return PickResult.withSubchannel(subchannel, - OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( subchannelToReportListenerMap.getOrDefault(subchannel, - ((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty)))); + wChild.getOrCreateOrcaListener(errorUtilizationPenalty)))); } else { return PickResult.withSubchannel(subchannel); } } private void updateWeight() { - float[] newWeights = new float[list.size()]; - for (int i = 0; i < list.size(); i++) { - WrrSubchannel subchannel = (WrrSubchannel) list.get(i); - double newWeight = subchannel.getWeight(); + float[] newWeights = new float[children.size()]; + for (int i = 0; i < children.size(); i++) { + double newWeight = ((WeightedChildLbState)children.get(i)).getWeight(); newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f; } this.scheduler = new StaticStrideScheduler(newWeights, sequence); @@ -302,12 +370,12 @@ public String toString() { return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class) .add("enableOobLoadReport", enableOobLoadReport) .add("errorUtilizationPenalty", errorUtilizationPenalty) - .add("list", list).toString(); + .add("list", children).toString(); } @VisibleForTesting - List getList() { - return list; + List getChildren() { + return children; } @Override @@ -322,7 +390,8 @@ public boolean isEquivalentTo(RoundRobinPicker picker) { // the lists cannot contain duplicate subchannels return enableOobLoadReport == other.enableOobLoadReport && Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0 - && list.size() == other.list.size() && new HashSet<>(list).containsAll(other.list); + && children.size() == other.children.size() && new HashSet<>( + children).containsAll(other.children); } } @@ -504,11 +573,13 @@ private Builder() { } + @SuppressWarnings("UnusedReturnValue") Builder setBlackoutPeriodNanos(long blackoutPeriodNanos) { this.blackoutPeriodNanos = blackoutPeriodNanos; return this; } + @SuppressWarnings("UnusedReturnValue") Builder setWeightExpirationPeriodNanos(long weightExpirationPeriodNanos) { this.weightExpirationPeriodNanos = weightExpirationPeriodNanos; return this; diff --git a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java index c90a9f58d31..32e905225d2 100644 --- a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java +++ b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java @@ -202,7 +202,9 @@ public interface OrcaOobReportListener { */ public static void setListener(Subchannel subchannel, OrcaOobReportListener listener, OrcaReportingConfig config) { - SubchannelImpl orcaSubchannel = subchannel.getAttributes().get(ORCA_REPORTING_STATE_KEY); + Attributes attributes = subchannel.getAttributes(); + SubchannelImpl orcaSubchannel = + (attributes == null) ? null : attributes.get(ORCA_REPORTING_STATE_KEY); if (orcaSubchannel == null) { throw new IllegalArgumentException("Subchannel does not have orca Out-Of-Band stream enabled." + " Try to use a subchannel created by OrcaOobUtil.OrcaHelper."); @@ -241,7 +243,9 @@ protected Helper delegate() { public Subchannel createSubchannel(CreateSubchannelArgs args) { syncContext.throwIfNotInThisSynchronizationContext(); Subchannel subchannel = super.createSubchannel(args); - SubchannelImpl orcaSubchannel = subchannel.getAttributes().get(ORCA_REPORTING_STATE_KEY); + Attributes attributes = subchannel.getAttributes(); + SubchannelImpl orcaSubchannel = + (attributes == null) ? null : attributes.get(ORCA_REPORTING_STATE_KEY); OrcaReportingState orcaState; if (orcaSubchannel == null) { // Only the first load balancing policy requesting ORCA reports instantiates an diff --git a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java index 33676367866..aeae59c6122 100644 --- a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java @@ -22,17 +22,15 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static io.grpc.xds.LeastRequestLoadBalancer.IN_FLIGHTS; -import static io.grpc.xds.LeastRequestLoadBalancer.STATE_INFO; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -60,10 +58,13 @@ import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Metadata; import io.grpc.Status; +import io.grpc.util.AbstractTestHelper; +import io.grpc.util.MultiChildLoadBalancer.ChildLbState; import io.grpc.xds.LeastRequestLoadBalancer.EmptyPicker; import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; +import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestLbState; +import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestPicker; import io.grpc.xds.LeastRequestLoadBalancer.ReadyPicker; -import io.grpc.xds.LeastRequestLoadBalancer.Ref; import java.net.SocketAddress; import java.util.ArrayList; import java.util.Arrays; @@ -71,6 +72,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -81,10 +83,8 @@ import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; -import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; -import org.mockito.stubbing.Answer; /** Unit test for {@link LeastRequestLoadBalancer}. */ @RunWith(JUnit4.class) @@ -96,8 +96,6 @@ public class LeastRequestLoadBalancerTest { private LeastRequestLoadBalancer loadBalancer; private final List servers = Lists.newArrayList(); private final Map, Subchannel> subchannels = Maps.newLinkedHashMap(); - private final Map subchannelStateListeners = - Maps.newLinkedHashMap(); private final Attributes affinity = Attributes.newBuilder().set(MAJOR_KEY, "I got the keys").build(); @@ -107,8 +105,9 @@ public class LeastRequestLoadBalancerTest { private ArgumentCaptor stateCaptor; @Captor private ArgumentCaptor createArgsCaptor; - @Mock - private Helper mockHelper; + private final TestHelper testHelperInstance = new TestHelper(); + private Helper helper = mock(Helper.class, delegatesTo(testHelperInstance)); + @Mock private ThreadSafeRandom mockRandom; @@ -121,31 +120,9 @@ public void setUp() { SocketAddress addr = new FakeSocketAddress("server" + i); EquivalentAddressGroup eag = new EquivalentAddressGroup(addr); servers.add(eag); - Subchannel sc = mock(Subchannel.class); - subchannels.put(Arrays.asList(eag), sc); } - when(mockHelper.createSubchannel(any(CreateSubchannelArgs.class))) - .then(new Answer() { - @Override - public Subchannel answer(InvocationOnMock invocation) throws Throwable { - CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; - final Subchannel subchannel = subchannels.get(args.getAddresses()); - when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); - when(subchannel.getAttributes()).thenReturn(args.getAttributes()); - doAnswer( - new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - subchannelStateListeners.put( - subchannel, (SubchannelStateListener) invocation.getArguments()[0]); - return null; - } - }).when(subchannel).start(any(SubchannelStateListener.class)); - return subchannel; - } - }); - loadBalancer = new LeastRequestLoadBalancer(mockHelper, mockRandom); + loadBalancer = new LeastRequestLoadBalancer(helper, mockRandom); } @After @@ -156,13 +133,13 @@ public void tearDown() throws Exception { @Test public void pickAfterResolved() throws Exception { - final Subchannel readySubchannel = subchannels.values().iterator().next(); boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); assertThat(addressesAccepted).isTrue(); + final Subchannel readySubchannel = subchannels.values().iterator().next(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); - verify(mockHelper, times(3)).createSubchannel(createArgsCaptor.capture()); + verify(helper, times(3)).createSubchannel(createArgsCaptor.capture()); List> capturedAddrs = new ArrayList<>(); for (CreateSubchannelArgs arg : createArgsCaptor.getAllValues()) { capturedAddrs.add(arg.getAddresses()); @@ -174,22 +151,18 @@ public void pickAfterResolved() throws Exception { verify(subchannel, never()).shutdown(); } - verify(mockHelper, times(2)) + verify(helper, times(2)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); assertEquals(CONNECTING, stateCaptor.getAllValues().get(0)); assertEquals(READY, stateCaptor.getAllValues().get(1)); assertThat(getList(pickerCaptor.getValue())).containsExactly(readySubchannel); - verifyNoMoreInteractions(mockHelper); + verifyNoMoreInteractions(helper); } @Test public void pickAfterResolvedUpdatedHosts() throws Exception { - Subchannel removedSubchannel = mock(Subchannel.class); - Subchannel oldSubchannel = mock(Subchannel.class); - Subchannel newSubchannel = mock(Subchannel.class); - Attributes.Key key = Attributes.Key.create("check-that-it-is-propagated"); FakeSocketAddress removedAddr = new FakeSocketAddress("removed"); EquivalentAddressGroup removedEag = new EquivalentAddressGroup(removedAddr); @@ -201,33 +174,33 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { EquivalentAddressGroup newEag = new EquivalentAddressGroup( newAddr, Attributes.newBuilder().set(key, "newattr").build()); - subchannels.put(Collections.singletonList(removedEag), removedSubchannel); - subchannels.put(Collections.singletonList(oldEag1), oldSubchannel); - subchannels.put(Collections.singletonList(newEag), newSubchannel); - List currentServers = Lists.newArrayList(removedEag, oldEag1); - InOrder inOrder = inOrder(mockHelper); + InOrder inOrder = inOrder(helper); boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(currentServers).setAttributes(affinity) .build()); assertThat(addressesAccepted).isTrue(); + Subchannel removedSubchannel = getSubchannel(removedEag); + Subchannel oldSubchannel = getSubchannel(oldEag1); + SubchannelStateListener removedListener = + testHelperInstance.getSubchannelStateListeners() + .get(testHelperInstance.getRealForMockSubChannel(removedSubchannel)); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); deliverSubchannelState(removedSubchannel, ConnectivityStateInfo.forNonError(READY)); deliverSubchannelState(oldSubchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); + inOrder.verify(helper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); SubchannelPicker picker = pickerCaptor.getValue(); assertThat(getList(picker)).containsExactly(removedSubchannel, oldSubchannel); verify(removedSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).requestConnection(); - assertThat(loadBalancer.getSubchannels()).containsExactly(removedSubchannel, - oldSubchannel); + assertThat(getChildEags(loadBalancer)).containsExactly(removedEag, oldEag1); // This time with Attributes List latestServers = Lists.newArrayList(oldEag2, newEag); @@ -236,81 +209,105 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { ResolvedAddresses.newBuilder().setAddresses(latestServers).setAttributes(affinity).build()); assertThat(addressesAccepted).isTrue(); + Subchannel newSubchannel = getSubchannel(newEag); + verify(newSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).updateAddresses(Arrays.asList(oldEag2)); verify(removedSubchannel, times(1)).shutdown(); - deliverSubchannelState(removedSubchannel, ConnectivityStateInfo.forNonError(SHUTDOWN)); + removedListener.onSubchannelState(ConnectivityStateInfo.forNonError(SHUTDOWN)); deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY)); - assertThat(loadBalancer.getSubchannels()).containsExactly(oldSubchannel, - newSubchannel); + assertThat(getChildEags(loadBalancer)).containsExactly(oldEag2, newEag); + + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + inOrder.verify(helper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); + + assertThat(getList(pickerCaptor.getValue())).containsExactly(oldSubchannel, newSubchannel); + + verifyNoMoreInteractions(helper); + } + + private Subchannel getSubchannel(EquivalentAddressGroup removedEag) { + return subchannels.get(Collections.singletonList(removedEag)); + } + + private Subchannel getSubchannel(ChildLbState childLbState) { + return subchannels.get(Collections.singletonList(childLbState.getEag())); + } - verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); + private static List getChildEags(LeastRequestLoadBalancer loadBalancer) { + return loadBalancer.getChildLbStates().stream() + .map(ChildLbState::getEag) + // .map(EquivalentAddressGroup::getAddresses) + .collect(Collectors.toList()); + } - picker = pickerCaptor.getValue(); - assertThat(getList(picker)).containsExactly(oldSubchannel, newSubchannel); + private List getSubchannels(LeastRequestLoadBalancer lb) { + return lb.getChildLbStates().stream() + .map(this::getSubchannel) + .collect(Collectors.toList()); + } - verifyNoMoreInteractions(mockHelper); + private LeastRequestLbState getChildLbState(PickResult pickResult) { + EquivalentAddressGroup eag = pickResult.getSubchannel().getAddresses(); + return (LeastRequestLbState) loadBalancer.getChildLbState(eag); } @Test public void pickAfterStateChange() throws Exception { - InOrder inOrder = inOrder(mockHelper); + InOrder inOrder = inOrder(helper); boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); assertThat(addressesAccepted).isTrue(); - Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); - Ref subchannelStateInfo = subchannel.getAttributes().get( - STATE_INFO); + ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); + Subchannel subchannel = getSubchannel(childLbState); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); - assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(IDLE)); + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); + inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class); - assertThat(subchannelStateInfo.value).isEqualTo( - ConnectivityStateInfo.forNonError(READY)); + assertThat(childLbState.getCurrentState()).isEqualTo(READY); Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"); deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); - assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); - assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error); - inOrder.verify(mockHelper).refreshNameResolution(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); + assertThat(childLbState.getCurrentPicker().toString()).contains(error.toString()); + inOrder.verify(helper).refreshNameResolution(); + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); - inOrder.verify(mockHelper).refreshNameResolution(); - assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); - assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error); + inOrder.verify(helper).refreshNameResolution(); + assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); + assertThat(childLbState.getCurrentPicker().toString()).contains(error.toString()); verify(subchannel, times(2)).requestConnection(); - verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verifyNoMoreInteractions(mockHelper); + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verifyNoMoreInteractions(helper); } @Test public void pickAfterConfigChange() { final LeastRequestConfig oldConfig = new LeastRequestConfig(4); final LeastRequestConfig newConfig = new LeastRequestConfig(6); - final Subchannel readySubchannel = subchannels.values().iterator().next(); boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity) .setLoadBalancingPolicyConfig(oldConfig).build()); assertThat(addressesAccepted).isTrue(); + final Subchannel readySubchannel = subchannels.values().iterator().next(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); - verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(mockHelper, times(2)) + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper, times(2)) .updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture()); - // At this point it should use a ReadyPicker with oldConfig + // At this point it should use a ReadyPicker with oldConfig and 1 ready subchannel pickerCaptor.getValue().pickSubchannel(mockArgs); verify(mockRandom, times(oldConfig.choiceCount)).nextInt(1); @@ -318,26 +315,26 @@ public void pickAfterConfigChange() { ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity) .setLoadBalancingPolicyConfig(newConfig).build()); assertThat(addressesAccepted).isTrue(); - verify(mockHelper, times(3)) + verify(helper, times(3)) .updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture()); // At this point it should use a ReadyPicker with newConfig pickerCaptor.getValue().pickSubchannel(mockArgs); verify(mockRandom, times(oldConfig.choiceCount + newConfig.choiceCount)).nextInt(1); - verifyNoMoreInteractions(mockHelper); + verifyNoMoreInteractions(helper); } @Test public void ignoreShutdownSubchannelStateChange() { - InOrder inOrder = inOrder(mockHelper); + InOrder inOrder = inOrder(helper); boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); assertThat(addressesAccepted).isTrue(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); loadBalancer.shutdown(); - for (Subchannel sc : loadBalancer.getSubchannels()) { + for (Subchannel sc : getSubchannels(loadBalancer)) { verify(sc).shutdown(); // When the subchannel is being shut down, a SHUTDOWN connectivity state is delivered // back to the subchannel state listener. @@ -349,71 +346,101 @@ public void ignoreShutdownSubchannelStateChange() { @Test public void stayTransientFailureUntilReady() { - InOrder inOrder = inOrder(mockHelper); + InOrder inOrder = inOrder(helper); boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); assertThat(addressesAccepted).isTrue(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); // Simulate state transitions for each subchannel individually. - for (Subchannel sc : loadBalancer.getSubchannels()) { + for (ChildLbState childLbState : loadBalancer.getChildLbStates()) { + Subchannel sc = getSubchannel(childLbState); Status error = Status.UNKNOWN.withDescription("connection broken"); - deliverSubchannelState( - sc, - ConnectivityStateInfo.forTransientFailure(error)); - inOrder.verify(mockHelper).refreshNameResolution(); - deliverSubchannelState( - sc, - ConnectivityStateInfo.forNonError(CONNECTING)); - Ref scStateInfo = sc.getAttributes().get( - STATE_INFO); - assertThat(scStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); - assertThat(scStateInfo.value.getStatus()).isEqualTo(error); + deliverSubchannelState(sc, ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(helper).refreshNameResolution(); + deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING)); + assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); } - inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(EmptyPicker.class)); + inOrder.verify(helper) + .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + assertThat(getStatusString((LeastRequestPicker)pickerCaptor.getValue())) + .contains("Status{code=UNKNOWN, description=connection broken"); inOrder.verifyNoMoreInteractions(); - Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); + ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); + Subchannel subchannel = getSubchannel(childLbState); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - Ref subchannelStateInfo = subchannel.getAttributes().get( - STATE_INFO); - assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class)); + assertThat(childLbState.getCurrentState()).isEqualTo(READY); + inOrder.verify(helper).updateBalancingState(eq(READY), isA(ReadyPicker.class)); + + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verifyNoMoreInteractions(helper); + } + + private String getStatusString(LeastRequestPicker picker) { + if (picker == null) { + return ""; + } + + if (picker instanceof EmptyPicker) { + if (((EmptyPicker) picker).getStatus() == null) { + return ""; + } + return ((EmptyPicker) picker).getStatus().toString(); + } else if (picker instanceof ReadyPicker) { + List childLbStates = ((ReadyPicker)picker).getChildLbStates(); + if (childLbStates == null || childLbStates.isEmpty()) { + return ""; + }; + + // Note that this is dependent on PickFirst's picker toString retaining the representation + // of the status, but since it is a test and we don't want to expose this value it seems + // a reasonable tradeoff + String pickerStr = childLbStates.get(0).getCurrentPicker().toString(); + int beg = pickerStr.indexOf(", status=Status{"); + if (beg < 0) { + return ""; + } + int end = pickerStr.indexOf('}', beg); + if (end < 0) { + return ""; + } + return pickerStr.substring(beg + ", status=".length(), end + 1); + } - verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verifyNoMoreInteractions(mockHelper); + throw new IllegalArgumentException("Unrecognized picker: " + picker); } @Test public void refreshNameResolutionWhenSubchannelConnectionBroken() { - InOrder inOrder = inOrder(mockHelper); + InOrder inOrder = inOrder(helper); boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); assertThat(addressesAccepted).isTrue(); - verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); // Simulate state transitions for each subchannel individually. - for (Subchannel sc : loadBalancer.getSubchannels()) { + for (Subchannel sc : getSubchannels(loadBalancer)) { verify(sc).requestConnection(); deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING)); Status error = Status.UNKNOWN.withDescription("connection broken"); deliverSubchannelState(sc, ConnectivityStateInfo.forTransientFailure(error)); - inOrder.verify(mockHelper).refreshNameResolution(); + inOrder.verify(helper).refreshNameResolution(); deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class)); + inOrder.verify(helper).updateBalancingState(eq(READY), isA(ReadyPicker.class)); // Simulate receiving go-away so READY subchannels transit to IDLE. deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(IDLE)); - inOrder.verify(mockHelper).refreshNameResolution(); + inOrder.verify(helper).refreshNameResolution(); verify(sc, times(2)).requestConnection(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); } - verifyNoMoreInteractions(mockHelper); + verifyNoMoreInteractions(helper); } @Test @@ -426,68 +453,64 @@ public void pickerLeastRequest() throws Exception { .build()); assertThat(addressesAccepted).isTrue(); - assertEquals(3, loadBalancer.getSubchannels().size()); + assertEquals(3, loadBalancer.getChildLbStates().size()); - List subchannels = Lists.newArrayList(loadBalancer.getSubchannels()); + List childLbStates = Lists.newArrayList(loadBalancer.getChildLbStates()); // Make sure all inFlight counters have started at 0 - assertEquals(0, - subchannels.get(0).getAttributes().get(IN_FLIGHTS).get()); - assertEquals(0, - subchannels.get(1).getAttributes().get(IN_FLIGHTS).get()); - assertEquals(0, - subchannels.get(2).getAttributes().get(IN_FLIGHTS).get()); - - for (Subchannel sc : subchannels) { - deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(READY)); + for (int i = 0; i < 3; i++) { + assertEquals("counter for child " + i, 0, + ((LeastRequestLbState) childLbStates.get(i)).getActiveRequests()); + } + + for (ChildLbState cs : childLbStates) { + deliverSubchannelState(getSubchannel(cs), ConnectivityStateInfo.forNonError(READY)); } // Capture the active ReadyPicker once all subchannels are READY - verify(mockHelper, times(4)) + verify(helper, times(4)) .updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class); ReadyPicker picker = (ReadyPicker) pickerCaptor.getValue(); - assertThat(picker.getList()).containsExactlyElementsIn(subchannels); + assertThat(picker.getChildLbStates()).containsExactlyElementsIn(childLbStates); // Make random return 0, then 2 for the sample indexes. - when(mockRandom.nextInt(subchannels.size())).thenReturn(0, 2); + when(mockRandom.nextInt(childLbStates.size())).thenReturn(0, 2); PickResult pickResult1 = picker.pickSubchannel(mockArgs); - verify(mockRandom, times(choiceCount)).nextInt(subchannels.size()); - assertEquals(subchannels.get(0), pickResult1.getSubchannel()); + verify(mockRandom, times(choiceCount)).nextInt(childLbStates.size()); + assertEquals(childLbStates.get(0), getChildLbState(pickResult1)); // This simulates sending the actual RPC on the picked channel ClientStreamTracer streamTracer1 = pickResult1.getStreamTracerFactory() .newClientStreamTracer(StreamInfo.newBuilder().build(), new Metadata()); streamTracer1.streamCreated(Attributes.EMPTY, new Metadata()); - assertEquals(1, - pickResult1.getSubchannel().getAttributes().get(IN_FLIGHTS).get()); + assertEquals(1, getChildLbState(pickResult1).getActiveRequests()); // For the second pick it should pick the one with lower inFlight. - when(mockRandom.nextInt(subchannels.size())).thenReturn(0, 2); + when(mockRandom.nextInt(childLbStates.size())).thenReturn(0, 2); PickResult pickResult2 = picker.pickSubchannel(mockArgs); // Since this is the second pick we expect the total random samples to be choiceCount * 2 - verify(mockRandom, times(choiceCount * 2)).nextInt(subchannels.size()); - assertEquals(subchannels.get(2), pickResult2.getSubchannel()); + verify(mockRandom, times(choiceCount * 2)).nextInt(childLbStates.size()); + assertEquals(childLbStates.get(2), getChildLbState(pickResult2)); // For the third pick we unavoidably pick subchannel with index 1. - when(mockRandom.nextInt(subchannels.size())).thenReturn(1, 1); + when(mockRandom.nextInt(childLbStates.size())).thenReturn(1, 1); PickResult pickResult3 = picker.pickSubchannel(mockArgs); - verify(mockRandom, times(choiceCount * 3)).nextInt(subchannels.size()); - assertEquals(subchannels.get(1), pickResult3.getSubchannel()); + verify(mockRandom, times(choiceCount * 3)).nextInt(childLbStates.size()); + assertEquals(childLbStates.get(1), getChildLbState(pickResult3)); // Finally ensure a finished RPC decreases inFlight streamTracer1.streamClosed(Status.OK); - assertEquals(0, - pickResult1.getSubchannel().getAttributes().get(IN_FLIGHTS).get()); + assertEquals(0, getChildLbState(pickResult1).getActiveRequests()); } @Test public void pickerEmptyList() throws Exception { SubchannelPicker picker = new EmptyPicker(Status.UNKNOWN); - assertEquals(null, picker.pickSubchannel(mockArgs).getSubchannel()); + assertNull(picker.pickSubchannel(mockArgs).getSubchannel()); assertEquals(Status.UNKNOWN, picker.pickSubchannel(mockArgs).getStatus()); } @@ -495,28 +518,37 @@ public void pickerEmptyList() throws Exception { @Test public void nameResolutionErrorWithNoChannels() throws Exception { Status error = Status.NOT_FOUND.withDescription("nameResolutionError"); + loadBalancer.setResolvingAddresses(true); loadBalancer.handleNameResolutionError(error); - verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + loadBalancer.setResolvingAddresses(false); + verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs); assertNull(pickResult.getSubchannel()); assertEquals(error, pickResult.getStatus()); - verifyNoMoreInteractions(mockHelper); + verifyNoMoreInteractions(helper); } @Test public void nameResolutionErrorWithActiveChannels() throws Exception { int choiceCount = 8; - final Subchannel readySubchannel = subchannels.values().iterator().next(); boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setLoadBalancingPolicyConfig(new LeastRequestConfig(choiceCount)) .setAddresses(servers).setAttributes(affinity).build()); assertThat(addressesAccepted).isTrue(); + final Subchannel readySubchannel = subchannels.values().iterator().next(); + deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); + // TODO This test assumes that existing subchannels are left unchanged while the logic we have + // is to tell all of the children that there was a nameResolutionError. This seems to me to + // make more sense, just ignore a bad update. + loadBalancer.setResolvingAddresses(true); loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError")); + loadBalancer.setResolvingAddresses(false); - verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(mockHelper, times(2)) + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper, times(2)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); Iterator stateIterator = stateCaptor.getAllValues().iterator(); @@ -531,20 +563,20 @@ public void nameResolutionErrorWithActiveChannels() throws Exception { LoadBalancer.PickResult pickResult2 = pickerCaptor.getValue().pickSubchannel(mockArgs); verify(mockRandom, times(choiceCount * 2)).nextInt(1); assertEquals(readySubchannel, pickResult2.getSubchannel()); - verifyNoMoreInteractions(mockHelper); + verifyNoMoreInteractions(helper); } @Test public void subchannelStateIsolation() throws Exception { + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) + .build()); + assertThat(addressesAccepted).isTrue(); Iterator subchannelIterator = subchannels.values().iterator(); Subchannel sc1 = subchannelIterator.next(); Subchannel sc2 = subchannelIterator.next(); Subchannel sc3 = subchannelIterator.next(); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) - .build()); - assertThat(addressesAccepted).isTrue(); verify(sc1, times(1)).requestConnection(); verify(sc2, times(1)).requestConnection(); verify(sc3, times(1)).requestConnection(); @@ -555,7 +587,7 @@ public void subchannelStateIsolation() throws Exception { deliverSubchannelState(sc2, ConnectivityStateInfo.forNonError(IDLE)); deliverSubchannelState(sc3, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); - verify(mockHelper, times(6)) + verify(helper, times(6)) .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); Iterator stateIterator = stateCaptor.getAllValues().iterator(); Iterator pickers = pickerCaptor.getAllValues().iterator(); @@ -584,7 +616,7 @@ public void subchannelStateIsolation() throws Exception { public void readyPicker_emptyList() { try { // ready picker list must be non-empty - new ReadyPicker(Collections.emptyList(), 2, mockRandom); + new ReadyPicker(Collections.emptyList(), 2, mockRandom); fail(); } catch (IllegalArgumentException expected) { } @@ -596,15 +628,19 @@ public void internalPickerComparisons() { EmptyPicker emptyOk2 = new EmptyPicker(Status.OK.withDescription("different OK")); EmptyPicker emptyErr = new EmptyPicker(Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯")); - Iterator subchannelIterator = subchannels.values().iterator(); - Subchannel sc1 = subchannelIterator.next(); - Subchannel sc2 = subchannelIterator.next(); - ReadyPicker ready1 = new ReadyPicker(Arrays.asList(sc1, sc2), 2, mockRandom); - ReadyPicker ready2 = new ReadyPicker(Arrays.asList(sc1), 2, mockRandom); - ReadyPicker ready3 = new ReadyPicker(Arrays.asList(sc2, sc1), 2, mockRandom); - ReadyPicker ready4 = new ReadyPicker(Arrays.asList(sc1, sc2), 2, mockRandom); - ReadyPicker ready5 = new ReadyPicker(Arrays.asList(sc2, sc1), 2, mockRandom); - ReadyPicker ready6 = new ReadyPicker(Arrays.asList(sc2, sc1), 8, mockRandom); + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); + + + Iterator iterator = loadBalancer.getChildLbStates().iterator(); + ChildLbState child1 = iterator.next(); + ChildLbState child2 = iterator.next(); + ReadyPicker ready1 = new ReadyPicker(Arrays.asList(child1, child2), 2, mockRandom); + ReadyPicker ready2 = new ReadyPicker(Arrays.asList(child1), 2, mockRandom); + ReadyPicker ready3 = new ReadyPicker(Arrays.asList(child2, child1), 2, mockRandom); + ReadyPicker ready4 = new ReadyPicker(Arrays.asList(child1, child2), 2, mockRandom); + ReadyPicker ready5 = new ReadyPicker(Arrays.asList(child2, child1), 2, mockRandom); + ReadyPicker ready6 = new ReadyPicker(Arrays.asList(child2, child1), 8, mockRandom); assertTrue(emptyOk1.isEquivalentTo(emptyOk2)); assertFalse(emptyOk1.isEquivalentTo(emptyErr)); @@ -623,16 +659,22 @@ public void emptyAddresses() { ResolvedAddresses.newBuilder() .setAddresses(Collections.emptyList()) .setAttributes(affinity) - .build())).isFalse(); + .build())) + .isFalse(); } - private static List getList(SubchannelPicker picker) { - return picker instanceof ReadyPicker ? ((ReadyPicker) picker).getList() : - Collections.emptyList(); + private List getList(SubchannelPicker picker) { + if (picker instanceof ReadyPicker) { + return ((ReadyPicker) picker).getChildLbStates().stream() + .map(this::getSubchannel) + .collect(Collectors.toList()); + } else { + return Collections.emptyList(); + } } private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { - subchannelStateListeners.get(subchannel).onSubchannelState(newState); + testHelperInstance.deliverSubchannelState(subchannel, newState); } private static class FakeSocketAddress extends SocketAddress { @@ -647,4 +689,12 @@ public String toString() { return "FakeSocketAddress-" + name; } } + + private class TestHelper extends AbstractTestHelper { + + @Override + public Map, Subchannel> getSubchannelMap() { + return subchannels; + } + } } diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index ac08f69f88c..c59ad1318e2 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -17,11 +17,10 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -35,7 +34,6 @@ import com.google.protobuf.Duration; import io.grpc.Attributes; import io.grpc.Channel; -import io.grpc.ChannelLogger; import io.grpc.ClientCall; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; @@ -50,12 +48,15 @@ import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; +import io.grpc.internal.TestUtils; import io.grpc.services.InternalCallMetricRecorder; import io.grpc.services.MetricReport; +import io.grpc.util.AbstractTestHelper; +import io.grpc.util.MultiChildLoadBalancer.ChildLbState; import io.grpc.xds.WeightedRoundRobinLoadBalancer.StaticStrideScheduler; +import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedChildLbState; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker; -import io.grpc.xds.WeightedRoundRobinLoadBalancer.WrrSubchannel; import java.net.SocketAddress; import java.util.Arrays; import java.util.HashMap; @@ -67,6 +68,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Before; @@ -87,8 +89,8 @@ public class WeightedRoundRobinLoadBalancerTest { @Rule public final MockitoRule mockito = MockitoJUnit.rule(); - @Mock - Helper helper; + private final TestHelper testHelperInstance = new TestHelper(); + private Helper helper = mock(Helper.class, delegatesTo(testHelperInstance)); @Mock private LoadBalancer.PickSubchannelArgs mockArgs; @@ -99,9 +101,8 @@ public class WeightedRoundRobinLoadBalancerTest { private ArgumentCaptor pickerCaptor2; private final List servers = Lists.newArrayList(); - private final Map, Subchannel> subchannels = Maps.newLinkedHashMap(); - + private final Map mockToRealSubChannelMap = new HashMap<>(); private final Map subchannelStateListeners = Maps.newLinkedHashMap(); @@ -134,7 +135,8 @@ public void setup() { SocketAddress addr = new FakeSocketAddress("server" + i); EquivalentAddressGroup eag = new EquivalentAddressGroup(addr); servers.add(eag); - Subchannel sc = mock(Subchannel.class); + Subchannel sc = helper.createSubchannel(CreateSubchannelArgs.newBuilder().setAddresses(eag) + .build()); Channel channel = mock(Channel.class); when(channel.newCall(any(), any())).then( new Answer>() { @@ -147,35 +149,13 @@ public ClientCall answer( return clientCall; } }); - when(sc.asChannel()).thenReturn(channel); + testHelperInstance.setChannel(mockToRealSubChannelMap.get(sc), channel); subchannels.put(Arrays.asList(eag), sc); } - when(helper.getSynchronizationContext()).thenReturn(syncContext); - when(helper.getScheduledExecutorService()).thenReturn( - fakeClock.getScheduledExecutorService()); - when(helper.createSubchannel(any(CreateSubchannelArgs.class))) - .then(new Answer() { - @Override - public Subchannel answer(InvocationOnMock invocation) throws Throwable { - CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; - final Subchannel subchannel = subchannels.get(args.getAddresses()); - when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); - when(subchannel.getAttributes()).thenReturn(args.getAttributes()); - when(subchannel.getChannelLogger()).thenReturn(mock(ChannelLogger.class)); - doAnswer( - new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - subchannelStateListeners.put( - subchannel, (SubchannelStateListener) invocation.getArguments()[0]); - return null; - } - }).when(subchannel).start(any(SubchannelStateListener.class)); - return subchannel; - } - }); wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker(), new FakeRandom(0)); + + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); } @Test @@ -183,44 +163,44 @@ public void wrrLifeCycle() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel connectingSubchannel = it.next(); - subchannelStateListeners.get(connectingSubchannel).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.CONNECTING)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); - assertThat(weightedPicker.getList().size()).isEqualTo(1); + assertThat(weightedPicker.getChildren().size()).isEqualTo(1); weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - assertThat(weightedPicker.getList().size()).isEqualTo(2); + assertThat(weightedPicker.getChildren().size()).isEqualTo(2); String weightedPickerStr = weightedPicker.toString(); assertThat(weightedPickerStr).contains("enableOobLoadReport=false"); assertThat(weightedPickerStr).contains("errorUtilizationPenalty=1.0"); assertThat(weightedPickerStr).contains("list="); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); - assertThat(weightedPicker.pickSubchannel(mockArgs) - .getSubchannel()).isEqualTo(weightedSubchannel1); + + assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(weightedChild1.getEag()); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder() .setWeightUpdatePeriodNanos(500_000_000L) //.5s @@ -238,35 +218,44 @@ weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalt verifyNoMoreInteractions(mockArgs); } + /** + * Picks subchannel using mockArgs, gets its EAG, and then strips the Attrs to make a key. + */ + private EquivalentAddressGroup getAddressesFromPick(WeightedRoundRobinPicker weightedPicker) { + return TestUtils.stripAttrs( + weightedPicker.pickSubchannel(mockArgs).getSubchannel().getAddresses()); + } + @Test public void enableOobLoadReportConfig() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.9, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); - assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1); + assertThat(getAddresses(pickResult)) + .isEqualTo(weightedChild1.getEag()); assertThat(pickResult.getStreamTracerFactory()).isNotNull(); // verify per-request listener assertThat(oobCalls.isEmpty()).isTrue(); @@ -280,7 +269,8 @@ weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalt eq(ConnectivityState.READY), pickerCaptor2.capture()); weightedPicker = (WeightedRoundRobinPicker) pickerCaptor2.getAllValues().get(2); pickResult = weightedPicker.pickSubchannel(mockArgs); - assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1); + assertThat(getAddresses(pickResult)) + .isEqualTo(weightedChild1.getEag()); assertThat(pickResult.getStreamTracerFactory()).isNull(); OrcaLoadReportRequest golden = OrcaLoadReportRequest.newBuilder().setReportInterval( Duration.newBuilder().setSeconds(20).setNanos(30000000).build()).build(); @@ -295,46 +285,52 @@ private void pickByWeight(MetricReport r1, MetricReport r2, MetricReport r3, syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel3 = it.next(); - subchannelStateListeners.get(readySubchannel3).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel3).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( - r1); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( - r2); - weightedSubchannel3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( - r3); + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r1); + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r2); + weightedChild3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r3); + assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); - Map pickCount = new HashMap<>(); + Map pickCount = new HashMap<>(); for (int i = 0; i < 10000; i++) { - Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(3); - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 10000.0 - subchannel1PickRatio)) - .isLessThan(0.0002); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 10000.0 - subchannel2PickRatio )) - .isLessThan(0.0002); - assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 10000.0 - subchannel3PickRatio )) - .isLessThan(0.0002); + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - subchannel1PickRatio)) + .isAtMost(0.0002); + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - subchannel2PickRatio )) + .isAtMost(0.0002); + assertThat(Math.abs(pickCount.get(weightedChild3.getEag()) / 10000.0 - subchannel3PickRatio )) + .isAtMost(0.0002); + } + + private SubchannelStateListener getSubchannelStateListener(Subchannel mockSubChannel) { + return subchannelStateListeners.get(mockToRealSubChannelMap.get(mockSubChannel)); + } + + private static ChildLbState getChild(WeightedRoundRobinPicker picker, int index) { + return picker.getChildren().get(index); } @Test @@ -472,14 +468,14 @@ public void emptyConfig() { assertThat(wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(null) .setAttributes(affinity).build())).isFalse(); - verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(fakeClock.getPendingTasks()).isEmpty(); syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); assertThat(pickerCaptor.getValue().getClass().getName()) @@ -492,51 +488,51 @@ public void blackoutPeriod() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1); - Map pickCount = new HashMap<>(); - for (int i = 0; i < 1000; i++) { - Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + Map pickCount = new HashMap<>(); + for (int i = 0; i < 10000; i++) { + EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); // within blackout period, fallback to simple round robin - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 0.5)).isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 0.5)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - 0.5)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - 0.5)).isLessThan(0.002); assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1); pickCount = new HashMap<>(); - for (int i = 0; i < 1000; i++) { - Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + for (int i = 0; i < 10000; i++) { + EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); // after blackout period - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - 2.0 / 3)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - 1.0 / 3)) .isLessThan(0.002); } @@ -545,39 +541,39 @@ public void updateWeightTimer() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel connectingSubchannel = it.next(); - subchannelStateListeners.get(connectingSubchannel).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.CONNECTING)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); - assertThat(weightedPicker.getList().size()).isEqualTo(1); + assertThat(weightedPicker.getChildren().size()).isEqualTo(1); weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - assertThat(weightedPicker.getList().size()).isEqualTo(2); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + assertThat(weightedPicker.getChildren().size()).isEqualTo(2); + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); - assertThat(weightedPicker.pickSubchannel(mockArgs) - .getSubchannel()).isEqualTo(weightedSubchannel1); + assertThat(getAddressesFromPick(weightedPicker)) + .isEqualTo(weightedChild1.getEag()); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder() .setWeightUpdatePeriodNanos(500_000_000L) //.5s @@ -586,17 +582,18 @@ weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalt .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); //timer fires, new weight updated assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(1); - assertThat(weightedPicker.pickSubchannel(mockArgs) - .getSubchannel()).isEqualTo(weightedSubchannel2); - + assertThat(getAddressesFromPick(weightedPicker)) + .isEqualTo(weightedChild2.getEag()); + assertThat(getAddressesFromPick(weightedPicker)) + .isEqualTo(weightedChild1.getEag()); } @Test @@ -604,52 +601,52 @@ public void weightExpired() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); - Map pickCount = new HashMap<>(); + Map pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { - Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 2.0 / 3)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 3)) .isLessThan(0.002); // weight expired, fallback to simple round robin assertThat(fakeClock.forwardTime(300, TimeUnit.SECONDS)).isEqualTo(1); pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { - Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 0.5)) + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 0.5)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 0.5)) + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 0.5)) .isLessThan(0.002); } @@ -658,107 +655,113 @@ public void rrFallback() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - Map qpsByChannel = ImmutableMap.of(weightedSubchannel1, 2, - weightedSubchannel2, 1); - Map pickCount = new HashMap<>(); + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + Map qpsByChannel = ImmutableMap.of(weightedChild1.getEag(), 2, + weightedChild2.getEag(), 1); + Map pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); - pickCount.put(pickResult.getSubchannel(), - pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1); + EquivalentAddressGroup addresses = getAddresses(pickResult); + pickCount.merge(addresses, 1, Integer::sum); assertThat(pickResult.getStreamTracerFactory()).isNotNull(); - WrrSubchannel subchannel = (WrrSubchannel)pickResult.getSubchannel(); - subchannel.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState childLbState = (WeightedChildLbState) wrr.getChildLbStateEag(addresses); + childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( - 0.1, 0, 0.1, qpsByChannel.get(subchannel), 0, + 0.1, 0, 0.1, qpsByChannel.get(addresses), 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); } - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 1.0 / 2)) + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 1.0 / 2)) .isAtMost(0.1); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 2)) + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 2)) .isAtMost(0.1); + + // Identical to above except forwards time after each pick pickCount.clear(); for (int i = 0; i < 1000; i++) { PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); - pickCount.put(pickResult.getSubchannel(), - pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1); + EquivalentAddressGroup addresses = getAddresses(pickResult); + pickCount.merge(addresses, 1, Integer::sum); assertThat(pickResult.getStreamTracerFactory()).isNotNull(); - WrrSubchannel subchannel = (WrrSubchannel) pickResult.getSubchannel(); - subchannel.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState childLbState = (WeightedChildLbState) wrr.getChildLbStateEag(addresses); + childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( - 0.1, 0, 0.1, qpsByChannel.get(subchannel), 0, + 0.1, 0, 0.1, qpsByChannel.get(addresses), 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); fakeClock.forwardTime(50, TimeUnit.MILLISECONDS); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 2.0 / 3)) .isAtMost(0.1); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 3)) .isAtMost(0.1); } + private static EquivalentAddressGroup getAddresses(PickResult pickResult) { + return TestUtils.stripAttrs(pickResult.getSubchannel().getAddresses()); + } + @Test public void unknownWeightIsAvgWeight() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( - any(CreateSubchannelArgs.class)); + verify(helper, times(6)).createSubchannel( + any(CreateSubchannelArgs.class)); // 3 from setup plus 3 from the execute assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo - .forNonError(ConnectivityState.READY)); + getSubchannelStateListener(readySubchannel1) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo - .forNonError(ConnectivityState.READY)); + getSubchannelStateListener(readySubchannel2) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); Subchannel readySubchannel3 = it.next(); - subchannelStateListeners.get(readySubchannel3).onSubchannelState(ConnectivityStateInfo - .forNonError(ConnectivityState.READY)); + getSubchannelStateListener(readySubchannel3) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); - Map pickCount = new HashMap<>(); + Map pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); - pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + pickCount.merge(result.getAddresses(), 1, Integer::sum); } assertThat(pickCount.size()).isEqualTo(3); - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 4.0 / 9)) + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 4.0 / 9)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 2.0 / 9)) + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 2.0 / 9)) .isLessThan(0.002); // subchannel3's weight is average of subchannel1 and subchannel2 - assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 1000.0 - 3.0 / 9)) + assertThat(Math.abs(pickCount.get(weightedChild3.getEag()) / 1000.0 - 3.0 / 9)) .isLessThan(0.002); } @@ -767,33 +770,33 @@ public void pickFromOtherThread() throws Exception { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); CyclicBarrier barrier = new CyclicBarrier(2); - Map pickCount = new ConcurrentHashMap<>(); - pickCount.put(weightedSubchannel1, new AtomicInteger(0)); - pickCount.put(weightedSubchannel2, new AtomicInteger(0)); + Map pickCount = new ConcurrentHashMap<>(); + pickCount.put(weightedChild1.getEag(), new AtomicInteger(0)); + pickCount.put(weightedChild2.getEag(), new AtomicInteger(0)); new Thread(new Runnable() { @Override public void run() { @@ -802,7 +805,7 @@ public void run() { barrier.await(); for (int i = 0; i < 1000; i++) { Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); - pickCount.get(result).addAndGet(1); + pickCount.get(result.getAddresses()).addAndGet(1); } barrier.await(); } catch (Exception ex) { @@ -813,15 +816,15 @@ public void run() { assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); barrier.await(); for (int i = 0; i < 1000; i++) { - Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + EquivalentAddressGroup result = getAddresses(weightedPicker.pickSubchannel(mockArgs)); pickCount.get(result).addAndGet(1); } barrier.await(); assertThat(pickCount.size()).isEqualTo(2); // after blackout period - assertThat(Math.abs(pickCount.get(weightedSubchannel1).get() / 2000.0 - 2.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()).get() / 2000.0 - 2.0 / 3)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedSubchannel2).get() / 2000.0 - 1.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()).get() / 2000.0 - 1.0 / 3)) .isLessThan(0.002); } @@ -1104,4 +1107,34 @@ public int nextInt() { return nextInt; } } + + private class TestHelper extends AbstractTestHelper { + + @Override + public Map, Subchannel> getSubchannelMap() { + return subchannels; + } + + @Override + public Map getMockToRealSubChannelMap() { + return mockToRealSubChannelMap; + } + + @Override + public Map getSubchannelStateListeners() { + return subchannelStateListeners; + } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } + + @Override + public ScheduledExecutorService getScheduledExecutorService() { + return fakeClock.getScheduledExecutorService(); + } + + + } }