Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] add mitre attack based auto-correlations support in correlation engine #540

Merged
merged 2 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.apache.lucene.search.join.ScoreMode;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.cluster.routing.Preference;
import org.opensearch.commons.alerting.model.DocLevelQuery;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.search.MultiSearchRequest;
import org.opensearch.action.search.MultiSearchResponse;
Expand All @@ -23,26 +24,32 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.NestedQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.securityanalytics.config.monitors.DetectorMonitorConfig;
import org.opensearch.securityanalytics.logtype.LogTypeService;
import org.opensearch.securityanalytics.model.CorrelationQuery;
import org.opensearch.securityanalytics.model.CorrelationRule;
import org.opensearch.securityanalytics.model.Detector;
import org.opensearch.securityanalytics.transport.TransportCorrelateFindingAction;
import org.opensearch.securityanalytics.util.AutoCorrelationsRepo;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;


Expand All @@ -58,18 +65,147 @@ public class JoinEngine {

private final TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction;

private final LogTypeService logTypeService;

private static final Logger log = LogManager.getLogger(JoinEngine.class);

public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRegistry xContentRegistry,
long corrTimeWindow, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction) {
long corrTimeWindow, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction,
LogTypeService logTypeService) {
this.client = client;
this.request = request;
this.xContentRegistry = xContentRegistry;
this.corrTimeWindow = corrTimeWindow;
this.correlateFindingAction = correlateFindingAction;
this.logTypeService = logTypeService;
}

public void onSearchDetectorResponse(Detector detector, Finding finding) {
try {
generateAutoCorrelations(detector, finding);
} catch (IOException ex) {
correlateFindingAction.onFailures(ex);
}
}

@SuppressWarnings("unchecked")
private void generateAutoCorrelations(Detector detector, Finding finding) throws IOException {
Map<String, Set<String>> autoCorrelations = AutoCorrelationsRepo.autoCorrelationsAsMap();
long findingTimestamp = finding.getTimestamp().toEpochMilli();

Set<String> tags = new HashSet<>();
for (DocLevelQuery query : finding.getDocLevelQueries()) {
tags.addAll(query.getTags().stream().filter(tag -> tag.startsWith("attack.")).collect(Collectors.toList()));
}
Set<String> validIntrusionSets = AutoCorrelationsRepo.validIntrusionSets(autoCorrelations, tags);

MatchQueryBuilder queryBuilder = QueryBuilders.matchQuery("source", "Sigma");

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(queryBuilder);
searchSourceBuilder.size(100);

SearchRequest request = new SearchRequest();
request.source(searchSourceBuilder);
logTypeService.searchLogTypes(request, new ActionListener<>() {
@Override
public void onResponse(SearchResponse response) {
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
SearchHit[] logTypes = response.getHits().getHits();
List<String> logTypeNames = new ArrayList<>();
for (SearchHit logType: logTypes) {
String logTypeName = logType.getSourceAsMap().get("name").toString();
logTypeNames.add(logTypeName);

RangeQueryBuilder queryBuilder = QueryBuilders.rangeQuery("timestamp")
.gte(findingTimestamp - corrTimeWindow)
.lte(findingTimestamp + corrTimeWindow);

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(queryBuilder);
searchSourceBuilder.size(10000);
searchSourceBuilder.fetchField("queries");
SearchRequest searchRequest = new SearchRequest();
searchRequest.indices(DetectorMonitorConfig.getAllFindingsIndicesPattern(logTypeName));
searchRequest.source(searchSourceBuilder);
searchRequest.preference(Preference.PRIMARY_FIRST.type());
mSearchRequest.add(searchRequest);
}

if (!mSearchRequest.requests().isEmpty()) {
client.multiSearch(mSearchRequest, new ActionListener<>() {
@Override
public void onResponse(MultiSearchResponse items) {
MultiSearchResponse.Item[] responses = items.getResponses();

Map<String, List<String>> autoCorrelationsMap = new HashMap<>();
int idx = 0;
for (MultiSearchResponse.Item response : responses) {
if (response.isFailure()) {
log.info(response.getFailureMessage());
continue;
}
String logTypeName = logTypeNames.get(idx);

SearchHit[] findings = response.getResponse().getHits().getHits();

for (SearchHit foundFinding : findings) {
if (!foundFinding.getId().equals(finding.getId())) {
Set<String> findingTags = new HashSet<>();
List<Map<String, Object>> queries = (List<Map<String, Object>>) foundFinding.getSourceAsMap().get("queries");
for (Map<String, Object> query : queries) {
List<String> queryTags = (List<String>) query.get("tags");
findingTags.addAll(queryTags.stream().filter(queryTag -> queryTag.startsWith("attack.")).collect(Collectors.toList()));
}

boolean canCorrelate = false;
for (String tag: tags) {
if (findingTags.contains(tag)) {
canCorrelate = true;
break;
}
}

Set<String> foundIntrusionSets = AutoCorrelationsRepo.validIntrusionSets(autoCorrelations, findingTags);
for (String validIntrusionSet: validIntrusionSets) {
if (foundIntrusionSets.contains(validIntrusionSet)) {
canCorrelate = true;
break;
}
}

if (canCorrelate) {
if (autoCorrelationsMap.containsKey(logTypeName)) {
autoCorrelationsMap.get(logTypeName).add(foundFinding.getId());
} else {
List<String> autoCorrelatedFindings = new ArrayList<>();
autoCorrelatedFindings.add(foundFinding.getId());
autoCorrelationsMap.put(logTypeName, autoCorrelatedFindings);
}
}
}
}
++idx;
}
onAutoCorrelations(detector, finding, autoCorrelationsMap);
}

@Override
public void onFailure(Exception e) {
correlateFindingAction.onFailures(e);
}
});
}
}

@Override
public void onFailure(Exception e) {
correlateFindingAction.onFailures(e);
}
});
}

private void onAutoCorrelations(Detector detector, Finding finding, Map<String, List<String>> autoCorrelations) {
String detectorType = detector.getDetectorType().toLowerCase(Locale.ROOT);
List<String> indices = detector.getInputs().get(0).getIndices();
List<String> relatedDocIds = finding.getCorrelatedDocIds();
Expand Down Expand Up @@ -113,20 +249,20 @@ public void onResponse(SearchResponse response) {
}
}

getValidDocuments(detectorType, indices, correlationRules, relatedDocIds);
getValidDocuments(detectorType, indices, correlationRules, relatedDocIds, autoCorrelations);
}

@Override
public void onFailure(Exception e) {
correlateFindingAction.onFailures(e);
getValidDocuments(detectorType, indices, List.of(), List.of(), autoCorrelations);
}
});
}

/**
* this method checks if the finding to be correlated has valid related docs(or not) which match join criteria.
*/
private void getValidDocuments(String detectorType, List<String> indices, List<CorrelationRule> correlationRules, List<String> relatedDocIds) {
private void getValidDocuments(String detectorType, List<String> indices, List<CorrelationRule> correlationRules, List<String> relatedDocIds, Map<String, List<String>> autoCorrelations) {
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<CorrelationRule> validCorrelationRules = new ArrayList<>();

Expand Down Expand Up @@ -189,7 +325,9 @@ public void onResponse(MultiSearchResponse items) {
}
}
searchFindingsByTimestamp(detectorType, categoryToQueriesMap,
filteredCorrelationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList()));
filteredCorrelationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList()),
autoCorrelations
);
}

@Override
Expand All @@ -198,15 +336,19 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), List.of());
if (!autoCorrelations.isEmpty()) {
correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of());
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), List.of());
}
}
}

/**
* this method searches for parent findings given the log category & correlation time window & collects all related docs
* for them.
*/
private void searchFindingsByTimestamp(String detectorType, Map<String, List<CorrelationQuery>> categoryToQueriesMap, List<String> correlationRules) {
private void searchFindingsByTimestamp(String detectorType, Map<String, List<CorrelationQuery>> categoryToQueriesMap, List<String> correlationRules, Map<String, List<String>> autoCorrelations) {
long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli();
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<Pair<String, List<CorrelationQuery>>> categoryToQueriesPairs = new ArrayList<>();
Expand Down Expand Up @@ -260,7 +402,7 @@ public void onResponse(MultiSearchResponse items) {
relatedDocIds));
++idx;
}
searchDocsWithFilterKeys(detectorType, relatedDocsMap, correlationRules);
searchDocsWithFilterKeys(detectorType, relatedDocsMap, correlationRules, autoCorrelations);
}

@Override
Expand All @@ -269,14 +411,18 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
if (!autoCorrelations.isEmpty()) {
correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of());
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
}
}
}

/**
* Given the related docs from parent findings, this method filters only those related docs which match parent join criteria.
*/
private void searchDocsWithFilterKeys(String detectorType, Map<String, DocSearchCriteria> relatedDocsMap, List<String> correlationRules) {
private void searchDocsWithFilterKeys(String detectorType, Map<String, DocSearchCriteria> relatedDocsMap, List<String> correlationRules, Map<String, List<String>> autoCorrelations) {
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<String> categories = new ArrayList<>();

Expand Down Expand Up @@ -324,7 +470,7 @@ public void onResponse(MultiSearchResponse items) {
filteredRelatedDocIds.put(categories.get(idx), docIds);
++idx;
}
getCorrelatedFindings(detectorType, filteredRelatedDocIds, correlationRules);
getCorrelatedFindings(detectorType, filteredRelatedDocIds, correlationRules, autoCorrelations);
}

@Override
Expand All @@ -333,15 +479,19 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
if (!autoCorrelations.isEmpty()) {
correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of());
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
}
}
}

/**
* Given the filtered related docs of the parent findings, this method gets the actual filtered parent findings for
* the finding to be correlated.
*/
private void getCorrelatedFindings(String detectorType, Map<String, List<String>> filteredRelatedDocIds, List<String> correlationRules) {
private void getCorrelatedFindings(String detectorType, Map<String, List<String>> filteredRelatedDocIds, List<String> correlationRules, Map<String, List<String>> autoCorrelations) {
long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli();
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<String> categories = new ArrayList<>();
Expand Down Expand Up @@ -397,6 +547,16 @@ public void onResponse(MultiSearchResponse items) {
}
++idx;
}

for (Map.Entry<String, List<String>> autoCorrelation: autoCorrelations.entrySet()) {
if (correlatedFindings.containsKey(autoCorrelation.getKey())) {
Set<String> alreadyCorrelatedFindings = new HashSet<>(correlatedFindings.get(autoCorrelation.getKey()));
alreadyCorrelatedFindings.addAll(autoCorrelation.getValue());
correlatedFindings.put(autoCorrelation.getKey(), new ArrayList<>(alreadyCorrelatedFindings));
} else {
correlatedFindings.put(autoCorrelation.getKey(), autoCorrelation.getValue());
}
}
correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules);
}

Expand All @@ -406,7 +566,11 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
if (!autoCorrelations.isEmpty()) {
correlateFindingAction.getTimestampFeature(detectorType, autoCorrelations, null, List.of());
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ public class TransportCorrelateFindingAction extends HandledTransportAction<Acti

private final CorrelationIndices correlationIndices;

private final LogTypeService logTypeService;

private final ClusterService clusterService;

private final Settings settings;
Expand All @@ -100,6 +102,7 @@ public TransportCorrelateFindingAction(TransportService transportService,
NamedXContentRegistry xContentRegistry,
DetectorIndices detectorIndices,
CorrelationIndices correlationIndices,
LogTypeService logTypeService,
ClusterService clusterService,
Settings settings,
ActionFilters actionFilters) {
Expand All @@ -108,6 +111,7 @@ public TransportCorrelateFindingAction(TransportService transportService,
this.xContentRegistry = xContentRegistry;
this.detectorIndices = detectorIndices;
this.correlationIndices = correlationIndices;
this.logTypeService = logTypeService;
this.clusterService = clusterService;
this.settings = settings;
this.threadPool = this.detectorIndices.getThreadPool();
Expand Down Expand Up @@ -186,7 +190,7 @@ public class AsyncCorrelateFindingAction {

this.response =new AtomicReference<>();

this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, this);
this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, this, logTypeService);
this.vectorEmbeddingsEngine = new VectorEmbeddingsEngine(client, indexTimeout, corrTimeWindow, this);
}

Expand Down
Loading