Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds findingIds as a search parameter for alerts #794

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import java.io.IOException;
import java.util.Locale;
import java.util.ArrayList;
import java.util.List;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
Expand All @@ -19,36 +21,49 @@
public class GetAlertsRequest extends ActionRequest {

private String detectorId;
private ArrayList<String> findingIds;
private String logType;
private Table table;
private String severityLevel;
private String alertState;

public static final String DETECTOR_ID = "detector_id";


// Updated the constructor to include findingIds
public GetAlertsRequest(
String detectorId,
ArrayList<String> findingIds,
String logType,
Table table,
String severityLevel,
String alertState
) {
super();
this.detectorId = detectorId;
this.findingIds = findingIds;
this.logType = logType;
this.table = table;
this.severityLevel = severityLevel;
this.alertState = alertState;
}
public GetAlertsRequest(StreamInput sin) throws IOException {
this(
sin.readOptionalString(),
sin.readOptionalString(),
Table.readFrom(sin),
sin.readString(),
sin.readString()
);
}

public GetAlertsRequest(StreamInput sin) throws IOException {
super();

this.detectorId = sin.readOptionalString();

List<String> findingIdsList = sin.readStringList();
this.findingIds = findingIdsList != null ? new ArrayList<>(findingIdsList) : new ArrayList<>();

this.logType = sin.readOptionalString();
this.table = Table.readFrom(sin);
this.severityLevel = sin.readString();
this.alertState = sin.readString();
}




@Override
public ActionRequestValidationException validate() {
Expand All @@ -61,9 +76,11 @@ public ActionRequestValidationException validate() {
return validationException;
}

// Added the writeTo for findingIds
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(detectorId);
out.writeStringCollection(findingIds);
out.writeOptionalString(logType);
table.writeTo(out);
out.writeString(severityLevel);
Expand All @@ -89,4 +106,9 @@ public String getAlertState() {
public String getLogType() {
return logType;
}

// Getter Function for findingIds
public ArrayList<String> getFindingIds() {
return findingIds;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,27 @@ public class GetFindingsRequest extends ActionRequest {
private String detectorId;
private Table table;


public static final String DETECTOR_ID = "detector_id";

public GetFindingsRequest(String detectorId) {
super();
this.detectorId = detectorId;
}

public GetFindingsRequest(StreamInput sin) throws IOException {
this(

sin.readOptionalString(),
// sin.readOptionalList for arraylist findingIds
sin.readOptionalString(),
Table.readFrom(sin)
);
}

public GetFindingsRequest(String detectorId, String logType, Table table) {
this.detectorId = detectorId;
// Updated param above
this.logType = logType;
this.table = table;
}
Expand All @@ -57,6 +62,7 @@ public ActionRequestValidationException validate() {
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(detectorId);
out.writeOptionalString(logType);
// Write the finding ids
table.writeTo(out);
}

Expand All @@ -71,4 +77,5 @@ public String getLogType() {
public Table getTable() {
return table;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,15 @@ public AlertsService(Client client) {
* Searches alerts generated by specific Detector
*
* @param detectorId id of Detector
* @param findingIds finding id to filter alerts
* @param table group of search related parameters
* @param severityLevel alert severity level
* @param alertState current alert state
* @param listener ActionListener to get notified on response or error
*/
public void getAlertsByDetectorId(
String detectorId,
ArrayList<String> findingIds,
Table table,
String severityLevel,
String alertState,
Expand All @@ -81,9 +83,11 @@ public void onResponse(GetDetectorResponse getDetectorResponse) {
monitorId -> monitorToDetectorMapping.put(monitorId, detector.getId())
);
// Get alerts for all monitor ids
// Do i need to add finding IDs for this method? Line 128 another doubt
AlertsService.this.getAlertsByMonitorIds(
monitorToDetectorMapping,
monitorIds,
findingIds,
DetectorMonitorConfig.getAllAlertsIndicesPattern(detector.getDetectorType()),
table,
severityLevel,
Expand Down Expand Up @@ -116,15 +120,18 @@ public void onFailure(Exception e) {
*
* @param monitorToDetectorMapping monitorId to detectorId mapping
* @param monitorIds list of monitor ids
* @param findingIds finding id to filter alerts
* @param alertIndex alert index to search alerts on
* @param table group of search related parameters
* @param severityLevel alert severity level
* @param alertState current alert state *
* @param listener ActionListener to get notified on response or error
*/

public void getAlertsByMonitorIds(
Map<String, String> monitorToDetectorMapping,
List<String> monitorIds,
List<String> findingIds,
String alertIndex,
Table table,
String severityLevel,
Expand All @@ -135,6 +142,7 @@ public void getAlertsByMonitorIds(
org.opensearch.commons.alerting.action.GetAlertsRequest req =
new org.opensearch.commons.alerting.action.GetAlertsRequest(
table,
findingIds,
severityLevel,
alertState,
null,
Expand Down Expand Up @@ -174,6 +182,7 @@ void setIndicesAdminClient(Client client) {

public void getAlerts(
List<Detector> detectors,
ArrayList<String> findingIds,
String logType,
Table table,
String severityLevel,
Expand All @@ -200,6 +209,7 @@ public void getAlerts(
AlertsService.this.getAlertsByMonitorIds(
monitorToDetectorMapping,
allMonitorIds,
findingIds,
DetectorMonitorConfig.getAllAlertsIndicesPattern(logType),
table,
severityLevel,
Expand Down Expand Up @@ -243,12 +253,14 @@ private AlertDto mapAlertToAlertDto(Alert alert, String detectorId) {
);
}

// Check where exactly is this method used?
public void getAlerts(List<String> alertIds,
Detector detector,
Table table,
ActionListener<org.opensearch.commons.alerting.action.GetAlertsResponse> actionListener) {
GetAlertsRequest request = new GetAlertsRequest(
table,
List.of(),
"ALL",
"ALL",
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ public FindingsService(Client client) {
* @param table group of search related parameters
* @param listener ActionListener to get notified on response or error
*/

// This is the function and add a new parameter for finding ids
public void getFindingsByDetectorId(String detectorId, Table table, ActionListener<GetFindingsResponse> listener ) {
this.client.execute(GetDetectorAction.INSTANCE, new GetDetectorRequest(detectorId, -3L), new ActionListener<>() {

Expand Down Expand Up @@ -131,7 +133,7 @@ public void getFindingsByMonitorIds(

org.opensearch.commons.alerting.action.GetFindingsRequest req =
new org.opensearch.commons.alerting.action.GetFindingsRequest(
null,
null, // Need to pass the findingId as List but in api it is a sting[it will change]
table,
null,
findingIndexName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.io.IOException;
import java.util.List;
import java.util.ArrayList;
import java.util.Locale;
import org.opensearch.client.node.NodeClient;
import org.opensearch.commons.alerting.model.Table;
Expand All @@ -19,8 +20,8 @@
import org.opensearch.securityanalytics.action.GetFindingsRequest;
import org.opensearch.securityanalytics.model.Detector;


import static java.util.Collections.singletonList;
import java.util.Arrays;
import static org.opensearch.rest.RestRequest.Method.GET;

public class RestGetAlertsAction extends BaseRestHandler {
Expand All @@ -32,8 +33,8 @@ public String getName() {

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {

String detectorId = request.param("detector_id", null);
String[] findingIds = request.paramAsStringArray("findingIds", null);
String detectorType = request.param("detectorType", null);
String severityLevel = request.param("severityLevel", "ALL");
String alertState = request.param("alertState", "ALL");
Expand All @@ -56,12 +57,14 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli

GetAlertsRequest req = new GetAlertsRequest(
detectorId,
convertFindingIdsToList(findingIds),
detectorType,
table,
severityLevel,
alertState
);

// Request goes to TransportGetAlertsAction class
return channel -> client.execute(
GetAlertsAction.INSTANCE,
req,
Expand All @@ -73,4 +76,12 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
public List<Route> routes() {
return singletonList(new Route(GET, SecurityAnalyticsPlugin.ALERTS_BASE_URI));
}
}

private ArrayList<String> convertFindingIdsToList(String[] findingIds) {
if (findingIds == null) {
return new ArrayList<>();
}
return new ArrayList<>(Arrays.asList(findingIds));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,12 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli

GetFindingsRequest req = new GetFindingsRequest(
detectorId,
// Add finding ids
detectorType,
table
);

// Request goes to TransportGetFindingsAction class
return channel -> client.execute(
GetFindingsAction.INSTANCE,
req,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ public TransportGetAlertsAction(TransportService transportService, ActionFilters
this.clusterService.getClusterSettings().addSettingsUpdateConsumer(SecurityAnalyticsSettings.FILTER_BY_BACKEND_ROLES, this::setFilterByEnabled);
}

// The client request hits here
@Override
protected void doExecute(Task task, GetAlertsRequest request, ActionListener<GetAlertsResponse> actionListener) {

Expand All @@ -88,6 +89,8 @@ protected void doExecute(Task task, GetAlertsRequest request, ActionListener<Get
if (request.getLogType() == null) {
alertsService.getAlertsByDetectorId(
request.getDetectorId(),
// Added the getFinding Ids param
request.getFindingIds(),
request.getTable(),
request.getSeverityLevel(),
request.getAlertState(),
Expand Down Expand Up @@ -131,6 +134,7 @@ public void onResponse(SearchResponse searchResponse) {
}
alertsService.getAlerts(
detectors,
request.getFindingIds(),
request.getLogType(),
request.getTable(),
request.getSeverityLevel(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ public TransportGetFindingsAction(
this.clusterService.getClusterSettings().addSettingsUpdateConsumer(SecurityAnalyticsSettings.FILTER_BY_BACKEND_ROLES, this::setFilterByEnabled);
}


// Request hits here
@Override
protected void doExecute(Task task, GetFindingsRequest request, ActionListener<GetFindingsResponse> actionListener) {

Expand All @@ -106,6 +108,7 @@ protected void doExecute(Task task, GetFindingsRequest request, ActionListener<G

if (request.getLogType() == null) {
findingsService.getFindingsByDetectorId(
// request finding ids
request.getDetectorId(),
request.getTable(),
actionListener
Expand Down Expand Up @@ -146,6 +149,7 @@ public void onResponse(SearchResponse searchResponse) {
);
return;
}
// Need to add finding Ids in this method too
findingsService.getFindings(
detectors,
request.getLogType(),
Expand Down
Loading