diff --git a/alerting/src/main/kotlin/org/opensearch/alerting/AlertingPlugin.kt b/alerting/src/main/kotlin/org/opensearch/alerting/AlertingPlugin.kt index 7ef92d435..b4bee8d84 100644 --- a/alerting/src/main/kotlin/org/opensearch/alerting/AlertingPlugin.kt +++ b/alerting/src/main/kotlin/org/opensearch/alerting/AlertingPlugin.kt @@ -52,6 +52,7 @@ import org.opensearch.alerting.settings.AlertingSettings.Companion.DOC_LEVEL_MON import org.opensearch.alerting.settings.DestinationSettings import org.opensearch.alerting.settings.LegacyOpenDistroAlertingSettings import org.opensearch.alerting.settings.LegacyOpenDistroDestinationSettings +import org.opensearch.alerting.threatintel.ThreatIntelDetectionService import org.opensearch.alerting.transport.TransportAcknowledgeAlertAction import org.opensearch.alerting.transport.TransportAcknowledgeChainedAlertAction import org.opensearch.alerting.transport.TransportDeleteMonitorAction @@ -258,6 +259,7 @@ internal class AlertingPlugin : PainlessExtension, ActionPlugin, ScriptPlugin, R val lockService = LockService(client, clusterService) alertIndices = AlertIndices(settings, client, threadPool, clusterService) val alertService = AlertService(client, xContentRegistry, alertIndices) + val threatIntelDetectionService = ThreatIntelDetectionService(client, xContentRegistry) val triggerService = TriggerService(scriptService) runner = MonitorRunnerService .registerClusterService(clusterService) @@ -310,7 +312,8 @@ internal class AlertingPlugin : PainlessExtension, ActionPlugin, ScriptPlugin, R destinationMigrationCoordinator, lockService, alertService, - triggerService + triggerService, + threatIntelDetectionService ) } diff --git a/alerting/src/main/kotlin/org/opensearch/alerting/threatintel/ThreatIntelDetectionService.kt b/alerting/src/main/kotlin/org/opensearch/alerting/threatintel/ThreatIntelDetectionService.kt new file mode 100644 index 000000000..5ab3f18ad --- /dev/null +++ 
b/alerting/src/main/kotlin/org/opensearch/alerting/threatintel/ThreatIntelDetectionService.kt
@@ -0,0 +1,227 @@
+package org.opensearch.alerting.threatintel
+
+import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
+import org.apache.logging.log4j.LogManager
+import org.opensearch.action.DocWriteRequest
+import org.opensearch.action.admin.indices.refresh.RefreshAction
+import org.opensearch.action.admin.indices.refresh.RefreshRequest
+import org.opensearch.action.admin.indices.refresh.RefreshResponse
+import org.opensearch.action.bulk.BulkRequest
+import org.opensearch.action.bulk.BulkResponse
+import org.opensearch.action.index.IndexRequest
+import org.opensearch.action.search.SearchRequest
+import org.opensearch.action.search.SearchResponse
+import org.opensearch.action.support.GroupedActionListener
+import org.opensearch.alerting.opensearchapi.suspendUntil
+import org.opensearch.client.Client
+import org.opensearch.common.document.DocumentField
+import org.opensearch.common.xcontent.XContentType
+import org.opensearch.commons.alerting.model.DocLevelMonitorInput
+import org.opensearch.commons.alerting.model.DocLevelQuery
+import org.opensearch.commons.alerting.model.Finding
+import org.opensearch.commons.alerting.model.Monitor
+import org.opensearch.commons.alerting.util.string
+import org.opensearch.core.xcontent.NamedXContentRegistry
+import org.opensearch.core.xcontent.ToXContent
+import org.opensearch.core.xcontent.XContentBuilder
+import org.opensearch.index.query.QueryBuilders
+import org.opensearch.search.SearchHit
+import java.time.Instant
+import java.util.UUID
+import kotlin.coroutines.resume
+import kotlin.coroutines.resumeWithException
+import kotlin.coroutines.suspendCoroutine
+import kotlin.math.min
+
+private val log = LogManager.getLogger(ThreatIntelDetectionService::class.java)
+private val scope: CoroutineScope = CoroutineScope(Dispatchers.IO)
+
+/**
+ * Scans indicator-of-compromise (IoC) values found in monitored documents against threat intel
+ * indices and writes one finding per matched IoC to the monitor's findings index.
+ */
+// todo logging n try-catch
+class ThreatIntelDetectionService(
+    val client: Client,
+    val xContentRegistry: NamedXContentRegistry,
+) {
+
+    /** Maximum number of IoC values sent in a single terms query. */
+    val BATCH_SIZE = 65536
+
+    /** Field of a threat intel document that is queried for, and fetched as, the IoC value. */
+    val IOC_FIELD_NAME = "ioc"
+
+    /**
+     * Collects the values of the monitor's IoC fields from [hits], looks them up in
+     * [threatIntelIndices] and indexes a finding for every IoC value that matched.
+     *
+     * Failures are logged and not rethrown; coroutine cancellation is the only exception.
+     */
+    suspend fun scanDataAgainstThreatIntel(monitor: Monitor, threatIntelIndices: List<String>, hits: List<SearchHit>) {
+        val start = System.currentTimeMillis()
+        try {
+            val stringList = buildTerms(monitor, hits)
+            log.debug("TI_DEBUG: num iocs in queried data: ${stringList.size}")
+            searchTermsOnIndices(monitor, stringList.toList(), threatIntelIndices)
+        } catch (e: CancellationException) {
+            throw e // cancellation must keep propagating to the caller
+        } catch (e: Exception) {
+            log.error("TI_DEBUG: failed to scan data against threat intel", e)
+        } finally {
+            val end = System.currentTimeMillis()
+            if (hits.isNotEmpty() && threatIntelIndices.isNotEmpty()) {
+                val l = end - start
+                log.debug("TI_DEBUG: TOTAL TIME TAKEN for Threat intel matching for ${hits.size} is $l millis")
+            }
+        }
+    }
+
+    /**
+     * Returns the distinct string values of every field of [hits] whose name is listed in the
+     * monitor input's `iocFieldNames`. Returns an empty set if extraction fails.
+     */
+    private fun buildTerms(monitor: Monitor, hits: List<SearchHit>): MutableSet<String> {
+        try {
+            val input = monitor.inputs[0] as DocLevelMonitorInput
+            val iocFieldNames = input.iocFieldNames
+            val iocsInData = mutableSetOf<String>()
+            for (hit in hits) {
+                for ((fieldName, field) in hit.fields) {
+                    if (fieldName in iocFieldNames) {
+                        // fixme should we get input from customer on which specific ioc like ip or dns is present in which field
+                        field.values.mapTo(iocsInData) { it.toString() }
+                    }
+                }
+            }
+            return iocsInData
+        } catch (e: Exception) {
+            log.error("TI_DEBUG: Failed to extract IoC's from the queryable data to scan against threat intel", e)
+            return mutableSetOf()
+        }
+    }
+
+    /**
+     * Runs one terms query per batch of [iocs] against [threatIntelIndices] and creates findings
+     * for the IoC values those queries return. Does nothing when either list is empty.
+     */
+    private suspend fun searchTermsOnIndices(monitor: Monitor, iocs: List<String>, threatIntelIndices: List<String>) {
+        // Nothing to look up. Returning here also guarantees that the grouped listener below is
+        // created for at least one search and that no search is issued without target indices.
+        if (iocs.isEmpty() || threatIntelIndices.isEmpty()) return
+        // Only batches that are actually searched may be counted: the grouped listener completes
+        // after exactly iocSubLists.size responses.
+        val iocSubLists = iocs.chunkSublists(BATCH_SIZE).filter { it.isNotEmpty() }
+        // TODO get unique values from list first
+        val responses: Collection<SearchResponse> =
+            suspendCoroutine { cont -> // todo implement a listener that tolerates multiple exceptions
+                val groupedListener = GroupedActionListener(
+                    object : org.opensearch.core.action.ActionListener<Collection<SearchResponse>> {
+                        override fun onResponse(responses: Collection<SearchResponse>) {
+                            cont.resume(responses)
+                        }
+
+                        override fun onFailure(e: Exception) {
+                            // surface the underlying cause when there is one
+                            cont.resumeWithException((e.cause as? Exception) ?: e)
+                        }
+                    },
+                    iocSubLists.size
+                )
+                // chunk all iocs from queryable data and perform terms query for matches
+                // if matched return only the ioc's that matched and not the entire document
+                // NOTE(review): no explicit size is set on the search source, so each query presumably
+                // returns only the default first page of hits - confirm matches are not truncated.
+                for (iocSubList in iocSubLists) {
+                    val searchRequest = SearchRequest(*threatIntelIndices.toTypedArray())
+                    val queryBuilder = QueryBuilders.boolQuery()
+                    queryBuilder.filter(QueryBuilders.boolQuery().must(QueryBuilders.termsQuery(IOC_FIELD_NAME, iocSubList)))
+                    searchRequest.source().query(queryBuilder)
+                    searchRequest.source().fetchSource(false).fetchField(IOC_FIELD_NAME)
+                    client.search(searchRequest, groupedListener)
+                }
+            }
+        val iocMatches = mutableSetOf<String>()
+        for (response in responses) {
+            log.debug("TI_DEBUG search response took: ${response.took} millis")
+            for (hit in response.hits.hits) {
+                val element: DocumentField? = hit.fields?.get(IOC_FIELD_NAME)
+                // only the first value of the fetched field is recorded as the match
+                val firstValue = element?.values?.firstOrNull() ?: continue
+                iocMatches.add(firstValue.toString())
+            }
+        }
+        log.debug("TI_DEBUG num ioc matches: ${iocMatches.size}")
+        createFindings(monitor, iocMatches.toList())
+    }
+
+    // Function to chunk a list into sublists of specified size.
+    // Start offsets stop before `size`, so no empty trailing sublist is produced and an empty
+    // list yields no sublists.
+    fun <T> List<T>.chunkSublists(chunkSize: Int): List<List<T>> {
+        return (0 until size step chunkSize).map { subList(fromIndex = it, toIndex = min(it + chunkSize, size)) }
+    }
+
+    /**
+     * Builds one finding per entry of [iocMatches], with the IoC as the finding's related doc id,
+     * and bulk-indexes the findings into the monitor's findings index.
+     */
+    suspend fun createFindings(monitor: Monitor, iocMatches: List<String>) {
+        val indexRequests = mutableListOf<IndexRequest>()
+
+        for (iocMatch in iocMatches) {
+            val finding = Finding(
+                id = "ioc" + UUID.randomUUID().toString(),
+                relatedDocIds = listOf(iocMatch),
+                correlatedDocIds = listOf(),
+                monitorId = monitor.id,
+                monitorName = monitor.name,
+                index = (monitor.inputs[0] as DocLevelMonitorInput).indices[0],
+                docLevelQueries = listOf(DocLevelQuery("threat_intel", iocMatch, emptyList(), "", emptyList())),
+                timestamp = Instant.now(),
+                executionId = null,
+            )
+            val findingStr =
+                finding.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)
+                    .string()
+            log.debug("Findings: $findingStr")
+            indexRequests += IndexRequest(monitor.dataSources.findingsIndex)
+                .source(findingStr, XContentType.JSON)
+                .id(finding.id)
+                .opType(DocWriteRequest.OpType.CREATE)
+        }
+        bulkIndexFindings(monitor, indexRequests)
+    }
+
+    /**
+     * Indexes [indexRequests] in batches of 1000 and then refreshes the monitor's findings index.
+     * Per-item bulk failures are logged, not thrown. Does nothing for an empty list.
+     */
+    private suspend fun bulkIndexFindings(
+        monitor: Monitor,
+        indexRequests: List<IndexRequest>,
+    ) {
+        if (indexRequests.isEmpty()) return
+        indexRequests.chunked(1000).forEach { batch ->
+            val bulkResponse: BulkResponse = client.suspendUntil {
+                bulk(BulkRequest().add(batch), it)
+            }
+            if (bulkResponse.hasFailures()) {
+                bulkResponse.items.forEach { item ->
+                    if (item.isFailed) {
+                        log.error("Failed indexing the finding ${item.id} of monitor [${monitor.id}]")
+                    }
+                }
+            } else {
+                log.debug("[${bulkResponse.items.size}] All findings successfully indexed.")
+            }
+        }
+        // Hand suspendUntil's listener to execute() so that the refresh completing resumes this coroutine.
+        val refreshResponse: RefreshResponse = client.suspendUntil {
+            execute(RefreshAction.INSTANCE, RefreshRequest(monitor.dataSources.findingsIndex), it)
+        }
+        log.debug("Refreshed findings index of monitor [${monitor.id}]: ${refreshResponse.status}")
+    }
+}
diff --git a/alerting/src/main/kotlin/org/opensearch/alerting/transport/TransportDocLevelMonitorFanOutAction.kt b/alerting/src/main/kotlin/org/opensearch/alerting/transport/TransportDocLevelMonitorFanOutAction.kt index 48f55dadc..11a03dcab 100644 --- a/alerting/src/main/kotlin/org/opensearch/alerting/transport/TransportDocLevelMonitorFanOutAction.kt +++ b/alerting/src/main/kotlin/org/opensearch/alerting/transport/TransportDocLevelMonitorFanOutAction.kt @@ -61,6 +61,7 @@ import org.opensearch.alerting.settings.AlertingSettings.Companion.MAX_ACTIONABL import org.opensearch.alerting.settings.AlertingSettings.Companion.PERCOLATE_QUERY_DOCS_SIZE_MEMORY_PERCENTAGE_LIMIT import org.opensearch.alerting.settings.AlertingSettings.Companion.PERCOLATE_QUERY_MAX_NUM_DOCS_IN_MEMORY import org.opensearch.alerting.settings.DestinationSettings +import org.opensearch.alerting.threatintel.ThreatIntelDetectionService import org.opensearch.alerting.util.AlertingException import org.opensearch.alerting.util.defaultToPerExecutionAction import org.opensearch.alerting.util.destinationmigration.NotificationActionConfigs @@ -140,7 +141,8 @@ class TransportDocLevelMonitorFanOutAction val alertService: AlertService, val scriptService: ScriptService, val settings: Settings, - val xContentRegistry: NamedXContentRegistry + val xContentRegistry: NamedXContentRegistry, + val threatIntelDetectionService: ThreatIntelDetectionService, ) : HandledTransportAction( DocLevelMonitorFanOutAction.NAME, transportService, actionFilters, ::DocLevelMonitorFanOutRequest ), @@ -152,21 +154,38 @@ class TransportDocLevelMonitorFanOutAction var totalDocsSizeInBytesStat = 0L var docsSizeOfBatchInBytes = 0L var findingsToTriggeredQueries: Map> = mutableMapOf() + // Maps a finding ID to 
the related document. private val findingIdToDocSource = mutableMapOf() - @Volatile var percQueryMaxNumDocsInMemory: Int = PERCOLATE_QUERY_MAX_NUM_DOCS_IN_MEMORY.get(settings) - @Volatile var percQueryDocsSizeMemoryPercentageLimit: Int = PERCOLATE_QUERY_DOCS_SIZE_MEMORY_PERCENTAGE_LIMIT.get(settings) - @Volatile var docLevelMonitorShardFetchSize: Int = DOC_LEVEL_MONITOR_SHARD_FETCH_SIZE.get(settings) - @Volatile var findingsIndexBatchSize: Int = FINDINGS_INDEXING_BATCH_SIZE.get(settings) - @Volatile var maxActionableAlertCount: Long = MAX_ACTIONABLE_ALERT_COUNT.get(settings) - @Volatile var retryPolicy = BackoffPolicy.constantBackoff(ALERT_BACKOFF_MILLIS.get(settings), ALERT_BACKOFF_COUNT.get(settings)) - @Volatile var allowList: List = DestinationSettings.ALLOW_LIST.get(settings) - @Volatile var fetchOnlyQueryFieldNames = DOC_LEVEL_MONITOR_FETCH_ONLY_QUERY_FIELDS_ENABLED.get(settings) + @Volatile + var percQueryMaxNumDocsInMemory: Int = PERCOLATE_QUERY_MAX_NUM_DOCS_IN_MEMORY.get(settings) + + @Volatile + var percQueryDocsSizeMemoryPercentageLimit: Int = PERCOLATE_QUERY_DOCS_SIZE_MEMORY_PERCENTAGE_LIMIT.get(settings) + + @Volatile + var docLevelMonitorShardFetchSize: Int = DOC_LEVEL_MONITOR_SHARD_FETCH_SIZE.get(settings) + + @Volatile + var findingsIndexBatchSize: Int = FINDINGS_INDEXING_BATCH_SIZE.get(settings) + + @Volatile + var maxActionableAlertCount: Long = MAX_ACTIONABLE_ALERT_COUNT.get(settings) + + @Volatile + var retryPolicy = BackoffPolicy.constantBackoff(ALERT_BACKOFF_MILLIS.get(settings), ALERT_BACKOFF_COUNT.get(settings)) + + @Volatile + var allowList: List = DestinationSettings.ALLOW_LIST.get(settings) + + @Volatile + var fetchOnlyQueryFieldNames = DOC_LEVEL_MONITOR_FETCH_ONLY_QUERY_FIELDS_ENABLED.get(settings) /* Contains list of docs source that are held in memory to submit to percolate query against query index. 
* Docs are fetched from the source index per shard and transformed.*/ val transformedDocs = mutableListOf>() + val searchHitsBeingProcessed: MutableList = mutableListOf() init { clusterService.clusterSettings.addSettingsUpdateConsumer(PERCOLATE_QUERY_MAX_NUM_DOCS_IN_MEMORY) { @@ -201,7 +220,7 @@ class TransportDocLevelMonitorFanOutAction override fun doExecute( task: Task, request: DocLevelMonitorFanOutRequest, - listener: ActionListener + listener: ActionListener, ) { scope.launch { executeMonitor(request, listener) @@ -210,7 +229,7 @@ class TransportDocLevelMonitorFanOutAction private suspend fun executeMonitor( request: DocLevelMonitorFanOutRequest, - listener: ActionListener + listener: ActionListener, ) { try { val monitor = request.monitor @@ -265,7 +284,8 @@ class TransportDocLevelMonitorFanOutAction updatedIndexNames, concreteIndicesSeenSoFar, ArrayList(fieldsToBeQueried), - shardIds.map { it.id } + iocFields = docLevelMonitorInput.iocFieldNames, + shardIds.map { it.id }, ) { shard, maxSeqNo -> // function passed to update last run context with new max sequence number indexExecutionContext.updatedLastRunContext[shard] = maxSeqNo } @@ -357,7 +377,7 @@ class TransportDocLevelMonitorFanOutAction queryToDocIds: Map>, dryrun: Boolean, executionId: String, - workflowRunContext: WorkflowRunContext? 
+ workflowRunContext: WorkflowRunContext?, ): DocumentLevelTriggerRunResult { val triggerCtx = DocumentLevelTriggerExecutionContext(monitor, trigger) val triggerResult = triggerService.runDocLevelTrigger(monitor, trigger, queryToDocIds) @@ -536,7 +556,7 @@ class TransportDocLevelMonitorFanOutAction private suspend fun bulkIndexFindings( monitor: Monitor, - indexRequests: List + indexRequests: List, ) { indexRequests.chunked(findingsIndexBatchSize).forEach { batch -> val bulkResponse: BulkResponse = client.suspendUntil { @@ -557,7 +577,7 @@ class TransportDocLevelMonitorFanOutAction private fun publishFinding( monitor: Monitor, - finding: Finding + finding: Finding, ) { val publishFindingsRequest = PublishFindingsRequest(monitor.id, finding) AlertingPluginInterface.publishFinding( @@ -575,7 +595,7 @@ class TransportDocLevelMonitorFanOutAction action: Action, ctx: TriggerExecutionContext, monitor: Monitor, - dryrun: Boolean + dryrun: Boolean, ): ActionRunResult { return try { if (ctx is QueryLevelTriggerExecutionContext && !MonitorRunnerService.isActionActionable(action, ctx.alert)) { @@ -624,7 +644,7 @@ class TransportDocLevelMonitorFanOutAction protected suspend fun getConfigAndSendNotification( action: Action, subject: String?, - message: String + message: String, ): String { val config = getConfigForNotificationAction(action) if (config.destination == null && config.channel == null) { @@ -673,8 +693,9 @@ class TransportDocLevelMonitorFanOutAction monitorInputIndices: List, concreteIndices: List, fieldsToBeQueried: List, + iocFields: List, shardList: List, - updateLastRunContext: (String, String) -> Unit + updateLastRunContext: (String, String) -> Unit, ) { for (shardId in shardList) { val shard = shardId.toString() @@ -690,6 +711,7 @@ class TransportDocLevelMonitorFanOutAction to, indexExecutionCtx.docIds, fieldsToBeQueried, + iocFields ) if (hits.hits.isEmpty()) { if (to == Long.MAX_VALUE) { @@ -697,6 +719,7 @@ class TransportDocLevelMonitorFanOutAction } break 
} + searchHitsBeingProcessed.addAll(hits.hits.asList()) if (to == Long.MAX_VALUE) { // max sequence number of shard needs to be computed updateLastRunContext(shard, hits.hits[0].seqNo.toString()) } @@ -783,9 +806,16 @@ class TransportDocLevelMonitorFanOutAction } } totalDocsQueriedStat += transformedDocs.size.toLong() + if ((monitor.inputs[0] as DocLevelMonitorInput).iocFieldNames.isNotEmpty()) + threatIntelDetectionService.scanDataAgainstThreatIntel( + monitor, + listOf(".opensearch-sap-threat-intel*"), + searchHitsBeingProcessed + ) } finally { transformedDocs.clear() docsSizeOfBatchInBytes = 0 + searchHitsBeingProcessed.clear() } } @@ -869,6 +899,7 @@ class TransportDocLevelMonitorFanOutAction maxSeqNo: Long, docIds: List? = null, fieldsToFetch: List, + iocFields: List, ): SearchHits { if (prevSeqNo?.equals(maxSeqNo) == true && maxSeqNo != 0L) { return SearchHits.empty() @@ -898,6 +929,7 @@ class TransportDocLevelMonitorFanOutAction request.source().fetchField(field) } } + iocFields.forEach { request.source().fetchField(it) } val response: SearchResponse = client.suspendUntil { client.search(request, it) } if (response.status() !== RestStatus.OK) { throw IOException("Failed to search shard: [$shard] in index [$index]. Response status is ${response.status()}") @@ -1042,7 +1074,7 @@ class TransportDocLevelMonitorFanOutAction */ private suspend fun getDocSources( findingToDocPairs: List>, - monitor: Monitor + monitor: Monitor, ) { val docFieldTags = parseSampleDocTags(monitor.triggers) val request = MultiGetRequest() @@ -1074,7 +1106,7 @@ class TransportDocLevelMonitorFanOutAction * To cover both of these cases, the Notification config will take precedence and if it is not found, the Destination will be retrieved. */ private suspend fun getConfigForNotificationAction( - action: Action + action: Action, ): NotificationActionConfigs { var destination: Destination? = null var notificationPermissionException: Exception? 
= null @@ -1150,7 +1182,7 @@ class TransportDocLevelMonitorFanOutAction } private fun constructErrorMessageFromTriggerResults( - triggerResults: MutableMap? = null + triggerResults: MutableMap? = null, ): String { var errorMessage = "" if (triggerResults != null) { @@ -1175,6 +1207,6 @@ class TransportDocLevelMonitorFanOutAction var indexName: String, var concreteIndexName: String, var docId: String, - var docSource: BytesReference + var docSource: BytesReference, ) } diff --git a/alerting/src/test/kotlin/org/opensearch/alerting/MonitorDataSourcesIT.kt b/alerting/src/test/kotlin/org/opensearch/alerting/MonitorDataSourcesIT.kt index 7480a02c7..598787559 100644 --- a/alerting/src/test/kotlin/org/opensearch/alerting/MonitorDataSourcesIT.kt +++ b/alerting/src/test/kotlin/org/opensearch/alerting/MonitorDataSourcesIT.kt @@ -413,7 +413,7 @@ class MonitorDataSourcesIT : AlertingSingleNodeTestCase() { queryFieldNames = listOf("alias.some.fff", "source.ip.v6.v1") ) val docLevelInput = DocLevelMonitorInput( - "description", listOf(index), listOf(docQuery1, docQuery2, docQuery3, docQuery4, docQuery5, docQuery6, docQuery7) + "description", listOf(index), listOf(docQuery1), iocFieldNames = listOf("source.ip.v6.v1") ) val trigger = randomDocumentLevelTrigger(condition = ALWAYS_RUN) val customFindingsIndex = "custom_findings_index" @@ -442,7 +442,23 @@ class MonitorDataSourcesIT : AlertingSingleNodeTestCase() { "type.subtype" : "some subtype", "supertype.type" : "some type" }""" + val testDoc1 = """{ + "message" : "This is an error from IAD region", + "source.ip.v6.v1" : 123456, + "source.ip.v6.v2" : 16645, + "source.ip.v4.v0" : 120, + "test_bad_char" : "\u0000", + "test_strict_date_time" : "$testTime", + "test_field.some_other_field" : "us-west-2", + "type.subtype" : "some subtype", + "supertype.type" : "some type" + }""" + val doc = "{\"ioc\" : \"12345\"}" + val doc1 = "{\"ioc\" : \"123456\"}" indexDoc(index, "1", testDoc) + indexDoc(index, "2", testDoc1) + 
indexDoc(".opensearch-sap-threat-intel", "1", doc) + indexDoc(".opensearch-sap-threat-intel", "2", doc1) client().admin().indices().putMapping( PutMappingRequest(index).source("alias.some.fff", "type=alias,path=test_field.some_other_field") )