Skip to content

Commit

Permalink
Delegate to NettyAllocator.getAllocator() for ByteBufAllocator instea…
Browse files Browse the repository at this point in the history
…d of hard-coding PooledByteBufAllocator. (#1396)

Signed-off-by: Peter Nied <[email protected]>
  • Loading branch information
vrozov authored Jun 7, 2022
1 parent ce59944 commit 1ea2cd4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLParameters;

import io.netty.buffer.PooledByteBufAllocator;
import io.netty.handler.ssl.ApplicationProtocolConfig;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.OpenSsl;
Expand Down Expand Up @@ -78,6 +77,7 @@
import org.opensearch.security.ssl.util.ExceptionUtils;
import org.opensearch.security.ssl.util.KeystoreProps;
import org.opensearch.security.ssl.util.SSLConfigConstants;
import org.opensearch.transport.NettyAllocator;

public class DefaultSecurityKeyStore implements SecurityKeyStore {

Expand Down Expand Up @@ -653,21 +653,21 @@ private boolean areSameCerts(final X509Certificate[] currentX509Certs, final X50
}

public SSLEngine createHTTPSSLEngine() throws SSLException {
final SSLEngine engine = httpSslContext.newEngine(PooledByteBufAllocator.DEFAULT);
final SSLEngine engine = httpSslContext.newEngine(NettyAllocator.getAllocator());
engine.setEnabledProtocols(getEnabledSSLProtocols(this.sslHTTPProvider, true));
return engine;

}

public SSLEngine createServerTransportSSLEngine() throws SSLException {
final SSLEngine engine = transportServerSslContext.newEngine(PooledByteBufAllocator.DEFAULT);
final SSLEngine engine = transportServerSslContext.newEngine(NettyAllocator.getAllocator());
engine.setEnabledProtocols(getEnabledSSLProtocols(this.sslTransportServerProvider, false));
return engine;
}

public SSLEngine createClientTransportSSLEngine(final String peerHost, final int peerPort) throws SSLException {
if (peerHost != null) {
final SSLEngine engine = transportClientSslContext.newEngine(PooledByteBufAllocator.DEFAULT, peerHost,
final SSLEngine engine = transportClientSslContext.newEngine(NettyAllocator.getAllocator(), peerHost,
peerPort);

final SSLParameters sslParams = new SSLParameters();
Expand All @@ -676,7 +676,7 @@ public SSLEngine createClientTransportSSLEngine(final String peerHost, final int
engine.setEnabledProtocols(getEnabledSSLProtocols(this.sslTransportClientProvider, false));
return engine;
} else {
final SSLEngine engine = transportClientSslContext.newEngine(PooledByteBufAllocator.DEFAULT);
final SSLEngine engine = transportClientSslContext.newEngine(NettyAllocator.getAllocator());
engine.setEnabledProtocols(getEnabledSSLProtocols(this.sslTransportClientProvider, false));
return engine;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
Expand All @@ -34,10 +33,13 @@
import org.opensearch.security.ssl.SecurityKeyStore;
import org.opensearch.security.ssl.util.SSLConnectionTestUtil;

import static org.opensearch.transport.NettyAllocator.getAllocator;

public class DualModeSSLHandlerTests {

public static final int TLS_MAJOR_VERSION = 3;
public static final int TLS_MINOR_VERSION = 0;
private static final ByteBufAllocator ALLOCATOR = getAllocator();

private SecurityKeyStore securityKeyStore;
private ChannelPipeline pipeline;
Expand All @@ -58,8 +60,7 @@ public void setup() {
public void testInvalidMessage() throws Exception {
DualModeSSLHandler handler = new DualModeSSLHandler(securityKeyStore);

ByteBufAllocator alloc = PooledByteBufAllocator.DEFAULT;
handler.decode(ctx, alloc.directBuffer(4), null);
handler.decode(ctx, ALLOCATOR.buffer(4), null);
// ensure pipeline is not fetched and manipulated
Mockito.verify(ctx, Mockito.times(0)).pipeline();
}
Expand All @@ -68,8 +69,7 @@ public void testInvalidMessage() throws Exception {
public void testValidTLSMessage() throws Exception {
DualModeSSLHandler handler = new DualModeSSLHandler(securityKeyStore, sslHandler);

ByteBufAllocator alloc = PooledByteBufAllocator.DEFAULT;
ByteBuf buffer = alloc.directBuffer(6);
ByteBuf buffer = ALLOCATOR.buffer(6);
buffer.writeByte(20);
buffer.writeByte(TLS_MAJOR_VERSION);
buffer.writeByte(TLS_MINOR_VERSION);
Expand All @@ -90,8 +90,7 @@ public void testValidTLSMessage() throws Exception {
public void testNonTLSMessage() throws Exception {
DualModeSSLHandler handler = new DualModeSSLHandler(securityKeyStore, sslHandler);

ByteBufAllocator alloc = PooledByteBufAllocator.DEFAULT;
ByteBuf buffer = alloc.directBuffer(6);
ByteBuf buffer = ALLOCATOR.buffer(6);

for (int i = 0; i < 6; i++) {
buffer.writeByte(1);
Expand All @@ -112,8 +111,7 @@ public void testDualModeClientHelloMessage() throws Exception {
Mockito.when(ctx.writeAndFlush(Mockito.any())).thenReturn(channelFuture);
Mockito.when(channelFuture.addListener(Mockito.any())).thenReturn(channelFuture);

ByteBufAllocator alloc = PooledByteBufAllocator.DEFAULT;
ByteBuf buffer = alloc.directBuffer(6);
ByteBuf buffer = ALLOCATOR.buffer(6);
buffer.writeCharSequence(SSLConnectionTestUtil.DUAL_MODE_CLIENT_HELLO_MSG, StandardCharsets.UTF_8);

DualModeSSLHandler handler = new DualModeSSLHandler(securityKeyStore, sslHandler);
Expand Down
13 changes: 6 additions & 7 deletions src/test/java/org/opensearch/security/ssl/util/TLSUtilTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.PooledByteBufAllocator;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import static org.opensearch.transport.NettyAllocator.getAllocator;

public class TLSUtilTests {

public static final int TLS_MAJOR_VERSION = 3;
public static final int TLS_MINOR_VERSION = 0;
private static final ByteBufAllocator ALLOCATOR = getAllocator();

@Before
public void setup() {
Expand All @@ -35,8 +37,7 @@ public void setup() {
public void testSSLUtilSuccess() {
// byte 20 to 24 are ssl headers
for (int byteToSend = 20; byteToSend <= 24; byteToSend++) {
ByteBufAllocator alloc = PooledByteBufAllocator.DEFAULT;
ByteBuf buffer = alloc.directBuffer(5);
ByteBuf buffer = ALLOCATOR.buffer(5);
buffer.writeByte(byteToSend);
buffer.writeByte(TLS_MAJOR_VERSION);
buffer.writeByte(TLS_MINOR_VERSION);
Expand All @@ -50,8 +51,7 @@ public void testSSLUtilSuccess() {
public void testSSLUtilWrongTLSVersion() {
// byte 20 to 24 are ssl headers
for (int byteToSend = 20; byteToSend <= 24; byteToSend++) {
ByteBufAllocator alloc = PooledByteBufAllocator.DEFAULT;
ByteBuf buffer = alloc.directBuffer(5);
ByteBuf buffer = ALLOCATOR.buffer(5);
buffer.writeByte(byteToSend);
//setting invalid TLS version 100
buffer.writeByte(100);
Expand All @@ -66,8 +66,7 @@ public void testSSLUtilWrongTLSVersion() {
public void testSSLUtilInvalidContentLength() {
// byte 20 to 24 are ssl headers
for (int byteToSend = 20; byteToSend <= 24; byteToSend++) {
ByteBufAllocator alloc = PooledByteBufAllocator.DEFAULT;
ByteBuf buffer = alloc.directBuffer(5);
ByteBuf buffer = ALLOCATOR.buffer(5);
buffer.writeByte(byteToSend);
buffer.writeByte(TLS_MAJOR_VERSION);
buffer.writeByte(TLS_MINOR_VERSION);
Expand Down

0 comments on commit 1ea2cd4

Please sign in to comment.