Skip to content

Commit

Permalink
[Backport 2.11] Add early rejection from RestHandler for unauthorized…
Browse files Browse the repository at this point in the history
… requests (#3418) (#3495)

### Description

Backport of 6b0b682 from #3418

Previously unauthorized requests were fully processed and rejected once
they reached the RestHandler. This allocations more memory and resources
for these requests that might not be useful if they are already detected
as unauthorized. Using the headerVerifer and decompressor customization
from [1], perform an early authorization check when only the headers are
available, save an 'early response' for transmission and do not perform
the decompression on the request to speed up closing out the connection.

- Resolves opensearch-project/OpenSearch#10260

Signed-off-by: Peter Nied <[email protected]>
Signed-off-by: Craig Perkins <[email protected]>
Signed-off-by: Craig Perkins <[email protected]>
Co-authored-by: Craig Perkins <[email protected]>
  • Loading branch information
peternied and cwperks authored Oct 7, 2023
1 parent 919589c commit f7c47af
Show file tree
Hide file tree
Showing 31 changed files with 990 additions and 281 deletions.
46 changes: 46 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ on:

env:
GRADLE_OPTS: -Dhttp.keepAlive=false
CI_ENVIRONMENT: normal

jobs:
generate-test-list:
Expand Down Expand Up @@ -107,6 +108,51 @@ jobs:
arguments: |
integrationTest -Dbuild.snapshot=false
resource-tests:
env:
CI_ENVIRONMENT: resource-test
strategy:
fail-fast: false
matrix:
jdk: [17]
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}

steps:
- name: Set up JDK for build and test
uses: actions/setup-java@v3
with:
distribution: temurin # Temurin is a distribution of adoptium
java-version: ${{ matrix.jdk }}

- name: Checkout security
uses: actions/checkout@v4

- name: Build and Test
uses: gradle/gradle-build-action@v2
with:
cache-disabled: true
arguments: |
integrationTest -Dbuild.snapshot=false --tests org.opensearch.security.ResourceFocusedTests
backward-compatibility-build:
runs-on: ubuntu-latest
steps:
- uses: actions/setup-java@v3
with:
distribution: temurin # Temurin is a distribution of adoptium
java-version: 17

- name: Checkout Security Repo
uses: actions/checkout@v4

- name: Build BWC tests
uses: gradle/gradle-build-action@v2
with:
cache-disabled: true
arguments: |
-p bwc-test build -x test -x integTest
backward-compatibility:
strategy:
fail-fast: false
Expand Down
22 changes: 16 additions & 6 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -470,17 +470,27 @@ sourceSets {

//add new task that runs integration tests
task integrationTest(type: Test) {
doFirst {
// Only run resources tests on resource-test CI environments or locally
if (System.getenv('CI_ENVIRONMENT') == 'resource-test' || System.getenv('CI_ENVIRONMENT') == null) {
include '**/ResourceFocusedTests.class'
} else {
exclude '**/ResourceFocusedTests.class'
}
// Only run with retries while in CI systems
if (System.getenv('CI_ENVIRONMENT') == 'normal') {
retry {
failOnPassedAfterRetry = false
maxRetries = 2
maxFailures = 10
}
}
}
description = 'Run integration tests.'
group = 'verification'
systemProperty "java.util.logging.manager", "org.apache.logging.log4j.jul.LogManager"
testClassesDirs = sourceSets.integrationTest.output.classesDirs
classpath = sourceSets.integrationTest.runtimeClasspath
jvmArgs += "-Djdk.internal.httpclient.disableHostnameVerification"
retry {
failOnPassedAfterRetry = false
maxRetries = 2
maxFailures = 10
}
//run the integrationTest task after the test task
shouldRunAfter test
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
package org.opensearch.security;

import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE;
import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL;
import static org.opensearch.test.framework.TestSecurityConfig.Role.ALL_ACCESS;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.management.GarbageCollectorMXBean;
import java.lang.management.ManagementFactory;
import java.lang.management.MemoryPoolMXBean;
import java.lang.management.MemoryUsage;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.zip.GZIPOutputStream;

import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.entity.ContentType;
import org.apache.http.message.BasicHeader;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.client.Client;
import org.opensearch.test.framework.TestSecurityConfig;
import org.opensearch.test.framework.TestSecurityConfig.User;
import org.opensearch.test.framework.cluster.ClusterManager;
import org.opensearch.test.framework.cluster.LocalCluster;
import org.opensearch.test.framework.cluster.TestRestClient;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;

@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class)
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class ResourceFocusedTests {
private static final User ADMIN_USER = new User("admin").roles(ALL_ACCESS);
private static final User LIMITED_USER = new User("limited_user").roles(
new TestSecurityConfig.Role("limited-role").clusterPermissions(
"indices:data/read/mget",
"indices:data/read/msearch",
"indices:data/read/scroll",
"cluster:monitor/state",
"cluster:monitor/health"
)
.indexPermissions(
"indices:data/read/search",
"indices:data/read/mget*",
"indices:monitor/settings/get",
"indices:monitor/stats"
)
.on("*")
);

@ClassRule
public static LocalCluster cluster = new LocalCluster.Builder().clusterManager(ClusterManager.THREE_CLUSTER_MANAGERS)
.authc(AUTHC_HTTPBASIC_INTERNAL)
.users(ADMIN_USER, LIMITED_USER)
.anonymousAuth(false)
.doNotFailOnForbidden(true)
.build();

@BeforeClass
public static void createTestData() {
try (Client client = cluster.getInternalNodeClient()) {
client.index(new IndexRequest().setRefreshPolicy(IMMEDIATE).index("document").source(Map.of("foo", "bar", "abc", "xyz")))
.actionGet();
}
}

@Test
public void testUnauthenticatedFewBig() {
// Tweaks:
final RequestBodySize size = RequestBodySize.XLarge;
final String requestPath = "/*/_search";
final int parrallelism = 5;
final int totalNumberOfRequests = 100;
final boolean statsPrinter = false;

runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter);
}

@Test
public void testUnauthenticatedManyMedium() {
// Tweaks:
final RequestBodySize size = RequestBodySize.Medium;
final String requestPath = "/*/_search";
final int parrallelism = 20;
final int totalNumberOfRequests = 10_000;
final boolean statsPrinter = false;

runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter);
}

@Test
public void testUnauthenticatedTonsSmall() {
// Tweaks:
final RequestBodySize size = RequestBodySize.Small;
final String requestPath = "/*/_search";
final int parrallelism = 100;
final int totalNumberOfRequests = 1_000_000;
final boolean statsPrinter = false;

runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter);
}

private Long runResourceTest(
final RequestBodySize size,
final String requestPath,
final int parrallelism,
final int totalNumberOfRequests,
final boolean statsPrinter
) {
final byte[] compressedRequestBody = createCompressedRequestBody(size);
try (final TestRestClient client = cluster.getRestClient(new BasicHeader("Content-Encoding", "gzip"))) {

if (statsPrinter) {
printStats();
}
final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath);
post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON));

final ForkJoinPool forkJoinPool = new ForkJoinPool(parrallelism);

final List<CompletableFuture<Void>> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests)
.boxed()
.map(i -> CompletableFuture.runAsync(() -> client.executeRequest(post), forkJoinPool))
.collect(Collectors.toList());
Supplier<Long> getCount = () -> waitingOn.stream().filter(cf -> cf.isDone() && !cf.isCompletedExceptionally()).count();

CompletableFuture<Void> statPrinter = statsPrinter ? CompletableFuture.runAsync(() -> {
while (true) {
printStats();
System.err.println(" & Succesful completions: " + getCount.get());
try {
Thread.sleep(500);
} catch (Exception e) {
break;
}
}
}, forkJoinPool) : CompletableFuture.completedFuture(null);

final CompletableFuture<Void> allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0]));

try {
allOfThem.get(30, TimeUnit.SECONDS);
statPrinter.cancel(true);
} catch (final Exception e) {
// Ignored
}

if (statsPrinter) {
printStats();
System.err.println(" & Succesful completions: " + getCount.get());
}
return getCount.get();
}
}

static enum RequestBodySize {
Small(1),
Medium(1_000),
XLarge(1_000_000);

public final int elementCount;

private RequestBodySize(final int elementCount) {
this.elementCount = elementCount;
}
}

private byte[] createCompressedRequestBody(final RequestBodySize size) {
final int repeatCount = size.elementCount;
final String prefix = "{ \"items\": [";
final String repeatedElement = IntStream.range(0, 20)
.mapToObj(n -> ('a' + n) + "")
.map(n -> '"' + n + '"' + ": 123")
.collect(Collectors.joining(",", "{", "}"));
final String postfix = "]}";
long uncompressedBytesSize = 0;

try (
final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
final GZIPOutputStream gzipOutputStream = new GZIPOutputStream(byteArrayOutputStream)
) {

final byte[] prefixBytes = prefix.getBytes(StandardCharsets.UTF_8);
final byte[] repeatedElementBytes = repeatedElement.getBytes(StandardCharsets.UTF_8);
final byte[] postfixBytes = postfix.getBytes(StandardCharsets.UTF_8);

gzipOutputStream.write(prefixBytes);
uncompressedBytesSize = uncompressedBytesSize + prefixBytes.length;
for (int i = 0; i < repeatCount; i++) {
gzipOutputStream.write(repeatedElementBytes);
uncompressedBytesSize = uncompressedBytesSize + repeatedElementBytes.length;
}
gzipOutputStream.write(postfixBytes);
uncompressedBytesSize = uncompressedBytesSize + postfixBytes.length;
gzipOutputStream.finish();

final byte[] compressedRequestBody = byteArrayOutputStream.toByteArray();
System.err.println(
"^^^"
+ String.format(
"Original size was %,d bytes, compressed to %,d bytes, ratio %,.2f",
uncompressedBytesSize,
compressedRequestBody.length,
((double) uncompressedBytesSize / compressedRequestBody.length)
)
);
return compressedRequestBody;
} catch (final IOException ioe) {
throw new RuntimeException(ioe);
}
}

private void printStats() {
System.err.println("** Stats ");
printMemory();
printMemoryPools();
printGCPools();
}

private void printMemory() {
final Runtime runtime = Runtime.getRuntime();

final long totalMemory = runtime.totalMemory(); // Total allocated memory
final long freeMemory = runtime.freeMemory(); // Amount of free memory
final long usedMemory = totalMemory - freeMemory; // Amount of used memory

System.err.println(" Memory Total: " + totalMemory + " Free:" + freeMemory + " Used:" + usedMemory);
}

private void printMemoryPools() {
List<MemoryPoolMXBean> memoryPools = ManagementFactory.getMemoryPoolMXBeans();
for (MemoryPoolMXBean memoryPool : memoryPools) {
MemoryUsage usage = memoryPool.getUsage();
System.err.println(" " + memoryPool.getName() + " USED: " + usage.getUsed() + " MAX: " + usage.getMax());
}
}

private void printGCPools() {
List<GarbageCollectorMXBean> garbageCollectors = ManagementFactory.getGarbageCollectorMXBeans();
for (GarbageCollectorMXBean garbageCollector : garbageCollectors) {
System.err.println(" " + garbageCollector.getName() + " COLLECTION TIME: " + garbageCollector.getCollectionTime());
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ public void shouldImpersonateUser_negativeJean() {

response.assertStatusCode(403);
String expectedMessage = String.format("'%s' is not allowed to impersonate as '%s'", USER_KIRK, USER_JEAN);
System.out.println("&&&& " + response.getBody());
assertThat(response.getTextFromJsonBody(POINTER_ERROR_REASON), equalTo(expectedMessage));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ public void createRoleMapping(String backendRoleName, String roleName) {
response.assertStatusCode(201);
}

protected final String getHttpServerUri() {
public final String getHttpServerUri() {
return "http" + (enableHTTPClientSSL ? "s" : "") + "://" + nodeHttpAddress.getHostString() + ":" + nodeHttpAddress.getPort();
}

Expand Down Expand Up @@ -403,7 +403,7 @@ private JsonNode getJsonNodeAt(String jsonPointer) {
try {
return toJsonNode().at(jsonPointer);
} catch (IOException e) {
throw new IllegalArgumentException("Cound not convert response body to JSON node ", e);
throw new IllegalArgumentException("Cound not convert response body to JSON node '" + getBody() + "'", e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ private Optional<SecurityResponse> handleLowLevel(RestRequest restRequest) throw
return Optional.of(new SecurityResponse(HttpStatus.SC_OK, SecurityResponse.CONTENT_TYPE_APP_JSON, responseBodyString));
} catch (JsonProcessingException e) {
log.warn("Error while parsing JSON for /_opendistro/_security/api/authtoken", e);
return Optional.of(new SecurityResponse(HttpStatus.SC_BAD_REQUEST, null, "JSON could not be parsed"));
return Optional.of(new SecurityResponse(HttpStatus.SC_BAD_REQUEST, new Exception("JSON could not be parsed")));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public class HTTPSamlAuthenticator implements HTTPAuthenticator, Destroyable {
public static final String IDP_METADATA_FILE = "idp.metadata_file";
public static final String IDP_METADATA_CONTENT = "idp.metadata_content";

private static final String API_AUTHTOKEN_SUFFIX = "api/authtoken";
public static final String API_AUTHTOKEN_SUFFIX = "api/authtoken";
private static final String AUTHINFO_SUFFIX = "authinfo";
private static final String REGEX_PATH_PREFIX = "/(" + LEGACY_OPENDISTRO_PREFIX + "|" + PLUGINS_PREFIX + ")/" + "(.*)";
private static final Pattern PATTERN_PATH_PREFIX = Pattern.compile(REGEX_PATH_PREFIX);
Expand Down
Loading

0 comments on commit f7c47af

Please sign in to comment.