diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index de939e8d0e9..8931008d7ff 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1612,6 +1612,49 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) { } } +func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + postureCheckIDs []string + expectedCount int + }{ + { + name: "retrieve existing posture checks by existing IDs", + postureCheckIDs: []string{"csplshq7qv948l48f7t0", "cspnllq7qv95uq1r4k90"}, + expectedCount: 2, + }, + { + name: "empty posture check IDs list", + postureCheckIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing posture check IDs", + postureCheckIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing posture check IDs", + postureCheckIDs: []string{"cspnllq7qv95uq1r4k90", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + groups, err := store.GetPostureChecksByIDs(context.Background(), LockingStrengthShare, accountID, tt.postureCheckIDs) + require.NoError(t, err) + require.Len(t, groups, tt.expectedCount) + }) + } +} + func TestSqlStore_SavePostureChecks(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) @@ -1699,3 +1742,116 @@ func TestSqlStore_DeletePostureChecks(t *testing.T) { }) } } + +func TestSqlStore_GetPolicyByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + policyID string + expectError bool + }{ + { + name: "retrieve existing policy", + policyID: "cs1tnh0hhcjnqoiuebf0", + expectError: false, + }, + { + name: "retrieve non-existing policy checks", + policyID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty policy ID", + policyID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, tt.policyID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, policy) + } else { + require.NoError(t, err) + require.NotNil(t, policy) + require.Equal(t, tt.policyID, policy.ID) + } + }) + } +} + +func TestSqlStore_CreatePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + policy := &Policy{ + ID: "policy-id", + AccountID: accountID, + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupC"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + err = store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err) + + savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID) + require.NoError(t, err) + require.Equal(t, savePolicy, policy) + +} + +func TestSqlStore_SavePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + policyID := "cs1tnh0hhcjnqoiuebf0" + + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID) + require.NoError(t, err) + + policy.Enabled = false + policy.Description = "policy" + err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err) + + savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID) + require.NoError(t, err) + require.Equal(t, savePolicy, policy) +} + +func TestSqlStore_DeletePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + policyID := "cs1tnh0hhcjnqoiuebf0" + + err = store.DeletePolicy(context.Background(), LockingStrengthShare, accountID, policyID) + require.NoError(t, err) + + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID) + require.Error(t, err) + require.Nil(t, policy) +} diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 1646ff4da6c..37db2731625 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -35,4 +35,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-3465 INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); +INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); INSERT INTO installations VALUES(1,'');