From fc15eb7786f8a5fb4c6af5b641c9f79b74fbecd1 Mon Sep 17 00:00:00 2001 From: Craig Perkins Date: Tue, 18 Apr 2023 17:12:21 -0400 Subject: [PATCH] Fix NPE and add additional graceful error handling (#2691) Signed-off-by: Craig Perkins (cherry picked from commit 28086604c909e609f27f0d4f19464714023808b6) --- .../security/OpenSearchSecurityPlugin.java | 32 +++- .../GuardedSearchOperationWrapper.java | 85 ++++++++++ .../GuardedSearchOperationWrapperTest.java | 146 ++++++++++++++++++ 3 files changed, 257 insertions(+), 6 deletions(-) create mode 100644 src/main/java/org/opensearch/security/support/GuardedSearchOperationWrapper.java create mode 100644 src/test/java/org/opensearch/security/support/GuardedSearchOperationWrapperTest.java diff --git a/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java b/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java index 2f5c8494cd..4a1f5f5227 100644 --- a/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java +++ b/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java @@ -39,7 +39,18 @@ import java.security.MessageDigest; import java.security.PrivilegedAction; import java.security.Security; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Predicate; import java.util.function.Supplier; @@ -107,7 +118,7 @@ import org.opensearch.index.Index; import org.opensearch.index.IndexModule; import org.opensearch.index.cache.query.QueryCache; -import org.opensearch.index.shard.SearchOperationListener; +import org.opensearch.indices.IndicesService; import org.opensearch.indices.SystemIndexDescriptor; import org.opensearch.indices.breaker.CircuitBreakerService; import org.opensearch.plugins.ClusterPlugin; @@ -144,6 +155,14 @@ import org.opensearch.security.ssl.SslExceptionHandler; import org.opensearch.security.ssl.transport.SecuritySSLNettyTransport; import org.opensearch.security.ssl.util.SSLConfigConstants; +import org.opensearch.security.support.ConfigConstants; +import org.opensearch.security.support.GuardedSearchOperationWrapper; +import org.opensearch.security.support.HeaderHelper; +import org.opensearch.security.support.ModuleInfo; +import org.opensearch.security.support.ReflectionHelper; +import org.opensearch.security.support.SecuritySettings; +import org.opensearch.security.support.SecurityUtils; +import org.opensearch.security.support.WildcardMatcher; import org.opensearch.security.transport.DefaultInterClusterRequestEvaluator; import org.opensearch.security.transport.InterClusterRequestEvaluator; import org.opensearch.security.user.User; @@ -202,7 +221,7 @@ public final class OpenSearchSecurityPlugin extends OpenSearchSecuritySSLPlugin private final List demoCertHashes = new ArrayList(3); private volatile SecurityFilter sf; private volatile IndexResolverReplacer irr; - private volatile NamedXContentRegistry namedXContentRegistry = null; + private final AtomicReference namedXContentRegistry = new AtomicReference<>(NamedXContentRegistry.EMPTY);; private volatile DlsFlsRequestValve dlsFlsValve = null; private volatile Salt salt; private volatile OpensearchDynamicSetting transportPassiveAuthSetting; @@ -543,11 +562,11 @@ public Weight doCache(Weight weight, QueryCachingPolicy policy) { } }); - indexModule.addSearchOperationListener(new SearchOperationListener() { + indexModule.addSearchOperationListener(new GuardedSearchOperationWrapper() { @Override public void onPreQueryPhase(SearchContext context) { - dlsFlsValve.handleSearchContext(context, threadPool, namedXContentRegistry); + dlsFlsValve.handleSearchContext(context, threadPool, namedXContentRegistry.get()); } @Override @@ -617,7 +636,7 @@ public void onQueryPhase(SearchContext searchContext, long tookInNanos) { } } } - }); + }.toListener()); } } @@ -771,6 +790,7 @@ public Collection createComponents(Client localClient, ClusterService cl final PrivilegesInterceptor privilegesInterceptor; + namedXContentRegistry.set(xContentRegistry); if (SSLConfig.isSslOnlyMode()) { dlsFlsValve = new DlsFlsRequestValve.NoopDlsFlsRequestValve(); auditLog = new NullAuditLog(); diff --git a/src/main/java/org/opensearch/security/support/GuardedSearchOperationWrapper.java b/src/main/java/org/opensearch/security/support/GuardedSearchOperationWrapper.java new file mode 100644 index 0000000000..76d316de2d --- /dev/null +++ b/src/main/java/org/opensearch/security/support/GuardedSearchOperationWrapper.java @@ -0,0 +1,85 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import org.opensearch.index.shard.SearchOperationListener; +import org.opensearch.search.internal.ReaderContext; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.transport.TransportRequest; + +/** + * Guarded version of Search Operation Listener to ensure critical request paths succeed + */ +public interface GuardedSearchOperationWrapper { + + static final Logger log = LogManager.getLogger(GuardedSearchOperationWrapper.class); + + void onPreQueryPhase(final SearchContext context); + + void onNewReaderContext(final ReaderContext readerContext); + + void onNewScrollContext(final ReaderContext readerContext); + + void validateReaderContext(final ReaderContext readerContext, final TransportRequest transportRequest); + + void onQueryPhase(final SearchContext searchContext, final long tookInNanos); + + default SearchOperationListener toListener() { + return new InnerSearchOperationListener(this); + } + + static class InnerSearchOperationListener implements SearchOperationListener { + + private GuardedSearchOperationWrapper that; + InnerSearchOperationListener(GuardedSearchOperationWrapper that) { + this.that = that; + } + + @Override + public void onPreQueryPhase(final SearchContext searchContext) { + try { + that.onPreQueryPhase(searchContext); + } catch (final Exception e) { + searchContext.setTask(null); + log.error("Cancelled request due to internal error", e); + } + } + + @Override + public void onNewReaderContext(final ReaderContext readerContext) { + that.onNewReaderContext(readerContext); + } + + @Override + public void onNewScrollContext(final ReaderContext readerContext) { + that.onNewScrollContext(readerContext); + } + + @Override + public void validateReaderContext(final ReaderContext readerContext, final TransportRequest transportRequest) { + that.validateReaderContext(readerContext, transportRequest); + } + + @Override + public void onQueryPhase(final SearchContext searchContext, final long tookInNanos) { + try { + that.onQueryPhase(searchContext, tookInNanos); + } catch (final Exception e) { + searchContext.setTask(null); + log.error("Cancelled request due to internal error", e); + } + } + } +} diff --git a/src/test/java/org/opensearch/security/support/GuardedSearchOperationWrapperTest.java b/src/test/java/org/opensearch/security/support/GuardedSearchOperationWrapperTest.java new file mode 100644 index 0000000000..982d1108ad --- /dev/null +++ b/src/test/java/org/opensearch/security/support/GuardedSearchOperationWrapperTest.java @@ -0,0 +1,146 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ +package org.opensearch.security.support; + +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.Test; + +import org.opensearch.index.shard.SearchOperationListener; +import org.opensearch.search.internal.ReaderContext; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.transport.TransportRequest; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + + +public class GuardedSearchOperationWrapperTest { + + @Test + public void onNewReaderContextCanThrowException() { + final String expectedExceptionText = "abcd1234"; + + DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() { + @Override + public void onNewReaderContext(ReaderContext readerContext) { + throw new RuntimeException(expectedExceptionText); + } + }; + + final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods); + + assertThat(expectedException.getMessage(), equalTo(expectedExceptionText)); + } + + @Test + public void onNewScrollContextCanThrowException() { + final String expectedExceptionText = "qwerty978"; + + DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() { + @Override + public void onNewScrollContext(ReaderContext readerContext) { + throw new RuntimeException(expectedExceptionText); + } + }; + + final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods); + + assertThat(expectedException.getMessage(), equalTo(expectedExceptionText)); + } + + @Test + public void validateReaderContextCanThrowException() { + final String expectedExceptionText = "validationException"; + + DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() { + @Override + public void validateReaderContext(ReaderContext readerContext, TransportRequest transportRequest) { + throw new RuntimeException(expectedExceptionText); + } + }; + + final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods); + + assertThat(expectedException.getMessage(), equalTo(expectedExceptionText)); + } + + @Test + public void onPreQueryPhaseCannotThrow() { + AtomicReference calledSearchContext = new AtomicReference<>(); + DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() { + @Override + public void onPreQueryPhase(SearchContext context) { + calledSearchContext.set(context); + throw new RuntimeException("EXCEPTIONAL!"); + } + }; + + testWrapper.exerciseAllMethods(); + + assertThat(calledSearchContext.get(), notNullValue()); + verify(calledSearchContext.get()).setTask(null); + } + + @Test + public void onQueryPhaseCannotThrow() { + AtomicReference calledSearchContext = new AtomicReference<>(); + DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() { + @Override + public void onQueryPhase(SearchContext context, long tookInNanos) { + calledSearchContext.set(context); + throw new RuntimeException("EXCEPTIONAL!"); + } + }; + + testWrapper.exerciseAllMethods(); + + assertThat(calledSearchContext.get(), notNullValue()); + verify(calledSearchContext.get()).setTask(null); + } + + /** Only use to make testing easier */ + private static class DefaultingGuardedSearchOperationWrapper implements GuardedSearchOperationWrapper { + + @Override + public void onNewReaderContext(ReaderContext readerContext) { + } + + @Override + public void onNewScrollContext(ReaderContext readerContext) { + } + + @Override + public void onPreQueryPhase(SearchContext context) { + } + + @Override + public void onQueryPhase(SearchContext searchContext, long tookInNanos) { + } + + @Override + public void validateReaderContext(ReaderContext readerContext, TransportRequest transportRequest) { + } + + void exerciseAllMethods(){ + final SearchOperationListener sol = this.toListener(); + sol.onNewReaderContext(mock(ReaderContext.class)); + sol.onNewScrollContext(mock(ReaderContext.class)); + sol.onPreQueryPhase(mock(SearchContext.class)); + sol.onQueryPhase(mock(SearchContext.class), 12345L); + sol.validateReaderContext(mock(ReaderContext.class), mock(TransportRequest.class)); + } + } +}