Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix NPE and add additional graceful error handling #2687

Merged
merged 5 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
Expand Down Expand Up @@ -98,7 +99,6 @@
import org.opensearch.index.Index;
import org.opensearch.index.IndexModule;
import org.opensearch.index.cache.query.QueryCache;
import org.opensearch.index.shard.SearchOperationListener;
import org.opensearch.indices.IndicesService;
import org.opensearch.indices.SystemIndexDescriptor;
import org.opensearch.indices.breaker.CircuitBreakerService;
Expand Down Expand Up @@ -165,6 +165,7 @@
import org.opensearch.security.ssl.transport.SecuritySSLNettyTransport;
import org.opensearch.security.ssl.util.SSLConfigConstants;
import org.opensearch.security.support.ConfigConstants;
import org.opensearch.security.support.GuardedSearchOperationWrapper;
import org.opensearch.security.support.HeaderHelper;
import org.opensearch.security.support.ModuleInfo;
import org.opensearch.security.support.ReflectionHelper;
Expand Down Expand Up @@ -215,7 +216,7 @@ public final class OpenSearchSecurityPlugin extends OpenSearchSecuritySSLPlugin
private final List<String> demoCertHashes = new ArrayList<String>(3);
private volatile SecurityFilter sf;
private volatile IndexResolverReplacer irr;
private volatile NamedXContentRegistry namedXContentRegistry = null;
private volatile AtomicReference<NamedXContentRegistry> namedXContentRegistry = new AtomicReference<>(NamedXContentRegistry.EMPTY);;
cwperks marked this conversation as resolved.
Show resolved Hide resolved
private volatile DlsFlsRequestValve dlsFlsValve = null;
private volatile Salt salt;
private volatile OpensearchDynamicSetting<Boolean> transportPassiveAuthSetting;
Expand Down Expand Up @@ -569,11 +570,11 @@ public Weight doCache(Weight weight, QueryCachingPolicy policy) {
}
});

indexModule.addSearchOperationListener(new SearchOperationListener() {
indexModule.addSearchOperationListener(new GuardedSearchOperationWrapper() {

@Override
public void onPreQueryPhase(SearchContext context) {
dlsFlsValve.handleSearchContext(context, threadPool, namedXContentRegistry);
dlsFlsValve.handleSearchContext(context, threadPool, namedXContentRegistry.get());
}

@Override
Expand Down Expand Up @@ -643,7 +644,7 @@ public void onQueryPhase(SearchContext searchContext, long tookInNanos) {
}
}
}
});
}.toListener());
}
}

Expand Down Expand Up @@ -798,6 +799,7 @@ public Collection<Object> createComponents(Client localClient, ClusterService cl

final PrivilegesInterceptor privilegesInterceptor;

namedXContentRegistry.set(xContentRegistry);
if (SSLConfig.isSslOnlyMode()) {
dlsFlsValve = new DlsFlsRequestValve.NoopDlsFlsRequestValve();
auditLog = new NullAuditLog();
Expand All @@ -822,7 +824,7 @@ public Collection<Object> createComponents(Client localClient, ClusterService cl
// DLS-FLS is enabled if not client and not disabled and not SSL only.
final boolean dlsFlsEnabled = !SSLConfig.isSslOnlyMode();
evaluator = new PrivilegesEvaluator(clusterService, threadPool, cr, resolver, auditLog,
settings, privilegesInterceptor, cih, irr, dlsFlsEnabled, namedXContentRegistry);
settings, privilegesInterceptor, cih, irr, dlsFlsEnabled, namedXContentRegistry.get());

sf = new SecurityFilter(settings, evaluator, adminDns, dlsFlsValve, auditLog, threadPool, cs, compatConfig, irr, xffResolver);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.security.support;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.opensearch.index.shard.SearchOperationListener;
import org.opensearch.search.internal.ReaderContext;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.transport.TransportRequest;

/**
* Guarded version of Search Operation Listener to ensure critical request paths succeed
*/
public interface GuardedSearchOperationWrapper {

static final Logger log = LogManager.getLogger(GuardedSearchOperationWrapper.class);

void onPreQueryPhase(final SearchContext context);

void onNewReaderContext(final ReaderContext readerContext);

void onNewScrollContext(final ReaderContext readerContext);

void validateReaderContext(final ReaderContext readerContext, final TransportRequest transportRequest);

void onQueryPhase(final SearchContext searchContext, final long tookInNanos);

default SearchOperationListener toListener() {
return new InnerSearchOperationListener(this);
}

static class InnerSearchOperationListener implements SearchOperationListener {

private GuardedSearchOperationWrapper that;
InnerSearchOperationListener(GuardedSearchOperationWrapper that) {
this.that = that;
}

@Override
public void onPreQueryPhase(final SearchContext searchContext) {
try {
that.onPreQueryPhase(searchContext);
} catch (final Exception e) {
searchContext.setTask(null);
log.error("Cancelled request due to internal error", e);
}
}

@Override
public void onNewReaderContext(final ReaderContext readerContext) {
that.onNewReaderContext(readerContext);
}

@Override
public void onNewScrollContext(final ReaderContext readerContext) {
that.onNewScrollContext(readerContext);
}

@Override
public void validateReaderContext(final ReaderContext readerContext, final TransportRequest transportRequest) {
that.validateReaderContext(readerContext, transportRequest);
}

@Override
public void onQueryPhase(final SearchContext searchContext, final long tookInNanos) {
try {
that.onQueryPhase(searchContext, tookInNanos);
} catch (final Exception e) {
searchContext.setTask(null);
log.error("Cancelled request due to internal error", e);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
package org.opensearch.security.support;

import java.util.concurrent.atomic.AtomicReference;

import org.junit.Test;

import org.opensearch.index.shard.SearchOperationListener;
import org.opensearch.search.internal.ReaderContext;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.transport.TransportRequest;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;


public class GuardedSearchOperationWrapperTest {

@Test
public void onNewReaderContextCanThrowException() {
final String expectedExceptionText = "abcd1234";

DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
@Override
public void onNewReaderContext(ReaderContext readerContext) {
throw new RuntimeException(expectedExceptionText);
}
};

final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods);

assertThat(expectedException.getMessage(), equalTo(expectedExceptionText));
}

@Test
public void onNewScrollContextCanThrowException() {
final String expectedExceptionText = "qwerty978";

DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
@Override
public void onNewScrollContext(ReaderContext readerContext) {
throw new RuntimeException(expectedExceptionText);
}
};

final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods);

assertThat(expectedException.getMessage(), equalTo(expectedExceptionText));
}

@Test
public void validateReaderContextCanThrowException() {
final String expectedExceptionText = "validationException";

DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
@Override
public void validateReaderContext(ReaderContext readerContext, TransportRequest transportRequest) {
throw new RuntimeException(expectedExceptionText);
}
};

final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods);

assertThat(expectedException.getMessage(), equalTo(expectedExceptionText));
}

@Test
public void onPreQueryPhaseCannotThrow() {
AtomicReference<SearchContext> calledSearchContext = new AtomicReference<>();
DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
@Override
public void onPreQueryPhase(SearchContext context) {
calledSearchContext.set(context);
throw new RuntimeException("EXCEPTIONAL!");
}
};

testWrapper.exerciseAllMethods();

assertThat(calledSearchContext.get(), notNullValue());
verify(calledSearchContext.get()).setTask(null);
}

@Test
public void onQueryPhaseCannotThrow() {
AtomicReference<SearchContext> calledSearchContext = new AtomicReference<>();
DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
@Override
public void onQueryPhase(SearchContext context, long tookInNanos) {
calledSearchContext.set(context);
throw new RuntimeException("EXCEPTIONAL!");
}
};

testWrapper.exerciseAllMethods();

assertThat(calledSearchContext.get(), notNullValue());
verify(calledSearchContext.get()).setTask(null);
}

/** Only use to make testing easier */
private static class DefaultingGuardedSearchOperationWrapper implements GuardedSearchOperationWrapper {

@Override
public void onNewReaderContext(ReaderContext readerContext) {
}

@Override
public void onNewScrollContext(ReaderContext readerContext) {
}

@Override
public void onPreQueryPhase(SearchContext context) {
}

@Override
public void onQueryPhase(SearchContext searchContext, long tookInNanos) {
}

@Override
public void validateReaderContext(ReaderContext readerContext, TransportRequest transportRequest) {
}

void exerciseAllMethods(){
final SearchOperationListener sol = this.toListener();
sol.onNewReaderContext(mock(ReaderContext.class));
sol.onNewScrollContext(mock(ReaderContext.class));
sol.onPreQueryPhase(mock(SearchContext.class));
sol.onQueryPhase(mock(SearchContext.class), 12345L);
sol.validateReaderContext(mock(ReaderContext.class), mock(TransportRequest.class));
}
}
}