From e213e9c23124418f6f3f2b4a35568ab89ae25232 Mon Sep 17 00:00:00 2001 From: Peter Nied Date: Wed, 13 Apr 2022 14:12:54 -0500 Subject: [PATCH] Add signal/wait model for TestAuditlogImpl (#1758) * Add signal/wait model for TestAuditlogImpl I have been tracking test failures with testRestMethod very often showing failures. My theory is that the execution environment can impact the order of operations sometimes causing the audit log not to contain messages before it is checked. Adding a new method `doThenWaitForMessages(...)` this ensures the log queue is fresh, the triggering action completes, and the expected number of messages were recieved. There is a second long time window that allows for the messages to be flushed, this is likely more than enough - if the messages are recieved the count down latch immediately continues execution so the tests will not wait if they are ready to proceed. While this new method is much more reliable not all tests were encountering such issues, so I've keep the original convention. This can be migrated in one-offs or all at once if we see more troublesome behavoir. The previous methods/fields are depreciated to push future tests to follow the new pattern. Modifications to the rest helper not to throw exceptions were needed to keep the Runnable declaration clean and small. Signed-off-by: Peter Nied --- .../security/auditlog/impl/AuditMessage.java | 8 ++ .../integration/BasicAuditlogTest.java | 125 ++++++++++-------- .../integration/TestAuditlogImpl.java | 51 ++++++- .../security/test/helper/rest/RestHelper.java | 39 ++++-- 4 files changed, 149 insertions(+), 74 deletions(-) diff --git a/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java b/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java index cf348ab120..8c3a19586e 100644 --- a/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java +++ b/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java @@ -433,10 +433,18 @@ public String getRequestType() { return (String) this.auditInfo.get(TRANSPORT_REQUEST_TYPE); } + public RestRequest.Method getRequestMethod() { + return (RestRequest.Method) this.auditInfo.get(REST_REQUEST_METHOD); + } + public AuditCategory getCategory() { return msgCategory; } + public String getExceptionStackTrace() { + return (String) this.auditInfo.get(EXCEPTION); + } + @Override public String toString() { try { diff --git a/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java b/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java index e7e41cdf63..3317fada98 100644 --- a/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java +++ b/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java @@ -49,8 +49,15 @@ import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.Collections; +import java.util.List; import java.util.Objects; +import static org.opensearch.rest.RestRequest.Method.GET; +import static org.opensearch.rest.RestRequest.Method.DELETE; +import static org.opensearch.rest.RestRequest.Method.PATCH; +import static org.opensearch.rest.RestRequest.Method.POST; +import static org.opensearch.rest.RestRequest.Method.PUT; + public class BasicAuditlogTest extends AbstractAuditlogiUnitTest { @Test @@ -123,22 +130,18 @@ public void testSSLPlainText() throws Exception { .build(); setup(additionalSettings); - TestAuditlogImpl.clear(); - - try { - nonSslRestHelper().executeGetRequest("_search", encodeBasicHeader("admin", "admin")); - Assert.fail(); - } catch (NoHttpResponseException e) { - //expected - } - - Thread.sleep(1500); - System.out.println(TestAuditlogImpl.sb.toString()); - Assert.assertFalse(TestAuditlogImpl.messages.isEmpty()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("SSL_EXCEPTION")); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("exception_stacktrace")); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("not an SSL/TLS record")); - Assert.assertTrue(validateMsgs(TestAuditlogImpl.messages)); + final List messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + final RuntimeException ex = Assert.assertThrows(RuntimeException.class, + () -> nonSslRestHelper().executeGetRequest("_search", encodeBasicHeader("admin", "admin"))); + Assert.assertEquals("org.apache.http.NoHttpResponseException", ex.getCause().getClass().getName()); + }, 4); + + // All of the messages should be the same as the http client is attempting multiple times. + messages.stream().forEach((message) -> { + Assert.assertEquals(AuditCategory.SSL_EXCEPTION, message.getCategory()); + Assert.assertTrue(message.getExceptionStackTrace().contains("not an SSL/TLS record")); + }); + Assert.assertTrue(validateMsgs(messages)); } @Test @@ -767,6 +770,10 @@ public void testIndexRequests() throws Exception { Assert.assertTrue(auditlogs.contains("\"audit_transport_request_type\" : \"DeleteIndexRequest\",")); } + private String messageRestRequestMethod(AuditMessage msg) { + return msg.getAsMap().get("audit_rest_request_method").toString(); + } + @Test public void testRestMethod() throws Exception { final Settings settings = Settings.builder() @@ -777,66 +784,70 @@ public void testRestMethod() throws Exception { .build(); setup(settings); final Header adminHeader = encodeBasicHeader("admin", "admin"); + List messages; // test GET - TestAuditlogImpl.clear(); - rh.executeGetRequest("test", adminHeader); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"GET\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executeGetRequest("test", adminHeader); + }, 1); + Assert.assertEquals(GET, messages.get(0).getRequestMethod()); // test PUT - TestAuditlogImpl.clear(); - rh.executePutRequest("test/_doc/0", "{}", adminHeader); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"PUT\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executePutRequest("test/_doc/0", "{}", adminHeader); + }, 1); + Assert.assertEquals(PUT, messages.get(0).getRequestMethod()); // test DELETE - TestAuditlogImpl.clear(); - rh.executeDeleteRequest("test", adminHeader); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"DELETE\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executeDeleteRequest("test", adminHeader); + }, 1); + Assert.assertEquals(DELETE, messages.get(0).getRequestMethod()); // test POST - TestAuditlogImpl.clear(); - rh.executePostRequest("test/_doc", "{}", adminHeader); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"POST\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executePostRequest("test/_doc", "{}", adminHeader); + }, 1); + Assert.assertEquals(POST, messages.get(0).getRequestMethod()); // test PATCH - TestAuditlogImpl.clear(); - rh.executePatchRequest("/_opendistro/_security/api/audit", "[]"); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"PATCH\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executePatchRequest("/_opendistro/_security/api/audit", "[]"); + }, 1); + Assert.assertEquals(PATCH, messages.get(0).getRequestMethod()); // test MISSING_PRIVILEGES // admin does not have REST role here - TestAuditlogImpl.clear(); - rh.executePatchRequest("/_opendistro/_security/api/audit", "[]", adminHeader); - Assert.assertEquals(2, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("MISSING_PRIVILEGES")); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("AUTHENTICATED")); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"PATCH\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executePatchRequest("/_opendistro/_security/api/audit", "[]", adminHeader); + }, 2); + // The intital request is authenicated + Assert.assertEquals(PATCH, messages.get(0).getRequestMethod()); + Assert.assertEquals(AuditCategory.AUTHENTICATED, messages.get(0).getCategory()); + // The secondary request does not have permissions + Assert.assertEquals(PATCH, messages.get(1).getRequestMethod()); + Assert.assertEquals(AuditCategory.MISSING_PRIVILEGES, messages.get(1).getCategory()); // test AUTHENTICATED - TestAuditlogImpl.clear(); - rh.executeGetRequest("test", adminHeader); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("AUTHENTICATED")); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"GET\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executeGetRequest("test", adminHeader); + }, 1); + Assert.assertEquals(AuditCategory.AUTHENTICATED, messages.get(0).getCategory()); + Assert.assertEquals(GET, messages.get(0).getRequestMethod()); // test FAILED_LOGIN - TestAuditlogImpl.clear(); - rh.executeGetRequest("test", encodeBasicHeader("random", "random")); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("FAILED_LOGIN")); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"GET\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executeGetRequest("test", encodeBasicHeader("random", "random")); + }, 1); + Assert.assertEquals(AuditCategory.FAILED_LOGIN, messages.get(0).getCategory()); + Assert.assertEquals(GET, messages.get(0).getRequestMethod()); // test BAD_HEADERS - TestAuditlogImpl.clear(); - rh.executeGetRequest("test", new BasicHeader("_opendistro_security_user", "xxx")); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("BAD_HEADERS")); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"GET\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executeGetRequest("test", new BasicHeader("_opendistro_security_user", "xxx")); + }, 1); + Assert.assertEquals(AuditCategory.BAD_HEADERS, messages.get(0).getCategory()); + Assert.assertEquals(GET, messages.get(0).getRequestMethod()); } @Test diff --git a/src/test/java/org/opensearch/security/auditlog/integration/TestAuditlogImpl.java b/src/test/java/org/opensearch/security/auditlog/integration/TestAuditlogImpl.java index ecafbfa469..a160a2504d 100644 --- a/src/test/java/org/opensearch/security/auditlog/integration/TestAuditlogImpl.java +++ b/src/test/java/org/opensearch/security/auditlog/integration/TestAuditlogImpl.java @@ -17,6 +17,9 @@ import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.opensearch.common.settings.Settings; @@ -25,23 +28,61 @@ public class TestAuditlogImpl extends AuditLogSink { + /** Use the results of `doThenWaitForMessages(...)` instead */ + @Deprecated public static List messages = new ArrayList(100); + /** Check messages indvidually instead of searching this string */ + @Deprecated public static StringBuffer sb = new StringBuffer(); + private static final AtomicReference countDownRef = new AtomicReference<>(); + private static final AtomicReference> messagesRef = new AtomicReference<>(); public TestAuditlogImpl(String name, Settings settings, String settingsPrefix, AuditLogSink fallbackSink) { super(name, settings, null, fallbackSink); } - - public synchronized boolean doStore(AuditMessage msg) { + public synchronized boolean doStore(AuditMessage msg) { + if (messagesRef.get() == null || countDownRef.get() == null) { + throw new RuntimeException("No message latch is waiting"); + } sb.append(msg.toPrettyString()+System.lineSeparator()); - messages.add(msg); + messagesRef.get().add(msg); + countDownRef.get().countDown(); return true; } + /** Unneeded after switching to `doThenWaitForMessages(...)` as data is automatically flushed */ + @Deprecated public static synchronized void clear() { - sb.setLength(0); - messages.clear(); + doThenWaitForMessages(() -> {}, 0); + } + + /** + * Perform an action and then wait until the expected number of messages have been found. + */ + public static List doThenWaitForMessages(final Runnable action, final int expectedCount) { + final CountDownLatch latch = new CountDownLatch(expectedCount); + final List messages = new ArrayList<>(); + countDownRef.set(latch); + messagesRef.set(messages); + + TestAuditlogImpl.sb = new StringBuffer(); + TestAuditlogImpl.messages = messages; + + try { + action.run(); + final int maxSecondsToWaitForMessages = 1; + final boolean foundAll = latch.await(maxSecondsToWaitForMessages, TimeUnit.SECONDS); + if (!foundAll) { + throw new RuntimeException("Did not recieve all " + expectedCount +" audit messages after a short wait."); + } + if (messages.size() != expectedCount) { + throw new RuntimeException("Unexpected number of messages, was expecting " + expectedCount + ", recieved " + messages.size()); + } + } catch (final InterruptedException e) { + throw new RuntimeException("Unexpected exception", e); + } + return new ArrayList<>(messages); } @Override diff --git a/src/test/java/org/opensearch/security/test/helper/rest/RestHelper.java b/src/test/java/org/opensearch/security/test/helper/rest/RestHelper.java index a096eb5496..9530622bc7 100644 --- a/src/test/java/org/opensearch/security/test/helper/rest/RestHelper.java +++ b/src/test/java/org/opensearch/security/test/helper/rest/RestHelper.java @@ -32,6 +32,7 @@ import java.io.FileInputStream; import java.io.IOException; +import java.io.UnsupportedEncodingException; import java.nio.charset.StandardCharsets; import java.security.KeyStore; import java.util.Arrays; @@ -136,48 +137,48 @@ public HttpResponse[] executeMultipleAsyncPutRequest(final int numOfRequests, fi .toArray(s -> new HttpResponse[s]); } - public HttpResponse executeGetRequest(final String request, Header... header) throws Exception { + public HttpResponse executeGetRequest(final String request, Header... header) { return executeRequest(new HttpGet(getHttpServerUri() + "/" + request), header); } - public HttpResponse executeHeadRequest(final String request, Header... header) throws Exception { + public HttpResponse executeHeadRequest(final String request, Header... header) { return executeRequest(new HttpHead(getHttpServerUri() + "/" + request), header); } - public HttpResponse executeOptionsRequest(final String request) throws Exception { + public HttpResponse executeOptionsRequest(final String request) { return executeRequest(new HttpOptions(getHttpServerUri() + "/" + request)); } - public HttpResponse executePutRequest(final String request, String body, Header... header) throws Exception { + public HttpResponse executePutRequest(final String request, String body, Header... header) { HttpPut uriRequest = new HttpPut(getHttpServerUri() + "/" + request); if (body != null && !body.isEmpty()) { - uriRequest.setEntity(new StringEntity(body)); + uriRequest.setEntity(createStringEntity(body)); } return executeRequest(uriRequest, header); } - public HttpResponse executeDeleteRequest(final String request, Header... header) throws Exception { + public HttpResponse executeDeleteRequest(final String request, Header... header) { return executeRequest(new HttpDelete(getHttpServerUri() + "/" + request), header); } - public HttpResponse executePostRequest(final String request, String body, Header... header) throws Exception { + public HttpResponse executePostRequest(final String request, String body, Header... header) { HttpPost uriRequest = new HttpPost(getHttpServerUri() + "/" + request); if (body != null && !body.isEmpty()) { - uriRequest.setEntity(new StringEntity(body)); + uriRequest.setEntity(createStringEntity(body)); } return executeRequest(uriRequest, header); } - public HttpResponse executePatchRequest(final String request, String body, Header... header) throws Exception { + public HttpResponse executePatchRequest(final String request, String body, Header... header) { HttpPatch uriRequest = new HttpPatch(getHttpServerUri() + "/" + request); if (body != null && !body.isEmpty()) { - uriRequest.setEntity(new StringEntity(body)); + uriRequest.setEntity(createStringEntity(body)); } return executeRequest(uriRequest, header); } - public HttpResponse executeRequest(HttpUriRequest uriRequest, Header... header) throws Exception { + public HttpResponse executeRequest(HttpUriRequest uriRequest, Header... header) { CloseableHttpClient httpClient = null; try { @@ -197,13 +198,27 @@ public HttpResponse executeRequest(HttpUriRequest uriRequest, Header... header) HttpResponse res = new HttpResponse(httpClient.execute(uriRequest)); log.debug(res.getBody()); return res; + } catch (final Exception e) { + throw new RuntimeException(e); } finally { if (httpClient != null) { - httpClient.close(); + try { + httpClient.close(); + } catch (final Exception e) { + throw new RuntimeException(e); + } } } } + + private StringEntity createStringEntity(String body) { + try { + return new StringEntity(body); + } catch (final UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + } protected final String getHttpServerUri() { final String address = "http" + (enableHTTPClientSSL ? "s" : "") + "://" + clusterInfo.httpHost + ":" + clusterInfo.httpPort;