diff --git a/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java b/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java index cf348ab120..8c3a19586e 100644 --- a/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java +++ b/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java @@ -433,10 +433,18 @@ public String getRequestType() { return (String) this.auditInfo.get(TRANSPORT_REQUEST_TYPE); } + public RestRequest.Method getRequestMethod() { + return (RestRequest.Method) this.auditInfo.get(REST_REQUEST_METHOD); + } + public AuditCategory getCategory() { return msgCategory; } + public String getExceptionStackTrace() { + return (String) this.auditInfo.get(EXCEPTION); + } + @Override public String toString() { try { diff --git a/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java b/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java index e7e41cdf63..3317fada98 100644 --- a/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java +++ b/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java @@ -49,8 +49,15 @@ import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.Collections; +import java.util.List; import java.util.Objects; +import static org.opensearch.rest.RestRequest.Method.GET; +import static org.opensearch.rest.RestRequest.Method.DELETE; +import static org.opensearch.rest.RestRequest.Method.PATCH; +import static org.opensearch.rest.RestRequest.Method.POST; +import static org.opensearch.rest.RestRequest.Method.PUT; + public class BasicAuditlogTest extends AbstractAuditlogiUnitTest { @Test @@ -123,22 +130,18 @@ public void testSSLPlainText() throws Exception { .build(); setup(additionalSettings); - TestAuditlogImpl.clear(); - - try { - nonSslRestHelper().executeGetRequest("_search", encodeBasicHeader("admin", "admin")); - Assert.fail(); - } catch (NoHttpResponseException e) { - //expected - } - - Thread.sleep(1500); - System.out.println(TestAuditlogImpl.sb.toString()); - Assert.assertFalse(TestAuditlogImpl.messages.isEmpty()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("SSL_EXCEPTION")); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("exception_stacktrace")); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("not an SSL/TLS record")); - Assert.assertTrue(validateMsgs(TestAuditlogImpl.messages)); + final List messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + final RuntimeException ex = Assert.assertThrows(RuntimeException.class, + () -> nonSslRestHelper().executeGetRequest("_search", encodeBasicHeader("admin", "admin"))); + Assert.assertEquals("org.apache.http.NoHttpResponseException", ex.getCause().getClass().getName()); + }, 4); + + // All of the messages should be the same as the http client is attempting multiple times. + messages.stream().forEach((message) -> { + Assert.assertEquals(AuditCategory.SSL_EXCEPTION, message.getCategory()); + Assert.assertTrue(message.getExceptionStackTrace().contains("not an SSL/TLS record")); + }); + Assert.assertTrue(validateMsgs(messages)); } @Test @@ -767,6 +770,10 @@ public void testIndexRequests() throws Exception { Assert.assertTrue(auditlogs.contains("\"audit_transport_request_type\" : \"DeleteIndexRequest\",")); } + private String messageRestRequestMethod(AuditMessage msg) { + return msg.getAsMap().get("audit_rest_request_method").toString(); + } + @Test public void testRestMethod() throws Exception { final Settings settings = Settings.builder() @@ -777,66 +784,70 @@ public void testRestMethod() throws Exception { .build(); setup(settings); final Header adminHeader = encodeBasicHeader("admin", "admin"); + List messages; // test GET - TestAuditlogImpl.clear(); - rh.executeGetRequest("test", adminHeader); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"GET\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executeGetRequest("test", adminHeader); + }, 1); + Assert.assertEquals(GET, messages.get(0).getRequestMethod()); // test PUT - TestAuditlogImpl.clear(); - rh.executePutRequest("test/_doc/0", "{}", adminHeader); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"PUT\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executePutRequest("test/_doc/0", "{}", adminHeader); + }, 1); + Assert.assertEquals(PUT, messages.get(0).getRequestMethod()); // test DELETE - TestAuditlogImpl.clear(); - rh.executeDeleteRequest("test", adminHeader); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"DELETE\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executeDeleteRequest("test", adminHeader); + }, 1); + Assert.assertEquals(DELETE, messages.get(0).getRequestMethod()); // test POST - TestAuditlogImpl.clear(); - rh.executePostRequest("test/_doc", "{}", adminHeader); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"POST\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executePostRequest("test/_doc", "{}", adminHeader); + }, 1); + Assert.assertEquals(POST, messages.get(0).getRequestMethod()); // test PATCH - TestAuditlogImpl.clear(); - rh.executePatchRequest("/_opendistro/_security/api/audit", "[]"); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"PATCH\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executePatchRequest("/_opendistro/_security/api/audit", "[]"); + }, 1); + Assert.assertEquals(PATCH, messages.get(0).getRequestMethod()); // test MISSING_PRIVILEGES // admin does not have REST role here - TestAuditlogImpl.clear(); - rh.executePatchRequest("/_opendistro/_security/api/audit", "[]", adminHeader); - Assert.assertEquals(2, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("MISSING_PRIVILEGES")); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("AUTHENTICATED")); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"PATCH\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executePatchRequest("/_opendistro/_security/api/audit", "[]", adminHeader); + }, 2); + // The intital request is authenicated + Assert.assertEquals(PATCH, messages.get(0).getRequestMethod()); + Assert.assertEquals(AuditCategory.AUTHENTICATED, messages.get(0).getCategory()); + // The secondary request does not have permissions + Assert.assertEquals(PATCH, messages.get(1).getRequestMethod()); + Assert.assertEquals(AuditCategory.MISSING_PRIVILEGES, messages.get(1).getCategory()); // test AUTHENTICATED - TestAuditlogImpl.clear(); - rh.executeGetRequest("test", adminHeader); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("AUTHENTICATED")); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"GET\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executeGetRequest("test", adminHeader); + }, 1); + Assert.assertEquals(AuditCategory.AUTHENTICATED, messages.get(0).getCategory()); + Assert.assertEquals(GET, messages.get(0).getRequestMethod()); // test FAILED_LOGIN - TestAuditlogImpl.clear(); - rh.executeGetRequest("test", encodeBasicHeader("random", "random")); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("FAILED_LOGIN")); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"GET\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executeGetRequest("test", encodeBasicHeader("random", "random")); + }, 1); + Assert.assertEquals(AuditCategory.FAILED_LOGIN, messages.get(0).getCategory()); + Assert.assertEquals(GET, messages.get(0).getRequestMethod()); // test BAD_HEADERS - TestAuditlogImpl.clear(); - rh.executeGetRequest("test", new BasicHeader("_opendistro_security_user", "xxx")); - Assert.assertEquals(1, TestAuditlogImpl.messages.size()); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("BAD_HEADERS")); - Assert.assertTrue(TestAuditlogImpl.sb.toString().contains("\"audit_rest_request_method\" : \"GET\"")); + messages = TestAuditlogImpl.doThenWaitForMessages(() -> { + rh.executeGetRequest("test", new BasicHeader("_opendistro_security_user", "xxx")); + }, 1); + Assert.assertEquals(AuditCategory.BAD_HEADERS, messages.get(0).getCategory()); + Assert.assertEquals(GET, messages.get(0).getRequestMethod()); } @Test diff --git a/src/test/java/org/opensearch/security/auditlog/integration/TestAuditlogImpl.java b/src/test/java/org/opensearch/security/auditlog/integration/TestAuditlogImpl.java index ecafbfa469..a160a2504d 100644 --- a/src/test/java/org/opensearch/security/auditlog/integration/TestAuditlogImpl.java +++ b/src/test/java/org/opensearch/security/auditlog/integration/TestAuditlogImpl.java @@ -17,6 +17,9 @@ import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.opensearch.common.settings.Settings; @@ -25,23 +28,61 @@ public class TestAuditlogImpl extends AuditLogSink { + /** Use the results of `doThenWaitForMessages(...)` instead */ + @Deprecated public static List messages = new ArrayList(100); + /** Check messages indvidually instead of searching this string */ + @Deprecated public static StringBuffer sb = new StringBuffer(); + private static final AtomicReference countDownRef = new AtomicReference<>(); + private static final AtomicReference> messagesRef = new AtomicReference<>(); public TestAuditlogImpl(String name, Settings settings, String settingsPrefix, AuditLogSink fallbackSink) { super(name, settings, null, fallbackSink); } - - public synchronized boolean doStore(AuditMessage msg) { + public synchronized boolean doStore(AuditMessage msg) { + if (messagesRef.get() == null || countDownRef.get() == null) { + throw new RuntimeException("No message latch is waiting"); + } sb.append(msg.toPrettyString()+System.lineSeparator()); - messages.add(msg); + messagesRef.get().add(msg); + countDownRef.get().countDown(); return true; } + /** Unneeded after switching to `doThenWaitForMessages(...)` as data is automatically flushed */ + @Deprecated public static synchronized void clear() { - sb.setLength(0); - messages.clear(); + doThenWaitForMessages(() -> {}, 0); + } + + /** + * Perform an action and then wait until the expected number of messages have been found. + */ + public static List doThenWaitForMessages(final Runnable action, final int expectedCount) { + final CountDownLatch latch = new CountDownLatch(expectedCount); + final List messages = new ArrayList<>(); + countDownRef.set(latch); + messagesRef.set(messages); + + TestAuditlogImpl.sb = new StringBuffer(); + TestAuditlogImpl.messages = messages; + + try { + action.run(); + final int maxSecondsToWaitForMessages = 1; + final boolean foundAll = latch.await(maxSecondsToWaitForMessages, TimeUnit.SECONDS); + if (!foundAll) { + throw new RuntimeException("Did not recieve all " + expectedCount +" audit messages after a short wait."); + } + if (messages.size() != expectedCount) { + throw new RuntimeException("Unexpected number of messages, was expecting " + expectedCount + ", recieved " + messages.size()); + } + } catch (final InterruptedException e) { + throw new RuntimeException("Unexpected exception", e); + } + return new ArrayList<>(messages); } @Override diff --git a/src/test/java/org/opensearch/security/test/helper/rest/RestHelper.java b/src/test/java/org/opensearch/security/test/helper/rest/RestHelper.java index a096eb5496..9530622bc7 100644 --- a/src/test/java/org/opensearch/security/test/helper/rest/RestHelper.java +++ b/src/test/java/org/opensearch/security/test/helper/rest/RestHelper.java @@ -32,6 +32,7 @@ import java.io.FileInputStream; import java.io.IOException; +import java.io.UnsupportedEncodingException; import java.nio.charset.StandardCharsets; import java.security.KeyStore; import java.util.Arrays; @@ -136,48 +137,48 @@ public HttpResponse[] executeMultipleAsyncPutRequest(final int numOfRequests, fi .toArray(s -> new HttpResponse[s]); } - public HttpResponse executeGetRequest(final String request, Header... header) throws Exception { + public HttpResponse executeGetRequest(final String request, Header... header) { return executeRequest(new HttpGet(getHttpServerUri() + "/" + request), header); } - public HttpResponse executeHeadRequest(final String request, Header... header) throws Exception { + public HttpResponse executeHeadRequest(final String request, Header... header) { return executeRequest(new HttpHead(getHttpServerUri() + "/" + request), header); } - public HttpResponse executeOptionsRequest(final String request) throws Exception { + public HttpResponse executeOptionsRequest(final String request) { return executeRequest(new HttpOptions(getHttpServerUri() + "/" + request)); } - public HttpResponse executePutRequest(final String request, String body, Header... header) throws Exception { + public HttpResponse executePutRequest(final String request, String body, Header... header) { HttpPut uriRequest = new HttpPut(getHttpServerUri() + "/" + request); if (body != null && !body.isEmpty()) { - uriRequest.setEntity(new StringEntity(body)); + uriRequest.setEntity(createStringEntity(body)); } return executeRequest(uriRequest, header); } - public HttpResponse executeDeleteRequest(final String request, Header... header) throws Exception { + public HttpResponse executeDeleteRequest(final String request, Header... header) { return executeRequest(new HttpDelete(getHttpServerUri() + "/" + request), header); } - public HttpResponse executePostRequest(final String request, String body, Header... header) throws Exception { + public HttpResponse executePostRequest(final String request, String body, Header... header) { HttpPost uriRequest = new HttpPost(getHttpServerUri() + "/" + request); if (body != null && !body.isEmpty()) { - uriRequest.setEntity(new StringEntity(body)); + uriRequest.setEntity(createStringEntity(body)); } return executeRequest(uriRequest, header); } - public HttpResponse executePatchRequest(final String request, String body, Header... header) throws Exception { + public HttpResponse executePatchRequest(final String request, String body, Header... header) { HttpPatch uriRequest = new HttpPatch(getHttpServerUri() + "/" + request); if (body != null && !body.isEmpty()) { - uriRequest.setEntity(new StringEntity(body)); + uriRequest.setEntity(createStringEntity(body)); } return executeRequest(uriRequest, header); } - public HttpResponse executeRequest(HttpUriRequest uriRequest, Header... header) throws Exception { + public HttpResponse executeRequest(HttpUriRequest uriRequest, Header... header) { CloseableHttpClient httpClient = null; try { @@ -197,13 +198,27 @@ public HttpResponse executeRequest(HttpUriRequest uriRequest, Header... header) HttpResponse res = new HttpResponse(httpClient.execute(uriRequest)); log.debug(res.getBody()); return res; + } catch (final Exception e) { + throw new RuntimeException(e); } finally { if (httpClient != null) { - httpClient.close(); + try { + httpClient.close(); + } catch (final Exception e) { + throw new RuntimeException(e); + } } } } + + private StringEntity createStringEntity(String body) { + try { + return new StringEntity(body); + } catch (final UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + } protected final String getHttpServerUri() { final String address = "http" + (enableHTTPClientSSL ? "s" : "") + "://" + clusterInfo.httpHost + ":" + clusterInfo.httpPort;