diff --git a/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java b/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java index ffd7f730ed..362a46843e 100644 --- a/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java +++ b/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java @@ -45,6 +45,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Predicate; @@ -98,7 +99,6 @@ import org.opensearch.index.Index; import org.opensearch.index.IndexModule; import org.opensearch.index.cache.query.QueryCache; -import org.opensearch.index.shard.SearchOperationListener; import org.opensearch.indices.IndicesService; import org.opensearch.indices.SystemIndexDescriptor; import org.opensearch.indices.breaker.CircuitBreakerService; @@ -165,6 +165,7 @@ import org.opensearch.security.ssl.transport.SecuritySSLNettyTransport; import org.opensearch.security.ssl.util.SSLConfigConstants; import org.opensearch.security.support.ConfigConstants; +import org.opensearch.security.support.GuardedSearchOperationWrapper; import org.opensearch.security.support.HeaderHelper; import org.opensearch.security.support.ModuleInfo; import org.opensearch.security.support.ReflectionHelper; @@ -215,7 +216,7 @@ public final class OpenSearchSecurityPlugin extends OpenSearchSecuritySSLPlugin private final List demoCertHashes = new ArrayList(3); private volatile SecurityFilter sf; private volatile IndexResolverReplacer irr; - private volatile NamedXContentRegistry namedXContentRegistry = null; + private final AtomicReference namedXContentRegistry = new AtomicReference<>(NamedXContentRegistry.EMPTY);; private volatile DlsFlsRequestValve dlsFlsValve = null; private volatile Salt salt; private volatile OpensearchDynamicSetting transportPassiveAuthSetting; @@ -569,11 +570,11 @@ public Weight doCache(Weight weight, QueryCachingPolicy policy) { } }); - indexModule.addSearchOperationListener(new SearchOperationListener() { + indexModule.addSearchOperationListener(new GuardedSearchOperationWrapper() { @Override public void onPreQueryPhase(SearchContext context) { - dlsFlsValve.handleSearchContext(context, threadPool, namedXContentRegistry); + dlsFlsValve.handleSearchContext(context, threadPool, namedXContentRegistry.get()); } @Override @@ -643,7 +644,7 @@ public void onQueryPhase(SearchContext searchContext, long tookInNanos) { } } } - }); + }.toListener()); } } @@ -798,6 +799,7 @@ public Collection createComponents(Client localClient, ClusterService cl final PrivilegesInterceptor privilegesInterceptor; + namedXContentRegistry.set(xContentRegistry); if (SSLConfig.isSslOnlyMode()) { dlsFlsValve = new DlsFlsRequestValve.NoopDlsFlsRequestValve(); auditLog = new NullAuditLog(); @@ -822,7 +824,7 @@ public Collection createComponents(Client localClient, ClusterService cl // DLS-FLS is enabled if not client and not disabled and not SSL only. final boolean dlsFlsEnabled = !SSLConfig.isSslOnlyMode(); evaluator = new PrivilegesEvaluator(clusterService, threadPool, cr, resolver, auditLog, - settings, privilegesInterceptor, cih, irr, dlsFlsEnabled, namedXContentRegistry); + settings, privilegesInterceptor, cih, irr, dlsFlsEnabled, namedXContentRegistry.get()); sf = new SecurityFilter(settings, evaluator, adminDns, dlsFlsValve, auditLog, threadPool, cs, compatConfig, irr, xffResolver); diff --git a/src/main/java/org/opensearch/security/support/GuardedSearchOperationWrapper.java b/src/main/java/org/opensearch/security/support/GuardedSearchOperationWrapper.java new file mode 100644 index 0000000000..76d316de2d --- /dev/null +++ b/src/main/java/org/opensearch/security/support/GuardedSearchOperationWrapper.java @@ -0,0 +1,85 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import org.opensearch.index.shard.SearchOperationListener; +import org.opensearch.search.internal.ReaderContext; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.transport.TransportRequest; + +/** + * Guarded version of Search Operation Listener to ensure critical request paths succeed + */ +public interface GuardedSearchOperationWrapper { + + static final Logger log = LogManager.getLogger(GuardedSearchOperationWrapper.class); + + void onPreQueryPhase(final SearchContext context); + + void onNewReaderContext(final ReaderContext readerContext); + + void onNewScrollContext(final ReaderContext readerContext); + + void validateReaderContext(final ReaderContext readerContext, final TransportRequest transportRequest); + + void onQueryPhase(final SearchContext searchContext, final long tookInNanos); + + default SearchOperationListener toListener() { + return new InnerSearchOperationListener(this); + } + + static class InnerSearchOperationListener implements SearchOperationListener { + + private GuardedSearchOperationWrapper that; + InnerSearchOperationListener(GuardedSearchOperationWrapper that) { + this.that = that; + } + + @Override + public void onPreQueryPhase(final SearchContext searchContext) { + try { + that.onPreQueryPhase(searchContext); + } catch (final Exception e) { + searchContext.setTask(null); + log.error("Cancelled request due to internal error", e); + } + } + + @Override + public void onNewReaderContext(final ReaderContext readerContext) { + that.onNewReaderContext(readerContext); + } + + @Override + public void onNewScrollContext(final ReaderContext readerContext) { + that.onNewScrollContext(readerContext); + } + + @Override + public void validateReaderContext(final ReaderContext readerContext, final TransportRequest transportRequest) { + that.validateReaderContext(readerContext, transportRequest); + } + + @Override + public void onQueryPhase(final SearchContext searchContext, final long tookInNanos) { + try { + that.onQueryPhase(searchContext, tookInNanos); + } catch (final Exception e) { + searchContext.setTask(null); + log.error("Cancelled request due to internal error", e); + } + } + } +} diff --git a/src/test/java/org/opensearch/security/support/GuardedSearchOperationWrapperTest.java b/src/test/java/org/opensearch/security/support/GuardedSearchOperationWrapperTest.java new file mode 100644 index 0000000000..982d1108ad --- /dev/null +++ b/src/test/java/org/opensearch/security/support/GuardedSearchOperationWrapperTest.java @@ -0,0 +1,146 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ +package org.opensearch.security.support; + +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.Test; + +import org.opensearch.index.shard.SearchOperationListener; +import org.opensearch.search.internal.ReaderContext; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.transport.TransportRequest; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + + +public class GuardedSearchOperationWrapperTest { + + @Test + public void onNewReaderContextCanThrowException() { + final String expectedExceptionText = "abcd1234"; + + DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() { + @Override + public void onNewReaderContext(ReaderContext readerContext) { + throw new RuntimeException(expectedExceptionText); + } + }; + + final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods); + + assertThat(expectedException.getMessage(), equalTo(expectedExceptionText)); + } + + @Test + public void onNewScrollContextCanThrowException() { + final String expectedExceptionText = "qwerty978"; + + DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() { + @Override + public void onNewScrollContext(ReaderContext readerContext) { + throw new RuntimeException(expectedExceptionText); + } + }; + + final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods); + + assertThat(expectedException.getMessage(), equalTo(expectedExceptionText)); + } + + @Test + public void validateReaderContextCanThrowException() { + final String expectedExceptionText = "validationException"; + + DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() { + @Override + public void validateReaderContext(ReaderContext readerContext, TransportRequest transportRequest) { + throw new RuntimeException(expectedExceptionText); + } + }; + + final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods); + + assertThat(expectedException.getMessage(), equalTo(expectedExceptionText)); + } + + @Test + public void onPreQueryPhaseCannotThrow() { + AtomicReference calledSearchContext = new AtomicReference<>(); + DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() { + @Override + public void onPreQueryPhase(SearchContext context) { + calledSearchContext.set(context); + throw new RuntimeException("EXCEPTIONAL!"); + } + }; + + testWrapper.exerciseAllMethods(); + + assertThat(calledSearchContext.get(), notNullValue()); + verify(calledSearchContext.get()).setTask(null); + } + + @Test + public void onQueryPhaseCannotThrow() { + AtomicReference calledSearchContext = new AtomicReference<>(); + DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() { + @Override + public void onQueryPhase(SearchContext context, long tookInNanos) { + calledSearchContext.set(context); + throw new RuntimeException("EXCEPTIONAL!"); + } + }; + + testWrapper.exerciseAllMethods(); + + assertThat(calledSearchContext.get(), notNullValue()); + verify(calledSearchContext.get()).setTask(null); + } + + /** Only use to make testing easier */ + private static class DefaultingGuardedSearchOperationWrapper implements GuardedSearchOperationWrapper { + + @Override + public void onNewReaderContext(ReaderContext readerContext) { + } + + @Override + public void onNewScrollContext(ReaderContext readerContext) { + } + + @Override + public void onPreQueryPhase(SearchContext context) { + } + + @Override + public void onQueryPhase(SearchContext searchContext, long tookInNanos) { + } + + @Override + public void validateReaderContext(ReaderContext readerContext, TransportRequest transportRequest) { + } + + void exerciseAllMethods(){ + final SearchOperationListener sol = this.toListener(); + sol.onNewReaderContext(mock(ReaderContext.class)); + sol.onNewScrollContext(mock(ReaderContext.class)); + sol.onPreQueryPhase(mock(SearchContext.class)); + sol.onQueryPhase(mock(SearchContext.class), 12345L); + sol.validateReaderContext(mock(ReaderContext.class), mock(TransportRequest.class)); + } + } +}