From 8d636c4ea366fca88f28c3226d3997f91992f55e Mon Sep 17 00:00:00 2001 From: Darshit Chanpura <35282393+DarshitChanpura@users.noreply.github.com> Date: Mon, 10 Jul 2023 08:43:41 -0400 Subject: [PATCH] Adds a check to skip serialization-deserialization if request is for same node (#2765) Signed-off-by: Darshit Chanpura Signed-off-by: Craig Perkins Co-authored-by: Craig Perkins --- .../security/OpenSearchSecurityPlugin.java | 13 +- .../transport/SecurityInterceptor.java | 53 +++-- .../transport/SecurityRequestHandler.java | 103 ++++------ .../transport/SecurityInterceptorTests.java | 183 ++++++++++++++++++ 4 files changed, 266 insertions(+), 86 deletions(-) create mode 100644 src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java diff --git a/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java b/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java index 67f046ed89..ffb7cfc075 100644 --- a/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java +++ b/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java @@ -73,6 +73,7 @@ import org.opensearch.action.support.ActionFilter; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.component.Lifecycle.State; @@ -211,6 +212,7 @@ public final class OpenSearchSecurityPlugin extends OpenSearchSecuritySSLPlugin private volatile ConfigurationRepository cr; private volatile AdminDNs adminDns; private volatile ClusterService cs; + private static volatile DiscoveryNode localNode; private volatile AuditLog auditLog; private volatile BackendRegistry backendRegistry; private volatile SslExceptionHandler sslExceptionHandler; @@ -1799,11 +1801,12 @@ public List getSettingsFilter() { } @Override - public void onNodeStarted() { + public void onNodeStarted(DiscoveryNode localNode) { log.info("Node started"); if (!SSLConfig.isSslOnlyMode() && !client && !disabled) { cr.initOnNodeStart(); } + this.localNode = localNode; final Set securityModules = ReflectionHelper.getModulesLoaded(); log.info("{} OpenSearch Security modules loaded so far: {}", securityModules.size(), securityModules); } @@ -1883,6 +1886,14 @@ private static String handleKeyword(final String field) { return field; } + public static DiscoveryNode getLocalNode() { + return localNode; + } + + public static void setLocalNode(DiscoveryNode node) { + localNode = node; + } + public static class GuiceHolder implements LifecycleComponent { private static RepositoriesService repositoriesService; diff --git a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java index 7bd5024d2c..66f4140d3e 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java +++ b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java @@ -43,6 +43,7 @@ import org.opensearch.action.get.GetRequest; import org.opensearch.action.search.SearchAction; import org.opensearch.action.search.SearchRequest; +import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; @@ -131,7 +132,6 @@ public void sendRequestDecorate( TransportRequestOptions options, TransportResponseHandler handler ) { - final Map origHeaders0 = getThreadContext().getHeaders(); final User user0 = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); final String injectedUserString = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER); @@ -146,6 +146,9 @@ public void sendRequestDecorate( final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS); final boolean isDebugEnabled = log.isDebugEnabled(); + final DiscoveryNode localNode = OpenSearchSecurityPlugin.getLocalNode(); + boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode()); + try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) { final TransportResponseHandler restoringHandler = new RestoringTransportResponseHandler(handler, stashedContext); getThreadContext().putHeader("_opendistro_security_remotecn", cs.getClusterName().value()); @@ -223,7 +226,7 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL getThreadContext().putHeader(headerMap); - ensureCorrectHeaders(remoteAddress0, user0, origin0, injectedUserString, injectedRolesString); + ensureCorrectHeaders(remoteAddress0, user0, origin0, injectedUserString, injectedRolesString, isSameNodeRequest); if (isActionTraceEnabled()) { getThreadContext().putHeader( @@ -249,7 +252,8 @@ private void ensureCorrectHeaders( final User origUser, final String origin, final String injectedUserString, - final String injectedRolesString + final String injectedRolesString, + boolean isSameNodeRequest ) { // keep original address @@ -263,30 +267,49 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN_HEADE getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN_HEADER, Origin.LOCAL.toString()); } + TransportAddress transportAddress = null; if (remoteAdr != null && remoteAdr instanceof TransportAddress) { - String remoteAddressHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER); - if (remoteAddressHeader == null) { - getThreadContext().putHeader( - ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER, - Base64Helper.serializeObject(((TransportAddress) remoteAdr).address()) - ); + transportAddress = (TransportAddress) remoteAdr; } } - String userHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); + // we put headers as transient for same node requests + if (isSameNodeRequest) { + if (transportAddress != null) { + getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, transportAddress); + } - if (userHeader == null) { if (origUser != null) { - getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(origUser)); + // if request is going to be handled by same node, we directly put transient value as the thread context is not going to be + // stah. + getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, origUser); } else if (StringUtils.isNotEmpty(injectedRolesString)) { - getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_HEADER, injectedRolesString); + getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES, injectedRolesString); } else if (StringUtils.isNotEmpty(injectedUserString)) { - getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER_HEADER, injectedUserString); + getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, injectedUserString); + } + } else { + if (transportAddress != null) { + getThreadContext().putHeader( + ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER, + Base64Helper.serializeObject(transportAddress.address()) + ); } - } + final String userHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); + if (userHeader == null) { + // put as headers for other requests + if (origUser != null) { + getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(origUser)); + } else if (StringUtils.isNotEmpty(injectedRolesString)) { + getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_HEADER, injectedRolesString); + } else if (StringUtils.isNotEmpty(injectedUserString)) { + getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER_HEADER, injectedUserString); + } + } + } } private ThreadContext getThreadContext() { diff --git a/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java b/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java index d1ad9b02e1..8ea82c9d9d 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java +++ b/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java @@ -95,7 +95,6 @@ protected void messageReceivedDecorate( final TransportChannel transportChannel, Task task ) throws Exception { - String resolvedActionClass = request.getClass().getSimpleName(); if (request instanceof BulkShardRequest) { @@ -142,7 +141,31 @@ protected void messageReceivedDecorate( } // bypass non-netty requests - if (channelType.equals("direct")) { + if (getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER) != null + || getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER) != null + || getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES) != null + || getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS) != null) { + + final String rolesValidation = getThreadContext().getHeader( + ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_VALIDATION_HEADER + ); + if (!Strings.isNullOrEmpty(rolesValidation)) { + getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_VALIDATION, rolesValidation); + } + + if (isActionTraceEnabled()) { + getThreadContext().putHeader( + "_opendistro_security_trace" + System.currentTimeMillis() + "#" + UUID.randomUUID().toString(), + Thread.currentThread().getName() + + " DIR -> " + + transportChannel.getChannelType() + + " " + + getThreadContext().getHeaders() + ); + } + + putInitialActionClassHeader(initialActionClassValue, resolvedActionClass); + } else { final String userHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); final String injectedRolesHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_HEADER); final String injectedUserHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER_HEADER); @@ -162,15 +185,15 @@ protected void messageReceivedDecorate( ); } - final String originalRemoteAddress = getThreadContext().getHeader( - ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER - ); + String originalRemoteAddress = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER); if (!Strings.isNullOrEmpty(originalRemoteAddress)) { getThreadContext().putTransient( ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(originalRemoteAddress)) ); + } else { + getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, request.remoteAddress()); } final String rolesValidation = getThreadContext().getHeader( @@ -179,20 +202,9 @@ protected void messageReceivedDecorate( if (!Strings.isNullOrEmpty(rolesValidation)) { getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_VALIDATION, rolesValidation); } + } - if (isActionTraceEnabled()) { - getThreadContext().putHeader( - "_opendistro_security_trace" + System.currentTimeMillis() + "#" + UUID.randomUUID().toString(), - Thread.currentThread().getName() - + " DIR -> " - + transportChannel.getChannelType() - + " " - + getThreadContext().getHeaders() - ); - } - - putInitialActionClassHeader(initialActionClassValue, resolvedActionClass); - + if (channelType.equals("direct")) { super.messageReceivedDecorate(request, handler, transportChannel, task); return; } @@ -272,58 +284,10 @@ protected void messageReceivedDecorate( // network intercluster request or cross search cluster request // CS-SUPPRESS-SINGLE: RegexpSingleline Used to allow/disallow TLS connections to extensions - if (HeaderHelper.isInterClusterRequest(getThreadContext()) + if (!(HeaderHelper.isInterClusterRequest(getThreadContext()) || HeaderHelper.isTrustedClusterRequest(getThreadContext()) - || HeaderHelper.isExtensionRequest(getThreadContext())) { + || HeaderHelper.isExtensionRequest(getThreadContext()))) { // CS-ENFORCE-SINGLE - - final String userHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); - final String injectedRolesHeader = getThreadContext().getHeader( - ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_HEADER - ); - final String injectedUserHeader = getThreadContext().getHeader( - ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER_HEADER - ); - - if (Strings.isNullOrEmpty(userHeader)) { - // Keeping role injection with higher priority as plugins under OpenSearch will be using this - // on transport layer - if (!Strings.isNullOrEmpty(injectedRolesHeader)) { - getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES, injectedRolesHeader); - } else if (!Strings.isNullOrEmpty(injectedUserHeader)) { - getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, injectedUserHeader); - } - } else { - getThreadContext().putTransient( - ConfigConstants.OPENDISTRO_SECURITY_USER, - Objects.requireNonNull((User) Base64Helper.deserializeObject(userHeader)) - ); - } - - String originalRemoteAddress = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER); - - if (!Strings.isNullOrEmpty(originalRemoteAddress)) { - getThreadContext().putTransient( - ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, - new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(originalRemoteAddress)) - ); - } else { - getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, request.remoteAddress()); - } - - final String rolesValidation = getThreadContext().getHeader( - ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_VALIDATION_HEADER - ); - if (!Strings.isNullOrEmpty(rolesValidation)) { - getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_VALIDATION, rolesValidation); - } - - } else { - // this is a netty request from a non-server node (maybe also be internal: or a shard request) - // and therefore issued by a transport client - - // since OS 2.0 we do not support this any longer because transport client no longer available - final OpenSearchException exception = ExceptionUtils.createTransportClientNoLongerSupportedException(); log.error(exception.toString()); transportChannel.sendResponse(exception); @@ -346,9 +310,8 @@ protected void messageReceivedDecorate( } putInitialActionClassHeader(initialActionClassValue, resolvedActionClass); - - super.messageReceivedDecorate(request, handler, transportChannel, task); } + super.messageReceivedDecorate(request, handler, transportChannel, task); } finally { if (isActionTraceEnabled()) { diff --git a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java new file mode 100644 index 0000000000..7291050d6e --- /dev/null +++ b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java @@ -0,0 +1,183 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.security.transport; + +// CS-SUPPRESS-SINGLE: RegexpSingleline Extensions manager used for creating a mock +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.action.search.PitService; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.extensions.ExtensionsManager; +import org.opensearch.indices.IndicesService; +import org.opensearch.repositories.RepositoriesService; +import org.opensearch.security.OpenSearchSecurityPlugin; +import org.opensearch.security.auditlog.AuditLog; +import org.opensearch.security.auth.BackendRegistry; +import org.opensearch.security.configuration.ClusterInfoHolder; +import org.opensearch.security.ssl.SslExceptionHandler; +import org.opensearch.security.ssl.transport.PrincipalExtractor; +import org.opensearch.security.ssl.transport.SSLConfig; +import org.opensearch.security.support.Base64Helper; +import org.opensearch.security.support.ConfigConstants; +import org.opensearch.security.user.User; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.transport.MockTransport; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.Transport.Connection; +import org.opensearch.transport.TransportInterceptor.AsyncSender; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportResponse; +import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportService; + +import static java.util.Collections.emptySet; +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +// CS-ENFORCE-SINGLE + +public class SecurityInterceptorTests { + + private SecurityInterceptor securityInterceptor; + + @Mock + private BackendRegistry backendRegistry; + + @Mock + private AuditLog auditLog; + + @Mock + private PrincipalExtractor principalExtractor; + + @Mock + private InterClusterRequestEvaluator requestEvalProvider; + + @Mock + private ClusterService clusterService; + + @Mock + private SslExceptionHandler sslExceptionHandler; + + @Mock + private ClusterInfoHolder clusterInfoHolder; + + @Mock + private SSLConfig sslConfig; + + private Settings settings; + + private ThreadPool threadPool; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + settings = Settings.builder() + .put("node.name", SecurityInterceptorTests.class.getSimpleName()) + .put("request.headers.default", "1") + .build(); + threadPool = new ThreadPool(settings); + securityInterceptor = new SecurityInterceptor( + settings, + threadPool, + backendRegistry, + auditLog, + principalExtractor, + requestEvalProvider, + clusterService, + sslExceptionHandler, + clusterInfoHolder, + sslConfig + ); + } + + @Test + public void testSendRequestDecorate() { + + ClusterName clusterName = ClusterName.DEFAULT; + when(clusterService.getClusterName()).thenReturn(clusterName); + + MockTransport transport = new MockTransport(); + TransportService transportService = transport.createTransportService( + Settings.EMPTY, + threadPool, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + boundTransportAddress -> clusterService.state().nodes().get(SecurityInterceptor.class.getSimpleName()), + null, + emptySet() + ); + + // CS-SUPPRESS-SINGLE: RegexpSingleline Extensions manager used for creating a mock + OpenSearchSecurityPlugin.GuiceHolder guiceHolder = new OpenSearchSecurityPlugin.GuiceHolder( + mock(RepositoriesService.class), + transportService, + mock(IndicesService.class), + mock(PitService.class), + mock(ExtensionsManager.class) + ); + // CS-ENFORCE-SINGLE + + User user = new User("John Doe"); + threadPool.getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, user); + + AsyncSender sender = mock(AsyncSender.class); + String action = "testAction"; + TransportRequest request = mock(TransportRequest.class); + TransportRequestOptions options = mock(TransportRequestOptions.class); + TransportResponseHandler handler = mock(TransportResponseHandler.class); + + DiscoveryNode localNode = new DiscoveryNode("local-node", OpenSearchTestCase.buildNewFakeTransportAddress(), Version.CURRENT); + Connection connection1 = transportService.getConnection(localNode); + + DiscoveryNode otherNode = new DiscoveryNode("local-node", OpenSearchTestCase.buildNewFakeTransportAddress(), Version.CURRENT); + Connection connection2 = transportService.getConnection(otherNode); + + // setting localNode value explicitly + OpenSearchSecurityPlugin.setLocalNode(localNode); + + // isSameNodeRequest = true + securityInterceptor.sendRequestDecorate(sender, connection1, action, request, options, handler); + // from thread context inside sendRequestDecorate + doAnswer(i -> { + User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); + assertEquals(transientUser, user); + return null; + }).when(sender).sendRequest(any(Connection.class), eq(action), eq(request), eq(options), eq(handler)); + + // from original context + User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); + assertEquals(transientUser, user); + assertEquals(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), null); + + // isSameNodeRequest = false + securityInterceptor.sendRequestDecorate(sender, connection2, action, request, options, handler); + // checking thread context inside sendRequestDecorate + doAnswer(i -> { + String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); + assertEquals(serializedUserHeader, Base64Helper.serializeObject(user)); + return null; + }).when(sender).sendRequest(any(Connection.class), eq(action), eq(request), eq(options), eq(handler)); + + // from original context + User transientUser2 = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); + assertEquals(transientUser2, user); + assertEquals(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), null); + + } + +}