Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xds: fix the race condition in SslContextProviderSupplier's updateSslContext and close #8294

Merged
merged 2 commits into from
Jul 9, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package io.grpc.xds.internal.sds;

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
Expand Down Expand Up @@ -56,13 +55,14 @@ public BaseTlsContext getTlsContext() {
public synchronized void updateSslContext(final SslContextProvider.Callback callback) {
checkNotNull(callback, "callback");
try {
checkState(!shutdown, "Supplier is shutdown!");
if (sslContextProvider == null) {
sslContextProvider = getSslContextProvider();
if (!shutdown) {
if (sslContextProvider == null) {
sslContextProvider = getSslContextProvider();
}
}
// we want to increment the ref-count so call findOrCreate again...
final SslContextProvider toRelease = getSslContextProvider();
Comment on lines +58 to 64
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just return if shutdown is true?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you return when shutdown is true, the caller's callback will never be called and that would be bad.

Let me give you some context for this change. This is for a very rare race condition: when the control plane sends a DownstreamTlsContext we translate it to an SslContextProviderSupplier (which internally performs lazy loading and ref-counting etc). For a new incoming connection to the server, gRPC will figure out the SslContextProviderSupplier to use and give it to the protocol negotiator. Before the protocol negotiator has a chance to use it if the control plane replaces the DownstreamTlsContext value for the server, then we call close on the existing SslContextProviderSupplier which will set shutdown to true and release the SslContextProvider (thereby making its ref-count 0). Let's say after this event the protocol negotiator wants to get the SslContext for the connection so it calls updateSslContext on the SslContextProviderSupplier which is now in shut-down state. We can either throw an exception via the callback's onException (but not silently return) or fall through here to get an SslContextProvider and use the callback to provide a proper SslContext to the protocol negotiator (even if it is related to the old DownstreamTlsContext value). So this change does the latter to allow the protocol negotiator to succeed. There will be delays due to fresh loading of the SslContextProvider but that is better than failing the connection just because it came in the middle of switching of the DownstreamTlsContext values.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hope this answers your question

sslContextProvider.addCallback(
toRelease.addCallback(
new SslContextProvider.Callback(callback.getExecutor()) {

@Override
Expand Down Expand Up @@ -115,6 +115,7 @@ public synchronized void close() {
tlsContextManager.releaseServerSslContextProvider(sslContextProvider);
}
}
// don't set sslContextProvider to null since we don't want reallocation under any circumstances
sanjaypujare marked this conversation as resolved.
Show resolved Hide resolved
shutdown = true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,13 @@
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.xds.EnvoyServerProtoData;
import io.grpc.xds.TlsContextManager;
import io.netty.handler.ssl.SslContext;
import java.util.concurrent.Executor;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -91,11 +88,11 @@ public void get_updateSecret() {
capturedCallback.updateSecret(mockSslContext);
verify(mockCallback, times(1)).updateSecret(eq(mockSslContext));
verify(mockTlsContextManager, times(1))
.releaseClientSslContextProvider(eq(mockSslContextProvider));
.releaseClientSslContextProvider(eq(mockSslContextProvider));
SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class);
supplier.updateSslContext(mockCallback);
verify(mockTlsContextManager, times(3))
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
}

@Test
Expand All @@ -106,9 +103,11 @@ public void get_onException() {
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
SslContextProvider.Callback capturedCallback = callbackCaptor.getValue();
assertThat(capturedCallback).isNotNull();
capturedCallback.onException(new Exception("test"));
Exception exception = new Exception("test");
capturedCallback.onException(exception);
verify(mockCallback, times(1)).onException(eq(exception));
verify(mockTlsContextManager, times(1))
.releaseClientSslContextProvider(eq(mockSslContextProvider));
.releaseClientSslContextProvider(eq(mockSslContextProvider));
}

@Test
Expand All @@ -118,20 +117,11 @@ public void testClose() {
supplier.close();
verify(mockTlsContextManager, times(1))
.releaseClientSslContextProvider(eq(mockSslContextProvider));
SslContextProvider.Callback mockCallback = spy(
new SslContextProvider.Callback(MoreExecutors.directExecutor()) {
@Override
public void updateSecret(SslContext sslContext) {
Assert.fail("unexpected call");
}

@Override
protected void onException(Throwable argument) {
assertThat(argument).isInstanceOf(IllegalStateException.class);
assertThat(argument).hasMessageThat().contains("Supplier is shutdown!");
}
});
supplier.updateSslContext(mockCallback);
verify(mockTlsContextManager, times(3))
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
verify(mockTlsContextManager, times(1))
.releaseClientSslContextProvider(any(SslContextProvider.class));
}

@Test
Expand All @@ -142,19 +132,8 @@ public void testClose_nullSslContextProvider() {
supplier.close();
verify(mockTlsContextManager, never())
.releaseClientSslContextProvider(eq(mockSslContextProvider));
SslContextProvider.Callback mockCallback = spy(
new SslContextProvider.Callback(MoreExecutors.directExecutor()) {
@Override
public void updateSecret(SslContext sslContext) {
Assert.fail("unexpected call");
}

@Override
protected void onException(Throwable argument) {
assertThat(argument).isInstanceOf(IllegalStateException.class);
assertThat(argument).hasMessageThat().contains("Supplier is shutdown!");
}
});
supplier.updateSslContext(mockCallback);
callUpdateSslContext();
verify(mockTlsContextManager, times(1))
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
}
}