Skip to content

Commit

Permalink
Improve code coverage for SSLNettyTransport class (#3953)
Browse files Browse the repository at this point in the history
### Description
[Describe what this change achieves]
This change increases code coverage for the SecuritySSLNettyTransport
class. In the middle of 12/23, a few unit tests were added to give
coverage to different parts of the class. This change builds on these
existing changes.

### Issues Resolved
Box three of #3137

Signed-off-by: Stephen Crawford <[email protected]>
  • Loading branch information
stephen-crawford authored Jan 18, 2024
1 parent 09051f4 commit 037bc20
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.logging.log4j.Logger;

import org.opensearch.ExceptionsHelper;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.Version;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.network.NetworkService;
Expand Down Expand Up @@ -103,6 +104,11 @@ public SecuritySSLNettyTransport(
this.SSLConfig = SSLConfig;
}

// This allows for testing log messages
Logger getLogger() {
return logger;
}

@Override
public void onException(TcpChannel channel, Exception e) {

Expand All @@ -113,8 +119,11 @@ public void onException(TcpChannel channel, Exception e) {
}

errorHandler.logError(cause, false);
logger.error("Exception during establishing a SSL connection: " + cause, cause);
getLogger().error("Exception during establishing a SSL connection: " + cause, cause);

if (channel == null || !channel.isOpen()) {
throw new OpenSearchSecurityException("The provided TCP channel is invalid.", e);
}
super.onException(channel, e);
}

Expand Down Expand Up @@ -156,7 +165,7 @@ public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) th
}

errorHandler.logError(cause, false);
logger.error("Exception during establishing a SSL connection: " + cause, cause);
getLogger().error("Exception during establishing a SSL connection: " + cause, cause);

super.exceptionCaught(ctx, cause);
}
Expand Down Expand Up @@ -291,7 +300,7 @@ public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) th
}

errorHandler.logError(cause, false);
logger.error("Exception during establishing a SSL connection: " + cause, cause);
getLogger().error("Exception during establishing a SSL connection: " + cause, cause);

super.exceptionCaught(ctx, cause);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@

package org.opensearch.security.ssl.transport;

import org.junit.Assert;
import java.util.Collections;

import org.apache.logging.log4j.Logger;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;

import org.opensearch.OpenSearchSecurityException;
import org.opensearch.Version;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.network.NetworkService;
Expand All @@ -28,15 +32,27 @@
import org.opensearch.security.ssl.transport.SecuritySSLNettyTransport.SSLServerChannelInitializer;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.FakeTcpChannel;
import org.opensearch.transport.SharedGroupFactory;
import org.opensearch.transport.TcpChannel;

import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.DecoderException;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class SecuritySSLNettyTransportTests {

Expand All @@ -45,16 +61,12 @@ public class SecuritySSLNettyTransportTests {
@Mock
private ThreadPool threadPool;
@Mock
private NetworkService networkService;
@Mock
private PageCacheRecycler pageCacheRecycler;
@Mock
private NamedWriteableRegistry namedWriteableRegistry;
@Mock
private CircuitBreakerService circuitBreakerService;
@Mock
private SharedGroupFactory sharedGroupFactory;
@Mock
private Tracer trace;
@Mock
private SecurityKeyStore ossks;
Expand All @@ -63,55 +75,127 @@ public class SecuritySSLNettyTransportTests {
@Mock
private DiscoveryNode discoveryNode;

// This initializes all the above mocks
@Rule
public MockitoRule rule = MockitoJUnit.rule();

private NetworkService networkService;
private SharedGroupFactory sharedGroupFactory;
private Logger mockLogger;
private SSLConfig sslConfig;
private SecuritySSLNettyTransport securitySSLNettyTransport;
Throwable testCause = new Throwable("Test Cause");

@Before
public void setup() {

sslConfig = new SSLConfig(Settings.EMPTY);
networkService = new NetworkService(Collections.emptyList());
sharedGroupFactory = new SharedGroupFactory(Settings.EMPTY);

securitySSLNettyTransport = new SecuritySSLNettyTransport(
Settings.EMPTY,
version,
threadPool,
networkService,
pageCacheRecycler,
namedWriteableRegistry,
circuitBreakerService,
ossks,
sslExceptionHandler,
sharedGroupFactory,
sslConfig,
trace
sslConfig = new SSLConfig(Settings.EMPTY);
mockLogger = mock(Logger.class);

securitySSLNettyTransport = spy(
new SecuritySSLNettyTransport(
Settings.EMPTY,
version,
threadPool,
networkService,
pageCacheRecycler,
namedWriteableRegistry,
circuitBreakerService,
ossks,
sslExceptionHandler,
sharedGroupFactory,
sslConfig,
trace
)
);
}

@Test
public void OnException_withNullChannelShouldThrowException() {

NullPointerException exception = new NullPointerException("Test Exception");
OpenSearchSecurityException exception = new OpenSearchSecurityException("The provided TCP channel is invalid");
assertThrows(OpenSearchSecurityException.class, () -> securitySSLNettyTransport.onException(null, exception));
}

@Test
public void OnException_withClosedChannelShouldThrowException() {

TcpChannel channel = new FakeTcpChannel();
channel.close();
OpenSearchSecurityException exception = new OpenSearchSecurityException("The provided TCP channel is invalid");
assertThrows(OpenSearchSecurityException.class, () -> securitySSLNettyTransport.onException(channel, exception));
}

@Test
public void OnException_withNullExceptionShouldSucceed() {

TcpChannel channel = new FakeTcpChannel();
securitySSLNettyTransport.onException(channel, null);
verify(securitySSLNettyTransport, times(1)).onException(channel, null);
channel.close();
}

Assert.assertThrows(NullPointerException.class, () -> securitySSLNettyTransport.onException(null, exception));
@Test
public void OnException_withDecoderExceptionShouldGetCause() {

when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger);
DecoderException exception = new DecoderException("Test Exception", testCause);
TcpChannel channel = new FakeTcpChannel();
securitySSLNettyTransport.onException(channel, exception);
verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception.getCause(), exception.getCause());
}

@Test
public void getServerChannelInitializer_shouldReturnValidServerChannel() {

ChannelHandler channelHandler = securitySSLNettyTransport.getServerChannelInitializer("test-server-channel");

assertThat(channelHandler, is(notNullValue()));
assertThat(channelHandler, is(instanceOf(SSLServerChannelInitializer.class)));
}

@Test
public void getClientChannelInitializer_shouldReturnValidClientChannel() {

ChannelHandler channelHandler = securitySSLNettyTransport.getClientChannelInitializer(discoveryNode);

assertThat(channelHandler, is(notNullValue()));
assertThat(channelHandler, is(instanceOf(SSLClientChannelInitializer.class)));
}

@Test
public void exceptionWithServerChannelHandlerContext_nonNullDecoderExceptionShouldGetCause() throws Exception {
when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger);
Throwable exception = new DecoderException("Test Exception", testCause);
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
securitySSLNettyTransport.getServerChannelInitializer(discoveryNode.getName()).exceptionCaught(ctx, exception);
verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception.getCause(), exception.getCause());
}

@Test
public void exceptionWithServerChannelHandlerContext_nonNullCauseOnlyShouldNotGetCause() throws Exception {
when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger);
Throwable exception = new OpenSearchSecurityException("Test Exception", testCause);
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
securitySSLNettyTransport.getServerChannelInitializer(discoveryNode.getName()).exceptionCaught(ctx, exception);
verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception, exception);
}

@Test
public void exceptionWithClientChannelHandlerContext_nonNullDecoderExceptionShouldGetCause() throws Exception {
when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger);
Throwable exception = new DecoderException("Test Exception", testCause);
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
securitySSLNettyTransport.getClientChannelInitializer(discoveryNode).exceptionCaught(ctx, exception);
verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception.getCause(), exception.getCause());
}

@Test
public void exceptionWithClientChannelHandlerContext_nonNullCauseOnlyShouldNotGetCause() throws Exception {
when(securitySSLNettyTransport.getLogger()).thenReturn(mockLogger);
Throwable exception = new OpenSearchSecurityException("Test Exception", testCause);
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
securitySSLNettyTransport.getClientChannelInitializer(discoveryNode).exceptionCaught(ctx, exception);
verify(mockLogger, times(1)).error("Exception during establishing a SSL connection: " + exception, exception);
}
}

0 comments on commit 037bc20

Please sign in to comment.