diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetAlertsRequest.java b/src/main/java/org/opensearch/securityanalytics/action/GetAlertsRequest.java index 1e0cb6113..4b30267ae 100644 --- a/src/main/java/org/opensearch/securityanalytics/action/GetAlertsRequest.java +++ b/src/main/java/org/opensearch/securityanalytics/action/GetAlertsRequest.java @@ -6,6 +6,8 @@ import java.io.IOException; import java.util.Locale; +import java.util.ArrayList; +import java.util.List; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.StreamInput; @@ -19,6 +21,7 @@ public class GetAlertsRequest extends ActionRequest { private String detectorId; + private ArrayList findingIds; private String logType; private Table table; private String severityLevel; @@ -26,8 +29,11 @@ public class GetAlertsRequest extends ActionRequest { public static final String DETECTOR_ID = "detector_id"; + + // Updated the constructor to include findingIds public GetAlertsRequest( String detectorId, + ArrayList findingIds, String logType, Table table, String severityLevel, @@ -35,20 +41,29 @@ public GetAlertsRequest( ) { super(); this.detectorId = detectorId; + this.findingIds = findingIds; this.logType = logType; this.table = table; this.severityLevel = severityLevel; this.alertState = alertState; } - public GetAlertsRequest(StreamInput sin) throws IOException { - this( - sin.readOptionalString(), - sin.readOptionalString(), - Table.readFrom(sin), - sin.readString(), - sin.readString() - ); - } + +public GetAlertsRequest(StreamInput sin) throws IOException { + super(); + + this.detectorId = sin.readOptionalString(); + + List findingIdsList = sin.readStringList(); + this.findingIds = findingIdsList != null ? new ArrayList<>(findingIdsList) : new ArrayList<>(); + + this.logType = sin.readOptionalString(); + this.table = Table.readFrom(sin); + this.severityLevel = sin.readString(); + this.alertState = sin.readString(); +} + + + @Override public ActionRequestValidationException validate() { @@ -61,9 +76,11 @@ public ActionRequestValidationException validate() { return validationException; } + // Added the writeTo for findingIds @Override public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(detectorId); + out.writeStringCollection(findingIds); out.writeOptionalString(logType); table.writeTo(out); out.writeString(severityLevel); @@ -89,4 +106,9 @@ public String getAlertState() { public String getLogType() { return logType; } + + // Getter Function for findingIds + public ArrayList getFindingIds() { + return findingIds; + } } diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetFindingsRequest.java b/src/main/java/org/opensearch/securityanalytics/action/GetFindingsRequest.java index 8e99720ee..f7f49f6d8 100644 --- a/src/main/java/org/opensearch/securityanalytics/action/GetFindingsRequest.java +++ b/src/main/java/org/opensearch/securityanalytics/action/GetFindingsRequest.java @@ -22,15 +22,19 @@ public class GetFindingsRequest extends ActionRequest { private String detectorId; private Table table; + public static final String DETECTOR_ID = "detector_id"; public GetFindingsRequest(String detectorId) { super(); this.detectorId = detectorId; } + public GetFindingsRequest(StreamInput sin) throws IOException { this( + sin.readOptionalString(), + // sin.readOptionalList for arraylist findingIds sin.readOptionalString(), Table.readFrom(sin) ); @@ -38,6 +42,7 @@ public GetFindingsRequest(StreamInput sin) throws IOException { public GetFindingsRequest(String detectorId, String logType, Table table) { this.detectorId = detectorId; + // Updated param above this.logType = logType; this.table = table; } @@ -57,6 +62,7 @@ public ActionRequestValidationException validate() { public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(detectorId); out.writeOptionalString(logType); + // Write the finding ids table.writeTo(out); } @@ -71,4 +77,5 @@ public String getLogType() { public Table getTable() { return table; } + } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/alerts/AlertsService.java b/src/main/java/org/opensearch/securityanalytics/alerts/AlertsService.java index a61fe9d35..87b1ed295 100644 --- a/src/main/java/org/opensearch/securityanalytics/alerts/AlertsService.java +++ b/src/main/java/org/opensearch/securityanalytics/alerts/AlertsService.java @@ -56,6 +56,7 @@ public AlertsService(Client client) { * Searches alerts generated by specific Detector * * @param detectorId id of Detector + * @param findingIds finding id to filter alerts * @param table group of search related parameters * @param severityLevel alert severity level * @param alertState current alert state @@ -63,6 +64,7 @@ public AlertsService(Client client) { */ public void getAlertsByDetectorId( String detectorId, + ArrayList findingIds, Table table, String severityLevel, String alertState, @@ -81,9 +83,11 @@ public void onResponse(GetDetectorResponse getDetectorResponse) { monitorId -> monitorToDetectorMapping.put(monitorId, detector.getId()) ); // Get alerts for all monitor ids + // Do i need to add finding IDs for this method? Line 128 another doubt AlertsService.this.getAlertsByMonitorIds( monitorToDetectorMapping, monitorIds, + findingIds, DetectorMonitorConfig.getAllAlertsIndicesPattern(detector.getDetectorType()), table, severityLevel, @@ -116,15 +120,18 @@ public void onFailure(Exception e) { * * @param monitorToDetectorMapping monitorId to detectorId mapping * @param monitorIds list of monitor ids + * @param findingIds finding id to filter alerts * @param alertIndex alert index to search alerts on * @param table group of search related parameters * @param severityLevel alert severity level * @param alertState current alert state * * @param listener ActionListener to get notified on response or error */ + public void getAlertsByMonitorIds( Map monitorToDetectorMapping, List monitorIds, + List findingIds, String alertIndex, Table table, String severityLevel, @@ -135,6 +142,7 @@ public void getAlertsByMonitorIds( org.opensearch.commons.alerting.action.GetAlertsRequest req = new org.opensearch.commons.alerting.action.GetAlertsRequest( table, + findingIds, severityLevel, alertState, null, @@ -174,6 +182,7 @@ void setIndicesAdminClient(Client client) { public void getAlerts( List detectors, + ArrayList findingIds, String logType, Table table, String severityLevel, @@ -200,6 +209,7 @@ public void getAlerts( AlertsService.this.getAlertsByMonitorIds( monitorToDetectorMapping, allMonitorIds, + findingIds, DetectorMonitorConfig.getAllAlertsIndicesPattern(logType), table, severityLevel, @@ -243,12 +253,14 @@ private AlertDto mapAlertToAlertDto(Alert alert, String detectorId) { ); } + // Check where exactly is this method used? public void getAlerts(List alertIds, Detector detector, Table table, ActionListener actionListener) { GetAlertsRequest request = new GetAlertsRequest( table, + List.of(), "ALL", "ALL", null, diff --git a/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java b/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java index 4674f40cc..7b8a80c83 100644 --- a/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java +++ b/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java @@ -52,6 +52,8 @@ public FindingsService(Client client) { * @param table group of search related parameters * @param listener ActionListener to get notified on response or error */ + + // This is the function and add a new parameter for finding ids public void getFindingsByDetectorId(String detectorId, Table table, ActionListener listener ) { this.client.execute(GetDetectorAction.INSTANCE, new GetDetectorRequest(detectorId, -3L), new ActionListener<>() { @@ -131,7 +133,7 @@ public void getFindingsByMonitorIds( org.opensearch.commons.alerting.action.GetFindingsRequest req = new org.opensearch.commons.alerting.action.GetFindingsRequest( - null, + null, // Need to pass the findingId as List but in api it is a sting[it will change] table, null, findingIndexName, diff --git a/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetAlertsAction.java index 0d6bcb52d..de10dbf8b 100644 --- a/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetAlertsAction.java +++ b/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetAlertsAction.java @@ -6,6 +6,7 @@ import java.io.IOException; import java.util.List; +import java.util.ArrayList; import java.util.Locale; import org.opensearch.client.node.NodeClient; import org.opensearch.commons.alerting.model.Table; @@ -19,8 +20,8 @@ import org.opensearch.securityanalytics.action.GetFindingsRequest; import org.opensearch.securityanalytics.model.Detector; - import static java.util.Collections.singletonList; +import java.util.Arrays; import static org.opensearch.rest.RestRequest.Method.GET; public class RestGetAlertsAction extends BaseRestHandler { @@ -32,8 +33,8 @@ public String getName() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - String detectorId = request.param("detector_id", null); + String[] findingIds = request.paramAsStringArray("findingIds", null); String detectorType = request.param("detectorType", null); String severityLevel = request.param("severityLevel", "ALL"); String alertState = request.param("alertState", "ALL"); @@ -56,12 +57,14 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli GetAlertsRequest req = new GetAlertsRequest( detectorId, + convertFindingIdsToList(findingIds), detectorType, table, severityLevel, alertState ); + // Request goes to TransportGetAlertsAction class return channel -> client.execute( GetAlertsAction.INSTANCE, req, @@ -73,4 +76,12 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli public List routes() { return singletonList(new Route(GET, SecurityAnalyticsPlugin.ALERTS_BASE_URI)); } -} \ No newline at end of file + + private ArrayList convertFindingIdsToList(String[] findingIds) { + if (findingIds == null) { + return new ArrayList<>(); + } + return new ArrayList<>(Arrays.asList(findingIds)); + } + +} diff --git a/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetFindingsAction.java b/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetFindingsAction.java index efc04e1e5..509ce2a18 100644 --- a/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetFindingsAction.java +++ b/src/main/java/org/opensearch/securityanalytics/resthandler/RestGetFindingsAction.java @@ -52,10 +52,12 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli GetFindingsRequest req = new GetFindingsRequest( detectorId, + // Add finding ids detectorType, table ); + // Request goes to TransportGetFindingsAction class return channel -> client.execute( GetFindingsAction.INSTANCE, req, diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportGetAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetAlertsAction.java index f01929fc9..8512941cb 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportGetAlertsAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetAlertsAction.java @@ -74,6 +74,7 @@ public TransportGetAlertsAction(TransportService transportService, ActionFilters this.clusterService.getClusterSettings().addSettingsUpdateConsumer(SecurityAnalyticsSettings.FILTER_BY_BACKEND_ROLES, this::setFilterByEnabled); } + // The client request hits here @Override protected void doExecute(Task task, GetAlertsRequest request, ActionListener actionListener) { @@ -88,6 +89,8 @@ protected void doExecute(Task task, GetAlertsRequest request, ActionListener actionListener) { @@ -106,6 +108,7 @@ protected void doExecute(Task task, GetFindingsRequest request, ActionListener() { + alertssService.getAlertsByDetectorId("detector_id123", new ArrayList(), table, "severity_low", Alert.State.COMPLETED.toString(), new ActionListener<>() { @Override public void onResponse(GetAlertsResponse getAlertsResponse) { assertEquals(2, (int)getAlertsResponse.getTotalAlerts()); @@ -259,7 +260,7 @@ public void testGetFindings_getFindingsByMonitorIdFailures() { ActionListener l = invocation.getArgument(6); l.onFailure(new IllegalArgumentException("Error getting findings")); return null; - }).when(alertssService).getAlertsByMonitorIds(any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(ActionListener.class)); + }).when(alertssService).getAlertsByMonitorIds(any(), any(), any(), anyString(), any(Table.class), anyString(), anyString(), any(ActionListener.class)); // Call getFindingsByDetectorId Table table = new Table( @@ -270,7 +271,7 @@ public void testGetFindings_getFindingsByMonitorIdFailures() { 0, null ); - alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), new ActionListener<>() { + alertssService.getAlertsByDetectorId("detector_id123",new ArrayList(), table, "severity_low", Alert.State.COMPLETED.toString(), new ActionListener<>() { @Override public void onResponse(GetAlertsResponse getAlertsResponse) { fail("this test should've failed"); @@ -305,7 +306,7 @@ public void testGetFindings_getDetectorFailure() { 0, null ); - alertssService.getAlertsByDetectorId("detector_id123", table, "severity_low", Alert.State.COMPLETED.toString(), new ActionListener<>() { + alertssService.getAlertsByDetectorId("detector_id123", new ArrayList(), table, "severity_low", Alert.State.COMPLETED.toString(), new ActionListener<>() { @Override public void onResponse(GetAlertsResponse getAlertsResponse) { fail("this test should've failed"); diff --git a/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java b/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java index fbd091595..d2855a87d 100644 --- a/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java +++ b/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java @@ -12,6 +12,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Iterator; import java.util.stream.Collectors; import org.apache.hc.core5.http.HttpStatus; @@ -38,11 +39,11 @@ import static org.opensearch.securityanalytics.TestHelpers.netFlowMappings; import static org.opensearch.securityanalytics.TestHelpers.randomAction; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; -import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndThreatIntel; +// import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndThreatIntel; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndTriggers; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithTriggers; import static org.opensearch.securityanalytics.TestHelpers.randomDoc; -import static org.opensearch.securityanalytics.TestHelpers.randomDocWithIpIoc; +// import static org.opensearch.securityanalytics.TestHelpers.randomDocWithIpIoc; import static org.opensearch.securityanalytics.TestHelpers.randomIndex; import static org.opensearch.securityanalytics.TestHelpers.randomRule; import static org.opensearch.securityanalytics.TestHelpers.windowsIndexMapping; @@ -86,7 +87,7 @@ public void testGetAlerts_success() throws IOException { Detector detector = randomDetectorWithInputsAndTriggers(List.of(new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(createdId)), getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()))), - List.of(new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(createdId), List.of(), List.of("attack.defense_evasion"), List.of(triggerAction), List.of()))); + List.of(new DetectorTrigger("", "test-trigger", "1", List.of(), List.of(createdId), List.of(), List.of("attack.defense_evasion"), List.of(triggerAction), List.of()))); createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); @@ -171,6 +172,167 @@ public void testGetAlerts_success() throws IOException { assertEquals(((ArrayList) ackAlertsResponseMap.get("acknowledged")).size(), 1); } + // @SuppressWarnings("unchecked") + public void testGetAlertsByFindingIds() throws IOException { + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + String rule = randomRule(); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.RULE_BASE_URI, Collections.singletonMap("category", randomDetectorType()), + new StringEntity(rule), new BasicHeader("Content-Type", "application/json")); + Assert.assertEquals("Create rule failed", RestStatus.CREATED, restStatus(createResponse)); + + Map responseBody = asMap(createResponse); + + String createdId = responseBody.get("_id").toString(); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response response = client().performRequest(createMappingRequest); + assertEquals(HttpStatus.SC_OK, response.getStatusLine().getStatusCode()); + + createAlertingMonitorConfigIndex(null); + Action triggerAction = randomAction(createDestination()); + + Detector detector = randomDetectorWithInputsAndTriggers(List.of(new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(createdId)), + getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList()))), + List.of(new DetectorTrigger("", "test-trigger", "1", List.of(), List.of(createdId), List.of(), List.of("attack.defense_evasion"), List.of(triggerAction), List.of()))); + + createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + Assert.assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + + responseBody = asMap(createResponse); + + createdId = responseBody.get("_id").toString(); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + createdId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + + String monitorId = ((List) ((Map) hit.getSourceAsMap().get("detector")).get("monitor_id")).get(0); + + indexDoc(index, "1", randomDoc()); + + Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); + Map executeResults = entityAsMap(executeResponse); + + int noOfSigmaRuleMatches = ((List>) ((Map) executeResults.get("input_results")).get("results")).get(0).size(); + Assert.assertEquals(6, noOfSigmaRuleMatches); + + // 2 findings and 2 alerts are generated + indexDoc(index, "2", randomDoc()); + + executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); + + Assert.assertEquals(1, ((Map) executeResults.get("trigger_results")).values().size()); + + for (Map.Entry> triggerResult: ((Map>) executeResults.get("trigger_results")).entrySet()) { + Assert.assertEquals(1, ((Map) triggerResult.getValue().get("action_results")).values().size()); + + for (Map.Entry> alertActionResult: ((Map>) triggerResult.getValue().get("action_results")).entrySet()) { + Map actionResults = alertActionResult.getValue(); + + for (Map.Entry actionResult: actionResults.entrySet()) { + Map actionOutput = ((Map>) actionResult.getValue()).get("output"); + String expectedMessage = triggerAction.getSubjectTemplate().getIdOrCode().replace("{{ctx.detector.name}}", detector.getName()) + .replace("{{ctx.trigger.name}}", "test-trigger").replace("{{ctx.trigger.severity}}", "1"); + + Assert.assertEquals(expectedMessage, actionOutput.get("subject")); + Assert.assertEquals(expectedMessage, actionOutput.get("message")); + } + } + } + + request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + hits = new ArrayList<>(); + + while (hits.size() == 0) { + hits = executeSearch(DetectorMonitorConfig.getAlertsIndex(randomDetectorType()), request); + } + + Map params = new HashMap<>(); + params.put("detector_id", createdId); + Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + Map getFindingsBody = entityAsMap(getFindingsResponse); + Assert.assertEquals(2, getFindingsBody.get("total_findings")); + + // print the contents of a java map + + System.out.println("Printing the contents of the alerts map: -------------------------------------------------------------------------"); + + if (!getFindingsBody.isEmpty()) { + Iterator> it = getFindingsBody.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry obj = it.next(); + System.out.println(obj.getValue() + "Key: " + obj.getKey()); + } + } + + // Call GetAlerts API + params.clear(); + params.put("detector_id", createdId); + + List> findingsList = (List>) getFindingsBody.get("findings"); + + + System.out.println("size of getFindings:" + findingsList.size()); + + Map firstFinding = findingsList.get(0); + + Object findingIds = firstFinding.get("id"); + + System.out.println("---------------------------------------------------------------------------------------------------------"); + System.out.println("The findings id is: " + findingIds.toString()); + System.out.println("---------------------------------------------------------------------------------------------------------"); + params.put("findingIds", findingIds.toString()); + + + Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params, null); + Map getAlertsBody = asMap(getAlertsResponse); + // TODO enable asserts here when able + Assert.assertEquals(0, getAlertsBody.get("total_alerts")); + + // Write the continuation of this test case + + + String alertId = (String) ((ArrayList>) getAlertsBody.get("alerts")).get(0).get("id"); + String detectorId = (String) ((ArrayList>) getAlertsBody.get("alerts")).get(0).get("detector_id"); + params = new HashMap<>(); + String body = String.format(Locale.getDefault(), "{\"alerts\":[\"%s\"]}", alertId); + Request post = new Request("POST", String.format( + Locale.getDefault(), + "%s/%s/_acknowledge/alerts", + SecurityAnalyticsPlugin.DETECTOR_BASE_URI, + detectorId)); + post.setJsonEntity(body); + Response ackAlertsResponse = client().performRequest(post); + assertNotNull(ackAlertsResponse); + Map ackAlertsResponseMap = entityAsMap(ackAlertsResponse); + assertTrue(((ArrayList) ackAlertsResponseMap.get("missing")).isEmpty()); + assertTrue(((ArrayList) ackAlertsResponseMap.get("failed")).isEmpty()); + assertEquals(((ArrayList) ackAlertsResponseMap.get("acknowledged")).size(), 1); + } + + public void testGetAlerts_noDetector_failure() throws IOException { // Call GetAlerts API Map params = new HashMap<>();