diff --git a/html/js/i18n.js b/html/js/i18n.js
index 957aa2a2..080063c8 100644
--- a/html/js/i18n.js
+++ b/html/js/i18n.js
@@ -66,6 +66,8 @@ const i18n = {
address: 'Address',
advanced: 'Temporarily enable advanced interface features',
ago: 'ago',
+ aiGenSummary: 'AI-Generated Summary',
+ aiSummaryStale: 'This AI summary was generated for a previous version of the detection and may not be accurate.',
alertAcknowledge: 'Acknowledge',
alertEscalated: 'This alert has already been escalated',
alertUndoAcknowledge: 'Undo Acknowledge',
diff --git a/html/js/routes/detection.js b/html/js/routes/detection.js
index 83164cb5..1edae521 100644
--- a/html/js/routes/detection.js
+++ b/html/js/routes/detection.js
@@ -186,6 +186,7 @@ routes.push({ path: '/detection/:id', name: 'detection', component: {
{ pattern: /condition:/m, message: this.$root.i18n.invalidDetectionStrelkaMissingCondition, match: false },
],
},
+ showUnreviewedAiSummaries: false,
}},
created() {
this.$root.initializeEditor();
@@ -211,6 +212,7 @@ routes.push({ path: '/detection/:id', name: 'detection', component: {
this.renderAbbreviatedCount = params["renderAbbreviatedCount"];
this.severityTranslations = params['severityTranslations'];
this.ruleTemplates = params['templateDetections'];
+ this.showUnreviewedAiSummaries = params['showUnreviewedAiSummaries'];
if (this.$route.params.id === 'create') {
this.detect = this.newDetection();
@@ -1469,6 +1471,9 @@ routes.push({ path: '/detection/:id', name: 'detection', component: {
},
checkOverrideChangedKey(id, index, key) {
return this.changedOverrideKeys?.[id]?.[index]?.includes(key);
+ },
+ showAiSummary() {
+ return !!(this?.detect?.aiSummary && (this.detect.aiSummaryReviewed || this.showUnreviewedAiSummaries));
}
}
}});
diff --git a/html/js/routes/detection.test.js b/html/js/routes/detection.test.js
index 63cee55a..d8b2b257 100644
--- a/html/js/routes/detection.test.js
+++ b/html/js/routes/detection.test.js
@@ -1179,4 +1179,38 @@ test('validateSuricata', () => {
comp.detect.publicId = '999999';
msg = comp.validateSuricata();
expect(msg).toBe(null);
+});
+
+test('showAiSummary', () => {
+ comp.detect = null;
+ expect(comp.showAiSummary()).toBe(false);
+
+ comp.detect = { engine: 'strelka' };
+ expect(comp.showAiSummary()).toBe(false);
+
+ comp.detect.aiSummary = 'aiSummary';
+ expect(comp.showAiSummary()).toBe(false);
+
+ comp.detect.aiSummaryReviewed = true;
+ expect(comp.showAiSummary()).toBe(true);
+
+ comp.detect.aiSummary = '';
+ expect(comp.showAiSummary()).toBe(false);
+
+ comp.showUnreviewedAiSummaries = true;
+
+ comp.detect = null;
+ expect(comp.showAiSummary()).toBe(false);
+
+ comp.detect = { engine: 'elastalert' };
+ expect(comp.showAiSummary()).toBe(false);
+
+ comp.detect.aiSummary = 'aiSummary';
+ expect(comp.showAiSummary()).toBe(true);
+
+ comp.detect.aiSummaryReviewed = true;
+ expect(comp.showAiSummary()).toBe(true);
+
+ comp.detect.aiSummary = '';
+ expect(comp.showAiSummary()).toBe(false);
});
\ No newline at end of file
diff --git a/model/detection.go b/model/detection.go
index 8073256b..9a017d6b 100644
--- a/model/detection.go
+++ b/model/detection.go
@@ -120,6 +120,15 @@ type Detection struct {
// elastalert - sigma only
Product string `json:"product,omitempty"`
Service string `json:"service,omitempty"`
+
+ // AI Description fields
+ *AiFields `json:",omitempty"`
+}
+
+type AiFields struct {
+ AiSummary string `json:"aiSummary"`
+ AiSummaryReviewed bool `json:"aiSummaryReviewed"`
+ IsAiSummaryStale bool `json:"isSummaryStale"`
}
type DetectionComment struct {
@@ -353,3 +362,10 @@ type AuditInfo struct {
Op string
Detection *Detection
}
+
+type AiSummary struct {
+ PublicId string
+ Reviewed bool `yaml:"Reviewed"`
+ Summary string `yaml:"Summary"`
+ RuleBodyHash string `yaml:"Rule-Body-Hash"`
+}
diff --git a/model/rulerepo.go b/model/rulerepo.go
index cd9fb148..4f99c853 100644
--- a/model/rulerepo.go
+++ b/model/rulerepo.go
@@ -12,6 +12,7 @@ import (
type RuleRepo struct {
Repo string
+ Branch *string
License string
Folder *string
Community bool
diff --git a/server/detectionengine.go b/server/detectionengine.go
index 3c5b381e..a10f2c4e 100644
--- a/server/detectionengine.go
+++ b/server/detectionengine.go
@@ -21,6 +21,7 @@ type DetectionEngine interface {
GetState() *model.EngineState
GenerateUnusedPublicId(ctx context.Context) (string, error)
ApplyFilters(detect *model.Detection) (didFilterAct bool, err error)
+ MergeAuxiliaryData(detect *model.Detection) error
}
type SyncStatus struct {
diff --git a/server/detectionhandler.go b/server/detectionhandler.go
index 3275107e..d912ad17 100644
--- a/server/detectionhandler.go
+++ b/server/detectionhandler.go
@@ -90,6 +90,22 @@ func (h *DetectionHandler) getDetection(w http.ResponseWriter, r *http.Request)
return
}
+ eng, ok := h.server.DetectionEngines[detect.Engine]
+ if !ok {
+ log.WithFields(log.Fields{
+ "detectionEngine": detect.Engine,
+ "detectionPublicId": detectId,
+ }).Error("retrieved detection with unsupported engine")
+ } else {
+ err = eng.MergeAuxiliaryData(detect)
+ if err != nil {
+ log.WithError(err).WithFields(log.Fields{
+ "detectionEngine": detect.Engine,
+ "detectionPublicId": detectId,
+ }).Error("unable to merge auxiliary data into detection")
+ }
+ }
+
web.Respond(w, r, http.StatusOK, detect)
}
@@ -597,8 +613,8 @@ func (h *DetectionHandler) bulkUpdateDetectionAsync(ctx context.Context, body *B
filterApplied, err := engine.ApplyFilters(detect)
if err != nil {
logger.WithError(err).WithFields(log.Fields{
- "publicId": detect.PublicID,
- "engine": detect.Engine,
+ "detectionPublicId": detect.PublicID,
+ "detectionEngine": detect.Engine,
}).Error("unable to apply engine filters to detection")
return
diff --git a/server/modules/detections/ai_summary.go b/server/modules/detections/ai_summary.go
new file mode 100644
index 00000000..0ab52d4e
--- /dev/null
+++ b/server/modules/detections/ai_summary.go
@@ -0,0 +1,146 @@
+package detections
+
+import (
+ "errors"
+ "fmt"
+ "net/url"
+ "path"
+ "path/filepath"
+ "sync"
+ "time"
+
+ "github.com/security-onion-solutions/securityonion-soc/model"
+
+ "github.com/apex/log"
+ "gopkg.in/yaml.v3"
+)
+
+var aiRepoMutex = sync.RWMutex{}
+var lastSuccessfulAiUpdate time.Time
+
+type AiLoader interface {
+ LoadAuxiliaryData(summaries []*model.AiSummary) error
+}
+
+//go:generate mockgen -destination mock/mock_ailoader.go -package mock . AiLoader
+
+func RefreshAiSummaries(eng AiLoader, lang model.SigLanguage, isRunning *bool, aiRepoPath string, aiRepoUrl string, aiRepoBranch string, logger *log.Entry, iom IOManager) error {
+ err := updateAiRepo(isRunning, aiRepoPath, aiRepoUrl, aiRepoBranch, iom)
+ if err != nil {
+ if errors.Is(err, ErrModuleStopped) {
+ return err
+ }
+
+ logger.WithError(err).WithFields(log.Fields{
+ "aiRepoUrl": aiRepoUrl,
+ "aiRepoPath": aiRepoPath,
+ }).Error("unable to update AI repo")
+ }
+
+ parser, err := url.Parse(aiRepoUrl)
+ if err != nil {
+ log.WithError(err).WithField("aiRepoUrl", aiRepoUrl).Error("failed to parse repo URL, doing nothing with it")
+ } else {
+ _, lastFolder := path.Split(parser.Path)
+ repoPath := filepath.Join(aiRepoPath, lastFolder)
+
+ sums, err := readAiSummary(isRunning, repoPath, lang, logger, iom)
+ if err != nil {
+ logger.WithError(err).WithField("repoPath", repoPath).Error("unable to read AI summaries")
+ } else {
+ err = eng.LoadAuxiliaryData(sums)
+ if err != nil {
+ logger.WithError(err).Error("unable to load AI summaries")
+ } else {
+ logger.Info("successfully loaded AI summaries")
+ }
+ }
+ }
+
+ return nil
+}
+
+func updateAiRepo(isRunning *bool, baseRepoFolder string, repoUrl string, branch string, iom IOManager) error {
+ if time.Since(lastSuccessfulAiUpdate) < time.Second*5 {
+ log.Info("AI summary repo was updated recently, skipping update")
+ return nil
+ }
+
+ aiRepoMutex.Lock()
+ defer aiRepoMutex.Unlock()
+
+ if time.Since(lastSuccessfulAiUpdate) < time.Second*5 {
+ log.Info("AI summary repo was updated recently, skipping update")
+ return nil
+ }
+
+ var branchPtr *string
+ if branch != "" {
+ branchPtr = &branch
+ }
+
+ _, _, err := UpdateRepos(isRunning, baseRepoFolder, []*model.RuleRepo{
+ {
+ Repo: repoUrl,
+ Branch: branchPtr,
+ },
+ }, iom)
+
+ if err == nil {
+ lastSuccessfulAiUpdate = time.Now()
+ }
+
+ return err
+}
+
+func readAiSummary(isRunning *bool, repoRoot string, lang model.SigLanguage, logger *log.Entry, iom IOManager) (sums []*model.AiSummary, err error) {
+ aiRepoMutex.RLock()
+ defer aiRepoMutex.RUnlock()
+
+ filename := fmt.Sprintf("%s_summaries.yaml", lang)
+ targetFile := filepath.Join(repoRoot, "detections-ai/", filename)
+
+ logger.WithField("targetFile", targetFile).Info("reading AI summaries")
+
+ raw, err := iom.ReadFile(targetFile)
+ if err != nil {
+ return nil, err
+ }
+
+ // large yaml files take 30+ seconds to unmarshal, so we need to check if the
+ // module has stopped or risk becoming unresponsive when sent a signal to stop
+ done := false
+ data := map[string]*model.AiSummary{}
+
+ go func() {
+ err = yaml.Unmarshal(raw, data)
+ done = true
+ }()
+
+ for !done {
+ if !*isRunning {
+ return nil, ErrModuleStopped
+ }
+
+ time.Sleep(time.Millisecond * 200)
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ logger.Info("successfully unmarshalled AI summaries, parsing...")
+
+ for pid, sum := range data {
+ if !*isRunning {
+ return nil, ErrModuleStopped
+ }
+
+ sum.PublicId = pid
+ sums = append(sums, sum)
+ }
+
+ logger.WithField("aiSummaryCount", len(sums)).Info("successfully parsed AI summaries")
+
+ return sums, nil
+}
diff --git a/server/modules/detections/ai_summary_test.go b/server/modules/detections/ai_summary_test.go
new file mode 100644
index 00000000..f2a0e851
--- /dev/null
+++ b/server/modules/detections/ai_summary_test.go
@@ -0,0 +1,46 @@
+package detections
+
+import (
+ "io/fs"
+ "testing"
+
+ "github.com/apex/log"
+ "github.com/security-onion-solutions/securityonion-soc/model"
+ "github.com/security-onion-solutions/securityonion-soc/server/modules/detections/mock"
+
+ "github.com/tj/assert"
+ "go.uber.org/mock/gomock"
+)
+
+func TestRefreshAiSummaries(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ isRunning := true
+ repo := "http://github.com/user/repo1"
+ branch := "generated-summaries-stable"
+ summaries := `{"87e55c67-46f0-4a7b-a3c6-d473ab7e8392": { "Reviewed": false, "Summary": "ai text goes here"}, "a23077fc-a5ef-427f-92ab-d3de7f56834d": { "Reviewed": true, "Summary": "ai text goes here" } }`
+
+ iom := mock.NewMockIOManager(ctrl)
+ loader := mock.NewMockAiLoader(ctrl)
+
+ iom.EXPECT().ReadDir("baseRepoFolder").Return([]fs.DirEntry{}, nil)
+ iom.EXPECT().CloneRepo(gomock.Any(), "baseRepoFolder/repo1", repo, &branch).Return(nil)
+ iom.EXPECT().ReadFile("baseRepoFolder/repo1/detections-ai/sigma_summaries.yaml").Return([]byte(summaries), nil)
+ loader.EXPECT().LoadAuxiliaryData([]*model.AiSummary{
+ {
+ PublicId: "87e55c67-46f0-4a7b-a3c6-d473ab7e8392",
+ Summary: "ai text goes here",
+ },
+ {
+ PublicId: "a23077fc-a5ef-427f-92ab-d3de7f56834d",
+ Reviewed: true,
+ Summary: "ai text goes here",
+ },
+ }).Return(nil)
+
+ logger := log.WithField("test", true)
+
+ err := RefreshAiSummaries(loader, model.SigLangSigma, &isRunning, "baseRepoFolder", repo, branch, logger, iom)
+ assert.NoError(t, err)
+}
diff --git a/server/modules/detections/detengine_helpers.go b/server/modules/detections/detengine_helpers.go
index 69d8e4ba..a69ef2ad 100644
--- a/server/modules/detections/detengine_helpers.go
+++ b/server/modules/detections/detengine_helpers.go
@@ -18,7 +18,6 @@ import (
"sync"
"time"
- "github.com/security-onion-solutions/securityonion-soc/config"
"github.com/security-onion-solutions/securityonion-soc/model"
"github.com/apex/log"
@@ -122,7 +121,7 @@ type RepoOnDisk struct {
WasModified bool
}
-func UpdateRepos(isRunning *bool, baseRepoFolder string, rulesRepos []*model.RuleRepo, cfg *config.ServerConfig, iom IOManager) (allRepos []*RepoOnDisk, anythingNew bool, err error) {
+func UpdateRepos(isRunning *bool, baseRepoFolder string, rulesRepos []*model.RuleRepo, iom IOManager) (allRepos []*RepoOnDisk, anythingNew bool, err error) {
allRepos = make([]*RepoOnDisk, 0, len(rulesRepos))
// read existing repos
@@ -171,7 +170,7 @@ func UpdateRepos(isRunning *bool, baseRepoFolder string, rulesRepos []*model.Rul
defer cancel()
// repo already exists, pull
- dirty.WasModified, reclone = iom.PullRepo(ctx, repoPath)
+ dirty.WasModified, reclone = iom.PullRepo(ctx, repoPath, repo.Branch)
if dirty.WasModified {
anythingNew = true
}
@@ -194,7 +193,7 @@ func UpdateRepos(isRunning *bool, baseRepoFolder string, rulesRepos []*model.Rul
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
defer cancel()
- err = iom.CloneRepo(ctx, repoPath, repo.Repo)
+ err = iom.CloneRepo(ctx, repoPath, repo.Repo, repo.Branch)
if err != nil {
log.WithError(err).WithField("repoPath", repoPath).Error("failed to clone repo, doing nothing with it")
continue
@@ -250,18 +249,18 @@ func CheckWriteNoRead(ctx context.Context, DetStore GetterByPublicId, writeNoRea
return false
}
- log.WithField("publicId", *writeNoRead).Error("detection was written but not read back, attempting read before continuing")
+ log.WithField("detectionPublicId", *writeNoRead).Error("detection was written but not read back, attempting read before continuing")
// det, err := e.srv.Detectionstore.GetDetectionByPublicId(e.srv.Context, *writeNoRead)
det, err := DetStore.GetDetectionByPublicId(ctx, *writeNoRead)
if err != nil {
- log.WithError(err).WithField("publicId", *writeNoRead).Error("failed to read back detection")
+ log.WithError(err).WithField("detectionPublicId", *writeNoRead).Error("failed to read back detection")
return true
}
if det == nil {
- log.WithField("publicId", *writeNoRead).Error("detection still not found")
+ log.WithField("detectionPublicId", *writeNoRead).Error("detection still not found")
return true
}
@@ -315,12 +314,12 @@ func DeduplicateByPublicId(detects []*model.Detection) []*model.Detection {
existing, inSet := set[detect.PublicID]
if inSet {
log.WithFields(log.Fields{
- "publicId": detect.PublicID,
- "engine": detect.Engine,
- "existingRuleset": existing.Ruleset,
- "duplicateRuleset": detect.Ruleset,
- "existingTitle": existing.Title,
- "duplicateTitle": detect.Title,
+ "detectionPublicId": detect.PublicID,
+ "detectionEngine": detect.Engine,
+ "existingRuleset": existing.Ruleset,
+ "duplicateRuleset": detect.Ruleset,
+ "existingTitle": existing.Title,
+ "duplicateTitle": detect.Title,
}).Warn("duplicate publicId found, skipping")
} else {
set[detect.PublicID] = detect
diff --git a/server/modules/detections/detengine_helpers_test.go b/server/modules/detections/detengine_helpers_test.go
index de90b179..30641f6d 100644
--- a/server/modules/detections/detengine_helpers_test.go
+++ b/server/modules/detections/detengine_helpers_test.go
@@ -13,7 +13,6 @@ import (
"testing"
"time"
- "github.com/security-onion-solutions/securityonion-soc/config"
"github.com/security-onion-solutions/securityonion-soc/model"
servermock "github.com/security-onion-solutions/securityonion-soc/server/mock"
"github.com/security-onion-solutions/securityonion-soc/server/modules/detections/handmock"
@@ -95,11 +94,11 @@ func TestTruncateList(t *testing.T) {
func TestDetermineWaitTimeNoState(t *testing.T) {
ctrl := gomock.NewController(t)
- mio := mock.NewMockIOManager(ctrl)
+ iom := mock.NewMockIOManager(ctrl)
- mio.EXPECT().ReadFile("state").Return(nil, fs.ErrNotExist)
+ iom.EXPECT().ReadFile("state").Return(nil, fs.ErrNotExist)
- lastImport, dur := DetermineWaitTime(mio, "state", time.Minute)
+ lastImport, dur := DetermineWaitTime(iom, "state", time.Minute)
assert.Nil(t, lastImport, "Expected lastImport to be nil")
assert.Equal(t, time.Minute*20, dur, "Expected duration to be 20 minutes")
@@ -107,47 +106,47 @@ func TestDetermineWaitTimeNoState(t *testing.T) {
func TestDetermineWaitTime(t *testing.T) {
ctrl := gomock.NewController(t)
- mio := mock.NewMockIOManager(ctrl)
+ iom := mock.NewMockIOManager(ctrl)
tenSecAgo := time.Now().Unix() - 10
tenSecAgoStr := strconv.FormatInt(tenSecAgo, 10)
- mio.EXPECT().ReadFile("state").Return([]byte(tenSecAgoStr), nil)
+ iom.EXPECT().ReadFile("state").Return([]byte(tenSecAgoStr), nil)
- lastImport, dur := DetermineWaitTime(mio, "state", time.Minute)
+ lastImport, dur := DetermineWaitTime(iom, "state", time.Minute)
assert.NotNil(t, lastImport, "Expected lastImport not to be nil")
assert.InEpsilon(t, time.Duration(time.Second*50), dur, 1)
}
func TestDetermineWaitTimeBadRead(t *testing.T) {
ctrl := gomock.NewController(t)
- mio := mock.NewMockIOManager(ctrl)
+ iom := mock.NewMockIOManager(ctrl)
- mio.EXPECT().ReadFile("state").Return(nil, errors.New("bad read"))
- mio.EXPECT().DeleteFile("state").Return(nil)
+ iom.EXPECT().ReadFile("state").Return(nil, errors.New("bad read"))
+ iom.EXPECT().DeleteFile("state").Return(nil)
- lastImport, dur := DetermineWaitTime(mio, "state", time.Minute)
+ lastImport, dur := DetermineWaitTime(iom, "state", time.Minute)
assert.Nil(t, lastImport, "Expected lastImport to be nil")
assert.Equal(t, time.Duration(time.Minute*20), dur)
}
func TestDetermineWaitTimeBadValue(t *testing.T) {
ctrl := gomock.NewController(t)
- mio := mock.NewMockIOManager(ctrl)
+ iom := mock.NewMockIOManager(ctrl)
- mio.EXPECT().ReadFile("state").Return([]byte("bad"), nil)
- mio.EXPECT().DeleteFile("state").Return(nil)
+ iom.EXPECT().ReadFile("state").Return([]byte("bad"), nil)
+ iom.EXPECT().DeleteFile("state").Return(nil)
- lastImport, dur := DetermineWaitTime(mio, "state", time.Minute)
+ lastImport, dur := DetermineWaitTime(iom, "state", time.Minute)
assert.Nil(t, lastImport, "Expected lastImport to be nil")
assert.Equal(t, time.Duration(time.Minute*20), dur)
}
func TestWriteStateFile(t *testing.T) {
ctrl := gomock.NewController(t)
- mio := mock.NewMockIOManager(ctrl)
+ iom := mock.NewMockIOManager(ctrl)
- mio.EXPECT().WriteFile("state", gomock.Any(), fs.FileMode(0644)).DoAndReturn(func(path string, contents []byte, perm fs.FileMode) error {
+ iom.EXPECT().WriteFile("state", gomock.Any(), fs.FileMode(0644)).DoAndReturn(func(path string, contents []byte, perm fs.FileMode) error {
unix, err := strconv.ParseInt(string(contents), 10, 64)
assert.NoError(t, err)
assert.InEpsilon(t, time.Now().Unix(), unix, 2)
@@ -155,7 +154,7 @@ func TestWriteStateFile(t *testing.T) {
return nil
})
- WriteStateFile(mio, "state")
+ WriteStateFile(iom, "state")
}
func TestCheckWriteNoRead(t *testing.T) {
@@ -167,28 +166,28 @@ func TestCheckWriteNoRead(t *testing.T) {
id := util.Ptr("99999")
ctx := context.Background()
- mio := servermock.NewMockDetectionstore(ctrl)
+ iom := servermock.NewMockDetectionstore(ctrl)
// No pending ID to read
- shouldFail := CheckWriteNoRead(ctx, mio, nil)
+ shouldFail := CheckWriteNoRead(ctx, iom, nil)
assert.False(t, shouldFail)
// Error querying ES
- mio.EXPECT().GetDetectionByPublicId(gomock.Any(), *id).Return(nil, errors.New("connection error"))
+ iom.EXPECT().GetDetectionByPublicId(gomock.Any(), *id).Return(nil, errors.New("connection error"))
- shouldFail = CheckWriteNoRead(ctx, mio, id)
+ shouldFail = CheckWriteNoRead(ctx, iom, id)
assert.True(t, shouldFail)
// Detection still not found
- mio.EXPECT().GetDetectionByPublicId(gomock.Any(), *id).Return(nil, nil)
+ iom.EXPECT().GetDetectionByPublicId(gomock.Any(), *id).Return(nil, nil)
- shouldFail = CheckWriteNoRead(ctx, mio, id)
+ shouldFail = CheckWriteNoRead(ctx, iom, id)
assert.True(t, shouldFail)
// Successfully read back the missing ID
- mio.EXPECT().GetDetectionByPublicId(gomock.Any(), *id).Return(&model.Detection{}, nil)
+ iom.EXPECT().GetDetectionByPublicId(gomock.Any(), *id).Return(&model.Detection{}, nil)
- shouldFail = CheckWriteNoRead(ctx, mio, id)
+ shouldFail = CheckWriteNoRead(ctx, iom, id)
assert.False(t, shouldFail)
}
@@ -408,6 +407,8 @@ func TestUpdateRepos(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
+ branch := "branch"
+
iom := mock.NewMockIOManager(ctrl)
iom.EXPECT().ReadDir("baseRepoFolder").Return([]fs.DirEntry{
&handmock.MockDirEntry{
@@ -419,8 +420,8 @@ func TestUpdateRepos(t *testing.T) {
Dir: true,
},
}, nil)
- iom.EXPECT().PullRepo(gomock.Any(), "baseRepoFolder/repo1").Return(false, false)
- iom.EXPECT().CloneRepo(gomock.Any(), "baseRepoFolder/repo2", "http://github.com/user/repo2").Return(nil)
+ iom.EXPECT().PullRepo(gomock.Any(), "baseRepoFolder/repo1", nil).Return(false, false)
+ iom.EXPECT().CloneRepo(gomock.Any(), "baseRepoFolder/repo2", "http://github.com/user/repo2", &branch).Return(nil)
iom.EXPECT().RemoveAll("baseRepoFolder/repo3").Return(nil)
isRunning := true
@@ -430,12 +431,12 @@ func TestUpdateRepos(t *testing.T) {
Repo: "http://github.com/user/repo1",
},
{
- Repo: "http://github.com/user/repo2",
+ Repo: "http://github.com/user/repo2",
+ Branch: &branch,
},
}
- cfg := &config.ServerConfig{}
- allRepos, anythingNew, err := UpdateRepos(&isRunning, "baseRepoFolder", repos, cfg, iom)
+ allRepos, anythingNew, err := UpdateRepos(&isRunning, "baseRepoFolder", repos, iom)
assert.NoError(t, err)
assert.Len(t, allRepos, len(repos))
assert.Equal(t, &RepoOnDisk{
diff --git a/server/modules/detections/io_manager.go b/server/modules/detections/io_manager.go
index b359847d..bb7909ac 100644
--- a/server/modules/detections/io_manager.go
+++ b/server/modules/detections/io_manager.go
@@ -9,6 +9,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
+ "fmt"
"io/fs"
"net/http"
"net/url"
@@ -19,6 +20,7 @@ import (
"github.com/apex/log"
"github.com/go-git/go-git/v5"
+ "github.com/go-git/go-git/v5/plumbing"
"github.com/security-onion-solutions/securityonion-soc/config"
)
@@ -34,8 +36,8 @@ type IOManager interface {
MakeRequest(*http.Request) (*http.Response, error)
ExecCommand(cmd *exec.Cmd) ([]byte, int, time.Duration, error)
WalkDir(root string, fn fs.WalkDirFunc) error
- CloneRepo(ctx context.Context, path string, repo string) (err error)
- PullRepo(ctx context.Context, path string) (pulled bool, reclone bool)
+ CloneRepo(ctx context.Context, path string, repo string, branch *string) (err error)
+ PullRepo(ctx context.Context, path string, branch *string) (pulled bool, reclone bool)
}
type ResourceManager struct {
@@ -120,25 +122,31 @@ func (_ *ResourceManager) WalkDir(root string, fn fs.WalkDirFunc) error {
return filepath.WalkDir(root, fn)
}
-func (rm *ResourceManager) CloneRepo(ctx context.Context, path string, repo string) (err error) {
+func (rm *ResourceManager) CloneRepo(ctx context.Context, path string, repo string, branch *string) (err error) {
proxyOpts, err := proxyToTransportOptions(rm.Config.Proxy)
if err != nil {
return err
}
- _, err = git.PlainCloneContext(ctx, path, false, &git.CloneOptions{
+ opts := &git.CloneOptions{
Depth: 1,
SingleBranch: true,
URL: repo,
ProxyOptions: proxyOpts,
CABundle: []byte(rm.Config.AdditionalCA),
InsecureSkipTLS: rm.Config.InsecureSkipVerify,
- })
+ }
+
+ if branch != nil && *branch != "" {
+ opts.ReferenceName = plumbing.ReferenceName(fmt.Sprintf("refs/heads/%s", *branch))
+ }
+
+ _, err = git.PlainCloneContext(ctx, path, false, opts)
return err
}
-func (rm *ResourceManager) PullRepo(ctx context.Context, path string) (pulled bool, reclone bool) {
+func (rm *ResourceManager) PullRepo(ctx context.Context, path string, branch *string) (pulled bool, reclone bool) {
gitrepo, err := git.PlainOpen(path)
if err != nil {
log.WithError(err).WithField("repoPath", path).Error("failed to open repo, doing nothing with it")
@@ -167,13 +175,19 @@ func (rm *ResourceManager) PullRepo(ctx context.Context, path string) (pulled bo
log.WithError(err).WithField("proxy", rm.Config.Proxy).Error("unable to parse proxy url, ignoring proxy")
}
- err = work.PullContext(ctx, &git.PullOptions{
+ opts := &git.PullOptions{
Depth: 1,
SingleBranch: true,
ProxyOptions: proxyOpts,
CABundle: []byte(rm.Config.AdditionalCA),
InsecureSkipTLS: rm.Config.InsecureSkipVerify,
- })
+ }
+
+ if branch != nil && *branch != "" {
+ opts.ReferenceName = plumbing.ReferenceName(fmt.Sprintf("refs/heads/%s", *branch))
+ }
+
+ err = work.PullContext(ctx, opts)
if err != nil && err != git.NoErrAlreadyUpToDate {
log.WithError(err).WithField("repoPath", path).Error("failed to pull repo, doing nothing with it")
diff --git a/server/modules/detections/mock/mock_ailoader.go b/server/modules/detections/mock/mock_ailoader.go
new file mode 100644
index 00000000..a93acbf2
--- /dev/null
+++ b/server/modules/detections/mock/mock_ailoader.go
@@ -0,0 +1,53 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/security-onion-solutions/securityonion-soc/server/modules/detections (interfaces: AiLoader)
+//
+// Generated by this command:
+//
+// mockgen -destination mock/mock_ailoader.go -package mock . AiLoader
+//
+// Package mock is a generated GoMock package.
+package mock
+
+import (
+ reflect "reflect"
+
+ model "github.com/security-onion-solutions/securityonion-soc/model"
+ gomock "go.uber.org/mock/gomock"
+)
+
+// MockAiLoader is a mock of AiLoader interface.
+type MockAiLoader struct {
+ ctrl *gomock.Controller
+ recorder *MockAiLoaderMockRecorder
+}
+
+// MockAiLoaderMockRecorder is the mock recorder for MockAiLoader.
+type MockAiLoaderMockRecorder struct {
+ mock *MockAiLoader
+}
+
+// NewMockAiLoader creates a new mock instance.
+func NewMockAiLoader(ctrl *gomock.Controller) *MockAiLoader {
+ mock := &MockAiLoader{ctrl: ctrl}
+ mock.recorder = &MockAiLoaderMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockAiLoader) EXPECT() *MockAiLoaderMockRecorder {
+ return m.recorder
+}
+
+// LoadAuxiliaryData mocks base method.
+func (m *MockAiLoader) LoadAuxiliaryData(arg0 []*model.AiSummary) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "LoadAuxiliaryData", arg0)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// LoadAuxiliaryData indicates an expected call of LoadAuxiliaryData.
+func (mr *MockAiLoaderMockRecorder) LoadAuxiliaryData(arg0 any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadAuxiliaryData", reflect.TypeOf((*MockAiLoader)(nil).LoadAuxiliaryData), arg0)
+}
diff --git a/server/modules/detections/mock/mock_iomanager.go b/server/modules/detections/mock/mock_iomanager.go
index 49c6e40f..1a43fb6d 100644
--- a/server/modules/detections/mock/mock_iomanager.go
+++ b/server/modules/detections/mock/mock_iomanager.go
@@ -43,17 +43,17 @@ func (m *MockIOManager) EXPECT() *MockIOManagerMockRecorder {
}
// CloneRepo mocks base method.
-func (m *MockIOManager) CloneRepo(arg0 context.Context, arg1, arg2 string) error {
+func (m *MockIOManager) CloneRepo(arg0 context.Context, arg1, arg2 string, arg3 *string) error {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "CloneRepo", arg0, arg1, arg2)
+ ret := m.ctrl.Call(m, "CloneRepo", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(error)
return ret0
}
// CloneRepo indicates an expected call of CloneRepo.
-func (mr *MockIOManagerMockRecorder) CloneRepo(arg0, arg1, arg2 any) *gomock.Call {
+func (mr *MockIOManagerMockRecorder) CloneRepo(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloneRepo", reflect.TypeOf((*MockIOManager)(nil).CloneRepo), arg0, arg1, arg2)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloneRepo", reflect.TypeOf((*MockIOManager)(nil).CloneRepo), arg0, arg1, arg2, arg3)
}
// DeleteFile mocks base method.
@@ -103,18 +103,18 @@ func (mr *MockIOManagerMockRecorder) MakeRequest(arg0 any) *gomock.Call {
}
// PullRepo mocks base method.
-func (m *MockIOManager) PullRepo(arg0 context.Context, arg1 string) (bool, bool) {
+func (m *MockIOManager) PullRepo(arg0 context.Context, arg1 string, arg2 *string) (bool, bool) {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "PullRepo", arg0, arg1)
+ ret := m.ctrl.Call(m, "PullRepo", arg0, arg1, arg2)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(bool)
return ret0, ret1
}
// PullRepo indicates an expected call of PullRepo.
-func (mr *MockIOManagerMockRecorder) PullRepo(arg0, arg1 any) *gomock.Call {
+func (mr *MockIOManagerMockRecorder) PullRepo(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PullRepo", reflect.TypeOf((*MockIOManager)(nil).PullRepo), arg0, arg1)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PullRepo", reflect.TypeOf((*MockIOManager)(nil).PullRepo), arg0, arg1, arg2)
}
// ReadDir mocks base method.
diff --git a/server/modules/detections/sync.go b/server/modules/detections/sync.go
index 466cd055..68f7b649 100644
--- a/server/modules/detections/sync.go
+++ b/server/modules/detections/sync.go
@@ -77,11 +77,11 @@ func SyncScheduler(e DetailedDetectionEngine, syncParams *SyncSchedulerParams, e
}
log.WithFields(log.Fields{
- "detectionEngine": engName,
- "waitTimeSeconds": timerDur.Seconds(),
- "forceSync": forceSync,
- "lastSyncSuccess": lastSyncStatus,
- "expectedStartTime": time.Now().Add(timerDur).Format(time.RFC3339),
+ "detectionEngine": engName,
+ "waitTimeSeconds": timerDur.Seconds(),
+ "forceSync": forceSync,
+ "lastSyncSuccess": lastSyncStatus,
+ "expectedStartTime": time.Now().Add(timerDur).Format(time.RFC3339),
}).Info("waiting for next community rules sync")
e.ResumeIntegrityChecker()
@@ -107,7 +107,7 @@ func SyncScheduler(e DetailedDetectionEngine, syncParams *SyncSchedulerParams, e
syncId := uuid.New().String()
logger := log.WithFields(log.Fields{
"detectionEngine": engName,
- "syncId": syncId,
+ "syncId": syncId,
})
startTime := time.Now()
diff --git a/server/modules/elastalert/elastalert.go b/server/modules/elastalert/elastalert.go
index 50827526..3bcf65cd 100644
--- a/server/modules/elastalert/elastalert.go
+++ b/server/modules/elastalert/elastalert.go
@@ -9,6 +9,7 @@ import (
"archive/zip"
"bytes"
"context"
+ "crypto/md5"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
@@ -60,6 +61,20 @@ const (
DEFAULT_COMMUNITY_RULES_IMPORT_ERROR_SECS = 300
DEFAULT_FAIL_AFTER_CONSECUTIVE_ERROR_COUNT = 10
DEFAULT_INTEGRITY_CHECK_FREQUENCY_SECONDS = 600
+ DEFAULT_AI_REPO = "https://github.com/Security-Onion-Solutions/securityonion-resources"
+ DEFAULT_AI_REPO_BRANCH = "generated-summaries-stable"
+ DEFAULT_AI_REPO_PATH = "/opt/sensoroni/repos"
+ DEFAULT_SHOW_AI_SUMMARIES = true
+)
+
+var ( // treat as constant
+ DEFAULT_RULES_REPOS = []*model.RuleRepo{
+ {
+ Repo: "https://github.com/Security-Onion-Solutions/securityonion-resources",
+ License: "DRL",
+ Folder: util.Ptr("sigma/stable"),
+ },
+ }
)
var acceptedExtensions = map[string]bool{
@@ -87,6 +102,11 @@ type ElastAlertEngine struct {
airgapEnabled bool
notify bool
writeNoRead *string
+ aiSummaries *sync.Map // map[string]*detections.AiSummary{}
+ showAiSummaries bool
+ aiRepoUrl string
+ aiRepoBranch string
+ aiRepoPath string
detections.SyncSchedulerParams
detections.IntegrityCheckerData
detections.IOManager
@@ -133,6 +153,7 @@ func (e *ElastAlertEngine) Init(config module.ModuleConfig) (err error) {
e.InterruptChan = make(chan bool, 1)
e.IntegrityCheckerData.Thread = &sync.WaitGroup{}
e.IntegrityCheckerData.Interrupt = make(chan bool, 1)
+ e.aiSummaries = &sync.Map{}
e.airgapBasePath = module.GetStringDefault(config, "airgapBasePath", DEFAULT_AIRGAP_BASE_PATH)
e.CommunityRulesImportFrequencySeconds = module.GetIntDefault(config, "communityRulesImportFrequencySeconds", DEFAULT_COMMUNITY_RULES_IMPORT_FREQUENCY_SECONDS)
@@ -152,13 +173,7 @@ func (e *ElastAlertEngine) Init(config module.ModuleConfig) (err error) {
e.parseSigmaPackages(pkgs)
e.reposFolder = module.GetStringDefault(config, "reposFolder", DEFAULT_REPOS_FOLDER)
- e.rulesRepos, err = model.GetReposDefault(config, "rulesRepos", []*model.RuleRepo{
- {
- Repo: "https://github.com/Security-Onion-Solutions/securityonion-resources",
- License: "DRL",
- Folder: util.Ptr("sigma/stable"),
- },
- })
+ e.rulesRepos, err = model.GetReposDefault(config, "rulesRepos", DEFAULT_RULES_REPOS)
if err != nil {
return fmt.Errorf("unable to parse ElastAlert's rulesRepos: %w", err)
}
@@ -171,6 +186,11 @@ func (e *ElastAlertEngine) Init(config module.ModuleConfig) (err error) {
e.SyncSchedulerParams.StateFilePath = module.GetStringDefault(config, "stateFilePath", DEFAULT_STATE_FILE_PATH)
+ e.showAiSummaries = module.GetBoolDefault(config, "showAiSummaries", DEFAULT_SHOW_AI_SUMMARIES)
+ e.aiRepoUrl = module.GetStringDefault(config, "aiRepoUrl", DEFAULT_AI_REPO)
+ e.aiRepoBranch = module.GetStringDefault(config, "aiRepoBranch", DEFAULT_AI_REPO_BRANCH)
+ e.aiRepoPath = module.GetStringDefault(config, "aiRepoPath", DEFAULT_AI_REPO_PATH)
+
return nil
}
@@ -179,9 +199,28 @@ func (e *ElastAlertEngine) Start() error {
e.isRunning = true
e.IntegrityCheckerData.IsRunning = true
+ // start long running processes
go detections.SyncScheduler(e, &e.SyncSchedulerParams, &e.EngineState, model.EngineNameElastAlert, &e.isRunning, e.IOManager)
go detections.IntegrityChecker(model.EngineNameElastAlert, e, &e.IntegrityCheckerData, &e.EngineState.IntegrityFailure)
+ // update Ai Summaries once and don't block
+ if e.showAiSummaries {
+ go func() {
+ logger := log.WithField("detectionEngine", model.EngineNameElastAlert)
+
+ err := detections.RefreshAiSummaries(e, model.SigLangSigma, &e.isRunning, e.aiRepoPath, e.aiRepoUrl, e.aiRepoBranch, logger, e.IOManager)
+ if err != nil {
+ if errors.Is(err, detections.ErrModuleStopped) {
+ return
+ }
+
+ logger.WithError(err).Error("unable to refresh AI summaries")
+ } else {
+ logger.Info("successfully refreshed AI summaries")
+ }
+ }()
+ }
+
return nil
}
@@ -432,6 +471,19 @@ func (e *ElastAlertEngine) Sync(logger *log.Entry, forceSync bool) error {
e.writeNoRead = nil
+ if e.showAiSummaries {
+ err := detections.RefreshAiSummaries(e, model.SigLangSigma, &e.isRunning, e.aiRepoPath, e.aiRepoUrl, e.aiRepoBranch, logger, e.IOManager)
+ if err != nil {
+ if errors.Is(err, detections.ErrModuleStopped) {
+ return err
+ }
+
+ logger.WithError(err).Error("unable to refresh AI summaries")
+ } else {
+ logger.Info("successfully refreshed AI summaries")
+ }
+ }
+
// announce the beginning of the sync
e.EngineState.Syncing = true
@@ -473,7 +525,7 @@ func (e *ElastAlertEngine) Sync(logger *log.Entry, forceSync bool) error {
}
// ensure repos are up to date
- dirtyRepos, repoChanges, err := detections.UpdateRepos(&e.isRunning, e.reposFolder, e.rulesRepos, e.srv.Config, e.IOManager)
+ dirtyRepos, repoChanges, err := detections.UpdateRepos(&e.isRunning, e.reposFolder, e.rulesRepos, e.IOManager)
if err != nil {
if errors.Is(err, detections.ErrModuleStopped) {
return err
@@ -1421,6 +1473,41 @@ func (e *ElastAlertEngine) DuplicateDetection(ctx context.Context, detection *mo
return det, nil
}
+func (e *ElastAlertEngine) LoadAuxiliaryData(summaries []*model.AiSummary) error {
+ sum := &sync.Map{}
+ for _, summary := range summaries {
+ sum.Store(summary.PublicId, summary)
+ }
+
+ e.aiSummaries = sum
+
+ log.WithFields(log.Fields{
+ "detectionEngine": model.EngineNameElastAlert,
+ "aiSummaryCount": len(summaries),
+ }).Info("loaded AI summaries")
+
+ return nil
+}
+
+func (e *ElastAlertEngine) MergeAuxiliaryData(detect *model.Detection) error {
+ if e.showAiSummaries {
+ obj, ok := e.aiSummaries.Load(detect.PublicID)
+ if ok {
+ sig := md5.Sum([]byte(detect.Content))
+ hexSig := hex.EncodeToString(sig[:])
+
+ summary := obj.(*model.AiSummary)
+ detect.AiFields = &model.AiFields{
+ AiSummary: summary.Summary,
+ AiSummaryReviewed: summary.Reviewed,
+ IsAiSummaryStale: !strings.EqualFold(summary.RuleBodyHash, hexSig),
+ }
+ }
+ }
+
+ return nil
+}
+
type CustomWrapper struct {
DetectionTitle string `yaml:"detection_title"`
DetectionPublicId string `yaml:"detection_public_id"`
diff --git a/server/modules/elastalert/elastalert_test.go b/server/modules/elastalert/elastalert_test.go
index 1f6402b9..29ca3fd6 100644
--- a/server/modules/elastalert/elastalert_test.go
+++ b/server/modules/elastalert/elastalert_test.go
@@ -1107,11 +1107,19 @@ func TestSyncIncrementalNoChanges(t *testing.T) {
IntegrityCheckerData: detections.IntegrityCheckerData{
IsRunning: true,
},
- IOManager: iom,
+ IOManager: iom,
+ showAiSummaries: true,
+ aiRepoUrl: "aiRepoUrl",
+ aiRepoBranch: "aiRepoBranch",
+ aiRepoPath: "aiRepoPath",
}
logger := log.WithField("detectionEngine", "test-elastalert")
+ // RefreshAiSummaries
+ iom.EXPECT().ReadDir("aiRepoPath").Return([]fs.DirEntry{}, nil)
+ iom.EXPECT().CloneRepo(gomock.Any(), "aiRepoPath/aiRepoUrl", "aiRepoUrl", util.Ptr("aiRepoBranch")).Return(nil)
+ iom.EXPECT().ReadFile("aiRepoPath/aiRepoUrl/detections-ai/sigma_summaries.yaml").Return([]byte("{}"), nil)
// checkSigmaPipelines
iom.EXPECT().ReadFile("sigmaPipelineFinal").Return([]byte("data"), nil)
iom.EXPECT().ReadFile("sigmaPipelineSO").Return([]byte("data"), nil)
@@ -1128,7 +1136,7 @@ func TestSyncIncrementalNoChanges(t *testing.T) {
Dir: true,
},
}, nil)
- iom.EXPECT().PullRepo(gomock.Any(), "repos/repo").Return(false, false)
+ iom.EXPECT().PullRepo(gomock.Any(), "repos/repo", nil).Return(false, false)
// check for changes before sync
iom.EXPECT().ReadFile("rulesFingerprintFile").Return([]byte(`{"core+": "c6OTI9nTQxGEeeNkSZZB9+OESMNvfMXrb+XLtMiVhf0="}`), nil)
// WriteStateFile
@@ -1200,7 +1208,11 @@ func TestSyncChanges(t *testing.T) {
IntegrityCheckerData: detections.IntegrityCheckerData{
IsRunning: true,
},
- IOManager: iom,
+ IOManager: iom,
+ showAiSummaries: true,
+ aiRepoUrl: "aiRepoUrl",
+ aiRepoBranch: "aiRepoBranch",
+ aiRepoPath: "aiRepoPath",
}
logger := log.WithField("detectionEngine", "test-elastalert")
@@ -1208,6 +1220,10 @@ func TestSyncChanges(t *testing.T) {
workItems := []esutil.BulkIndexerItem{}
auditItems := []esutil.BulkIndexerItem{}
+ // RefreshAiSummaries
+ iom.EXPECT().ReadDir("aiRepoPath").Return([]fs.DirEntry{}, nil)
+ iom.EXPECT().CloneRepo(gomock.Any(), "aiRepoPath/aiRepoUrl", "aiRepoUrl", util.Ptr("aiRepoBranch")).Return(nil)
+ iom.EXPECT().ReadFile("aiRepoPath/aiRepoUrl/detections-ai/sigma_summaries.yaml").Return([]byte("{}"), nil)
// checkSigmaPipelines
iom.EXPECT().ReadFile("sigmaPipelineFinal").Return([]byte("data"), nil)
iom.EXPECT().ReadFile("sigmaPipelineSO").Return([]byte("data"), nil)
@@ -1224,7 +1240,7 @@ func TestSyncChanges(t *testing.T) {
Dir: true,
},
}, nil)
- iom.EXPECT().PullRepo(gomock.Any(), "repos/repo").Return(false, false)
+ iom.EXPECT().PullRepo(gomock.Any(), "repos/repo", nil).Return(false, false)
// parseRepoRules
iom.EXPECT().WalkDir("repos/repo", gomock.Any()).DoAndReturn(func(path string, fn fs.WalkDirFunc) error {
files := []fs.DirEntry{
@@ -1348,3 +1364,87 @@ func TestSyncChanges(t *testing.T) {
assert.Equal(t, []string{"abc", "", "deleteme"}, workDocIds) // update has an id, create does not, delete does
}
+
+func TestLoadAndMergeAuxiliaryData(t *testing.T) {
+ tests := []struct {
+ Name string
+ PublicId string
+ Content string
+ ExpectedAiFields bool
+ ExpectedAiSummary string
+ ExpectedReviewed bool
+ ExpectedStale bool
+ }{
+ {
+ Name: "No Auxiliary Data",
+ PublicId: "bd82a1a6-7bac-401e-afcf-5adf07c0c035",
+ ExpectedAiFields: false,
+ },
+ {
+ Name: "Data, Unreviewed",
+ PublicId: "67ee455d-099f-4048-b021-43bb91af9298",
+ Content: "alert",
+ ExpectedAiFields: true,
+ ExpectedAiSummary: "Summary for 67ee455d-099f-4048-b021-43bb91af9298",
+ ExpectedReviewed: false,
+ ExpectedStale: false,
+ },
+ {
+ Name: "Data, Reviewed",
+ PublicId: "83b3a29f-3009-4884-86c6-b6c3973788fa",
+ Content: "no-alert",
+ ExpectedAiFields: true,
+ ExpectedAiSummary: "Summary for 83b3a29f-3009-4884-86c6-b6c3973788fa",
+ ExpectedReviewed: true,
+ ExpectedStale: true,
+ },
+ }
+
+ e := ElastAlertEngine{
+ showAiSummaries: true,
+ }
+ err := e.LoadAuxiliaryData([]*model.AiSummary{
+ {
+ PublicId: "83b3a29f-3009-4884-86c6-b6c3973788fa",
+ Summary: "Summary for 83b3a29f-3009-4884-86c6-b6c3973788fa",
+ Reviewed: true,
+ RuleBodyHash: "7ed21143076d0cca420653d4345baa2f",
+ },
+ {
+ PublicId: "67ee455d-099f-4048-b021-43bb91af9298",
+ Summary: "Summary for 67ee455d-099f-4048-b021-43bb91af9298",
+ Reviewed: false,
+ RuleBodyHash: "7ed21143076d0cca420653d4345baa2f",
+ },
+ })
+ assert.NoError(t, err)
+
+ for _, test := range tests {
+ test := test
+ t.Run(test.Name, func(t *testing.T) {
+ det := &model.Detection{
+ PublicID: test.PublicId,
+ Content: test.Content,
+ }
+
+ e.showAiSummaries = true
+ err := e.MergeAuxiliaryData(det)
+ assert.NoError(t, err)
+ if test.ExpectedAiFields {
+ assert.NotNil(t, det.AiFields)
+ assert.Equal(t, test.ExpectedAiSummary, det.AiSummary)
+ assert.Equal(t, test.ExpectedReviewed, det.AiSummaryReviewed)
+ assert.Equal(t, test.ExpectedStale, det.IsAiSummaryStale)
+ } else {
+ assert.Nil(t, det.AiFields)
+ }
+
+ e.showAiSummaries = false
+ det.AiFields = nil
+
+ err = e.MergeAuxiliaryData(det)
+ assert.NoError(t, err)
+ assert.Nil(t, det.AiFields)
+ })
+ }
+}
diff --git a/server/modules/strelka/strelka.go b/server/modules/strelka/strelka.go
index 33d0f857..3886d8ad 100644
--- a/server/modules/strelka/strelka.go
+++ b/server/modules/strelka/strelka.go
@@ -8,6 +8,7 @@ package strelka
import (
"bytes"
"context"
+ "crypto/md5"
"crypto/sha256"
"encoding/hex"
"encoding/json"
@@ -49,6 +50,10 @@ const (
DEFAULT_COMMUNITY_RULES_IMPORT_ERROR_SECS = 300
DEFAULT_FAIL_AFTER_CONSECUTIVE_ERROR_COUNT = 10
DEFAULT_INTEGRITY_CHECK_FREQUENCY_SECONDS = 600
+ DEFAULT_AI_REPO = "https://github.com/Security-Onion-Solutions/securityonion-resources"
+ DEFAULT_AI_REPO_BRANCH = "generated-summaries-stable"
+ DEFAULT_AI_REPO_PATH = "/opt/sensoroni/repos"
+ DEFAULT_SHOW_AI_SUMMARIES = true
)
var titleUpdater = regexp.MustCompile(`(?im)rule\s+(\w+)(\s+:(\s*[^{]+))?(\s+)(//.*$)?(\n?){`)
@@ -66,6 +71,11 @@ type StrelkaEngine struct {
compileYaraPythonScriptPath string
notify bool
writeNoRead *string
+ aiSummaries *sync.Map // map[string]*detections.AiSummary{}
+ showAiSummaries bool
+ aiRepoUrl string
+ aiRepoBranch string
+ aiRepoPath string
detections.SyncSchedulerParams
detections.IntegrityCheckerData
detections.IOManager
@@ -103,6 +113,7 @@ func (e *StrelkaEngine) Init(config module.ModuleConfig) (err error) {
e.InterruptChan = make(chan bool, 1)
e.IntegrityCheckerData.Thread = &sync.WaitGroup{}
e.IntegrityCheckerData.Interrupt = make(chan bool, 1)
+ e.aiSummaries = &sync.Map{}
e.CommunityRulesImportFrequencySeconds = module.GetIntDefault(config, "communityRulesImportFrequencySeconds", DEFAULT_COMMUNITY_RULES_IMPORT_FREQUENCY_SECONDS)
e.yaraRulesFolder = module.GetStringDefault(config, "yaraRulesFolder", DEFAULT_YARA_RULES_FOLDER)
@@ -125,6 +136,11 @@ func (e *StrelkaEngine) Init(config module.ModuleConfig) (err error) {
e.StateFilePath = module.GetStringDefault(config, "stateFilePath", DEFAULT_STATE_FILE_PATH)
+ e.showAiSummaries = module.GetBoolDefault(config, "showAiSummaries", DEFAULT_SHOW_AI_SUMMARIES)
+ e.aiRepoUrl = module.GetStringDefault(config, "aiRepoUrl", DEFAULT_AI_REPO)
+ e.aiRepoBranch = module.GetStringDefault(config, "aiRepoBranch", DEFAULT_AI_REPO_BRANCH)
+ e.aiRepoPath = module.GetStringDefault(config, "aiRepoPath", DEFAULT_AI_REPO_PATH)
+
return nil
}
@@ -132,9 +148,28 @@ func (e *StrelkaEngine) Start() error {
e.srv.DetectionEngines[model.EngineNameStrelka] = e
e.isRunning = true
+ // start long running processes
go detections.SyncScheduler(e, &e.SyncSchedulerParams, &e.EngineState, model.EngineNameStrelka, &e.isRunning, e.IOManager)
go detections.IntegrityChecker(model.EngineNameStrelka, e, &e.IntegrityCheckerData, &e.EngineState.IntegrityFailure)
+ // update Ai Summaries once and don't block
+ if e.showAiSummaries {
+ go func() {
+ logger := log.WithField("detectionEngine", model.EngineNameStrelka)
+
+ err := detections.RefreshAiSummaries(e, model.SigLangYara, &e.isRunning, e.aiRepoPath, e.aiRepoUrl, e.aiRepoBranch, logger, e.IOManager)
+ if err != nil {
+ if errors.Is(err, detections.ErrModuleStopped) {
+ return
+ }
+
+ logger.WithError(err).Error("unable to refresh AI summaries")
+ } else {
+ logger.Info("successfully refreshed AI summaries")
+ }
+ }()
+ }
+
return nil
}
@@ -193,7 +228,7 @@ func (e *StrelkaEngine) IsRunning() bool {
}
func (e *StrelkaEngine) ValidateRule(data string) (string, error) {
- _, err := e.parseYaraRules([]byte(data), false)
+ _, err := e.parseYaraRules([]byte(data))
if err != nil {
return "", err
}
@@ -210,7 +245,7 @@ func (e *StrelkaEngine) ConvertRule(ctx context.Context, detect *model.Detection
}
func (s *StrelkaEngine) ExtractDetails(detect *model.Detection) error {
- rules, err := s.parseYaraRules([]byte(detect.Content), false)
+ rules, err := s.parseYaraRules([]byte(detect.Content))
if err != nil {
return err
}
@@ -258,10 +293,23 @@ func (e *StrelkaEngine) Sync(logger *log.Entry, forceSync bool) error {
e.writeNoRead = nil
+ if e.showAiSummaries {
+ err := detections.RefreshAiSummaries(e, model.SigLangYara, &e.isRunning, e.aiRepoPath, e.aiRepoUrl, e.aiRepoBranch, logger, e.IOManager)
+ if err != nil {
+ if errors.Is(err, detections.ErrModuleStopped) {
+ return err
+ }
+
+ logger.WithError(err).Error("unable to refresh AI summaries")
+ } else {
+ logger.Info("successfully refreshed AI summaries")
+ }
+ }
+
e.EngineState.Syncing = true
// ensure repos are up to date
- allRepos, anythingNew, err := detections.UpdateRepos(&e.isRunning, e.reposFolder, e.rulesRepos, e.srv.Config, e.IOManager)
+ allRepos, anythingNew, err := detections.UpdateRepos(&e.isRunning, e.reposFolder, e.rulesRepos, e.IOManager)
if err != nil {
if errors.Is(err, detections.ErrModuleStopped) {
return err
@@ -365,7 +413,7 @@ func (e *StrelkaEngine) Sync(logger *log.Entry, forceSync bool) error {
return nil
}
- parsed, err := e.parseYaraRules(raw, true)
+ parsed, err := e.parseYaraRules(raw)
if err != nil {
logger.WithError(err).WithField("yaraRuleFile", path).Error("failed to parse yara rule file")
return nil
@@ -485,7 +533,7 @@ func (e *StrelkaEngine) Sync(logger *log.Entry, forceSync bool) error {
})
if err != nil && err.Error() == "Object not found" {
e.writeNoRead = util.Ptr(detect.PublicID)
- logger.WithField("publicId", detect.PublicID).Error("unable to read back successful write")
+ logger.WithField("detectionPublicId", detect.PublicID).Error("unable to read back successful write")
break
}
@@ -496,7 +544,7 @@ func (e *StrelkaEngine) Sync(logger *log.Entry, forceSync bool) error {
}
if err != nil {
- logger.WithError(err).WithField("publicId", detect.PublicID).Error("failed to update detection")
+ logger.WithError(err).WithField("detectionPublicId", detect.PublicID).Error("failed to update detection")
continue
}
} else {
@@ -538,7 +586,7 @@ func (e *StrelkaEngine) Sync(logger *log.Entry, forceSync bool) error {
})
if err != nil && err.Error() == "Object not found" {
e.writeNoRead = util.Ptr(detect.PublicID)
- logger.WithField("publicId", detect.PublicID).Error("unable to read back successful write")
+ logger.WithField("detectionPublicId", detect.PublicID).Error("unable to read back successful write")
break
}
@@ -549,7 +597,7 @@ func (e *StrelkaEngine) Sync(logger *log.Entry, forceSync bool) error {
}
if err != nil {
- logger.WithError(err).WithField("publicId", detect.PublicID).Error("failed to create detection")
+ logger.WithError(err).WithField("detectionPublicId", detect.PublicID).Error("failed to create detection")
continue
}
}
@@ -610,7 +658,7 @@ func (e *StrelkaEngine) Sync(logger *log.Entry, forceSync bool) error {
},
})
if err != nil {
- logger.WithError(err).WithField("publicId", publicId).Error("Failed to delete unreferenced community detection")
+ logger.WithError(err).WithField("detectionPublicId", publicId).Error("Failed to delete unreferenced community detection")
continue
}
}
@@ -739,7 +787,7 @@ func (e *StrelkaEngine) Sync(logger *log.Entry, forceSync bool) error {
return nil
}
-func (e *StrelkaEngine) parseYaraRules(data []byte, filter bool) ([]*YaraRule, error) {
+func (e *StrelkaEngine) parseYaraRules(data []byte) ([]*YaraRule, error) {
rules := []*YaraRule{}
rule := &YaraRule{}
@@ -1030,7 +1078,7 @@ func (e *StrelkaEngine) syncDetections(ctx context.Context) (err error) {
}
func (e *StrelkaEngine) DuplicateDetection(ctx context.Context, detection *model.Detection) (*model.Detection, error) {
- rules, err := e.parseYaraRules([]byte(detection.Content), false)
+ rules, err := e.parseYaraRules([]byte(detection.Content))
if err != nil {
return nil, err
}
@@ -1084,6 +1132,40 @@ func (e *StrelkaEngine) DuplicateDetection(ctx context.Context, detection *model
return det, nil
}
+func (e *StrelkaEngine) LoadAuxiliaryData(summaries []*model.AiSummary) error {
+ sum := &sync.Map{}
+ for _, summary := range summaries {
+ sum.Store(summary.PublicId, summary)
+ }
+
+ e.aiSummaries = sum
+
+ log.WithFields(log.Fields{
+ "detectionEngine": model.EngineNameStrelka,
+ "aiSummaryCount": len(summaries),
+ }).Info("loaded AI summaries")
+
+ return nil
+}
+
+func (e *StrelkaEngine) MergeAuxiliaryData(detect *model.Detection) error {
+ if e.showAiSummaries {
+ obj, ok := e.aiSummaries.Load(detect.PublicID)
+ if ok {
+ sig := md5.Sum([]byte(detect.Content))
+ hexSig := hex.EncodeToString(sig[:])
+
+ summary := obj.(*model.AiSummary)
+ detect.AiFields = &model.AiFields{
+ AiSummary: summary.Summary,
+ AiSummaryReviewed: summary.Reviewed,
+ IsAiSummaryStale: !strings.EqualFold(summary.RuleBodyHash, hexSig),
+ }
+ }
+ }
+
+ return nil
+}
func (e *StrelkaEngine) GenerateUnusedPublicId(ctx context.Context) (string, error) {
// PublicIDs for Strelka are the rule name which should correlate with what the rule does.
// Cannot generate arbitrary but still useful public IDs
diff --git a/server/modules/strelka/strelka_test.go b/server/modules/strelka/strelka_test.go
index 4a372f44..0300e9b6 100644
--- a/server/modules/strelka/strelka_test.go
+++ b/server/modules/strelka/strelka_test.go
@@ -284,7 +284,7 @@ func TestSyncStrelka(t *testing.T) {
}{
{
Name: "Enable Simple Rules",
- InitMock: func(mockDetStore *servermock.MockDetectionstore, mio *mock.MockIOManager) {
+ InitMock: func(mockDetStore *servermock.MockDetectionstore, iom *mock.MockIOManager) {
mockDetStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{
"1": {
PublicID: "1",
@@ -300,11 +300,11 @@ func TestSyncStrelka(t *testing.T) {
},
}, nil)
- mio.EXPECT().ReadDir("yaraRulesFolder").Return(nil, nil)
+ iom.EXPECT().ReadDir("yaraRulesFolder").Return(nil, nil)
- mio.EXPECT().WriteFile(gomock.Any(), []byte(simpleRule), fs.FileMode(0644)).Return(nil).MaxTimes(2)
+ iom.EXPECT().WriteFile(gomock.Any(), []byte(simpleRule), fs.FileMode(0644)).Return(nil).MaxTimes(2)
- mio.EXPECT().ExecCommand(gomock.Cond(func(c any) bool {
+ iom.EXPECT().ExecCommand(gomock.Cond(func(c any) bool {
cmd := c.(*exec.Cmd)
if !strings.HasSuffix(cmd.Path, "python3") {
@@ -330,7 +330,7 @@ func TestSyncStrelka(t *testing.T) {
ctrl := gomock.NewController(t)
mockDetStore := servermock.NewMockDetectionstore(ctrl)
- mio := mock.NewMockIOManager(ctrl)
+ iom := mock.NewMockIOManager(ctrl)
mod := NewStrelkaEngine(&server.Server{
DetectionEngines: map[model.EngineName]server.DetectionEngine{},
@@ -338,12 +338,12 @@ func TestSyncStrelka(t *testing.T) {
})
mod.isRunning = true
mod.srv.DetectionEngines[model.EngineNameSuricata] = mod
- mod.IOManager = mio
+ mod.IOManager = iom
mod.compileYaraPythonScriptPath = "compileYaraPythonScriptPath"
mod.yaraRulesFolder = "yaraRulesFolder"
- test.InitMock(mockDetStore, mio)
+ test.InitMock(mockDetStore, iom)
errMap, err := mod.SyncLocalDetections(ctx, nil)
@@ -563,7 +563,7 @@ func TestParseRule(t *testing.T) {
t.Run(test.Name, func(t *testing.T) {
t.Parallel()
- rules, err := e.parseYaraRules([]byte(test.Input), true)
+ rules, err := e.parseYaraRules([]byte(test.Input))
if test.ExpectedError == nil {
assert.NoError(t, err)
assert.NotNil(t, rules)
@@ -643,7 +643,7 @@ func TestToDetection(t *testing.T) {
License: "license",
}
- rules, err := e.parseYaraRules([]byte(BasicRuleWMeta), false)
+ rules, err := e.parseYaraRules([]byte(BasicRuleWMeta))
assert.NoError(t, err)
assert.NotEmpty(t, rules)
assert.Equal(t, 1, len(rules))
@@ -722,12 +722,12 @@ func TestGetCompilationResult(t *testing.T) {
"compiled_sha256": "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
}`
- mio := mock.NewMockIOManager(ctrl)
- mio.EXPECT().ReadFile("/opt/so/state/detections_yara_compilation-total.log").Return([]byte(jsn), nil)
+ iom := mock.NewMockIOManager(ctrl)
+ iom.EXPECT().ReadFile("/opt/so/state/detections_yara_compilation-total.log").Return([]byte(jsn), nil)
eng := &StrelkaEngine{
yaraRulesFolder: "/opt/so/conf/strelka/rules",
- IOManager: mio,
+ IOManager: iom,
}
report, err := eng.getCompilationReport()
@@ -755,12 +755,12 @@ func TestVerifyCompiledHash(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
- mio := mock.NewMockIOManager(ctrl)
- mio.EXPECT().ReadFile("/opt/so/saltstack/local/salt/strelka/rules/compiled/rules.compiled").Return([]byte("abc"), nil).Times(3)
- mio.EXPECT().ReadFile("/opt/so/saltstack/local/salt/strelka/rules/compiled/rules.compiled").Return(nil, os.ErrNotExist).Times(2)
+ iom := mock.NewMockIOManager(ctrl)
+ iom.EXPECT().ReadFile("/opt/so/saltstack/local/salt/strelka/rules/compiled/rules.compiled").Return([]byte("abc"), nil).Times(3)
+ iom.EXPECT().ReadFile("/opt/so/saltstack/local/salt/strelka/rules/compiled/rules.compiled").Return(nil, os.ErrNotExist).Times(2)
eng := &StrelkaEngine{
- IOManager: mio,
+ IOManager: iom,
yaraRulesFolder: "/opt/so/conf/strelka/rules",
}
@@ -996,23 +996,6 @@ func TestSyncIncrementalNoChanges(t *testing.T) {
detStore := servermock.NewMockDetectionstore(ctrl)
iom := mock.NewMockIOManager(ctrl)
- // UpdateRepos
- iom.EXPECT().ReadDir("repos").Return([]fs.DirEntry{
- &handmock.MockDirEntry{
- Filename: "repo",
- Dir: true,
- },
- }, nil)
- iom.EXPECT().PullRepo(gomock.Any(), "repos/repo").Return(false, false)
- // WriteStateFile
- iom.EXPECT().WriteFile("stateFilePath", gomock.Any(), fs.FileMode(0644)).Return(nil)
- // IntegrityCheck
- iom.EXPECT().ReadFile("/opt/so/state/detections_yara_compilation-total.log").Return([]byte(`{"timestamp": "now", "success": ["publicId"], "failure": [], "compiled_sha256": "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}`), nil) // getCompilationReport
- iom.EXPECT().ReadFile("/opt/so/saltstack/local/salt/strelka/rules/compiled/rules.compiled").Return([]byte("abc"), nil) // verifyCompiledHash
- detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{
- "publicId": nil,
- }, nil)
-
eng := &StrelkaEngine{
srv: &server.Server{
Detectionstore: detStore,
@@ -1031,11 +1014,36 @@ func TestSyncIncrementalNoChanges(t *testing.T) {
IntegrityCheckerData: detections.IntegrityCheckerData{
IsRunning: true,
},
- IOManager: iom,
+ IOManager: iom,
+ showAiSummaries: true,
+ aiRepoUrl: "aiRepoUrl",
+ aiRepoBranch: "aiRepoBranch",
+ aiRepoPath: "aiRepoPath",
}
logger := log.WithField("detectionEngine", "test-strelka")
+ // RefreshAiSummaries
+ iom.EXPECT().ReadDir("aiRepoPath").Return([]fs.DirEntry{}, nil)
+ iom.EXPECT().CloneRepo(gomock.Any(), "aiRepoPath/aiRepoUrl", "aiRepoUrl", util.Ptr("aiRepoBranch")).Return(nil)
+ iom.EXPECT().ReadFile("aiRepoPath/aiRepoUrl/detections-ai/yara_summaries.yaml").Return([]byte("{}"), nil)
+ // UpdateRepos
+ iom.EXPECT().ReadDir("repos").Return([]fs.DirEntry{
+ &handmock.MockDirEntry{
+ Filename: "repo",
+ Dir: true,
+ },
+ }, nil)
+ iom.EXPECT().PullRepo(gomock.Any(), "repos/repo", nil).Return(false, false)
+ // WriteStateFile
+ iom.EXPECT().WriteFile("stateFilePath", gomock.Any(), fs.FileMode(0644)).Return(nil)
+ // IntegrityCheck
+ iom.EXPECT().ReadFile("/opt/so/state/detections_yara_compilation-total.log").Return([]byte(`{"timestamp": "now", "success": ["publicId"], "failure": [], "compiled_sha256": "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}`), nil) // getCompilationReport
+ iom.EXPECT().ReadFile("/opt/so/saltstack/local/salt/strelka/rules/compiled/rules.compiled").Return([]byte("abc"), nil) // verifyCompiledHash
+ detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{
+ "publicId": nil,
+ }, nil)
+
err := eng.Sync(logger, false)
assert.NoError(t, err)
@@ -1080,7 +1088,11 @@ func TestSyncChanges(t *testing.T) {
IntegrityCheckerData: detections.IntegrityCheckerData{
IsRunning: true,
},
- IOManager: iom,
+ IOManager: iom,
+ showAiSummaries: true,
+ aiRepoUrl: "aiRepoUrl",
+ aiRepoBranch: "aiRepoBranch",
+ aiRepoPath: "aiRepoPath",
}
logger := log.WithField("detectionEngine", "test-strelka")
@@ -1088,6 +1100,10 @@ func TestSyncChanges(t *testing.T) {
workItems := []esutil.BulkIndexerItem{}
auditItems := []esutil.BulkIndexerItem{}
+ // RefreshAiSummaries
+ iom.EXPECT().ReadDir("aiRepoPath").Return([]fs.DirEntry{}, nil)
+ iom.EXPECT().CloneRepo(gomock.Any(), "aiRepoPath/aiRepoUrl", "aiRepoUrl", util.Ptr("aiRepoBranch")).Return(nil)
+ iom.EXPECT().ReadFile("aiRepoPath/aiRepoUrl/detections-ai/yara_summaries.yaml").Return([]byte("{}"), nil)
// UpdateRepos
iom.EXPECT().ReadDir("repos").Return([]fs.DirEntry{
&handmock.MockDirEntry{
@@ -1095,7 +1111,7 @@ func TestSyncChanges(t *testing.T) {
Dir: true,
},
}, nil)
- iom.EXPECT().PullRepo(gomock.Any(), "repos/repo").Return(true, false)
+ iom.EXPECT().PullRepo(gomock.Any(), "repos/repo", nil).Return(true, false)
// Sync
detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{
"dummy": {
@@ -1237,3 +1253,86 @@ func TestSyncChanges(t *testing.T) {
assert.Equal(t, []string{"abc", "", "deleteme"}, workDocIds) // update has an id, create does not, delete does
}
+
+func TestLoadAndMergeAuxiliaryData(t *testing.T) {
+ tests := []struct {
+ Name string
+ PublicId string
+ Content string
+ ExpectedAiFields bool
+ ExpectedAiSummary string
+ ExpectedReviewed bool
+ ExpectedStale bool
+ }{
+ {
+ Name: "No Auxiliary Data",
+ PublicId: "Webshell_FOPO_Obfuscation_APT_ON_Nov17_1",
+ ExpectedAiFields: false,
+ },
+ {
+ Name: "Data, Unreviewed",
+ PublicId: "Webshell_acid_FaTaLisTiCz_Fx_fx_p0isoN_sh3ll_x0rg_byp4ss_256",
+ Content: "no-alert",
+ ExpectedAiFields: true,
+ ExpectedAiSummary: "Summary for Webshell_acid_FaTaLisTiCz_Fx_fx_p0isoN_sh3ll_x0rg_byp4ss_256",
+ ExpectedReviewed: false,
+ ExpectedStale: true,
+ },
+ {
+ Name: "Data, Reviewed",
+ PublicId: "_root_040_zip_Folder_deploy",
+ Content: "alert",
+ ExpectedAiFields: true,
+ ExpectedAiSummary: "Summary for _root_040_zip_Folder_deploy",
+ ExpectedReviewed: true,
+ },
+ }
+
+ e := StrelkaEngine{
+ showAiSummaries: true,
+ }
+ err := e.LoadAuxiliaryData([]*model.AiSummary{
+ {
+ PublicId: "_root_040_zip_Folder_deploy",
+ Summary: "Summary for _root_040_zip_Folder_deploy",
+ Reviewed: true,
+ RuleBodyHash: "7ed21143076d0cca420653d4345baa2f",
+ },
+ {
+ PublicId: "Webshell_acid_FaTaLisTiCz_Fx_fx_p0isoN_sh3ll_x0rg_byp4ss_256",
+ Summary: "Summary for Webshell_acid_FaTaLisTiCz_Fx_fx_p0isoN_sh3ll_x0rg_byp4ss_256",
+ Reviewed: false,
+ RuleBodyHash: "7ed21143076d0cca420653d4345baa2f",
+ },
+ })
+ assert.NoError(t, err)
+
+ for _, test := range tests {
+ test := test
+ t.Run(test.Name, func(t *testing.T) {
+ det := &model.Detection{
+ PublicID: test.PublicId,
+ Content: test.Content,
+ }
+
+ e.showAiSummaries = true
+ err := e.MergeAuxiliaryData(det)
+ assert.NoError(t, err)
+ if test.ExpectedAiFields {
+ assert.NotNil(t, det.AiFields)
+ assert.Equal(t, test.ExpectedAiSummary, det.AiSummary)
+ assert.Equal(t, test.ExpectedReviewed, det.AiSummaryReviewed)
+ assert.Equal(t, test.ExpectedStale, det.IsAiSummaryStale)
+ } else {
+ assert.Nil(t, det.AiFields)
+ }
+
+ e.showAiSummaries = false
+ det.AiFields = nil
+
+ err = e.MergeAuxiliaryData(det)
+ assert.NoError(t, err)
+ assert.Nil(t, det.AiFields)
+ })
+ }
+}
diff --git a/server/modules/suricata/migration-2.4.70_test.go b/server/modules/suricata/migration-2.4.70_test.go
index b3e39e0f..44a93414 100644
--- a/server/modules/suricata/migration-2.4.70_test.go
+++ b/server/modules/suricata/migration-2.4.70_test.go
@@ -53,11 +53,11 @@ func TestM2470ReadStateFile(t *testing.T) {
for _, test := range tests {
test := test
t.Run(test.Name, func(t *testing.T) {
- mio := mock.NewMockIOManager(ctrl)
- mio.EXPECT().ReadFile(idstoolsYaml).Return([]byte(test.Contents), nil)
+ iom := mock.NewMockIOManager(ctrl)
+ iom.EXPECT().ReadFile(idstoolsYaml).Return([]byte(test.Contents), nil)
e := &SuricataEngine{
- IOManager: mio,
+ IOManager: iom,
}
shouldMigrate, err := e.m2470ReadStateFile(idstoolsYaml)
@@ -71,11 +71,11 @@ func TestM2470WriteStateFileSuccess(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
- mio := mock.NewMockIOManager(ctrl)
- mio.EXPECT().WriteFile("stateFile", []byte("1"), fs.FileMode(0644)).Return(nil)
+ iom := mock.NewMockIOManager(ctrl)
+ iom.EXPECT().WriteFile("stateFile", []byte("1"), fs.FileMode(0644)).Return(nil)
e := &SuricataEngine{
- IOManager: mio,
+ IOManager: iom,
}
err := e.m2470WriteStateFileSuccess("stateFile")
@@ -86,11 +86,11 @@ func TestM2470LoadEnabledDisabled(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
- mio := mock.NewMockIOManager(ctrl)
- mio.EXPECT().ReadFile(idstoolsYaml).Return([]byte(`{ "idstools": { "sids": {"enabled": ["1", "2", "3"], "disabled": ["4", "5", "6"]} }}`), nil)
+ iom := mock.NewMockIOManager(ctrl)
+ iom.EXPECT().ReadFile(idstoolsYaml).Return([]byte(`{ "idstools": { "sids": {"enabled": ["1", "2", "3"], "disabled": ["4", "5", "6"]} }}`), nil)
e := &SuricataEngine{
- IOManager: mio,
+ IOManager: iom,
}
enabled, disabled, err := e.m2470LoadEnabledDisabled()
@@ -99,7 +99,7 @@ func TestM2470LoadEnabledDisabled(t *testing.T) {
assert.Equal(t, []string{"1", "2", "3"}, enabled)
assert.Equal(t, []string{"4", "5", "6"}, disabled)
- mio.EXPECT().ReadFile(idstoolsYaml).Return([]byte(`{}`), nil)
+ iom.EXPECT().ReadFile(idstoolsYaml).Return([]byte(`{}`), nil)
enabled, disabled, err = e.m2470LoadEnabledDisabled()
assert.NoError(t, err)
@@ -107,7 +107,7 @@ func TestM2470LoadEnabledDisabled(t *testing.T) {
assert.Equal(t, 0, len(enabled))
assert.Equal(t, 0, len(disabled))
- mio.EXPECT().ReadFile(idstoolsYaml).Return([]byte(`{ "idstools": {}}`), nil)
+ iom.EXPECT().ReadFile(idstoolsYaml).Return([]byte(`{ "idstools": {}}`), nil)
enabled, disabled, err = e.m2470LoadEnabledDisabled()
assert.NoError(t, err)
@@ -314,13 +314,13 @@ func TestM2470LoadOverrides(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
- mio := mock.NewMockIOManager(ctrl)
- mio.EXPECT().ReadFile(sidsYaml).Return([]byte(`{ "2013030": [ "suppress": {"gen_id": 1, "track": "by_src", "ip": "10.10.3.0/24"} ]}`), nil) // success
- mio.EXPECT().ReadFile(sidsYaml).Return(nil, errors.New("bad")) // bad error
- mio.EXPECT().ReadFile(sidsYaml).Return(nil, fs.ErrNotExist) // good error
+ iom := mock.NewMockIOManager(ctrl)
+ iom.EXPECT().ReadFile(sidsYaml).Return([]byte(`{ "2013030": [ "suppress": {"gen_id": 1, "track": "by_src", "ip": "10.10.3.0/24"} ]}`), nil) // success
+ iom.EXPECT().ReadFile(sidsYaml).Return(nil, errors.New("bad")) // bad error
+ iom.EXPECT().ReadFile(sidsYaml).Return(nil, fs.ErrNotExist) // good error
e := &SuricataEngine{
- IOManager: mio,
+ IOManager: iom,
}
// file is present and contains data
diff --git a/server/modules/suricata/suricata.go b/server/modules/suricata/suricata.go
index 23eb4fb4..9fbac039 100644
--- a/server/modules/suricata/suricata.go
+++ b/server/modules/suricata/suricata.go
@@ -8,6 +8,7 @@ package suricata
import (
"bytes"
"context"
+ "crypto/md5"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -55,6 +56,10 @@ const (
DEFAULT_COMMUNITY_RULES_IMPORT_ERROR_SECS = 300
DEFAULT_FAIL_AFTER_CONSECUTIVE_ERROR_COUNT = 10
DEFAULT_INTEGRITY_CHECK_FREQUENCY_SECONDS = 600
+ DEFAULT_AI_REPO = "https://github.com/Security-Onion-Solutions/securityonion-resources"
+ DEFAULT_AI_REPO_BRANCH = "generated-summaries-stable"
+ DEFAULT_AI_REPO_PATH = "/opt/sensoroni/repos"
+ DEFAULT_SHOW_AI_SUMMARIES = true
CUSTOM_RULE_LOC = "/nsm/rules/detect-suricata/custom_temp"
)
@@ -79,6 +84,11 @@ type SuricataEngine struct {
checkMigrationsOnce func()
enableRegex []*regexp.Regexp
disableRegex []*regexp.Regexp
+ aiSummaries *sync.Map // map[string]*detections.AiSummary{}
+ showAiSummaries bool
+ aiRepoUrl string
+ aiRepoBranch string
+ aiRepoPath string
detections.SyncSchedulerParams
detections.IntegrityCheckerData
detections.IOManager
@@ -113,6 +123,7 @@ func (e *SuricataEngine) Init(config module.ModuleConfig) (err error) {
e.InterruptChan = make(chan bool, 1)
e.IntegrityCheckerData.Thread = &sync.WaitGroup{}
e.IntegrityCheckerData.Interrupt = make(chan bool, 1)
+ e.aiSummaries = &sync.Map{}
e.communityRulesFile = module.GetStringDefault(config, "communityRulesFile", DEFAULT_COMMUNITY_RULES_FILE)
e.allRulesFile = module.GetStringDefault(config, "allRulesFile", DEFAULT_ALL_RULES_FILE)
@@ -155,6 +166,11 @@ func (e *SuricataEngine) Init(config module.ModuleConfig) (err error) {
return fmt.Errorf("unable to get custom rulesets: %w", err)
}
+ e.showAiSummaries = module.GetBoolDefault(config, "showAiSummaries", DEFAULT_SHOW_AI_SUMMARIES)
+ e.aiRepoUrl = module.GetStringDefault(config, "aiRepoUrl", DEFAULT_AI_REPO)
+ e.aiRepoBranch = module.GetStringDefault(config, "aiRepoBranch", DEFAULT_AI_REPO_BRANCH)
+ e.aiRepoPath = module.GetStringDefault(config, "aiRepoPath", DEFAULT_AI_REPO_PATH)
+
return nil
}
@@ -163,9 +179,28 @@ func (e *SuricataEngine) Start() error {
e.isRunning = true
e.IntegrityCheckerData.IsRunning = true
+ // start long running processes
go detections.SyncScheduler(e, &e.SyncSchedulerParams, &e.EngineState, model.EngineNameSuricata, &e.isRunning, e.IOManager)
go detections.IntegrityChecker(model.EngineNameSuricata, e, &e.IntegrityCheckerData, &e.EngineState.IntegrityFailure)
+ // update Ai Summaries once and don't block
+ if e.showAiSummaries {
+ go func() {
+ logger := log.WithField("detectionEngine", model.EngineNameSuricata)
+
+ err := detections.RefreshAiSummaries(e, model.SigLangSuricata, &e.isRunning, e.aiRepoPath, e.aiRepoUrl, e.aiRepoBranch, logger, e.IOManager)
+ if err != nil {
+ if errors.Is(err, detections.ErrModuleStopped) {
+ return
+ }
+
+ logger.WithError(err).Error("unable to refresh AI summaries")
+ } else {
+ logger.Info("successfully refreshed AI summaries")
+ }
+ }()
+ }
+
return nil
}
@@ -318,6 +353,19 @@ func (e *SuricataEngine) Sync(logger *log.Entry, forceSync bool) error {
e.writeNoRead = nil
+ if e.showAiSummaries {
+ err := detections.RefreshAiSummaries(e, model.SigLangSuricata, &e.isRunning, e.aiRepoPath, e.aiRepoUrl, e.aiRepoBranch, logger, e.IOManager)
+ if err != nil {
+ if errors.Is(err, detections.ErrModuleStopped) {
+ return err
+ }
+
+ logger.WithError(err).Error("unable to refresh AI summaries")
+ } else {
+ logger.Info("successfully refreshed AI summaries")
+ }
+ }
+
e.EngineState.Syncing = true
rules, hash, err := e.readAndHash(e.communityRulesFile)
@@ -1693,6 +1741,41 @@ func (e *SuricataEngine) DuplicateDetection(ctx context.Context, detection *mode
return det, nil
}
+func (e *SuricataEngine) LoadAuxiliaryData(summaries []*model.AiSummary) error {
+ sum := &sync.Map{}
+ for _, summary := range summaries {
+ sum.Store(summary.PublicId, summary)
+ }
+
+ e.aiSummaries = sum
+
+ log.WithFields(log.Fields{
+ "detectionEngine": model.EngineNameSuricata,
+ "aiSummaryCount": len(summaries),
+ }).Info("loaded AI summaries")
+
+ return nil
+}
+
+func (e *SuricataEngine) MergeAuxiliaryData(detect *model.Detection) error {
+ if e.showAiSummaries {
+ obj, ok := e.aiSummaries.Load(detect.PublicID)
+ if ok {
+ sig := md5.Sum([]byte(detect.Content))
+ hexSig := hex.EncodeToString(sig[:])
+
+ summary := obj.(*model.AiSummary)
+ detect.AiFields = &model.AiFields{
+ AiSummary: summary.Summary,
+ AiSummaryReviewed: summary.Reviewed,
+ IsAiSummaryStale: !strings.EqualFold(summary.RuleBodyHash, hexSig),
+ }
+ }
+ }
+
+ return nil
+}
+
func (e *SuricataEngine) ReadCustomRulesets() (detects []*model.Detection, err error) {
detects = []*model.Detection{}
diff --git a/server/modules/suricata/suricata_test.go b/server/modules/suricata/suricata_test.go
index 6582f6e5..aed2206f 100644
--- a/server/modules/suricata/suricata_test.go
+++ b/server/modules/suricata/suricata_test.go
@@ -2184,11 +2184,19 @@ func TestSyncIncrementalNoChanges(t *testing.T) {
IntegrityCheckerData: detections.IntegrityCheckerData{
IsRunning: true,
},
- IOManager: iom,
+ IOManager: iom,
+ showAiSummaries: true,
+ aiRepoUrl: "aiRepoUrl",
+ aiRepoBranch: "aiRepoBranch",
+ aiRepoPath: "aiRepoPath",
}
logger := log.WithField("detectionEngine", "test-suricata")
+ // RefreshAiSummaries
+ iom.EXPECT().ReadDir("aiRepoPath").Return([]fs.DirEntry{}, nil)
+ iom.EXPECT().CloneRepo(gomock.Any(), "aiRepoPath/aiRepoUrl", "aiRepoUrl", util.Ptr("aiRepoBranch")).Return(nil)
+ iom.EXPECT().ReadFile("aiRepoPath/aiRepoUrl/detections-ai/suricata_summaries.yaml").Return([]byte("{}"), nil)
// readAndHash
iom.EXPECT().ReadFile("communityRulesFile").Return([]byte(SimpleRule), nil)
// readFingerprint
@@ -2254,7 +2262,11 @@ func TestSyncChanges(t *testing.T) {
IntegrityCheckerData: detections.IntegrityCheckerData{
IsRunning: true,
},
- IOManager: iom,
+ IOManager: iom,
+ showAiSummaries: true,
+ aiRepoUrl: "aiRepoUrl",
+ aiRepoBranch: "aiRepoBranch",
+ aiRepoPath: "aiRepoPath",
}
logger := log.WithField("detectionEngine", "test-suricata")
@@ -2262,6 +2274,10 @@ func TestSyncChanges(t *testing.T) {
workItems := []esutil.BulkIndexerItem{}
auditItems := []esutil.BulkIndexerItem{}
+ // RefreshAiSummaries
+ iom.EXPECT().ReadDir("aiRepoPath").Return([]fs.DirEntry{}, nil)
+ iom.EXPECT().CloneRepo(gomock.Any(), "aiRepoPath/aiRepoUrl", "aiRepoUrl", util.Ptr("aiRepoBranch")).Return(nil)
+ iom.EXPECT().ReadFile("aiRepoPath/aiRepoUrl/detections-ai/suricata_summaries.yaml").Return([]byte("{}"), nil)
// readAndHash
iom.EXPECT().ReadFile("communityRulesFile").Return([]byte(SimpleRule+"\n"+FlowbitsRuleA), nil)
// syncCommunityDetections
@@ -2484,3 +2500,87 @@ func toRegex(s ...string) []*regexp.Regexp {
return r
}
+
+func TestLoadAndMergeAuxiliaryData(t *testing.T) {
+ tests := []struct {
+ Name string
+ PublicId string
+ Content string
+ ExpectedAiFields bool
+ ExpectedAiSummary string
+ ExpectedReviewed bool
+ ExpectedStale bool
+ }{
+ {
+ Name: "No Auxiliary Data",
+ PublicId: "100000",
+ Content: "alert",
+ ExpectedAiFields: false,
+ },
+ {
+ Name: "Data, Unreviewed",
+ PublicId: "100002",
+ Content: "no-alert",
+ ExpectedAiFields: true,
+ ExpectedAiSummary: "Summary for 100002",
+ ExpectedReviewed: false,
+ ExpectedStale: true,
+ },
+ {
+ Name: "Data, Reviewed",
+ PublicId: "100001",
+ Content: "alert",
+ ExpectedAiFields: true,
+ ExpectedAiSummary: "Summary for 100001",
+ ExpectedReviewed: true,
+ },
+ }
+
+ e := SuricataEngine{
+ showAiSummaries: true,
+ }
+ err := e.LoadAuxiliaryData([]*model.AiSummary{
+ {
+ PublicId: "100001",
+ Summary: "Summary for 100001",
+ RuleBodyHash: "7ed21143076d0cca420653d4345baa2f",
+ Reviewed: true,
+ },
+ {
+ PublicId: "100002",
+ Summary: "Summary for 100002",
+ RuleBodyHash: "7ed21143076d0cca420653d4345baa2f",
+ Reviewed: false,
+ },
+ })
+ assert.NoError(t, err)
+
+ for _, test := range tests {
+ test := test
+ t.Run(test.Name, func(t *testing.T) {
+ det := &model.Detection{
+ PublicID: test.PublicId,
+ Content: test.Content,
+ }
+
+ e.showAiSummaries = true
+ err := e.MergeAuxiliaryData(det)
+ assert.NoError(t, err)
+ if test.ExpectedAiFields {
+ assert.NotNil(t, det.AiFields)
+ assert.Equal(t, test.ExpectedAiSummary, det.AiSummary)
+ assert.Equal(t, test.ExpectedReviewed, det.AiSummaryReviewed)
+ assert.Equal(t, test.ExpectedStale, det.IsAiSummaryStale)
+ } else {
+ assert.Nil(t, det.AiFields)
+ }
+
+ e.showAiSummaries = false
+ det.AiFields = nil
+
+ err = e.MergeAuxiliaryData(det)
+ assert.NoError(t, err)
+ assert.Nil(t, det.AiFields)
+ })
+ }
+}