diff --git a/src/integrationTest/java/org/opensearch/security/SecurityConfigurationTests.java b/src/integrationTest/java/org/opensearch/security/SecurityConfigurationTests.java index 76ea02494e..2a0d15c452 100644 --- a/src/integrationTest/java/org/opensearch/security/SecurityConfigurationTests.java +++ b/src/integrationTest/java/org/opensearch/security/SecurityConfigurationTests.java @@ -13,8 +13,13 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.apache.http.HttpStatus; import org.awaitility.Awaitility; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -24,6 +29,7 @@ import org.junit.runner.RunWith; import org.opensearch.client.Client; +import org.opensearch.test.framework.AsyncActions; import org.opensearch.test.framework.TestSecurityConfig.Role; import org.opensearch.test.framework.TestSecurityConfig.User; import org.opensearch.test.framework.certificate.TestCertificates; @@ -33,6 +39,8 @@ import org.opensearch.test.framework.cluster.TestRestClient.HttpResponse; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; import static org.opensearch.security.support.ConfigConstants.SECURITY_BACKGROUND_INIT_IF_SECURITYINDEX_NOT_EXIST; @@ -229,4 +237,39 @@ public void shouldUseSecurityAdminTool() throws Exception { .until(() -> client.get("_plugins/_security/api/rolesmapping/readall").getStatusCode(), equalTo(200)); } } + + @Test + public void testParallelTenantPutRequests() throws Exception { + final String TENANT_ENDPOINT = "_plugins/_security/api/tenants/tenant1"; + final String TENANT_BODY = "{\"description\":\"create new tenant\"}"; + final String TENANT_BODY_TWO = "{\"description\":\"update tenant\"}"; + + try (TestRestClient client = cluster.getRestClient(USER_ADMIN)) { + + final CountDownLatch countDownLatch = new CountDownLatch(1); + final List> conflictingRequests = AsyncActions.generate(() -> { + countDownLatch.await(); + return client.putJson(TENANT_ENDPOINT, TENANT_BODY); + }, 4, 4); + + // Make sure all requests start at the same time + countDownLatch.countDown(); + + AtomicInteger numCreatedResponses = new AtomicInteger(); + AsyncActions.getAll(conflictingRequests, 1, TimeUnit.SECONDS).forEach((response) -> { + assertThat(response.getStatusCode(), anyOf(equalTo(HttpStatus.SC_CREATED), equalTo(HttpStatus.SC_CONFLICT))); + if (response.getStatusCode() == HttpStatus.SC_CREATED) numCreatedResponses.getAndIncrement(); + }); + assertThat(numCreatedResponses.get(), equalTo(1)); // should only be one 201 + + TestRestClient.HttpResponse getResponse = client.get(TENANT_ENDPOINT); // make sure the one 201 works + assertThat(getResponse.getBody(), containsString("create new tenant")); + + TestRestClient.HttpResponse updateResponse = client.putJson(TENANT_ENDPOINT, TENANT_BODY_TWO); + assertThat(updateResponse.getStatusCode(), equalTo(HttpStatus.SC_OK)); + + getResponse = client.get(TENANT_ENDPOINT); // make sure update works + assertThat(getResponse.getBody(), containsString("update tenant")); + } + } }