diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d5d5e3430d..9e58ff4b4f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -108,6 +108,24 @@ jobs: arguments: | integrationTest -Dbuild.snapshot=false + backward-compatibility-build: + runs-on: ubuntu-latest + steps: + - uses: actions/setup-java@v3 + with: + distribution: temurin # Temurin is a distribution of adoptium + java-version: 17 + + - name: Checkout Security Repo + uses: actions/checkout@v4 + + - name: Build BWC tests + uses: gradle/gradle-build-action@v2 + with: + cache-disabled: true + arguments: | + -p bwc-test build -x test -x integTest + backward-compatibility: strategy: fail-fast: false diff --git a/bwc-test/build.gradle b/bwc-test/build.gradle index 24cc645ba1..6fb7fc2348 100644 --- a/bwc-test/build.gradle +++ b/bwc-test/build.gradle @@ -47,6 +47,7 @@ buildscript { opensearch_version = System.getProperty("opensearch.version", "3.0.0-SNAPSHOT") opensearch_group = "org.opensearch" common_utils_version = System.getProperty("common_utils.version", '2.9.0.0-SNAPSHOT') + jackson_version = System.getProperty("jackson_version", "2.15.2") } repositories { mavenLocal() @@ -72,6 +73,9 @@ dependencies { testImplementation "org.opensearch.test:framework:${opensearch_version}" testImplementation "org.apache.logging.log4j:log4j-core:${versions.log4j}" testImplementation "org.opensearch:common-utils:${common_utils_version}" + testImplementation "com.fasterxml.jackson.core:jackson-databind:${jackson_version}" + testImplementation "com.fasterxml.jackson.core:jackson-annotations:${jackson_version}" + } loggerUsageCheck.enabled = false diff --git a/bwc-test/src/test/java/SecurityBackwardsCompatibilityIT.java b/bwc-test/src/test/java/SecurityBackwardsCompatibilityIT.java deleted file mode 100644 index 3758b43265..0000000000 --- a/bwc-test/src/test/java/SecurityBackwardsCompatibilityIT.java +++ /dev/null @@ -1,205 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.security.bwc; - -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; - -import org.apache.hc.client5.http.auth.AuthScope; -import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; -import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; -import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder; -import org.apache.hc.client5.http.nio.AsyncClientConnectionManager; -import org.apache.hc.client5.http.ssl.ClientTlsStrategyBuilder; -import org.apache.hc.client5.http.ssl.NoopHostnameVerifier; -import org.apache.hc.core5.function.Factory; -import org.apache.hc.core5.http.Header; -import org.apache.hc.core5.http.HttpHost; -import org.apache.hc.core5.http.message.BasicHeader; -import org.apache.hc.core5.http.nio.ssl.TlsStrategy; -import org.apache.hc.core5.reactor.ssl.TlsDetails; -import org.apache.hc.core5.ssl.SSLContextBuilder; -import org.junit.Assume; -import org.junit.Before; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.test.rest.OpenSearchRestTestCase; - -import org.opensearch.Version; -import org.opensearch.common.settings.Settings; -import org.opensearch.test.rest.OpenSearchRestTestCase; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.hasItem; - -import org.opensearch.client.RestClient; -import org.opensearch.client.RestClientBuilder; - -import org.junit.Assert; - -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLEngine; - -public class SecurityBackwardsCompatibilityIT extends OpenSearchRestTestCase { - - private ClusterType CLUSTER_TYPE; - private String CLUSTER_NAME; - - @Before - private void testSetup() { - final String bwcsuiteString = System.getProperty("tests.rest.bwcsuite"); - Assume.assumeTrue("Test cannot be run outside the BWC gradle task 'bwcTestSuite' or its dependent tasks", bwcsuiteString != null); - CLUSTER_TYPE = ClusterType.parse(bwcsuiteString); - CLUSTER_NAME = System.getProperty("tests.clustername"); - } - - @Override - protected final boolean preserveClusterUponCompletion() { - return true; - } - - @Override - protected final boolean preserveIndicesUponCompletion() { - return true; - } - - @Override - protected final boolean preserveReposUponCompletion() { - return true; - } - - @Override - protected boolean preserveTemplatesUponCompletion() { - return true; - } - - @Override - protected String getProtocol() { - return "https"; - } - - @Override - protected final Settings restClientSettings() { - return Settings.builder() - .put(super.restClientSettings()) - // increase the timeout here to 90 seconds to handle long waits for a green - // cluster health. the waits for green need to be longer than a minute to - // account for delayed shards - .put(OpenSearchRestTestCase.CLIENT_SOCKET_TIMEOUT, "90s") - .build(); - } - - @Override - protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOException { - RestClientBuilder builder = RestClient.builder(hosts); - configureHttpsClient(builder, settings); - boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true); - builder.setStrictDeprecationMode(strictDeprecationMode); - return builder.build(); - } - - protected static void configureHttpsClient(RestClientBuilder builder, Settings settings) throws IOException { - Map headers = ThreadContext.buildDefaultHeaders(settings); - Header[] defaultHeaders = new Header[headers.size()]; - int i = 0; - for (Map.Entry entry : headers.entrySet()) { - defaultHeaders[i++] = new BasicHeader(entry.getKey(), entry.getValue()); - } - builder.setDefaultHeaders(defaultHeaders); - builder.setHttpClientConfigCallback(httpClientBuilder -> { - String userName = Optional.ofNullable(System.getProperty("tests.opensearch.username")) - .orElseThrow(() -> new RuntimeException("user name is missing")); - String password = Optional.ofNullable(System.getProperty("tests.opensearch.password")) - .orElseThrow(() -> new RuntimeException("password is missing")); - BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); - credentialsProvider.setCredentials(new AuthScope(null, -1), new UsernamePasswordCredentials(userName, password.toCharArray())); - try { - SSLContext sslContext = SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build(); - - TlsStrategy tlsStrategy = ClientTlsStrategyBuilder.create() - .setSslContext(sslContext) - .setTlsVersions(new String[] { "TLSv1", "TLSv1.1", "TLSv1.2", "SSLv3" }) - .setHostnameVerifier(NoopHostnameVerifier.INSTANCE) - // See please https://issues.apache.org/jira/browse/HTTPCLIENT-2219 - .setTlsDetailsFactory(new Factory() { - @Override - public TlsDetails create(final SSLEngine sslEngine) { - return new TlsDetails(sslEngine.getSession(), sslEngine.getApplicationProtocol()); - } - }) - .build(); - - final AsyncClientConnectionManager cm = PoolingAsyncClientConnectionManagerBuilder.create() - .setTlsStrategy(tlsStrategy) - .build(); - return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider).setConnectionManager(cm); - } catch (Exception e) { - throw new RuntimeException(e); - } - }); - } - - public void testBasicBackwardsCompatibility() throws Exception { - String round = System.getProperty("tests.rest.bwcsuite_round"); - - if (round.equals("first") || round.equals("old")) { - assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-0/plugins"); - } else if (round.equals("second")) { - assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-1/plugins"); - } else if (round.equals("third")) { - assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-2/plugins"); - } - } - - @SuppressWarnings("unchecked") - public void testWhoAmI() throws Exception { - Map responseMap = (Map) getAsMap("_plugins/_security/whoami"); - Assert.assertTrue(responseMap.containsKey("dn")); - } - - private enum ClusterType { - OLD, - MIXED, - UPGRADED; - - public static ClusterType parse(String value) { - switch (value) { - case "old_cluster": - return OLD; - case "mixed_cluster": - return MIXED; - case "upgraded_cluster": - return UPGRADED; - default: - throw new AssertionError("unknown cluster type: " + value); - } - } - } - - @SuppressWarnings("unchecked") - private void assertPluginUpgrade(String uri) throws Exception { - Map> responseMap = (Map>) getAsMap(uri).get("nodes"); - for (Map response : responseMap.values()) { - List> plugins = (List>) response.get("plugins"); - Set pluginNames = plugins.stream().map(map -> (String) map.get("name")).collect(Collectors.toSet()); - - final Version minNodeVersion = this.minimumNodeVersion(); - - if (minNodeVersion.major <= 1) { - assertThat(pluginNames, hasItem("opensearch_security")); - } else { - assertThat(pluginNames, hasItem("opensearch-security")); - } - - } - } -} diff --git a/bwc-test/src/test/java/org/opensearch/security/bwc/ClusterType.java b/bwc-test/src/test/java/org/opensearch/security/bwc/ClusterType.java new file mode 100644 index 0000000000..7fe849d5b3 --- /dev/null +++ b/bwc-test/src/test/java/org/opensearch/security/bwc/ClusterType.java @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.security.bwc; + +public enum ClusterType { + OLD, + MIXED, + UPGRADED; + + public static ClusterType parse(String value) { + switch (value) { + case "old_cluster": + return OLD; + case "mixed_cluster": + return MIXED; + case "upgraded_cluster": + return UPGRADED; + default: + throw new AssertionError("unknown cluster type: " + value); + } + } +} diff --git a/bwc-test/src/test/java/org/opensearch/security/bwc/SecurityBackwardsCompatibilityIT.java b/bwc-test/src/test/java/org/opensearch/security/bwc/SecurityBackwardsCompatibilityIT.java new file mode 100644 index 0000000000..1647dbb132 --- /dev/null +++ b/bwc-test/src/test/java/org/opensearch/security/bwc/SecurityBackwardsCompatibilityIT.java @@ -0,0 +1,367 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.security.bwc; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import javax.net.ssl.SSLContext; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.hc.client5.http.auth.AuthScope; +import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; +import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider; +import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder; +import org.apache.hc.client5.http.nio.AsyncClientConnectionManager; +import org.apache.hc.client5.http.ssl.ClientTlsStrategyBuilder; +import org.apache.hc.client5.http.ssl.NoopHostnameVerifier; +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpHost; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.hc.core5.http.nio.ssl.TlsStrategy; +import org.apache.hc.core5.reactor.ssl.TlsDetails; +import org.apache.hc.core5.ssl.SSLContextBuilder; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestClientBuilder; +import org.opensearch.common.Randomness; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.util.io.IOUtils; +import org.opensearch.security.bwc.helper.RestHelper; +import org.opensearch.test.rest.OpenSearchRestTestCase; +import org.opensearch.Version; + +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.equalTo; + +public class SecurityBackwardsCompatibilityIT extends OpenSearchRestTestCase { + + private ClusterType CLUSTER_TYPE; + private String CLUSTER_NAME; + + private final String TEST_USER = "user"; + private final String TEST_PASSWORD = "290735c0-355d-4aaf-9b42-1aaa1f2a3cee"; + private final String TEST_ROLE = "test-dls-fls-role"; + private static RestClient testUserRestClient = null; + + @Before + public void testSetup() { + final String bwcsuiteString = System.getProperty("tests.rest.bwcsuite"); + Assume.assumeTrue("Test cannot be run outside the BWC gradle task 'bwcTestSuite' or its dependent tasks", bwcsuiteString != null); + CLUSTER_TYPE = ClusterType.parse(bwcsuiteString); + CLUSTER_NAME = System.getProperty("tests.clustername"); + if (testUserRestClient == null) { + testUserRestClient = buildClient( + super.restClientSettings(), + super.getClusterHosts().toArray(new HttpHost[0]), + TEST_USER, + TEST_PASSWORD + ); + } + } + + @Override + protected final boolean preserveClusterUponCompletion() { + return true; + } + + @Override + protected final boolean preserveIndicesUponCompletion() { + return true; + } + + @Override + protected final boolean preserveReposUponCompletion() { + return true; + } + + @Override + protected boolean preserveTemplatesUponCompletion() { + return true; + } + + @Override + protected String getProtocol() { + return "https"; + } + + @Override + protected final Settings restClientSettings() { + return Settings.builder() + .put(super.restClientSettings()) + // increase the timeout here to 90 seconds to handle long waits for a green + // cluster health. the waits for green need to be longer than a minute to + // account for delayed shards + .put(OpenSearchRestTestCase.CLIENT_SOCKET_TIMEOUT, "90s") + .build(); + } + + protected RestClient buildClient(Settings settings, HttpHost[] hosts, String username, String password) { + RestClientBuilder builder = RestClient.builder(hosts); + configureHttpsClient(builder, settings, username, password); + boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true); + builder.setStrictDeprecationMode(strictDeprecationMode); + return builder.build(); + } + + @Override + protected RestClient buildClient(Settings settings, HttpHost[] hosts) { + String username = Optional.ofNullable(System.getProperty("tests.opensearch.username")) + .orElseThrow(() -> new RuntimeException("user name is missing")); + String password = Optional.ofNullable(System.getProperty("tests.opensearch.password")) + .orElseThrow(() -> new RuntimeException("password is missing")); + return buildClient(super.restClientSettings(), super.getClusterHosts().toArray(new HttpHost[0]), username, password); + } + + private static void configureHttpsClient(RestClientBuilder builder, Settings settings, String userName, String password) { + Map headers = ThreadContext.buildDefaultHeaders(settings); + Header[] defaultHeaders = new Header[headers.size()]; + int i = 0; + for (Map.Entry entry : headers.entrySet()) { + defaultHeaders[i++] = new BasicHeader(entry.getKey(), entry.getValue()); + } + builder.setDefaultHeaders(defaultHeaders); + builder.setHttpClientConfigCallback(httpClientBuilder -> { + BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + credentialsProvider.setCredentials(new AuthScope(null, -1), new UsernamePasswordCredentials(userName, password.toCharArray())); + try { + SSLContext sslContext = SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build(); + + TlsStrategy tlsStrategy = ClientTlsStrategyBuilder.create() + .setSslContext(sslContext) + .setTlsVersions(new String[] { "TLSv1", "TLSv1.1", "TLSv1.2", "SSLv3" }) + .setHostnameVerifier(NoopHostnameVerifier.INSTANCE) + // See please https://issues.apache.org/jira/browse/HTTPCLIENT-2219 + .setTlsDetailsFactory(sslEngine -> new TlsDetails(sslEngine.getSession(), sslEngine.getApplicationProtocol())) + .build(); + + final AsyncClientConnectionManager cm = PoolingAsyncClientConnectionManagerBuilder.create() + .setTlsStrategy(tlsStrategy) + .build(); + return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider).setConnectionManager(cm); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + public void testWhoAmI() throws Exception { + Map responseMap = getAsMap("_plugins/_security/whoami"); + assertThat(responseMap, hasKey("dn")); + } + + public void testBasicBackwardsCompatibility() throws Exception { + String round = System.getProperty("tests.rest.bwcsuite_round"); + + if (round.equals("first") || round.equals("old")) { + assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-0/plugins"); + } else if (round.equals("second")) { + assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-1/plugins"); + } else if (round.equals("third")) { + assertPluginUpgrade("_nodes/" + CLUSTER_NAME + "-2/plugins"); + } + } + + /** + * Tests backward compatibility by created a test user and role with DLS, FLS and masked field settings. Ingests + * data into a test index and runs a matchAll query against the same. + */ + public void testDataIngestionAndSearchBackwardsCompatibility() throws Exception { + String round = System.getProperty("tests.rest.bwcsuite_round"); + String index = "test_index"; + if (round.equals("old")) { + createTestRoleIfNotExists(TEST_ROLE); + createUserIfNotExists(TEST_USER, TEST_PASSWORD, TEST_ROLE); + createIndexIfNotExists(index); + } + ingestData(index); + searchMatchAll(index); + } + + public void testNodeStats() throws IOException { + List responses = RestHelper.requestAgainstAllNodes(client(), "GET", "_nodes/stats", null); + responses.forEach(r -> Assert.assertEquals(200, r.getStatusLine().getStatusCode())); + } + + @SuppressWarnings("unchecked") + private void assertPluginUpgrade(String uri) throws Exception { + Map> responseMap = (Map>) getAsMap(uri).get("nodes"); + for (Map response : responseMap.values()) { + List> plugins = (List>) response.get("plugins"); + Set pluginNames = plugins.stream().map(map -> (String) map.get("name")).collect(Collectors.toSet()); + + final Version minNodeVersion = minimumNodeVersion(); + + if (minNodeVersion.major <= 1) { + assertThat(pluginNames, hasItem("opensearch_security")); // With underscore seperator + } else { + assertThat(pluginNames, hasItem("opensearch-security")); // With dash seperator + } + } + } + + /** + * Ingests data into the test index + * @param index index to ingest data into + */ + + private void ingestData(String index) throws IOException { + StringBuilder bulkRequestBody = new StringBuilder(); + ObjectMapper objectMapper = new ObjectMapper(); + int numberOfRequests = Randomness.get().nextInt(10); + while (numberOfRequests-- > 0) { + for (int i = 0; i < Randomness.get().nextInt(100); i++) { + Map> indexRequest = new HashMap<>(); + indexRequest.put("index", new HashMap<>() { + { + put("_index", index); + } + }); + bulkRequestBody.append(objectMapper.writeValueAsString(indexRequest) + "\n"); + bulkRequestBody.append(objectMapper.writeValueAsString(Song.randomSong().asJson()) + "\n"); + } + List responses = RestHelper.requestAgainstAllNodes( + testUserRestClient, + "POST", + "_bulk?refresh=wait_for", + RestHelper.toHttpEntity(bulkRequestBody.toString()) + ); + responses.forEach(r -> assertEquals(200, r.getStatusLine().getStatusCode())); + } + } + + /** + * Runs a matchAll query against the test index + * @param index index to search + */ + private void searchMatchAll(String index) throws IOException { + String matchAllQuery = "{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}"; + int numberOfRequests = Randomness.get().nextInt(10); + while (numberOfRequests-- > 0) { + List responses = RestHelper.requestAgainstAllNodes( + testUserRestClient, + "POST", + index + "/_search", + RestHelper.toHttpEntity(matchAllQuery) + ); + responses.forEach(r -> assertEquals(200, r.getStatusLine().getStatusCode())); + } + } + + /** + * Checks if a resource at the specified URL exists + * @param url of the resource to be checked for existence + * @return true if the resource exists, false otherwise + */ + + private boolean resourceExists(String url) throws IOException { + try { + RestHelper.get(adminClient(), url); + return true; + } catch (ResponseException e) { + if (e.getResponse().getStatusLine().getStatusCode() == 404) { + return false; + } else { + throw e; + } + } + } + + /** + * Creates a test role with DLS, FLS and masked field settings on the test index. + */ + private void createTestRoleIfNotExists(String role) throws IOException { + String url = "_plugins/_security/api/roles/" + role; + String roleSettings = "{\n" + + " \"cluster_permissions\": [\n" + + " \"unlimited\"\n" + + " ],\n" + + " \"index_permissions\": [\n" + + " {\n" + + " \"index_patterns\": [\n" + + " \"test_index*\"\n" + + " ],\n" + + " \"dls\": \"{ \\\"bool\\\": { \\\"must\\\": { \\\"match\\\": { \\\"genre\\\": \\\"rock\\\" } } } }\",\n" + + " \"fls\": [\n" + + " \"~lyrics\"\n" + + " ],\n" + + " \"masked_fields\": [\n" + + " \"artist\"\n" + + " ],\n" + + " \"allowed_actions\": [\n" + + " \"read\",\n" + + " \"write\"\n" + + " ]\n" + + " }\n" + + " ],\n" + + " \"tenant_permissions\": []\n" + + "}\n"; + Response response = RestHelper.makeRequest(adminClient(), "PUT", url, RestHelper.toHttpEntity(roleSettings)); + + assertThat(response.getStatusLine().getStatusCode(), anyOf(equalTo(200), equalTo(201))); + } + + /** + * Creates a test index if it does not exist already + * @param index index to create + */ + + private void createIndexIfNotExists(String index) throws IOException { + String settings = "{\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"number_of_shards\": 3,\n" + + " \"number_of_replicas\": 1\n" + + " }\n" + + " }\n" + + "}"; + if (!resourceExists(index)) { + Response response = RestHelper.makeRequest(client(), "PUT", index, RestHelper.toHttpEntity(settings)); + assertThat(response.getStatusLine().getStatusCode(), equalTo(200)); + } + } + + /** + * Creates the test user if it does not exist already and maps it to the test role with DLS/FLS settings. + * @param user user to be created + * @param password password for the new user + * @param role roles that the user has to be mapped to + */ + private void createUserIfNotExists(String user, String password, String role) throws IOException { + String url = "_plugins/_security/api/internalusers/" + user; + if (!resourceExists(url)) { + String userSettings = String.format( + Locale.ENGLISH, + "{\n" + " \"password\": \"%s\",\n" + " \"opendistro_security_roles\": [\"%s\"],\n" + " \"backend_roles\": []\n" + "}", + password, + role + ); + Response response = RestHelper.makeRequest(adminClient(), "PUT", url, RestHelper.toHttpEntity(userSettings)); + assertThat(response.getStatusLine().getStatusCode(), equalTo(201)); + } + } + + @AfterClass + public static void cleanUp() throws IOException { + OpenSearchRestTestCase.closeClients(); + IOUtils.close(testUserRestClient); + } +} diff --git a/bwc-test/src/test/java/org/opensearch/security/bwc/Song.java b/bwc-test/src/test/java/org/opensearch/security/bwc/Song.java new file mode 100644 index 0000000000..3cfd2c03e8 --- /dev/null +++ b/bwc-test/src/test/java/org/opensearch/security/bwc/Song.java @@ -0,0 +1,117 @@ +/* +* Copyright OpenSearch Contributors +* SPDX-License-Identifier: Apache-2.0 +* +* The OpenSearch Contributors require contributions made to +* this file be licensed under the Apache-2.0 license or a +* compatible open source license. +* +*/ +package org.opensearch.security.bwc; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.opensearch.common.Randomness; + +import java.util.Map; +import java.util.Objects; +import java.util.UUID; + +public class Song { + + public static final String FIELD_TITLE = "title"; + public static final String FIELD_ARTIST = "artist"; + public static final String FIELD_LYRICS = "lyrics"; + public static final String FIELD_STARS = "stars"; + public static final String FIELD_GENRE = "genre"; + public static final String ARTIST_FIRST = "First artist"; + public static final String ARTIST_STRING = "String"; + public static final String ARTIST_TWINS = "Twins"; + public static final String TITLE_MAGNUM_OPUS = "Magnum Opus"; + public static final String TITLE_SONG_1_PLUS_1 = "Song 1+1"; + public static final String TITLE_NEXT_SONG = "Next song"; + public static final String ARTIST_NO = "No!"; + public static final String TITLE_POISON = "Poison"; + + public static final String ARTIST_YES = "yes"; + + public static final String TITLE_AFFIRMATIVE = "Affirmative"; + + public static final String ARTIST_UNKNOWN = "unknown"; + public static final String TITLE_CONFIDENTIAL = "confidential"; + + public static final String LYRICS_1 = "Very deep subject"; + public static final String LYRICS_2 = "Once upon a time"; + public static final String LYRICS_3 = "giant nonsense"; + public static final String LYRICS_4 = "Much too much"; + public static final String LYRICS_5 = "Little to little"; + public static final String LYRICS_6 = "confidential secret classified"; + + public static final String GENRE_ROCK = "rock"; + public static final String GENRE_JAZZ = "jazz"; + public static final String GENRE_BLUES = "blues"; + + public static final String QUERY_TITLE_NEXT_SONG = FIELD_TITLE + ":" + "\"" + TITLE_NEXT_SONG + "\""; + public static final String QUERY_TITLE_POISON = FIELD_TITLE + ":" + TITLE_POISON; + public static final String QUERY_TITLE_MAGNUM_OPUS = FIELD_TITLE + ":" + TITLE_MAGNUM_OPUS; + + public static final Song[] SONGS = { + new Song(ARTIST_FIRST, TITLE_MAGNUM_OPUS, LYRICS_1, 1, GENRE_ROCK), + new Song(ARTIST_STRING, TITLE_SONG_1_PLUS_1, LYRICS_2, 2, GENRE_BLUES), + new Song(ARTIST_TWINS, TITLE_NEXT_SONG, LYRICS_3, 3, GENRE_JAZZ), + new Song(ARTIST_NO, TITLE_POISON, LYRICS_4, 4, GENRE_ROCK), + new Song(ARTIST_YES, TITLE_AFFIRMATIVE, LYRICS_5, 5, GENRE_BLUES), + new Song(ARTIST_UNKNOWN, TITLE_CONFIDENTIAL, LYRICS_6, 6, GENRE_JAZZ) }; + + private final String artist; + private final String title; + private final String lyrics; + private final Integer stars; + private final String genre; + + public Song(String artist, String title, String lyrics, Integer stars, String genre) { + this.artist = Objects.requireNonNull(artist, "Artist is required"); + this.title = Objects.requireNonNull(title, "Title is required"); + this.lyrics = Objects.requireNonNull(lyrics, "Lyrics is required"); + this.stars = Objects.requireNonNull(stars, "Stars field is required"); + this.genre = Objects.requireNonNull(genre, "Genre field is required"); + } + + public String getArtist() { + return artist; + } + + public String getTitle() { + return title; + } + + public String getLyrics() { + return lyrics; + } + + public Integer getStars() { + return stars; + } + + public String getGenre() { + return genre; + } + + public Map asMap() { + return Map.of(FIELD_ARTIST, artist, FIELD_TITLE, title, FIELD_LYRICS, lyrics, FIELD_STARS, stars, FIELD_GENRE, genre); + } + + public String asJson() throws JsonProcessingException { + return new ObjectMapper().writeValueAsString(this.asMap()); + } + + public static Song randomSong() { + return new Song( + UUID.randomUUID().toString(), + UUID.randomUUID().toString(), + UUID.randomUUID().toString(), + Randomness.get().nextInt(5), + UUID.randomUUID().toString() + ); + } +} diff --git a/bwc-test/src/test/java/org/opensearch/security/bwc/helper/RestHelper.java b/bwc-test/src/test/java/org/opensearch/security/bwc/helper/RestHelper.java new file mode 100644 index 0000000000..3272ac736a --- /dev/null +++ b/bwc-test/src/test/java/org/opensearch/security/bwc/helper/RestHelper.java @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.security.bwc.helper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.client.WarningsHandler; + +import static org.apache.hc.core5.http.ContentType.APPLICATION_JSON; + +public class RestHelper { + + private static final Logger log = LogManager.getLogger(RestHelper.class); + + public static HttpEntity toHttpEntity(String jsonString) { + return new StringEntity(jsonString, APPLICATION_JSON); + } + + public static Response get(RestClient client, String url) throws IOException { + return makeRequest(client, "GET", url, null, null); + } + + public static Response makeRequest(RestClient client, String method, String endpoint, HttpEntity entity) throws IOException { + return makeRequest(client, method, endpoint, entity, null); + } + + public static Response makeRequest(RestClient client, String method, String endpoint, HttpEntity entity, List
headers) + throws IOException { + log.info("Making request " + method + " " + endpoint + ", with headers " + headers); + + Request request = new Request(method, endpoint); + + RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); + options.setWarningsHandler(WarningsHandler.PERMISSIVE); + if (headers != null) { + headers.forEach(header -> options.addHeader(header.getName(), header.getValue())); + } + request.setOptions(options.build()); + + if (entity != null) { + request.setEntity(entity); + } + + Response response = client.performRequest(request); + log.info("Recieved response " + response.getStatusLine()); + return response; + } + + public static List requestAgainstAllNodes(RestClient client, String method, String endpoint, HttpEntity entity) + throws IOException { + return requestAgainstAllNodes(client, method, endpoint, entity, null); + } + + public static List requestAgainstAllNodes( + RestClient client, + String method, + String endpoint, + HttpEntity entity, + List
headers + ) throws IOException { + int nodeCount = client.getNodes().size(); + List responses = new ArrayList<>(); + while (nodeCount-- > 0) { + responses.add(makeRequest(client, method, endpoint, entity, headers)); + } + return responses; + } + + public static Header getAuthorizationHeader(String username, String password) { + return new BasicHeader("Authorization", "Basic " + username + ":" + password); + } +} diff --git a/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java b/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java index 907d605860..f752ce4a49 100755 --- a/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java +++ b/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java @@ -11,6 +11,7 @@ package com.amazon.dlic.auth.ldap; +import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -20,6 +21,8 @@ import com.amazon.dlic.auth.ldap.util.Utils; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.security.support.WildcardMatcher; import org.opensearch.security.user.AuthCredentials; import org.opensearch.security.user.User; @@ -45,6 +48,12 @@ public LdapUser( attributes.putAll(extractLdapAttributes(originalUsername, userEntry, customAttrMaxValueLen, allowlistedCustomLdapAttrMatcher)); } + public LdapUser(StreamInput in) throws IOException { + super(in); + userEntry = null; + originalUsername = in.readString(); + } + /** * May return null because ldapEntry is transient * @@ -88,4 +97,10 @@ public static Map extractLdapAttributes( } return Collections.unmodifiableMap(attributes); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(originalUsername); + } } diff --git a/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java b/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java index 804e0a2114..a8f511be97 100644 --- a/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java +++ b/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java @@ -773,7 +773,8 @@ private TransportAddress getRemoteAddress() { if (address == null && threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER) != null) { address = new TransportAddress( (InetSocketAddress) Base64Helper.deserializeObject( - threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER) + threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER), + threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION) ) ); } @@ -784,7 +785,8 @@ private String getUser() { User user = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); if (user == null && threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER) != null) { user = (User) Base64Helper.deserializeObject( - threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER) + threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), + threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION) ); } return user == null ? null : user.getName(); diff --git a/src/main/java/org/opensearch/security/auth/UserInjector.java b/src/main/java/org/opensearch/security/auth/UserInjector.java index 3e89a52e93..30df84ef5f 100644 --- a/src/main/java/org/opensearch/security/auth/UserInjector.java +++ b/src/main/java/org/opensearch/security/auth/UserInjector.java @@ -26,6 +26,7 @@ package org.opensearch.security.auth; +import java.io.IOException; import java.io.ObjectStreamException; import java.net.InetAddress; import java.net.UnknownHostException; @@ -36,6 +37,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.rest.RestRequest; @@ -63,13 +66,18 @@ public UserInjector(Settings settings, ThreadPool threadPool, AuditLog auditLog, } - static class InjectedUser extends User { + public static class InjectedUser extends User { private transient TransportAddress transportAddress; public InjectedUser(String name) { super(name); } + public InjectedUser(StreamInput in) throws IOException { + super(in); + this.setInjected(true); + } + private Object writeReplace() throws ObjectStreamException { User user = new User(getName()); user.addRoles(getRoles()); @@ -96,6 +104,11 @@ public void setTransportAddress(String addr) throws UnknownHostException, Illega this.transportAddress = new TransportAddress(iAdress, port); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } } public InjectedUser getInjectedUser() { diff --git a/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java b/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java index 14eaed4e0d..b35137a35d 100644 --- a/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java +++ b/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java @@ -443,7 +443,8 @@ private void setDlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) } else { if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER) != null) { Object deserializedDlsQueries = Base64Helper.deserializeObject( - threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER) + threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER), + threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) ); if (!dlsQueries.equals(deserializedDlsQueries)) { throw new OpenSearchSecurityException( @@ -506,7 +507,10 @@ private void setFlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER) != null) { if (!maskedFieldsMap.equals( - Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER)) + Base64Helper.deserializeObject( + threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER), + threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) + ) )) { throw new OpenSearchSecurityException( ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER + " does not match (SG 901D)" @@ -542,7 +546,10 @@ private void setFlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) } else { if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER) != null) { if (!flsFields.equals( - Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER)) + Base64Helper.deserializeObject( + threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER), + threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) + ) )) { throw new OpenSearchSecurityException( ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER @@ -550,7 +557,8 @@ private void setFlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) + flsFields + "---" + Base64Helper.deserializeObject( - threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER) + threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER), + threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) ) ); } else { diff --git a/src/main/java/org/opensearch/security/filter/SecurityFilter.java b/src/main/java/org/opensearch/security/filter/SecurityFilter.java index 06f2fae397..f433a5857d 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityFilter.java +++ b/src/main/java/org/opensearch/security/filter/SecurityFilter.java @@ -183,6 +183,10 @@ private void ap threadContext.putTransient(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN, Origin.LOCAL.toString()); } + if (threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) == null) { + threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, false); + } + final ComplianceConfig complianceConfig = auditLog.getComplianceConfig(); if (complianceConfig != null && complianceConfig.isEnabled()) { attachSourceFieldContext(request); diff --git a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java index 0a1b94548e..c67579e30f 100644 --- a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java +++ b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java @@ -83,8 +83,14 @@ protected ThreadContext getThreadContext() { @Override public final void messageReceived(T request, TransportChannel channel, Task task) throws Exception { + ThreadContext threadContext = getThreadContext(); + threadContext.putTransient( + ConfigConstants.USE_JDK_SERIALIZATION, + channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION) + ); + if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) { final Exception exception = ExceptionUtils.createBadHeaderException(); channel.sendResponse(exception); diff --git a/src/main/java/org/opensearch/security/support/Base64CustomHelper.java b/src/main/java/org/opensearch/security/support/Base64CustomHelper.java new file mode 100644 index 0000000000..dc66268fcd --- /dev/null +++ b/src/main/java/org/opensearch/security/support/Base64CustomHelper.java @@ -0,0 +1,225 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import com.amazon.dlic.auth.ldap.LdapUser; +import com.google.common.base.Preconditions; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import com.google.common.io.BaseEncoding; +import org.opensearch.OpenSearchException; +import org.opensearch.common.Nullable; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.common.Strings; +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.User; + +import java.io.IOException; +import java.io.Serializable; + +import static org.opensearch.security.support.SafeSerializationUtils.prohibitUnsafeClasses; + +/** + * Provides support for Serialization/Deserialization of objects of supported classes into/from Base64 encoded stream + * using the OpenSearch's custom serialization protocol implemented by the StreamInput/StreamOutput classes. + */ +public class Base64CustomHelper { + + private enum CustomSerializationFormat { + + WRITEABLE(1), + STREAMABLE(2), + GENERIC(3); + + private final int id; + + CustomSerializationFormat(int id) { + this.id = id; + } + + static CustomSerializationFormat fromId(int id) { + switch (id) { + case 1: + return WRITEABLE; + case 2: + return STREAMABLE; + case 3: + return GENERIC; + default: + throw new IllegalArgumentException(String.format("%d is not a valid id", id)); + } + } + + } + + private static final BiMap, Integer> writeableClassToIdMap = HashBiMap.create(); + private static final StreamableRegistry streamableRegistry = StreamableRegistry.getInstance(); + + static { + registerAllWriteables(); + } + + protected static String serializeObject(final Serializable object) { + + Preconditions.checkArgument(object != null, "object must not be null"); + final BytesStreamOutput streamOutput = new SafeBytesStreamOutput(128); + Class clazz = object.getClass(); + try { + prohibitUnsafeClasses(clazz); + CustomSerializationFormat customSerializationFormat = getCustomSerializationMode(clazz); + switch (customSerializationFormat) { + case WRITEABLE: + streamOutput.writeByte((byte) CustomSerializationFormat.WRITEABLE.id); + streamOutput.writeByte((byte) getWriteableClassID(clazz).intValue()); + ((Writeable) object).writeTo(streamOutput); + break; + case STREAMABLE: + streamOutput.writeByte((byte) CustomSerializationFormat.STREAMABLE.id); + streamableRegistry.writeTo(streamOutput, object); + break; + case GENERIC: + streamOutput.writeByte((byte) CustomSerializationFormat.GENERIC.id); + streamOutput.writeGenericValue(object); + break; + default: + throw new IllegalArgumentException( + String.format("Could not determine custom serialization mode for class %s", clazz.getName()) + ); + } + } catch (final Exception e) { + throw new OpenSearchException("Instance {} of class {} is not serializable", e, object, object.getClass()); + } + final byte[] bytes = streamOutput.bytes().toBytesRef().bytes; + streamOutput.close(); + return BaseEncoding.base64().encode(bytes); + } + + protected static Serializable deserializeObject(final String string) { + + Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "object must not be null or empty"); + final byte[] bytes = BaseEncoding.base64().decode(string); + Serializable obj = null; + try (final BytesStreamInput streamInput = new SafeBytesStreamInput(bytes)) { + CustomSerializationFormat serializationFormat = CustomSerializationFormat.fromId(streamInput.readByte()); + switch (serializationFormat) { + case WRITEABLE: + final int classId = streamInput.readByte(); + Class clazz = getWriteableClassFromId(classId); + obj = (Serializable) clazz.getConstructor(StreamInput.class).newInstance(streamInput); + break; + case STREAMABLE: + obj = (Serializable) streamableRegistry.readFrom(streamInput); + break; + case GENERIC: + obj = (Serializable) streamInput.readGenericValue(); + break; + default: + throw new IllegalArgumentException("Could not determine custom deserialization mode"); + } + prohibitUnsafeClasses(obj.getClass()); + return obj; + } catch (final Exception e) { + throw new OpenSearchException(e); + } + } + + private static boolean isWriteable(Class clazz) { + return Writeable.class.isAssignableFrom(clazz); + } + + /** + * Returns integer ID for the registered Writeable class + *
+ * Protected for testing + */ + protected static Integer getWriteableClassID(Class clazz) { + if (!isWriteable(clazz)) { + throw new OpenSearchException("clazz should implement Writeable ", clazz); + } + if (!writeableClassToIdMap.containsKey(clazz)) { + throw new OpenSearchException("Writeable clazz not registered ", clazz); + } + return writeableClassToIdMap.get(clazz); + } + + private static Class getWriteableClassFromId(int id) { + return writeableClassToIdMap.inverse().get(id); + } + + /** + * Registers the given Writeable class for custom serialization by assigning an incrementing integer ID + * IDs are stored in a HashBiMap + * @param clazz class to be registered + */ + private static void registerWriteable(Class clazz) { + if (writeableClassToIdMap.containsKey(clazz)) { + throw new OpenSearchException("writeable clazz is already registered ", clazz.getName()); + } + int id = writeableClassToIdMap.size() + 1; + writeableClassToIdMap.put(clazz, id); + } + + /** + * Registers all Writeable classes for custom serialization support. + * Removing existing classes / changing order of registration will cause a breaking change in the serialization protocol + * as registerWriteable assigns an incrementing integer ID to each of the classes in the order it is called + * starting from 1. + *
+ * New classes can safely be added towards the end. + */ + private static void registerAllWriteables() { + registerWriteable(User.class); + registerWriteable(LdapUser.class); + registerWriteable(UserInjector.InjectedUser.class); + registerWriteable(SourceFieldsContext.class); + } + + private static CustomSerializationFormat getCustomSerializationMode(Class clazz) { + if (isWriteable(clazz)) { + return CustomSerializationFormat.WRITEABLE; + } else if (streamableRegistry.isStreamable(clazz)) { + return CustomSerializationFormat.STREAMABLE; + } else { + return CustomSerializationFormat.GENERIC; + } + } + + private static class SafeBytesStreamOutput extends BytesStreamOutput { + + public SafeBytesStreamOutput(int expectedSize) { + super(expectedSize); + } + + @Override + public void writeGenericValue(@Nullable Object value) throws IOException { + prohibitUnsafeClasses(value.getClass()); + super.writeGenericValue(value); + } + } + + private static class SafeBytesStreamInput extends BytesStreamInput { + + public SafeBytesStreamInput(byte[] bytes) { + super(bytes); + } + + @Override + public Object readGenericValue() throws IOException { + Object object = super.readGenericValue(); + prohibitUnsafeClasses(object.getClass()); + return object; + } + } +} diff --git a/src/main/java/org/opensearch/security/support/Base64Helper.java b/src/main/java/org/opensearch/security/support/Base64Helper.java index 836858decb..a5fbab8515 100644 --- a/src/main/java/org/opensearch/security/support/Base64Helper.java +++ b/src/main/java/org/opensearch/security/support/Base64Helper.java @@ -26,174 +26,47 @@ package org.opensearch.security.support; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.InvalidClassException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.io.ObjectStreamClass; -import java.io.OutputStream; import java.io.Serializable; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.net.SocketAddress; -import java.security.AccessController; -import java.security.PrivilegedAction; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.regex.Pattern; - -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import com.google.common.io.BaseEncoding; -import org.ldaptive.AbstractLdapBean; -import org.ldaptive.LdapAttribute; -import org.ldaptive.LdapEntry; -import org.ldaptive.SearchEntry; - -import com.amazon.dlic.auth.ldap.LdapUser; - -import org.opensearch.OpenSearchException; -import org.opensearch.SpecialPermission; -import org.opensearch.core.common.Strings; -import org.opensearch.security.user.User; public class Base64Helper { - private static final Set> SAFE_CLASSES = ImmutableSet.of( - String.class, - SocketAddress.class, - InetSocketAddress.class, - Pattern.class, - User.class, - SourceFieldsContext.class, - LdapUser.class, - SearchEntry.class, - LdapEntry.class, - AbstractLdapBean.class, - LdapAttribute.class - ); - - private static final List> SAFE_ASSIGNABLE_FROM_CLASSES = ImmutableList.of( - InetAddress.class, - Number.class, - Collection.class, - Map.class, - Enum.class - ); - - private static final Set SAFE_CLASS_NAMES = Collections.singleton("org.ldaptive.LdapAttribute$LdapAttributeValues"); - - private static boolean isSafeClass(Class cls) { - return cls.isArray() - || SAFE_CLASSES.contains(cls) - || SAFE_CLASS_NAMES.contains(cls.getName()) - || SAFE_ASSIGNABLE_FROM_CLASSES.stream().anyMatch(c -> c.isAssignableFrom(cls)); - } - - private final static class SafeObjectOutputStream extends ObjectOutputStream { - - private static final boolean useSafeObjectOutputStream = checkSubstitutionPermission(); - - @SuppressWarnings("removal") - private static boolean checkSubstitutionPermission() { - SecurityManager sm = System.getSecurityManager(); - if (sm != null) { - try { - sm.checkPermission(new SpecialPermission()); - - AccessController.doPrivileged((PrivilegedAction) () -> { - AccessController.checkPermission(SUBSTITUTION_PERMISSION); - return null; - }); - } catch (SecurityException e) { - return false; - } - } - return true; - } - - static ObjectOutputStream create(ByteArrayOutputStream out) throws IOException { - try { - return useSafeObjectOutputStream ? new SafeObjectOutputStream(out) : new ObjectOutputStream(out); - } catch (SecurityException e) { - // As we try to create SafeObjectOutputStream only when necessary permissions are granted, we should - // not reach here, but if we do, we can still return ObjectOutputStream after resetting ByteArrayOutputStream - out.reset(); - return new ObjectOutputStream(out); - } - } - - @SuppressWarnings("removal") - private SafeObjectOutputStream(OutputStream out) throws IOException { - super(out); - - SecurityManager sm = System.getSecurityManager(); - if (sm != null) { - sm.checkPermission(new SpecialPermission()); - } - - AccessController.doPrivileged((PrivilegedAction) () -> enableReplaceObject(true)); - } - - @Override - protected Object replaceObject(Object obj) throws IOException { - Class clazz = obj.getClass(); - if (isSafeClass(clazz)) { - return obj; - } - throw new IOException("Unauthorized serialization attempt " + clazz.getName()); - } + public static String serializeObject(final Serializable object, final boolean useJDKSerialization) { + return useJDKSerialization ? Base64JDKHelper.serializeObject(object) : Base64CustomHelper.serializeObject(object); } public static String serializeObject(final Serializable object) { - - Preconditions.checkArgument(object != null, "object must not be null"); - - final ByteArrayOutputStream bos = new ByteArrayOutputStream(); - try (final ObjectOutputStream out = SafeObjectOutputStream.create(bos)) { - out.writeObject(object); - } catch (final Exception e) { - throw new OpenSearchException("Instance {} of class {} is not serializable", e, object, object.getClass()); - } - final byte[] bytes = bos.toByteArray(); - return BaseEncoding.base64().encode(bytes); + return serializeObject(object, false); } public static Serializable deserializeObject(final String string) { - - Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "string must not be null or empty"); - - final byte[] bytes = BaseEncoding.base64().decode(string); - final ByteArrayInputStream bis = new ByteArrayInputStream(bytes); - try (SafeObjectInputStream in = new SafeObjectInputStream(bis)) { - return (Serializable) in.readObject(); - } catch (final Exception e) { - throw new OpenSearchException(e); - } + return deserializeObject(string, false); } - private final static class SafeObjectInputStream extends ObjectInputStream { - - public SafeObjectInputStream(InputStream in) throws IOException { - super(in); - } - - @Override - protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { - - Class clazz = super.resolveClass(desc); - if (isSafeClass(clazz)) { - return clazz; - } + public static Serializable deserializeObject(final String string, final boolean useJDKDeserialization) { + return useJDKDeserialization ? Base64JDKHelper.deserializeObject(string) : Base64CustomHelper.deserializeObject(string); + } - throw new InvalidClassException("Unauthorized deserialization attempt ", clazz.getName()); + /** + * Ensures that the returned string is JDK serialized. + * + * If the supplied string is a custom serialized representation, will deserialize it and further serialize using + * JDK, otherwise returns the string as is. + * + * @param string original string, can be JDK or custom serialized + * @return jdk serialized string + */ + public static String ensureJDKSerialized(final String string) { + Serializable serializable; + try { + serializable = Base64Helper.deserializeObject(string, false); + } catch (Exception e) { + // We received an exception when de-serializing the given string. It is probably JDK serialized. + // Try to deserialize using JDK + Base64Helper.deserializeObject(string, true); + // Since we could deserialize the object using JDK, the string is already JDK serialized, return as is + return string; } + // If we see an exception now, we want the caller to see it - + return Base64Helper.serializeObject(serializable, true); } } diff --git a/src/main/java/org/opensearch/security/support/Base64JDKHelper.java b/src/main/java/org/opensearch/security/support/Base64JDKHelper.java new file mode 100644 index 0000000000..a4ab87d813 --- /dev/null +++ b/src/main/java/org/opensearch/security/support/Base64JDKHelper.java @@ -0,0 +1,156 @@ +/* + * Copyright 2015-2018 _floragunn_ GmbH + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InvalidClassException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.ObjectStreamClass; +import java.io.OutputStream; +import java.io.Serializable; +import java.security.AccessController; +import java.security.PrivilegedAction; + +import com.google.common.base.Preconditions; +import com.google.common.io.BaseEncoding; + +import org.opensearch.OpenSearchException; +import org.opensearch.SpecialPermission; +import org.opensearch.core.common.Strings; + +import static org.opensearch.security.support.SafeSerializationUtils.isSafeClass; + +/** + * Provides support for Serialization/Deserialization of objects of supported classes into/from Base64 encoded stream + * using JDK's in-built serialization protocol implemented by the ObjectOutputStream and ObjectInputStream classes. + */ +public class Base64JDKHelper { + + private final static class SafeObjectOutputStream extends ObjectOutputStream { + + private static final boolean useSafeObjectOutputStream = checkSubstitutionPermission(); + + @SuppressWarnings("removal") + private static boolean checkSubstitutionPermission() { + SecurityManager sm = System.getSecurityManager(); + if (sm != null) { + try { + sm.checkPermission(new SpecialPermission()); + + AccessController.doPrivileged((PrivilegedAction) () -> { + AccessController.checkPermission(SUBSTITUTION_PERMISSION); + return null; + }); + } catch (SecurityException e) { + return false; + } + } + return true; + } + + static ObjectOutputStream create(ByteArrayOutputStream out) throws IOException { + try { + return useSafeObjectOutputStream ? new SafeObjectOutputStream(out) : new ObjectOutputStream(out); + } catch (SecurityException e) { + // As we try to create SafeObjectOutputStream only when necessary permissions are granted, we should + // not reach here, but if we do, we can still return ObjectOutputStream after resetting ByteArrayOutputStream + out.reset(); + return new ObjectOutputStream(out); + } + } + + @SuppressWarnings("removal") + private SafeObjectOutputStream(OutputStream out) throws IOException { + super(out); + + SecurityManager sm = System.getSecurityManager(); + if (sm != null) { + sm.checkPermission(new SpecialPermission()); + } + + AccessController.doPrivileged((PrivilegedAction) () -> enableReplaceObject(true)); + } + + @Override + protected Object replaceObject(Object obj) throws IOException { + Class clazz = obj.getClass(); + if (isSafeClass(clazz)) { + return obj; + } + throw new IOException("Unauthorized serialization attempt " + clazz.getName()); + } + } + + public static String serializeObject(final Serializable object) { + + Preconditions.checkArgument(object != null, "object must not be null"); + + final ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try (final ObjectOutputStream out = SafeObjectOutputStream.create(bos)) { + out.writeObject(object); + } catch (final Exception e) { + throw new OpenSearchException("Instance {} of class {} is not serializable", e, object, object.getClass()); + } + final byte[] bytes = bos.toByteArray(); + return BaseEncoding.base64().encode(bytes); + } + + public static Serializable deserializeObject(final String string) { + + Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "object must not be null or empty"); + + final byte[] bytes = BaseEncoding.base64().decode(string); + final ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + try (SafeObjectInputStream in = new SafeObjectInputStream(bis)) { + return (Serializable) in.readObject(); + } catch (final Exception e) { + throw new OpenSearchException(e); + } + } + + private final static class SafeObjectInputStream extends ObjectInputStream { + + public SafeObjectInputStream(InputStream in) throws IOException { + super(in); + } + + @Override + protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { + + Class clazz = super.resolveClass(desc); + if (isSafeClass(clazz)) { + return clazz; + } + + throw new InvalidClassException("Unauthorized deserialization attempt ", clazz.getName()); + } + } +} diff --git a/src/main/java/org/opensearch/security/support/ConfigConstants.java b/src/main/java/org/opensearch/security/support/ConfigConstants.java index 8317d65335..9ac73cd579 100644 --- a/src/main/java/org/opensearch/security/support/ConfigConstants.java +++ b/src/main/java/org/opensearch/security/support/ConfigConstants.java @@ -35,6 +35,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import org.opensearch.Version; import org.opensearch.common.settings.Settings; import org.opensearch.security.auditlog.impl.AuditCategory; @@ -242,6 +243,7 @@ public class ConfigConstants { "opendistro_security.compliance.history.write.ignore_users"; public static final String OPENDISTRO_SECURITY_COMPLIANCE_HISTORY_EXTERNAL_CONFIG_ENABLED = "opendistro_security.compliance.history.external_config_enabled"; + public static final String OPENDISTRO_SECURITY_SOURCE_FIELD_CONTEXT = OPENDISTRO_SECURITY_CONFIG_PREFIX + "source_field_context"; public static final String SECURITY_COMPLIANCE_DISABLE_ANONYMOUS_AUTHENTICATION = "plugins.security.compliance.disable_anonymous_authentication"; public static final String SECURITY_COMPLIANCE_IMMUTABLE_INDICES = "plugins.security.compliance.immutable_indices"; @@ -323,6 +325,9 @@ public enum RolesMappingResolution { public static final String TENANCY_GLOBAL_TENANT_NAME = "global"; public static final String TENANCY_GLOBAL_TENANT_DEFAULT_NAME = ""; + public static final String USE_JDK_SERIALIZATION = "plugins.security.use_jdk_serialization"; + public static final Version FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION = Version.V_3_0_0; + // On-behalf-of endpoints settings // CS-SUPPRESS-SINGLE: RegexpSingleline get Extensions Settings public static final String EXTENSIONS_BWC_PLUGIN_MODE = "bwcPluginMode"; diff --git a/src/main/java/org/opensearch/security/support/HeaderHelper.java b/src/main/java/org/opensearch/security/support/HeaderHelper.java index e8d50346a8..bbb44664fa 100644 --- a/src/main/java/org/opensearch/security/support/HeaderHelper.java +++ b/src/main/java/org/opensearch/security/support/HeaderHelper.java @@ -27,6 +27,8 @@ package org.opensearch.security.support; import java.io.Serializable; +import java.util.Arrays; +import java.util.List; import com.google.common.base.Strings; @@ -68,7 +70,7 @@ public static Serializable deserializeSafeFromHeader(final ThreadContext context final String objectAsBase64 = getSafeFromHeader(context, headerName); if (!Strings.isNullOrEmpty(objectAsBase64)) { - return Base64Helper.deserializeObject(objectAsBase64); + return Base64Helper.deserializeObject(objectAsBase64, context.getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); } return null; @@ -77,4 +79,16 @@ public static Serializable deserializeSafeFromHeader(final ThreadContext context public static boolean isTrustedClusterRequest(final ThreadContext context) { return context.getTransient(ConfigConstants.OPENDISTRO_SECURITY_SSL_TRANSPORT_TRUSTED_CLUSTER_REQUEST) == Boolean.TRUE; } + + public static List getAllSerializedHeaderNames() { + return Arrays.asList( + ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_DLS_FILTER_LEVEL_QUERY_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_SOURCE_FIELD_CONTEXT + ); + } } diff --git a/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java b/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java new file mode 100644 index 0000000000..c980959f68 --- /dev/null +++ b/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java @@ -0,0 +1,81 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import com.amazon.dlic.auth.ldap.LdapUser; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.ldaptive.AbstractLdapBean; +import org.ldaptive.LdapAttribute; +import org.ldaptive.LdapEntry; +import org.ldaptive.SearchEntry; +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.User; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Pattern; + +/** + * Provides functionality to verify if a class is categorised to be safe for serialization or + * deserialization by the security plugin. + *
+ * All methods are package private. + */ +public final class SafeSerializationUtils { + + private static final Set> SAFE_CLASSES = ImmutableSet.of( + String.class, + SocketAddress.class, + InetSocketAddress.class, + Pattern.class, + User.class, + UserInjector.InjectedUser.class, + SourceFieldsContext.class, + LdapUser.class, + SearchEntry.class, + LdapEntry.class, + AbstractLdapBean.class, + LdapAttribute.class + ); + + private static final List> SAFE_ASSIGNABLE_FROM_CLASSES = ImmutableList.of( + InetAddress.class, + Number.class, + Collection.class, + Map.class, + Enum.class + ); + + private static final Set SAFE_CLASS_NAMES = Collections.singleton("org.ldaptive.LdapAttribute$LdapAttributeValues"); + + static boolean isSafeClass(Class cls) { + return cls.isArray() + || SAFE_CLASSES.contains(cls) + || SAFE_CLASS_NAMES.contains(cls.getName()) + || SAFE_ASSIGNABLE_FROM_CLASSES.stream().anyMatch(c -> c.isAssignableFrom(cls)); + } + + static void prohibitUnsafeClasses(Class clazz) throws IOException { + if (!isSafeClass(clazz)) { + throw new IOException("Unauthorized serialization attempt " + clazz.getName()); + } + } + +} diff --git a/src/main/java/org/opensearch/security/support/SourceFieldsContext.java b/src/main/java/org/opensearch/security/support/SourceFieldsContext.java index 02f0ad9226..83bbb683e9 100644 --- a/src/main/java/org/opensearch/security/support/SourceFieldsContext.java +++ b/src/main/java/org/opensearch/security/support/SourceFieldsContext.java @@ -26,13 +26,18 @@ package org.opensearch.security.support; +import java.io.IOException; import java.io.Serializable; import java.util.Arrays; +import java.util.Objects; import org.opensearch.action.get.GetRequest; import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; -public class SourceFieldsContext implements Serializable { +public class SourceFieldsContext implements Serializable, Writeable { private String[] includes; private String[] excludes; @@ -77,6 +82,18 @@ public SourceFieldsContext(SearchRequest request) { // } } + public SourceFieldsContext(StreamInput in) throws IOException { + includes = in.readStringArray(); + if (includes.length == 0) { + includes = null; + } + excludes = in.readStringArray(); + if (excludes.length == 0) { + excludes = null; + } + fetchSource = in.readBoolean(); + } + public SourceFieldsContext(GetRequest request) { if (request.fetchSourceContext() != null) { includes = request.fetchSourceContext().includes(); @@ -117,4 +134,11 @@ public String toString() { + fetchSource + "]"; } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + streamOutput.writeStringArray(Objects.requireNonNullElseGet(includes, () -> new String[] {})); + streamOutput.writeStringArray(Objects.requireNonNullElseGet(excludes, () -> new String[] {})); + streamOutput.writeBoolean(fetchSource); + } } diff --git a/src/main/java/org/opensearch/security/support/StreamableRegistry.java b/src/main/java/org/opensearch/security/support/StreamableRegistry.java new file mode 100644 index 0000000000..bfde866376 --- /dev/null +++ b/src/main/java/org/opensearch/security/support/StreamableRegistry.java @@ -0,0 +1,134 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.HashMap; +import java.util.Map; + +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; + +import org.opensearch.OpenSearchException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; + +/** + * Registry for any class that does NOT implement the Writeable interface + * and needs to be serialized over the wire. Supports registration of writer and reader via registerStreamable + * for such classes and provides methods writeTo and readFrom for objects of such registered classes. + *
+ * Methods are protected and intended to be accessed from only within the package. (mostly by Base64Helper) + */ +public class StreamableRegistry { + + private static final StreamableRegistry INSTANCE = new StreamableRegistry(); + public final BiMap, Integer> classToIdMap = HashBiMap.create(); + private final Map idToEntryMap = new HashMap<>(); + + private StreamableRegistry() { + registerAllStreamables(); + } + + private static class Entry { + Writeable.Writer writer; + Writeable.Reader reader; + + Entry(Writeable.Writer writer, Writeable.Reader reader) { + this.writer = writer; + this.reader = reader; + } + } + + private Writeable.Writer getWriter(Class clazz) { + if (!classToIdMap.containsKey(clazz)) { + throw new OpenSearchException(String.format("No writer registered for class %s", clazz.getName())); + } + return idToEntryMap.get(classToIdMap.get(clazz)).writer; + } + + private Writeable.Reader getReader(int id) { + if (!idToEntryMap.containsKey(id)) { + throw new OpenSearchException(String.format("No reader registered for id %s", id)); + } + return idToEntryMap.get(id).reader; + } + + private int getId(Class clazz) { + if (!classToIdMap.containsKey(clazz)) { + throw new OpenSearchException(String.format("No writer registered for class %s", clazz.getName())); + } + return classToIdMap.get(clazz); + } + + protected boolean isStreamable(Class clazz) { + return classToIdMap.containsKey(clazz); + } + + protected void writeTo(StreamOutput out, Object object) throws IOException { + out.writeByte((byte) getId(object.getClass())); + getWriter(object.getClass()).write(out, object); + } + + protected Object readFrom(StreamInput in) throws IOException { + int id = in.readByte(); + return getReader(id).read(in); + } + + protected static StreamableRegistry getInstance() { + return INSTANCE; + } + + protected void registerStreamable(int streamableId, Class clazz, Writeable.Writer writer, Writeable.Reader reader) { + if (Writeable.class.isAssignableFrom(clazz)) { + throw new IllegalArgumentException( + String.format("%s is Writeable and should not be registered as a streamable", clazz.getName()) + ); + } + classToIdMap.put(clazz, streamableId); + idToEntryMap.put(streamableId, new Entry(writer, reader)); + } + + protected int getStreamableID(Class clazz) { + if (!isStreamable(clazz)) { + throw new OpenSearchException(String.format("class %s is in streamable registry", clazz.getName())); + } else { + return classToIdMap.get(clazz); + } + } + + /** + * Register all streamables here. + *
+ * Caution - Register new streamables towards the end. Removing / reordering a registered streamable will change the typeIDs associated with the streamables + * causing a breaking change in the serialization format. + */ + private void registerAllStreamables() { + + // InetSocketAddress + this.registerStreamable(1, InetSocketAddress.class, (o, v) -> { + final InetSocketAddress inetSocketAddress = (InetSocketAddress) v; + o.writeString(inetSocketAddress.getHostString()); + o.writeByteArray(inetSocketAddress.getAddress().getAddress()); + o.writeInt(inetSocketAddress.getPort()); + }, i -> { + String host = i.readString(); + byte[] addressBytes = i.readByteArray(); + int port = i.readInt(); + return new InetSocketAddress(InetAddress.getByAddress(host, addressBytes), port); + }); + } + +} diff --git a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java index 0c645c9a00..f064f0af04 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java +++ b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java @@ -59,6 +59,7 @@ import org.opensearch.security.ssl.transport.SSLConfig; import org.opensearch.security.support.Base64Helper; import org.opensearch.security.support.ConfigConstants; +import org.opensearch.security.support.HeaderHelper; import org.opensearch.security.user.User; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.Transport.Connection; @@ -147,6 +148,7 @@ public void sendRequestDecorate( final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS); final boolean isDebugEnabled = log.isDebugEnabled(); + final boolean useJDKSerialization = connection.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION); final boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode()); try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) { @@ -224,9 +226,26 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL ); } + if (useJDKSerialization) { + Map jdkSerializedHeaders = new HashMap<>(); + HeaderHelper.getAllSerializedHeaderNames() + .stream() + .filter(k -> headerMap.get(k) != null) + .forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k)))); + headerMap.putAll(jdkSerializedHeaders); + } + getThreadContext().putHeader(headerMap); - ensureCorrectHeaders(remoteAddress0, user0, origin0, injectedUserString, injectedRolesString, isSameNodeRequest); + ensureCorrectHeaders( + remoteAddress0, + user0, + origin0, + injectedUserString, + injectedRolesString, + isSameNodeRequest, + useJDKSerialization + ); if (isActionTraceEnabled()) { getThreadContext().putHeader( @@ -253,7 +272,8 @@ private void ensureCorrectHeaders( final String origin, final String injectedUserString, final String injectedRolesString, - boolean isSameNodeRequest + final boolean isSameNodeRequest, + final boolean useJDKSerialization ) { // keep original address @@ -294,7 +314,7 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN_HEADE if (transportAddress != null) { getThreadContext().putHeader( ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER, - Base64Helper.serializeObject(transportAddress.address()) + Base64Helper.serializeObject(transportAddress.address(), useJDKSerialization) ); } @@ -302,7 +322,10 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN_HEADE if (userHeader == null) { // put as headers for other requests if (origUser != null) { - getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(origUser)); + getThreadContext().putHeader( + ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, + Base64Helper.serializeObject(origUser, useJDKSerialization) + ); } else if (StringUtils.isNotEmpty(injectedRolesString)) { getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_HEADER, injectedRolesString); } else if (StringUtils.isNotEmpty(injectedUserString)) { diff --git a/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java b/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java index 1284ca9781..3ba379dd67 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java +++ b/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java @@ -107,6 +107,8 @@ protected void messageReceivedDecorate( resolvedActionClass = ((ConcreteShardRequest) request).getRequest().getClass().getSimpleName(); } + final boolean useJDKSerialization = getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION); + String initialActionClassValue = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INITIAL_ACTION_CLASS_HEADER); final ThreadContext.StoredContext sgContext = getThreadContext().newStoredContext(false); @@ -181,7 +183,7 @@ protected void messageReceivedDecorate( } else { getThreadContext().putTransient( ConfigConstants.OPENDISTRO_SECURITY_USER, - Objects.requireNonNull((User) Base64Helper.deserializeObject(userHeader)) + Objects.requireNonNull((User) Base64Helper.deserializeObject(userHeader, useJDKSerialization)) ); } @@ -190,7 +192,7 @@ protected void messageReceivedDecorate( if (!Strings.isNullOrEmpty(originalRemoteAddress)) { getThreadContext().putTransient( ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, - new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(originalRemoteAddress)) + new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(originalRemoteAddress, useJDKSerialization)) ); } else { getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, request.remoteAddress()); diff --git a/src/main/java/org/opensearch/security/user/User.java b/src/main/java/org/opensearch/security/user/User.java index 2642b368d7..394b251271 100644 --- a/src/main/java/org/opensearch/security/user/User.java +++ b/src/main/java/org/opensearch/security/user/User.java @@ -83,6 +83,9 @@ public User(final StreamInput in) throws IOException { name = in.readString(); roles.addAll(in.readList(StreamInput::readString)); requestedTenant = in.readString(); + if (requestedTenant.isEmpty()) { + requestedTenant = null; + } attributes = Collections.synchronizedMap(in.readMap(StreamInput::readString, StreamInput::readString)); securityRoles.addAll(in.readList(StreamInput::readString)); } @@ -167,9 +170,9 @@ public final boolean isUserInRole(final String role) { } /** - * Associate this user with a set of backend roles + * Associate this user with a set of custom attributes * - * @param roles The backend roles + * @param attributes custom attributes */ public final void addAttributes(final Map attributes) { if (attributes != null) { @@ -255,7 +258,7 @@ public final void copyRolesFrom(final User user) { public void writeTo(StreamOutput out) throws IOException { out.writeString(name); out.writeStringCollection(new ArrayList(roles)); - out.writeString(requestedTenant); + out.writeString(requestedTenant == null ? "" : requestedTenant); out.writeMap(attributes, StreamOutput::writeString, StreamOutput::writeString); out.writeStringCollection(securityRoles == null ? Collections.emptyList() : new ArrayList(securityRoles)); } diff --git a/src/test/java/org/opensearch/security/support/Base64CustomHelperTest.java b/src/test/java/org/opensearch/security/support/Base64CustomHelperTest.java new file mode 100644 index 0000000000..e35e1d72ba --- /dev/null +++ b/src/test/java/org/opensearch/security/support/Base64CustomHelperTest.java @@ -0,0 +1,159 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import com.amazon.dlic.auth.ldap.LdapUser; +import org.junit.Assert; +import org.junit.Test; +import org.ldaptive.LdapEntry; +import org.opensearch.OpenSearchException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.AuthCredentials; +import org.opensearch.security.user.User; + +import java.io.Serializable; +import java.net.InetSocketAddress; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.HashMap; + +import static org.opensearch.security.support.Base64CustomHelper.deserializeObject; +import static org.opensearch.security.support.Base64CustomHelper.serializeObject; + +public class Base64CustomHelperTest { + + private static final class NotSafeStreamable implements Serializable { + private static final long serialVersionUID = 5135559266828470092L; + } + + private static final class NotSafeWriteable implements Writeable, Serializable { + @Override + public void writeTo(StreamOutput out) { + + } + } + + private static Serializable ds(Serializable s) { + return deserializeObject(serializeObject(s)); + } + + @Test + public void testString() { + String string = "string"; + Assert.assertEquals(string, ds(string)); + } + + @Test + public void testInteger() { + Integer integer = 0; + Assert.assertEquals(integer, ds(integer)); + } + + @Test + public void testDouble() { + Double number = 0.; + Assert.assertEquals(number, ds(number)); + } + + @Test + public void testInetSocketAddress() { + InetSocketAddress inetSocketAddress = new InetSocketAddress(0); + Assert.assertEquals(inetSocketAddress, ds(inetSocketAddress)); + } + + @Test + public void testUser() { + User user = new User("user"); + Assert.assertEquals(user, ds(user)); + } + + @Test + public void testSourceFieldsContext() { + SourceFieldsContext sourceFieldsContext = new SourceFieldsContext(new SearchRequest("")); + Assert.assertEquals(sourceFieldsContext.toString(), ds(sourceFieldsContext).toString()); + } + + @Test + public void testHashMap() { + HashMap map = new HashMap<>() { + { + put("key", "value"); + } + }; + Assert.assertEquals(map, ds(map)); + } + + @Test + public void testArrayList() { + ArrayList list = new ArrayList<>() { + { + add("value"); + } + }; + Assert.assertEquals(list, ds(list)); + } + + @Test + public void testLdapUser() { + LdapUser ldapUser = new LdapUser( + "username", + "originalusername", + new LdapEntry("dn"), + new AuthCredentials("originalusername", "12345"), + 34, + WildcardMatcher.ANY + ); + Assert.assertEquals(ldapUser, ds(ldapUser)); + } + + @Test + public void testGetWriteableClassID() { + // a need to make a change in this test signifies a breaking change in security plugin's custom serialization + // format + Assert.assertEquals(Integer.valueOf(1), Base64CustomHelper.getWriteableClassID(User.class)); + Assert.assertEquals(Integer.valueOf(2), Base64CustomHelper.getWriteableClassID(LdapUser.class)); + Assert.assertEquals(Integer.valueOf(3), Base64CustomHelper.getWriteableClassID(UserInjector.InjectedUser.class)); + Assert.assertEquals(Integer.valueOf(4), Base64CustomHelper.getWriteableClassID(SourceFieldsContext.class)); + } + + @Test + public void testInjectedUser() { + UserInjector.InjectedUser injectedUser = new UserInjector.InjectedUser("username"); + + // for custom serialization, we expect InjectedUser to be returned on deserialization + UserInjector.InjectedUser deserializedInjecteduser = (UserInjector.InjectedUser) ds(injectedUser); + Assert.assertEquals(injectedUser, deserializedInjecteduser); + Assert.assertTrue(deserializedInjecteduser.isInjected()); + } + + @Test(expected = OpenSearchException.class) + public void testNotSafeStreamable() { + Base64JDKHelper.serializeObject(new NotSafeStreamable()); + } + + @Test(expected = OpenSearchException.class) + public void testNotSafeWriteable() { + Base64JDKHelper.serializeObject(new NotSafeWriteable()); + } + + @Test(expected = OpenSearchException.class) + public void testNotSafeGeneric() { + HashMap map = new HashMap<>(); + map.put(1, ZonedDateTime.now()); + map.put(2, ZonedDateTime.now()); + Base64JDKHelper.serializeObject(map); + } + +} diff --git a/src/test/java/org/opensearch/security/support/Base64HelperTest.java b/src/test/java/org/opensearch/security/support/Base64HelperTest.java index 81c2505985..f55581c7e7 100644 --- a/src/test/java/org/opensearch/security/support/Base64HelperTest.java +++ b/src/test/java/org/opensearch/security/support/Base64HelperTest.java @@ -10,100 +10,44 @@ */ package org.opensearch.security.support; -import java.io.ByteArrayOutputStream; -import java.io.ObjectOutputStream; import java.io.Serializable; -import java.net.InetSocketAddress; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.regex.Pattern; -import com.google.common.io.BaseEncoding; import org.junit.Assert; import org.junit.Test; -import org.opensearch.OpenSearchException; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.security.user.User; - import static org.opensearch.security.support.Base64Helper.deserializeObject; import static org.opensearch.security.support.Base64Helper.serializeObject; public class Base64HelperTest { - private static final class NotSafeSerializable implements Serializable { - private static final long serialVersionUID = 5135559266828470092L; + private static Serializable dsJDK(Serializable s) { + return deserializeObject(serializeObject(s, true), true); } private static Serializable ds(Serializable s) { return deserializeObject(serializeObject(s)); } + /** + * Just one sanity test comprising invocation of JDK and Custom Serialization. + * + * Individual scenarios are covered by Base64CustomHelperTest and Base64JDKHelperTest + */ @Test - public void testString() { - String string = "string"; - Assert.assertEquals(string, ds(string)); - } - - @Test - public void testInteger() { - Integer integer = Integer.valueOf(0); - Assert.assertEquals(integer, ds(integer)); - } - - @Test - public void testDouble() { - Double number = Double.valueOf(0.); - Assert.assertEquals(number, ds(number)); - } - - @Test - public void testInetSocketAddress() { - InetSocketAddress inetSocketAddress = new InetSocketAddress(0); - Assert.assertEquals(inetSocketAddress, ds(inetSocketAddress)); - } - - @Test - public void testPattern() { - Pattern pattern = Pattern.compile(".*"); - Assert.assertEquals(pattern.pattern(), ((Pattern) ds(pattern)).pattern()); - } - - @Test - public void testUser() { - User user = new User("user"); - Assert.assertEquals(user, ds(user)); - } - - @Test - public void testSourceFieldsContext() { - SourceFieldsContext sourceFieldsContext = new SourceFieldsContext(new SearchRequest("")); - Assert.assertEquals(sourceFieldsContext.toString(), ds(sourceFieldsContext).toString()); - } - - @Test - public void testHashMap() { - HashMap map = new HashMap(); - Assert.assertEquals(map, ds(map)); + public void testSerde() { + String test = "string"; + Assert.assertEquals(test, ds(test)); + Assert.assertEquals(test, dsJDK(test)); } @Test - public void testArrayList() { - ArrayList list = new ArrayList(); - Assert.assertEquals(list, ds(list)); - } + public void testEnsureJDKSerialized() { + String test = "string"; + String jdkSerialized = Base64Helper.serializeObject(test, true); + String customSerialized = Base64Helper.serializeObject(test, false); + Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(jdkSerialized)); + Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(customSerialized)); - @Test(expected = OpenSearchException.class) - public void notSafeSerializable() { - serializeObject(new NotSafeSerializable()); } - @Test(expected = OpenSearchException.class) - public void notSafeDeserializable() throws Exception { - final ByteArrayOutputStream bos = new ByteArrayOutputStream(); - try (final ObjectOutputStream out = new ObjectOutputStream(bos)) { - out.writeObject(new NotSafeSerializable()); - } - deserializeObject(BaseEncoding.base64().encode(bos.toByteArray())); - } } diff --git a/src/test/java/org/opensearch/security/support/Base64JDKHelperTest.java b/src/test/java/org/opensearch/security/support/Base64JDKHelperTest.java new file mode 100644 index 0000000000..704f1dc1d7 --- /dev/null +++ b/src/test/java/org/opensearch/security/support/Base64JDKHelperTest.java @@ -0,0 +1,128 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import com.amazon.dlic.auth.ldap.LdapUser; +import com.google.common.io.BaseEncoding; +import org.junit.Assert; +import org.junit.Test; +import org.ldaptive.LdapEntry; +import org.opensearch.OpenSearchException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.AuthCredentials; +import org.opensearch.security.user.User; + +import java.io.ByteArrayOutputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.HashMap; + +public class Base64JDKHelperTest { + private static final class NotSafeSerializable implements Serializable { + private static final long serialVersionUID = 5135559266828470092L; + } + + private static Serializable ds(Serializable s) { + return Base64JDKHelper.deserializeObject(Base64JDKHelper.serializeObject(s)); + } + + @Test + public void testString() { + String string = "string"; + Assert.assertEquals(string, ds(string)); + } + + @Test + public void testInteger() { + Integer integer = 0; + Assert.assertEquals(integer, ds(integer)); + } + + @Test + public void testDouble() { + Double number = 0.0; + Assert.assertEquals(number, ds(number)); + } + + @Test + public void testInetSocketAddress() { + InetSocketAddress inetSocketAddress = new InetSocketAddress(0); + Assert.assertEquals(inetSocketAddress, ds(inetSocketAddress)); + } + + @Test + public void testUser() { + User user = new User("user"); + Assert.assertEquals(user, ds(user)); + } + + @Test + public void testSourceFieldsContext() { + SourceFieldsContext sourceFieldsContext = new SourceFieldsContext(new SearchRequest("")); + Assert.assertEquals(sourceFieldsContext.toString(), ds(sourceFieldsContext).toString()); + } + + @Test + public void testHashMap() { + HashMap map = new HashMap<>(); + map.put("key", "value"); + Assert.assertEquals(map, ds(map)); + } + + @Test + public void testArrayList() { + ArrayList list = new ArrayList<>(); + list.add("value"); + Assert.assertEquals(list, ds(list)); + } + + @Test(expected = OpenSearchException.class) + public void notSafeSerializable() { + Base64JDKHelper.serializeObject(new NotSafeSerializable()); + } + + @Test(expected = OpenSearchException.class) + public void notSafeDeserializable() throws Exception { + final ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try (final ObjectOutputStream out = new ObjectOutputStream(bos)) { + out.writeObject(new NotSafeSerializable()); + } + Base64JDKHelper.deserializeObject(BaseEncoding.base64().encode(bos.toByteArray())); + } + + @Test + public void testLdapUser() { + LdapUser ldapUser = new LdapUser( + "username", + "originalusername", + new LdapEntry("dn"), + new AuthCredentials("originalusername", "12345"), + 34, + WildcardMatcher.ANY + ); + Assert.assertEquals(ldapUser, ds(ldapUser)); + } + + @Test + public void testInjectedUser() { + UserInjector.InjectedUser injectedUser = new UserInjector.InjectedUser("username"); + + // we expect to get User object when deserializing InjectedUser via JDK serialization + User user = new User("username"); + User deserializedUser = (User) ds(injectedUser); + Assert.assertEquals(user, deserializedUser); + Assert.assertTrue(deserializedUser.isInjected()); + } +} diff --git a/src/test/java/org/opensearch/security/support/StreamableRegistryTest.java b/src/test/java/org/opensearch/security/support/StreamableRegistryTest.java new file mode 100644 index 0000000000..13f2448b30 --- /dev/null +++ b/src/test/java/org/opensearch/security/support/StreamableRegistryTest.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.OpenSearchException; + +import java.net.InetSocketAddress; + +public class StreamableRegistryTest { + + StreamableRegistry streamableRegistry = StreamableRegistry.getInstance(); + + @Test + public void testStreamableTypeIDs() { + Assert.assertEquals(1, streamableRegistry.getStreamableID(InetSocketAddress.class)); + Assert.assertThrows(OpenSearchException.class, () -> streamableRegistry.getStreamableID(String.class)); + } +} diff --git a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java index d3363c54d8..abc0e314ef 100644 --- a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java +++ b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java @@ -47,9 +47,6 @@ import static java.util.Collections.emptySet; import static org.junit.Assert.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; // CS-ENFORCE-SINGLE @@ -110,9 +107,8 @@ public void setup() { ); } - @Test - public void testSendRequestDecorate() { - + private void testSendRequestDecorate(Version remoteNodeVersion) { + boolean useJDKSerialization = remoteNodeVersion.before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION); ClusterName clusterName = ClusterName.DEFAULT; when(clusterService.getClusterName()).thenReturn(clusterName); @@ -140,7 +136,6 @@ public void testSendRequestDecorate() { User user = new User("John Doe"); threadPool.getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, user); - AsyncSender sender = mock(AsyncSender.class); String action = "testAction"; TransportRequest request = mock(TransportRequest.class); TransportRequestOptions options = mock(TransportRequestOptions.class); @@ -156,37 +151,65 @@ public void testSendRequestDecorate() { DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(localAddress, 1234), Version.CURRENT); Connection connection1 = transportService.getConnection(localNode); - DiscoveryNode otherNode = new DiscoveryNode("local-node", new TransportAddress(localAddress, 4321), Version.CURRENT); + DiscoveryNode otherNode = new DiscoveryNode("remote-node", new TransportAddress(localAddress, 4321), remoteNodeVersion); Connection connection2 = transportService.getConnection(otherNode); + // from thread context inside sendRequestDecorate + AsyncSender sender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); + assertEquals(transientUser, user); + } + }; // isSameNodeRequest = true securityInterceptor.sendRequestDecorate(sender, connection1, action, request, options, handler, localNode); - // from thread context inside sendRequestDecorate - doAnswer(i -> { - User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); - assertEquals(transientUser, user); - return null; - }).when(sender).sendRequest(any(Connection.class), eq(action), eq(request), eq(options), eq(handler)); // from original context User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); assertEquals(transientUser, user); assertEquals(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), null); - // isSameNodeRequest = false - securityInterceptor.sendRequestDecorate(sender, connection2, action, request, options, handler, otherNode); // checking thread context inside sendRequestDecorate - doAnswer(i -> { - String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); - assertEquals(serializedUserHeader, Base64Helper.serializeObject(user)); - return null; - }).when(sender).sendRequest(any(Connection.class), eq(action), eq(request), eq(options), eq(handler)); + sender = new AsyncSender() { + @Override + public void sendRequest( + Connection connection, + String action, + TransportRequest request, + TransportRequestOptions options, + TransportResponseHandler handler + ) { + String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); + assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, useJDKSerialization)); + } + }; + // isSameNodeRequest = false + securityInterceptor.sendRequestDecorate(sender, connection2, action, request, options, handler, localNode); // from original context User transientUser2 = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); assertEquals(transientUser2, user); assertEquals(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), null); + } + + @Test + public void testSendRequestDecorate() { + testSendRequestDecorate(Version.CURRENT); + } + /** + * Tests the scenario when remote node does not implement custom serialization protocol and uses JDK serialization + */ + @Test + public void testSendRequestDecorateWhenRemoteNodeUsesJDKSerde() { + testSendRequestDecorate(Version.V_2_0_0); } } diff --git a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java new file mode 100644 index 0000000000..23a64e4be3 --- /dev/null +++ b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java @@ -0,0 +1,80 @@ +package org.opensearch.security.transport; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.opensearch.Version; +import org.opensearch.common.settings.Settings; +import org.opensearch.security.ssl.SslExceptionHandler; +import org.opensearch.security.ssl.transport.PrincipalExtractor; +import org.opensearch.security.ssl.transport.SSLConfig; +import org.opensearch.security.ssl.transport.SecuritySSLRequestHandler; +import org.opensearch.security.support.ConfigConstants; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestHandler; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class SecuritySSLRequestHandlerTests { + + @Mock + TransportRequestHandler actualHandler; + @Mock + SSLConfig sslConfig; + ThreadPool threadPool; + SslExceptionHandler sslExceptionHandler; + Settings settings; + SecuritySSLRequestHandler securitySSLRequestHandler; + String testAction; + + @Mock + private PrincipalExtractor principalExtractor; + + @Before + public void setUp() { + settings = Settings.builder() + .put("node.name", SecurityInterceptorTests.class.getSimpleName()) + .put("request.headers.default", "1") + .build(); + threadPool = new ThreadPool(settings); + testAction = "test_action"; + sslExceptionHandler = mock(SslExceptionHandler.class); + securitySSLRequestHandler = new SecuritySSLRequestHandler<>( + testAction, + actualHandler, + threadPool, + principalExtractor, + sslConfig, + sslExceptionHandler + ); + doNothing().when(sslExceptionHandler) + .logError(any(Exception.class), any(TransportRequest.class), any(String.class), any(Task.class), anyInt()); + } + + @Test + public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Exception { + TransportRequest transportRequest = mock(TransportRequest.class); + TransportChannel transportChannel = mock(TransportChannel.class); + Task task = mock(Task.class); + doNothing().when(transportChannel).sendResponse(ArgumentMatchers.any(Exception.class)); + when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0); + when(transportChannel.getChannelType()).thenReturn("transport"); + + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); + Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + } +}