Skip to content

Commit

Permalink
xds: fix the race condition in SslContextProviderSupplier's updateSsl…
Browse files Browse the repository at this point in the history
…Context and close (#8294)
  • Loading branch information
sanjaypujare authored Jul 9, 2021
1 parent 3965315 commit 629748d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package io.grpc.xds.internal.sds;

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
Expand Down Expand Up @@ -56,13 +55,14 @@ public BaseTlsContext getTlsContext() {
public synchronized void updateSslContext(final SslContextProvider.Callback callback) {
checkNotNull(callback, "callback");
try {
checkState(!shutdown, "Supplier is shutdown!");
if (sslContextProvider == null) {
sslContextProvider = getSslContextProvider();
if (!shutdown) {
if (sslContextProvider == null) {
sslContextProvider = getSslContextProvider();
}
}
// we want to increment the ref-count so call findOrCreate again...
final SslContextProvider toRelease = getSslContextProvider();
sslContextProvider.addCallback(
toRelease.addCallback(
new SslContextProvider.Callback(callback.getExecutor()) {

@Override
Expand Down Expand Up @@ -115,6 +115,7 @@ public synchronized void close() {
tlsContextManager.releaseServerSslContextProvider(sslContextProvider);
}
}
sslContextProvider = null;
shutdown = true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,13 @@
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.xds.EnvoyServerProtoData;
import io.grpc.xds.TlsContextManager;
import io.netty.handler.ssl.SslContext;
import java.util.concurrent.Executor;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -91,11 +88,11 @@ public void get_updateSecret() {
capturedCallback.updateSecret(mockSslContext);
verify(mockCallback, times(1)).updateSecret(eq(mockSslContext));
verify(mockTlsContextManager, times(1))
.releaseClientSslContextProvider(eq(mockSslContextProvider));
.releaseClientSslContextProvider(eq(mockSslContextProvider));
SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class);
supplier.updateSslContext(mockCallback);
verify(mockTlsContextManager, times(3))
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
}

@Test
Expand All @@ -106,9 +103,11 @@ public void get_onException() {
verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture());
SslContextProvider.Callback capturedCallback = callbackCaptor.getValue();
assertThat(capturedCallback).isNotNull();
capturedCallback.onException(new Exception("test"));
Exception exception = new Exception("test");
capturedCallback.onException(exception);
verify(mockCallback, times(1)).onException(eq(exception));
verify(mockTlsContextManager, times(1))
.releaseClientSslContextProvider(eq(mockSslContextProvider));
.releaseClientSslContextProvider(eq(mockSslContextProvider));
}

@Test
Expand All @@ -118,20 +117,11 @@ public void testClose() {
supplier.close();
verify(mockTlsContextManager, times(1))
.releaseClientSslContextProvider(eq(mockSslContextProvider));
SslContextProvider.Callback mockCallback = spy(
new SslContextProvider.Callback(MoreExecutors.directExecutor()) {
@Override
public void updateSecret(SslContext sslContext) {
Assert.fail("unexpected call");
}

@Override
protected void onException(Throwable argument) {
assertThat(argument).isInstanceOf(IllegalStateException.class);
assertThat(argument).hasMessageThat().contains("Supplier is shutdown!");
}
});
supplier.updateSslContext(mockCallback);
verify(mockTlsContextManager, times(3))
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
verify(mockTlsContextManager, times(1))
.releaseClientSslContextProvider(any(SslContextProvider.class));
}

@Test
Expand All @@ -142,19 +132,8 @@ public void testClose_nullSslContextProvider() {
supplier.close();
verify(mockTlsContextManager, never())
.releaseClientSslContextProvider(eq(mockSslContextProvider));
SslContextProvider.Callback mockCallback = spy(
new SslContextProvider.Callback(MoreExecutors.directExecutor()) {
@Override
public void updateSecret(SslContext sslContext) {
Assert.fail("unexpected call");
}

@Override
protected void onException(Throwable argument) {
assertThat(argument).isInstanceOf(IllegalStateException.class);
assertThat(argument).hasMessageThat().contains("Supplier is shutdown!");
}
});
supplier.updateSslContext(mockCallback);
callUpdateSslContext();
verify(mockTlsContextManager, times(1))
.findOrCreateClientSslContextProvider(eq(upstreamTlsContext));
}
}

0 comments on commit 629748d

Please sign in to comment.