Skip to content

Commit

Permalink
[Backport 2.x] Better support wrapping of TransportChannel (opensearc…
Browse files Browse the repository at this point in the history
…h-project#3778)

Backport 481b373 from opensearch-project#3769.

---------

Signed-off-by: Craig Perkins <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Craig Perkins <[email protected]>
  • Loading branch information
3 people authored Nov 30, 2023
1 parent 2942490 commit 9110114
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Set;
import javax.net.ssl.SSLPeerUnverifiedException;

import org.apache.logging.log4j.LogManager;
Expand Down Expand Up @@ -55,6 +56,8 @@ public class SecuritySSLRequestHandler<T extends TransportRequest> implements Tr
private final SslExceptionHandler errorHandler;
private final SSLConfig SSLConfig;

private static final Set<String> DEFAULT_CHANNEL_TYPES = Set.of("direct", "transport");

public SecuritySSLRequestHandler(
String action,
TransportRequestHandler<T> actualHandler,
Expand Down Expand Up @@ -86,6 +89,11 @@ public final void messageReceived(T request, TransportChannel channel, Task task

ThreadContext threadContext = getThreadContext();

String channelType = channel.getChannelType();
if (!DEFAULT_CHANNEL_TYPES.contains(channelType)) {
channel = getInnerChannel(channel);
}

threadContext.putTransient(
ConfigConstants.USE_JDK_SERIALIZATION,
channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION)
Expand All @@ -97,11 +105,6 @@ public final void messageReceived(T request, TransportChannel channel, Task task
throw exception;
}

String channelType = channel.getChannelType();
if (!channelType.equals("direct") && !channelType.equals("transport")) {
channel = getInnerChannel(channel);
}

if (!"transport".equals(channel.getChannelType())) { // netty4
messageReceivedDecorate(request, actualHandler, channel, task);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
*/
package org.opensearch.security.transport;

import java.io.IOException;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.opensearch.Version;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.security.ssl.SslExceptionHandler;
import org.opensearch.security.ssl.transport.PrincipalExtractor;
import org.opensearch.security.ssl.transport.SSLConfig;
Expand All @@ -27,11 +30,13 @@
import org.opensearch.transport.TransportRequestHandler;

import org.mockito.ArgumentMatchers;
import org.mockito.InOrder;
import org.mockito.Mock;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -88,4 +93,76 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
}

@Test
public void testUseJDKSerializationHeaderIsSetWithWrapperChannel() throws Exception {
TransportRequest transportRequest = mock(TransportRequest.class);
TransportChannel transportChannel = mock(TransportChannel.class);
TransportChannel wrappedChannel = new WrappedTransportChannel(transportChannel);
Task task = mock(Task.class);
doNothing().when(transportChannel).sendResponse(ArgumentMatchers.any(Exception.class));
when(transportChannel.getVersion()).thenReturn(Version.V_2_10_0);
when(transportChannel.getChannelType()).thenReturn("other");

Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
}

@Test
public void testUseJDKSerializationHeaderIsSetAfterGetInnerChannel() throws Exception {
TransportRequest transportRequest = mock(TransportRequest.class);
TransportChannel transportChannel = mock(TransportChannel.class);
WrappedTransportChannel wrappedChannel = mock(WrappedTransportChannel.class);
Task task = mock(Task.class);
when(wrappedChannel.getInnerChannel()).thenReturn(transportChannel);
when(wrappedChannel.getChannelType()).thenReturn("other");
doNothing().when(transportChannel).sendResponse(ArgumentMatchers.any(Exception.class));
when(transportChannel.getVersion()).thenReturn(Version.V_2_10_0);

Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

InOrder inOrder = inOrder(wrappedChannel, transportChannel);

inOrder.verify(wrappedChannel).getInnerChannel();
inOrder.verify(transportChannel).getVersion();
}

public class WrappedTransportChannel implements TransportChannel {

private TransportChannel inner;

public WrappedTransportChannel(TransportChannel inner) {
this.inner = inner;
}

@Override
public String getProfileName() {
return "WrappedTransportChannelProfileName";
}

public TransportChannel getInnerChannel() {
return this.inner;
}

@Override
public void sendResponse(TransportResponse response) throws IOException {
inner.sendResponse(response);
}

@Override
public void sendResponse(Exception e) throws IOException {

}

@Override
public String getChannelType() {
return "WrappedTransportChannelType";
}
}
}

0 comments on commit 9110114

Please sign in to comment.