Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds support for alerts and triggers on group by based sigma rules #545

Merged
merged 7 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,22 @@ public String getSeverity() {
return severity;
}

public List<String> getRuleTypes() {
return ruleTypes;
}

public List<String> getRuleIds() {
return ruleIds;
}

public List<String> getRuleSeverityLevels() {
return ruleSeverityLevels;
}

public List<String> getTags() {
return tags;
}

public List<Action> getActions() {
List<Action> transformedActions = new ArrayList<>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public class SecurityAnalyticsSettings {

public static final Setting<Boolean> ENABLE_WORKFLOW_USAGE = Setting.boolSetting(
"plugins.security_analytics.enable_workflow_usage",
false,
true,
Setting.Property.NodeScope, Setting.Property.Dynamic
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
import org.opensearch.securityanalytics.rules.exceptions.SigmaError;
import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings;
import org.opensearch.securityanalytics.util.DetectorIndices;
import org.opensearch.securityanalytics.util.DetectorUtils;
import org.opensearch.securityanalytics.util.IndexUtils;
import org.opensearch.securityanalytics.util.MonitorService;
import org.opensearch.securityanalytics.util.RuleIndices;
Expand All @@ -114,6 +115,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -155,7 +157,7 @@ public class TransportIndexDetectorAction extends HandledTransportAction<IndexDe
private final MonitorService monitorService;
private final IndexNameExpressionResolver indexNameExpressionResolver;

private volatile TimeValue indexTimeout;
private final TimeValue indexTimeout;
@Inject
public TransportIndexDetectorAction(TransportService transportService,
Client client,
Expand Down Expand Up @@ -275,15 +277,15 @@ private void createMonitorFromQueries(List<Pair<String, Rule>> rulesById, Detect

StepListener<List<IndexMonitorResponse>> indexMonitorsStep = new StepListener<>();
indexMonitorsStep.whenComplete(
indexMonitorResponses -> saveWorkflow(detector, indexMonitorResponses, refreshPolicy, listener),
indexMonitorResponses -> saveWorkflow(rulesById, detector, indexMonitorResponses, refreshPolicy, listener),
e -> {
log.error("Failed to index the workflow", e);
listener.onFailure(e);
});

int numberOfUnprocessedResponses = monitorRequests.size() - 1;
if (numberOfUnprocessedResponses == 0) {
saveWorkflow(detector, monitorResponses, refreshPolicy, listener);
saveWorkflow(rulesById, detector, monitorResponses, refreshPolicy, listener);
} else {
// Saves the rest of the monitors and saves the workflow if supported
saveMonitors(
Expand Down Expand Up @@ -312,7 +314,7 @@ private void createMonitorFromQueries(List<Pair<String, Rule>> rulesById, Detect
AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, monitorRequests.get(0), namedWriteableRegistry, indexDocLevelMonitorStep);
indexDocLevelMonitorStep.whenComplete(addedFirstMonitorResponse -> {
monitorResponses.add(addedFirstMonitorResponse);
saveWorkflow(detector, monitorResponses, refreshPolicy, listener);
saveWorkflow(rulesById, detector, monitorResponses, refreshPolicy, listener);
},
listener::onFailure
);
Expand Down Expand Up @@ -346,20 +348,23 @@ public void onFailure(Exception e) {
/**
* If the workflow is enabled, saves the workflow, updates the detector and returns the saved monitors
* if not, returns the saved monitors
*
* @param rulesById
* @param detector
* @param monitorResponses
* @param refreshPolicy
* @param actionListener
*/
private void saveWorkflow(
Detector detector,
List<IndexMonitorResponse> monitorResponses,
RefreshPolicy refreshPolicy,
ActionListener<List<IndexMonitorResponse>> actionListener
List<Pair<String, Rule>> rulesById, Detector detector,
List<IndexMonitorResponse> monitorResponses,
RefreshPolicy refreshPolicy,
ActionListener<List<IndexMonitorResponse>> actionListener
) {
if (enabledWorkflowUsage) {
workflowService.upsertWorkflow(
monitorResponses.stream().map(IndexMonitorResponse::getId).collect(Collectors.toList()),
rulesById,
monitorResponses,
null,
detector,
refreshPolicy,
Expand Down Expand Up @@ -446,7 +451,7 @@ public void onResponse(Map<String, Map<String, String>> ruleFieldMappings) {
monitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect(
Collectors.toList()));

updateAlertingMonitors(detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener);
updateAlertingMonitors(rulesById, detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener);
} catch (IOException | SigmaError ex) {
listener.onFailure(ex);
}
Expand Down Expand Up @@ -474,7 +479,7 @@ public void onFailure(Exception e) {
monitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect(
Collectors.toList()));

updateAlertingMonitors(detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener);
updateAlertingMonitors(rulesById, detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener);
}
}

Expand All @@ -493,6 +498,7 @@ public void onFailure(Exception e) {
* @param listener Listener that accepts the list of updated monitors if the action was successful
*/
private void updateAlertingMonitors(
List<Pair<String, Rule>> rulesById,
Detector detector,
List<IndexMonitorRequest> monitorsToBeAdded,
List<IndexMonitorRequest> monitorsToBeUpdated,
Expand All @@ -519,6 +525,7 @@ private void updateAlertingMonitors(
}
if (detector.isWorkflowSupported() && enabledWorkflowUsage) {
updateWorkflowStep(
rulesById,
detector,
monitorsToBeDeleted,
refreshPolicy,
Expand Down Expand Up @@ -560,6 +567,7 @@ public void onFailure(Exception e) {
}

private void updateWorkflowStep(
List<Pair<String, Rule>> rulesById,
Detector detector,
List<String> monitorsToBeDeleted,
RefreshPolicy refreshPolicy,
Expand Down Expand Up @@ -596,8 +604,9 @@ public void onFailure(Exception e) {
} else {
// Update workflow and delete the monitors
workflowService.upsertWorkflow(
addedMonitorIds,
updatedMonitorIds,
rulesById,
addNewMonitorsResponse,
updateMonitorResponse,
detector,
refreshPolicy,
detector.getWorkflowIds().get(0),
Expand Down Expand Up @@ -667,6 +676,58 @@ private IndexMonitorRequest createDocLevelMonitorRequest(List<Pair<String, Rule>
return new IndexMonitorRequest(monitorId, SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM, refreshPolicy, restMethod, monitor, null);
}

/**
* Creates doc level monitor which generates per document alerts for the findings of the bucket level delegate monitors in a workflow.
* This monitor has match all query applied to generate the alerts per each finding doc.
*/
private IndexMonitorRequest createDocLevelMonitorMatchAllRequest(
Detector detector,
WriteRequest.RefreshPolicy refreshPolicy,
String monitorId,
RestRequest.Method restMethod
) {
List<DocLevelMonitorInput> docLevelMonitorInputs = new ArrayList<>();
List<DocLevelQuery> docLevelQueries = new ArrayList<>();
String monitorName = detector.getName() + "_chained_findings";
String actualQuery = "_id:*";
DocLevelQuery docLevelQuery = new DocLevelQuery(
monitorName,
monitorName + "doc",
actualQuery,
Collections.emptyList()
);
docLevelQueries.add(docLevelQuery);

DocLevelMonitorInput docLevelMonitorInput = new DocLevelMonitorInput(detector.getName(), detector.getInputs().get(0).getIndices(), docLevelQueries);
docLevelMonitorInputs.add(docLevelMonitorInput);

List<DocumentLevelTrigger> triggers = new ArrayList<>();
List<DetectorTrigger> detectorTriggers = detector.getTriggers();

for (DetectorTrigger detectorTrigger : detectorTriggers) {
String id = detectorTrigger.getId();
String name = detectorTrigger.getName();
String severity = detectorTrigger.getSeverity();
List<Action> actions = detectorTrigger.getActions();
Script condition = detectorTrigger.convertToCondition();

triggers.add(new DocumentLevelTrigger(id, name, severity, actions, condition));
}

Monitor monitor = new Monitor(monitorId, Monitor.NO_VERSION, monitorName, false, detector.getSchedule(), detector.getLastUpdateTime(), null,
Monitor.MonitorType.DOC_LEVEL_MONITOR, detector.getUser(), 1, docLevelMonitorInputs, triggers, Map.of(),
new DataSources(detector.getRuleIndex(),
detector.getFindingsIndex(),
detector.getFindingsIndexPattern(),
detector.getAlertsIndex(),
detector.getAlertsHistoryIndex(),
detector.getAlertsHistoryIndexPattern(),
DetectorMonitorConfig.getRuleIndexMappingsByType(),
true), PLUGIN_OWNER_FIELD);

return new IndexMonitorRequest(monitorId, SequenceNumbers.UNASSIGNED_SEQ_NO, SequenceNumbers.UNASSIGNED_PRIMARY_TERM, refreshPolicy, restMethod, monitor, null);
}

private void buildBucketLevelMonitorRequests(List<Pair<String, Rule>> queries, Detector detector, WriteRequest.RefreshPolicy refreshPolicy, String monitorId, RestRequest.Method restMethod, ActionListener<List<IndexMonitorRequest>> listener) throws IOException, SigmaError {

logTypeService.getRuleFieldMappings(new ActionListener<>() {
Expand Down Expand Up @@ -697,6 +758,10 @@ public void onResponse(Map<String, Map<String, String>> ruleFieldMappings) {
queryBackendMap.get(rule.getCategory())));
}
}
// if workflow usage enabled, add chained findings monitor request if there are bucket level requests and if the detector triggers have any group by rules configured to trigger
if (enabledWorkflowUsage && !monitorRequests.isEmpty() && !DetectorUtils.getAggRuleIdsConfiguredToTrigger(detector, queries).isEmpty()) {
monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId()+"_chained_findings", Method.POST));
}
listener.onResponse(monitorRequests);
} catch (IOException | SigmaError ex) {
listener.onFailure(ex);
Expand Down Expand Up @@ -1431,7 +1496,11 @@ private Map<String, String> mapMonitorIds(List<IndexMonitorResponse> monitorResp
if (MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()) {
return it.getMonitor().getTriggers().get(0).getId();
} else {
return Detector.DOC_LEVEL_MONITOR;
if (it.getMonitor().getName().contains("_chained_findings")) {
return "chained_findings_monitor";
} else {
return Detector.DOC_LEVEL_MONITOR;
}
}
},
IndexMonitorResponse::getId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
*/
package org.opensearch.securityanalytics.util;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.lucene.search.TotalHits;
import org.opensearch.cluster.routing.Preference;
import org.opensearch.commons.alerting.action.IndexMonitorResponse;
import org.opensearch.commons.alerting.model.Monitor;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand All @@ -25,13 +28,15 @@
import org.opensearch.search.suggest.Suggest;
import org.opensearch.securityanalytics.model.Detector;
import org.opensearch.securityanalytics.model.DetectorInput;
import org.opensearch.securityanalytics.model.Rule;

import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class DetectorUtils {

Expand Down Expand Up @@ -95,4 +100,36 @@ public void onFailure(Exception e) {
}
});
}

public static List<String> getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger(
Detector detector,
List<Pair<String, Rule>> rulesById,
List<IndexMonitorResponse> monitorResponses
) {
List<String> aggRuleIdsConfiguredToTrigger = getAggRuleIdsConfiguredToTrigger(detector, rulesById);
return monitorResponses.stream().filter(
// In the case of bucket level monitors rule id is trigger id
it -> Monitor.MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()
&& !it.getMonitor().getTriggers().isEmpty()
&& aggRuleIdsConfiguredToTrigger.contains(it.getMonitor().getTriggers().get(0).getId())
).map(IndexMonitorResponse::getId).collect(Collectors.toList());
}
public static List<String> getAggRuleIdsConfiguredToTrigger(Detector detector, List<Pair<String, Rule>> rulesById) {
Set<String> ruleIdsConfiguredToTrigger = detector.getTriggers().stream().flatMap(t -> t.getRuleIds().stream()).collect(Collectors.toSet());
Set<String> tagsConfiguredToTrigger = detector.getTriggers().stream().flatMap(t -> t.getTags().stream()).collect(Collectors.toSet());
return rulesById.stream()
.filter(it -> checkIfRuleIsAggAndTriggerable( it.getRight(), ruleIdsConfiguredToTrigger, tagsConfiguredToTrigger))
.map(stringRulePair -> stringRulePair.getRight().getId())
.collect(Collectors.toList());
}

private static boolean checkIfRuleIsAggAndTriggerable(Rule rule, Set<String> ruleIdsConfiguredToTrigger, Set<String> tagsConfiguredToTrigger) {
if (rule.isAggregationRule()) {
return ruleIdsConfiguredToTrigger.contains(rule.getId())
|| rule.getTags().stream().anyMatch(tag -> tagsConfiguredToTrigger.contains(tag.getValue()));
}
return false;
}


}
Loading