Skip to content

Commit

Permalink
Fix NPE and add additional graceful error handling (#2691)
Browse files Browse the repository at this point in the history
Signed-off-by: Craig Perkins <[email protected]>
  • Loading branch information
cwperks authored Apr 18, 2023
1 parent c53f680 commit 2808660
Show file tree
Hide file tree
Showing 3 changed files with 257 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,18 @@
import java.security.MessageDigest;
import java.security.PrivilegedAction;
import java.security.Security;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
Expand Down Expand Up @@ -107,7 +118,7 @@
import org.opensearch.index.Index;
import org.opensearch.index.IndexModule;
import org.opensearch.index.cache.query.QueryCache;
import org.opensearch.index.shard.SearchOperationListener;
import org.opensearch.indices.IndicesService;
import org.opensearch.indices.SystemIndexDescriptor;
import org.opensearch.indices.breaker.CircuitBreakerService;
import org.opensearch.plugins.ClusterPlugin;
Expand Down Expand Up @@ -144,6 +155,14 @@
import org.opensearch.security.ssl.SslExceptionHandler;
import org.opensearch.security.ssl.transport.SecuritySSLNettyTransport;
import org.opensearch.security.ssl.util.SSLConfigConstants;
import org.opensearch.security.support.ConfigConstants;
import org.opensearch.security.support.GuardedSearchOperationWrapper;
import org.opensearch.security.support.HeaderHelper;
import org.opensearch.security.support.ModuleInfo;
import org.opensearch.security.support.ReflectionHelper;
import org.opensearch.security.support.SecuritySettings;
import org.opensearch.security.support.SecurityUtils;
import org.opensearch.security.support.WildcardMatcher;
import org.opensearch.security.transport.DefaultInterClusterRequestEvaluator;
import org.opensearch.security.transport.InterClusterRequestEvaluator;
import org.opensearch.security.user.User;
Expand Down Expand Up @@ -202,7 +221,7 @@ public final class OpenSearchSecurityPlugin extends OpenSearchSecuritySSLPlugin
private final List<String> demoCertHashes = new ArrayList<String>(3);
private volatile SecurityFilter sf;
private volatile IndexResolverReplacer irr;
private volatile NamedXContentRegistry namedXContentRegistry = null;
private final AtomicReference<NamedXContentRegistry> namedXContentRegistry = new AtomicReference<>(NamedXContentRegistry.EMPTY);;
private volatile DlsFlsRequestValve dlsFlsValve = null;
private volatile Salt salt;
private volatile OpensearchDynamicSetting<Boolean> transportPassiveAuthSetting;
Expand Down Expand Up @@ -543,11 +562,11 @@ public Weight doCache(Weight weight, QueryCachingPolicy policy) {
}
});

indexModule.addSearchOperationListener(new SearchOperationListener() {
indexModule.addSearchOperationListener(new GuardedSearchOperationWrapper() {

@Override
public void onPreQueryPhase(SearchContext context) {
dlsFlsValve.handleSearchContext(context, threadPool, namedXContentRegistry);
dlsFlsValve.handleSearchContext(context, threadPool, namedXContentRegistry.get());
}

@Override
Expand Down Expand Up @@ -617,7 +636,7 @@ public void onQueryPhase(SearchContext searchContext, long tookInNanos) {
}
}
}
});
}.toListener());
}
}

Expand Down Expand Up @@ -771,6 +790,7 @@ public Collection<Object> createComponents(Client localClient, ClusterService cl

final PrivilegesInterceptor privilegesInterceptor;

namedXContentRegistry.set(xContentRegistry);
if (SSLConfig.isSslOnlyMode()) {
dlsFlsValve = new DlsFlsRequestValve.NoopDlsFlsRequestValve();
auditLog = new NullAuditLog();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.security.support;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.opensearch.index.shard.SearchOperationListener;
import org.opensearch.search.internal.ReaderContext;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.transport.TransportRequest;

/**
* Guarded version of Search Operation Listener to ensure critical request paths succeed
*/
public interface GuardedSearchOperationWrapper {

static final Logger log = LogManager.getLogger(GuardedSearchOperationWrapper.class);

void onPreQueryPhase(final SearchContext context);

void onNewReaderContext(final ReaderContext readerContext);

void onNewScrollContext(final ReaderContext readerContext);

void validateReaderContext(final ReaderContext readerContext, final TransportRequest transportRequest);

void onQueryPhase(final SearchContext searchContext, final long tookInNanos);

default SearchOperationListener toListener() {
return new InnerSearchOperationListener(this);
}

static class InnerSearchOperationListener implements SearchOperationListener {

private GuardedSearchOperationWrapper that;
InnerSearchOperationListener(GuardedSearchOperationWrapper that) {
this.that = that;
}

@Override
public void onPreQueryPhase(final SearchContext searchContext) {
try {
that.onPreQueryPhase(searchContext);
} catch (final Exception e) {
searchContext.setTask(null);
log.error("Cancelled request due to internal error", e);
}
}

@Override
public void onNewReaderContext(final ReaderContext readerContext) {
that.onNewReaderContext(readerContext);
}

@Override
public void onNewScrollContext(final ReaderContext readerContext) {
that.onNewScrollContext(readerContext);
}

@Override
public void validateReaderContext(final ReaderContext readerContext, final TransportRequest transportRequest) {
that.validateReaderContext(readerContext, transportRequest);
}

@Override
public void onQueryPhase(final SearchContext searchContext, final long tookInNanos) {
try {
that.onQueryPhase(searchContext, tookInNanos);
} catch (final Exception e) {
searchContext.setTask(null);
log.error("Cancelled request due to internal error", e);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
package org.opensearch.security.support;

import java.util.concurrent.atomic.AtomicReference;

import org.junit.Test;

import org.opensearch.index.shard.SearchOperationListener;
import org.opensearch.search.internal.ReaderContext;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.transport.TransportRequest;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;


public class GuardedSearchOperationWrapperTest {

@Test
public void onNewReaderContextCanThrowException() {
final String expectedExceptionText = "abcd1234";

DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
@Override
public void onNewReaderContext(ReaderContext readerContext) {
throw new RuntimeException(expectedExceptionText);
}
};

final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods);

assertThat(expectedException.getMessage(), equalTo(expectedExceptionText));
}

@Test
public void onNewScrollContextCanThrowException() {
final String expectedExceptionText = "qwerty978";

DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
@Override
public void onNewScrollContext(ReaderContext readerContext) {
throw new RuntimeException(expectedExceptionText);
}
};

final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods);

assertThat(expectedException.getMessage(), equalTo(expectedExceptionText));
}

@Test
public void validateReaderContextCanThrowException() {
final String expectedExceptionText = "validationException";

DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
@Override
public void validateReaderContext(ReaderContext readerContext, TransportRequest transportRequest) {
throw new RuntimeException(expectedExceptionText);
}
};

final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods);

assertThat(expectedException.getMessage(), equalTo(expectedExceptionText));
}

@Test
public void onPreQueryPhaseCannotThrow() {
AtomicReference<SearchContext> calledSearchContext = new AtomicReference<>();
DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
@Override
public void onPreQueryPhase(SearchContext context) {
calledSearchContext.set(context);
throw new RuntimeException("EXCEPTIONAL!");
}
};

testWrapper.exerciseAllMethods();

assertThat(calledSearchContext.get(), notNullValue());
verify(calledSearchContext.get()).setTask(null);
}

@Test
public void onQueryPhaseCannotThrow() {
AtomicReference<SearchContext> calledSearchContext = new AtomicReference<>();
DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
@Override
public void onQueryPhase(SearchContext context, long tookInNanos) {
calledSearchContext.set(context);
throw new RuntimeException("EXCEPTIONAL!");
}
};

testWrapper.exerciseAllMethods();

assertThat(calledSearchContext.get(), notNullValue());
verify(calledSearchContext.get()).setTask(null);
}

/** Only use to make testing easier */
private static class DefaultingGuardedSearchOperationWrapper implements GuardedSearchOperationWrapper {

@Override
public void onNewReaderContext(ReaderContext readerContext) {
}

@Override
public void onNewScrollContext(ReaderContext readerContext) {
}

@Override
public void onPreQueryPhase(SearchContext context) {
}

@Override
public void onQueryPhase(SearchContext searchContext, long tookInNanos) {
}

@Override
public void validateReaderContext(ReaderContext readerContext, TransportRequest transportRequest) {
}

void exerciseAllMethods(){
final SearchOperationListener sol = this.toListener();
sol.onNewReaderContext(mock(ReaderContext.class));
sol.onNewScrollContext(mock(ReaderContext.class));
sol.onPreQueryPhase(mock(SearchContext.class));
sol.onQueryPhase(mock(SearchContext.class), 12345L);
sol.validateReaderContext(mock(ReaderContext.class), mock(TransportRequest.class));
}
}
}

0 comments on commit 2808660

Please sign in to comment.