diff --git a/server/modules/influxdb/influxdbmetrics.go b/server/modules/influxdb/influxdbmetrics.go index 43c85220f..e41824048 100644 --- a/server/modules/influxdb/influxdbmetrics.go +++ b/server/modules/influxdb/influxdbmetrics.go @@ -16,6 +16,7 @@ import ( "github.com/apex/log" influxdb2 "github.com/influxdata/influxdb-client-go/v2" "github.com/influxdata/influxdb-client-go/v2/api" + "github.com/security-onion-solutions/securityonion-soc/licensing" "github.com/security-onion-solutions/securityonion-soc/model" "github.com/security-onion-solutions/securityonion-soc/server" ) @@ -499,6 +500,9 @@ func (metrics *InfluxDBMetrics) UpdateNodeMetrics(ctx context.Context, node *mod enhancedStatusEnabled := (metrics.client != nil) status = node.UpdateOverallStatus(enhancedStatusEnabled) + + licensing.ValidateFeature(licensing.FEAT_FPS, node.FpsEnabled == 1) + licensing.ValidateFeature(licensing.FEAT_LKS, node.LksEnabled == 1) } return status } diff --git a/server/modules/influxdb/influxdbmetrics_test.go b/server/modules/influxdb/influxdbmetrics_test.go index 579954391..193db01da 100644 --- a/server/modules/influxdb/influxdbmetrics_test.go +++ b/server/modules/influxdb/influxdbmetrics_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/security-onion-solutions/securityonion-soc/licensing" "github.com/security-onion-solutions/securityonion-soc/model" "github.com/security-onion-solutions/securityonion-soc/server" "github.com/stretchr/testify/assert" @@ -137,3 +138,91 @@ func TestGetOsNeedsRestart(tester *testing.T) { assert.Equal(tester, 1, metrics.getOsNeedsRestart("bar")) assert.Equal(tester, 0, metrics.getOsNeedsRestart("missing")) } + +func TestUpdateNodeMetricsLksUnlicensed(tester *testing.T) { + licensing.Test("odc", 0, 0, "", "") + assert.Equal(tester, licensing.LICENSE_STATUS_ACTIVE, licensing.GetStatus()) + + metrics := NewInfluxDBMetrics(server.NewFakeAuthorizedServer(nil)) + metrics.lastOsUpdateTime = time.Now() + metrics.lksEnabled = make(map[string]int) + metrics.lksEnabled["id1"] = 1 + metrics.lksEnabled["id2"] = 0 + + node1 := model.NewNode("id1") + node2 := model.NewNode("id2") + node3 := model.NewNode("id3") + + metrics.UpdateNodeMetrics(context.Background(), node3) + assert.Equal(tester, licensing.LICENSE_STATUS_ACTIVE, licensing.GetStatus()) + metrics.UpdateNodeMetrics(context.Background(), node2) + assert.Equal(tester, licensing.LICENSE_STATUS_ACTIVE, licensing.GetStatus()) + metrics.UpdateNodeMetrics(context.Background(), node1) + assert.Equal(tester, licensing.LICENSE_STATUS_EXCEEDED, licensing.GetStatus()) +} + +func TestUpdateNodeMetricsLksLicensed(tester *testing.T) { + licensing.Test("lks", 0, 0, "", "") + assert.Equal(tester, licensing.LICENSE_STATUS_ACTIVE, licensing.GetStatus()) + + metrics := NewInfluxDBMetrics(server.NewFakeAuthorizedServer(nil)) + metrics.lastOsUpdateTime = time.Now() + metrics.lksEnabled = make(map[string]int) + metrics.lksEnabled["id1"] = 1 + metrics.lksEnabled["id2"] = 0 + + node1 := model.NewNode("id1") + node2 := model.NewNode("id2") + node3 := model.NewNode("id3") + + metrics.UpdateNodeMetrics(context.Background(), node3) + assert.Equal(tester, licensing.LICENSE_STATUS_ACTIVE, licensing.GetStatus()) + metrics.UpdateNodeMetrics(context.Background(), node2) + assert.Equal(tester, licensing.LICENSE_STATUS_ACTIVE, licensing.GetStatus()) + metrics.UpdateNodeMetrics(context.Background(), node1) + assert.Equal(tester, licensing.LICENSE_STATUS_ACTIVE, licensing.GetStatus()) +} + +func TestUpdateNodeMetricsFpsUnlicensed(tester *testing.T) { + licensing.Test("odc", 0, 0, "", "") + assert.Equal(tester, licensing.LICENSE_STATUS_ACTIVE, licensing.GetStatus()) + + metrics := NewInfluxDBMetrics(server.NewFakeAuthorizedServer(nil)) + metrics.lastOsUpdateTime = time.Now() + metrics.fpsEnabled = make(map[string]int) + metrics.fpsEnabled["id1"] = 1 + metrics.fpsEnabled["id2"] = 0 + + node1 := model.NewNode("id1") + node2 := model.NewNode("id2") + node3 := model.NewNode("id3") + + metrics.UpdateNodeMetrics(context.Background(), node3) + assert.Equal(tester, licensing.LICENSE_STATUS_ACTIVE, licensing.GetStatus()) + metrics.UpdateNodeMetrics(context.Background(), node2) + assert.Equal(tester, licensing.LICENSE_STATUS_ACTIVE, licensing.GetStatus()) + metrics.UpdateNodeMetrics(context.Background(), node1) + assert.Equal(tester, licensing.LICENSE_STATUS_EXCEEDED, licensing.GetStatus()) +} + +func TestUpdateNodeMetricsFpsLicensed(tester *testing.T) { + licensing.Test("fps", 0, 0, "", "") + assert.Equal(tester, licensing.LICENSE_STATUS_ACTIVE, licensing.GetStatus()) + + metrics := NewInfluxDBMetrics(server.NewFakeAuthorizedServer(nil)) + metrics.lastOsUpdateTime = time.Now() + metrics.fpsEnabled = make(map[string]int) + metrics.fpsEnabled["id1"] = 1 + metrics.fpsEnabled["id2"] = 0 + + node1 := model.NewNode("id1") + node2 := model.NewNode("id2") + node3 := model.NewNode("id3") + + metrics.UpdateNodeMetrics(context.Background(), node3) + assert.Equal(tester, licensing.LICENSE_STATUS_ACTIVE, licensing.GetStatus()) + metrics.UpdateNodeMetrics(context.Background(), node2) + assert.Equal(tester, licensing.LICENSE_STATUS_ACTIVE, licensing.GetStatus()) + metrics.UpdateNodeMetrics(context.Background(), node1) + assert.Equal(tester, licensing.LICENSE_STATUS_ACTIVE, licensing.GetStatus()) +}