Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Refactors reRequestAuthentication to call notifyIpAuthFailureListener before sending the response to the channel (#3411) #3445

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.util.concurrent.TimeUnit;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -36,8 +35,8 @@
@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class)
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class IpBruteForceAttacksPreventionTests {
private static final User USER_1 = new User("simple-user-1").roles(ALL_ACCESS);
private static final User USER_2 = new User("simple-user-2").roles(ALL_ACCESS);
static final User USER_1 = new User("simple-user-1").roles(ALL_ACCESS);
static final User USER_2 = new User("simple-user-2").roles(ALL_ACCESS);

public static final int ALLOWED_TRIES = 3;
public static final int TIME_WINDOW_SECONDS = 3;
Expand All @@ -51,7 +50,7 @@ public class IpBruteForceAttacksPreventionTests {
public static final String CLIENT_IP_8 = "127.0.0.8";
public static final String CLIENT_IP_9 = "127.0.0.9";

private static final AuthFailureListeners listener = new AuthFailureListeners().addRateLimit(
static final AuthFailureListeners listener = new AuthFailureListeners().addRateLimit(
new RateLimiting("internal_authentication_backend_limiting").type("ip")
.allowedTries(ALLOWED_TRIES)
.timeWindowSeconds(TIME_WINDOW_SECONDS)
Expand All @@ -60,13 +59,17 @@ public class IpBruteForceAttacksPreventionTests {
.maxTrackedClients(500)
);

@ClassRule
public static final LocalCluster cluster = new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE)
.anonymousAuth(false)
.authFailureListeners(listener)
.authc(AUTHC_HTTPBASIC_INTERNAL_WITHOUT_CHALLENGE)
.users(USER_1, USER_2)
.build();
@Rule
public LocalCluster cluster = createCluster();

public LocalCluster createCluster() {
return new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE)
.anonymousAuth(false)
.authFailureListeners(listener)
.authc(AUTHC_HTTPBASIC_INTERNAL_WITHOUT_CHALLENGE)
.users(USER_1, USER_2)
.build();
}

@Rule
public LogsRule logsRule = new LogsRule("org.opensearch.security.auth.BackendRegistry");
Expand Down Expand Up @@ -151,7 +154,7 @@ public void shouldReleaseIpAddressLock() throws InterruptedException {
}
}

private static void authenticateUserWithIncorrectPassword(String sourceIpAddress, User user, int numberOfRequests) {
void authenticateUserWithIncorrectPassword(String sourceIpAddress, User user, int numberOfRequests) {
var clientConfiguration = new TestRestClientConfiguration().username(user.getName())
.password("incorrect password")
.sourceInetAddress(sourceIpAddress);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
*/

package org.opensearch.security;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import org.junit.runner.RunWith;
import org.opensearch.test.framework.cluster.ClusterManager;
import org.opensearch.test.framework.cluster.LocalCluster;

import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL;

@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class)
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class IpBruteForceAttacksPreventionWithDomainChallengeTests extends IpBruteForceAttacksPreventionTests {
@Override
public LocalCluster createCluster() {
return new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE)
.anonymousAuth(false)
.authFailureListeners(listener)
.authc(AUTHC_HTTPBASIC_INTERNAL)
.users(USER_1, USER_2)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ public String toString() {
String clusterManagerNodes = nodeByTypeToString(CLUSTER_MANAGER);
String dataNodes = nodeByTypeToString(DATA);
String clientNodes = nodeByTypeToString(CLIENT);
return "\nES Cluster "
return "\nOS Cluster "
+ clusterName
+ "\ncluster manager nodes: "
+ clusterManagerNodes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.common.Strings;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.security.auth.HTTPAuthenticator;
Expand Down Expand Up @@ -236,11 +235,10 @@ public String[] extractRoles(JwtClaims claims) {
protected abstract KeyProvider initKeyProvider(Settings settings, Path configPath) throws Exception;

@Override
public boolean reRequestAuthentication(RestChannel channel, AuthCredentials authCredentials) {
public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials authCredentials) {
final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, "");
wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\"");
channel.sendResponse(wwwAuthenticateResponse);
return true;
return wwwAuthenticateResponse;
}

public String getRequiredAudience() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.security.auth.HTTPAuthenticator;
Expand Down Expand Up @@ -224,11 +223,10 @@ private AuthCredentials extractCredentials0(final RestRequest request) {
}

@Override
public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) {
public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) {
final BytesRestResponse wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, "");
wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Bearer realm=\"OpenSearch Security\"");
channel.sendResponse(wwwAuthenticateResponse);
return true;
return wwwAuthenticateResponse;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.env.Environment;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.security.auth.HTTPAuthenticator;
Expand Down Expand Up @@ -280,8 +279,7 @@ public GSSCredential run() throws GSSException {
}

@Override
public boolean reRequestAuthentication(final RestChannel channel, AuthCredentials creds) {

public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) {
final BytesRestResponse wwwAuthenticateResponse;
XContentBuilder response = getNegotiateResponseBody();

Expand All @@ -291,16 +289,15 @@ public boolean reRequestAuthentication(final RestChannel channel, AuthCredential
wwwAuthenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, EMPTY_STRING);
}

if (creds == null || creds.getNativeCredentials() == null) {
if (credentials == null || credentials.getNativeCredentials() == null) {
wwwAuthenticateResponse.addHeader("WWW-Authenticate", "Negotiate");
} else {
wwwAuthenticateResponse.addHeader(
"WWW-Authenticate",
"Negotiate " + Base64.getEncoder().encodeToString((byte[]) creds.getNativeCredentials())
"Negotiate " + Base64.getEncoder().encodeToString((byte[]) credentials.getNativeCredentials())
);
}
channel.sendResponse(wwwAuthenticateResponse);
return true;
return wwwAuthenticateResponse;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import javax.xml.parsers.ParserConfigurationException;
import javax.xml.xpath.XPathExpressionException;

import com.fasterxml.jackson.core.JsonParseException;
Expand All @@ -32,7 +31,6 @@
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.base.Strings;
import com.onelogin.saml2.authn.SamlResponse;
import com.onelogin.saml2.exception.SettingsException;
import com.onelogin.saml2.exception.ValidationError;
import com.onelogin.saml2.settings.Saml2Settings;
import com.onelogin.saml2.util.Util;
Expand All @@ -49,15 +47,13 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.joda.time.DateTime;
import org.xml.sax.SAXException;

import org.opensearch.OpenSearchSecurityException;
import org.opensearch.SpecialPermission;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestRequest.Method;
import org.opensearch.core.rest.RestStatus;
Expand Down Expand Up @@ -122,19 +118,18 @@ class AuthTokenProcessorHandler {
}

@SuppressWarnings("removal")
boolean handle(RestRequest restRequest, RestChannel restChannel) throws Exception {
BytesRestResponse handle(RestRequest restRequest) throws Exception {
try {
final SecurityManager sm = System.getSecurityManager();

if (sm != null) {
sm.checkPermission(new SpecialPermission());
}

return AccessController.doPrivileged(new PrivilegedExceptionAction<Boolean>() {
return AccessController.doPrivileged(new PrivilegedExceptionAction<BytesRestResponse>() {
@Override
public Boolean run() throws XPathExpressionException, SamlConfigException, IOException, ParserConfigurationException,
SAXException, SettingsException {
return handleLowLevel(restRequest, restChannel);
public BytesRestResponse run() throws SamlConfigException, IOException {
return handleLowLevel(restRequest);
}
});
} catch (PrivilegedActionException e) {
Expand All @@ -147,13 +142,11 @@ public Boolean run() throws XPathExpressionException, SamlConfigException, IOExc
}

private AuthTokenProcessorAction.Response handleImpl(
RestRequest restRequest,
RestChannel restChannel,
String samlResponseBase64,
String samlRequestId,
String acsEndpoint,
Saml2Settings saml2Settings
) throws XPathExpressionException, ParserConfigurationException, SAXException, IOException, SettingsException {
) {
if (token_log.isDebugEnabled()) {
try {
token_log.debug(
Expand Down Expand Up @@ -188,8 +181,7 @@ private AuthTokenProcessorAction.Response handleImpl(
}
}

private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel) throws SamlConfigException, IOException,
XPathExpressionException, ParserConfigurationException, SAXException, SettingsException {
private BytesRestResponse handleLowLevel(RestRequest restRequest) throws SamlConfigException, IOException {
try {

if (restRequest.getMediaType() != XContentType.JSON) {
Expand Down Expand Up @@ -234,31 +226,18 @@ private boolean handleLowLevel(RestRequest restRequest, RestChannel restChannel)
acsEndpoint = getAbsoluteAcsEndpoint(((ObjectNode) jsonRoot).get("acsEndpoint").textValue());
}

AuthTokenProcessorAction.Response responseBody = this.handleImpl(
restRequest,
restChannel,
samlResponseBase64,
samlRequestId,
acsEndpoint,
saml2Settings
);
AuthTokenProcessorAction.Response responseBody = this.handleImpl(samlResponseBase64, samlRequestId, acsEndpoint, saml2Settings);

if (responseBody == null) {
return false;
return null;
}

String responseBodyString = DefaultObjectMapper.objectMapper.writeValueAsString(responseBody);

BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.OK, "application/json", responseBodyString);
restChannel.sendResponse(authenticateResponse);

return true;
return new BytesRestResponse(RestStatus.OK, "application/json", responseBodyString);
} catch (JsonProcessingException e) {
log.warn("Error while parsing JSON for /_opendistro/_security/api/authtoken", e);

BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.BAD_REQUEST, "JSON could not be parsed");
restChannel.sendResponse(authenticateResponse);
return true;
return new BytesRestResponse(RestStatus.BAD_REQUEST, "JSON could not be parsed");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.security.auth.Destroyable;
Expand Down Expand Up @@ -171,26 +170,26 @@ public String getType() {
}

@Override
public boolean reRequestAuthentication(RestChannel restChannel, AuthCredentials authCredentials) {
public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) {
try {
RestRequest restRequest = restChannel.request();
Matcher matcher = PATTERN_PATH_PREFIX.matcher(restRequest.path());
Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path());
final String suffix = matcher.matches() ? matcher.group(2) : null;
if (API_AUTHTOKEN_SUFFIX.equals(suffix) && this.authTokenProcessorHandler.handle(restRequest, restChannel)) {
return true;
if (API_AUTHTOKEN_SUFFIX.equals(suffix)) {
final BytesRestResponse restResponse = this.authTokenProcessorHandler.handle(request);
if (restResponse != null) {
return restResponse;
}
}

Saml2Settings saml2Settings = this.saml2SettingsProvider.getCached();
BytesRestResponse authenticateResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, "");

authenticateResponse.addHeader("WWW-Authenticate", getWwwAuthenticateHeader(saml2Settings));

restChannel.sendResponse(authenticateResponse);

return true;
return authenticateResponse;
} catch (Exception e) {
log.error("Error in reRequestAuthentication()", e);
return false;
return null;
}
}

Expand Down
Loading