diff --git a/pkg/cmd/flags/policy_test.go b/pkg/cmd/flags/policy_test.go index f5e286c17f47..6e5d33613423 100644 --- a/pkg/cmd/flags/policy_test.go +++ b/pkg/cmd/flags/policy_test.go @@ -8,8 +8,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/aquasecurity/tracee/pkg/events" "github.com/aquasecurity/tracee/pkg/filters" k8s "github.com/aquasecurity/tracee/pkg/k8s/apis/tracee.aquasec.com/v1beta1" + "github.com/aquasecurity/tracee/pkg/policy" "github.com/aquasecurity/tracee/pkg/policy/v1beta1" ) @@ -2125,3 +2127,334 @@ func TestCreatePolicies(t *testing.T) { }) } } + +func TestCreateSinglePolicy(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + policyIdx int + scope policyScopes + events policyEvents + newBinary bool + wantPolicy func() *policy.Policy + wantErr error + }{ + { + name: "basic policy", + policyIdx: 1, + scope: policyScopes{ + policyName: "test-policy", + scopeFlags: []scopeFlag{{ + full: "comm=bash", + scopeName: "comm", + operator: "=", + operatorAndValues: "=bash", + }}, + }, + events: policyEvents{ + policyName: "test-policy", + eventFlags: []eventFlag{{ + full: "write", + eventName: "write", + }}, + }, + wantPolicy: func() *policy.Policy { + p := policy.NewPolicy() + p.ID = 1 + p.Name = "test-policy" + p.CommFilter = filters.NewStringFilter(nil) + _ = p.CommFilter.Parse("=bash") + p.Rules[events.Write] = policy.RuleData{ + EventID: events.Write, + ScopeFilter: filters.NewScopeFilter(), + DataFilter: filters.NewDataFilter(), + RetFilter: filters.NewIntFilter(), + } + return p + }, + }, + { + name: "multiple filters", + policyIdx: 2, + scope: policyScopes{ + policyName: "multi-filter", + scopeFlags: []scopeFlag{ + { + full: "uid=1000", + scopeName: "uid", + operator: "=", + operatorAndValues: "=1000", + }, + { + full: "container", + scopeName: "container", + }, + }, + }, + events: policyEvents{ + policyName: "multi-filter", + eventFlags: []eventFlag{ + { + full: "open", + eventName: "open", + }, + { + full: "write.retval=0", + eventName: "write", + eventOptionType: "retval", + operatorAndValues: "=0", + }, + }, + }, + wantPolicy: func() *policy.Policy { + p := policy.NewPolicy() + p.ID = 2 + p.Name = "multi-filter" + p.UIDFilter = filters.NewUInt32Filter() + _ = p.UIDFilter.Parse("=1000") + p.ContFilter = filters.NewBoolFilter() + _ = p.ContFilter.Parse("container") + + p.Rules[events.Open] = policy.RuleData{ + EventID: events.Open, + ScopeFilter: filters.NewScopeFilter(), + DataFilter: filters.NewDataFilter(), + RetFilter: filters.NewIntFilter(), + } + p.Rules[events.Write] = policy.RuleData{ + EventID: events.Write, + ScopeFilter: filters.NewScopeFilter(), + DataFilter: filters.NewDataFilter(), + RetFilter: filters.NewIntFilter(), + } + _ = p.Rules[events.Write].RetFilter.Parse("=0") + return p + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got, err := createSinglePolicy(tc.policyIdx, tc.scope, tc.events, tc.newBinary) + + if tc.wantErr != nil { + require.Error(t, err) + assert.Equal(t, tc.wantErr.Error(), err.Error()) + return + } + + require.NoError(t, err) + want := tc.wantPolicy() + assert.Equal(t, want, got) + }) + } +} + +func TestParseScopeFilters(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + policy *policy.Policy + scopeFlags []scopeFlag + newBinary bool + wantErr error + validate func(*testing.T, *policy.Policy) + }{ + { + name: "single comm filter", + policy: policy.NewPolicy(), + scopeFlags: []scopeFlag{{ + full: "comm=bash", + scopeName: "comm", + operator: "=", + operatorAndValues: "=bash", + }}, + validate: func(t *testing.T, p *policy.Policy) { + assert.NotNil(t, p.CommFilter) + }, + }, + { + name: "container filter variations", + policy: policy.NewPolicy(), + scopeFlags: []scopeFlag{ + { + full: "container", + scopeName: "container", + }, + { + full: "container=new", + scopeName: "container", + operator: "=", + operatorAndValues: "=new", + }, + }, + validate: func(t *testing.T, p *policy.Policy) { + assert.NotNil(t, p.ContFilter) + assert.NotNil(t, p.NewContFilter) + }, + }, + { + name: "invalid scope filter", + policy: policy.NewPolicy(), + scopeFlags: []scopeFlag{{ + full: "invalid=value", + scopeName: "invalid", + operator: "=", + operatorAndValues: "=value", + }}, + wantErr: InvalidScopeOptionError("invalid=value", false), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := parseScopeFilters(tc.policy, tc.scopeFlags, tc.newBinary) + + if tc.wantErr != nil { + require.Error(t, err) + assert.Equal(t, tc.wantErr.Error(), err.Error()) + return + } + + require.NoError(t, err) + if tc.validate != nil { + tc.validate(t, tc.policy) + } + }) + } +} + +func TestParseEventFilters(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + policy *policy.Policy + eventFlags []eventFlag + wantErr error + validate func(*testing.T, *policy.Policy) + }{ + { + name: "basic event", + policy: policy.NewPolicy(), + eventFlags: []eventFlag{{ + full: "write", + eventName: "write", + }}, + validate: func(t *testing.T, p *policy.Policy) { + assert.Contains(t, p.Rules, events.Write) + }, + }, + { + name: "event with retval filter", + policy: policy.NewPolicy(), + eventFlags: []eventFlag{{ + full: "write.retval=0", + eventName: "write", + eventOptionType: "retval", + operatorAndValues: "=0", + }}, + validate: func(t *testing.T, p *policy.Policy) { + assert.Contains(t, p.Rules, events.Write) + assert.NotNil(t, p.Rules[events.Write].RetFilter) + }, + }, + { + name: "event with data filter", + policy: policy.NewPolicy(), + eventFlags: []eventFlag{{ + full: "openat.data.pathname=/etc/passwd", + eventName: "openat", + eventOptionType: "data", + eventOptionName: "pathname", + operatorAndValues: "=/etc/passwd", + }}, + validate: func(t *testing.T, p *policy.Policy) { + assert.Contains(t, p.Rules, events.Openat) + assert.NotNil(t, p.Rules[events.Openat].DataFilter) + }, + }, + { + name: "wildcard event", + policy: policy.NewPolicy(), + eventFlags: []eventFlag{{ + full: "sched_process_*", + eventName: "sched_process_*", + }}, + validate: func(t *testing.T, p *policy.Policy) { + // Check that all sched_process events are included + assert.Contains(t, p.Rules, events.SchedProcessExec) + assert.Contains(t, p.Rules, events.SchedProcessFork) + assert.Contains(t, p.Rules, events.SchedProcessExit) + }, + }, + { + name: "wildcard event with filter", + policy: policy.NewPolicy(), + eventFlags: []eventFlag{{ + full: "sched_process_*", + eventName: "sched_process_*", + }, { + full: "sched_process_exec.retval=0", + eventName: "sched_process_exec", + eventOptionType: "retval", + operatorAndValues: "=0", + }}, + validate: func(t *testing.T, p *policy.Policy) { + assert.Contains(t, p.Rules, events.SchedProcessExec) + assert.Contains(t, p.Rules, events.SchedProcessFork) + assert.Contains(t, p.Rules, events.SchedProcessExit) + // Check that retval filter is applied only to sched_process_exec event + assert.NotNil(t, p.Rules[events.SchedProcessExec].RetFilter) + assert.NotNil(t, p.Rules[events.SchedProcessFork].RetFilter) + assert.NotNil(t, p.Rules[events.SchedProcessExit].RetFilter) + }, + }, + { + name: "non-existing event", + policy: policy.NewPolicy(), + eventFlags: []eventFlag{{ + full: "nonexistent", + eventName: "nonexistent", + }}, + wantErr: InvalidEventError("nonexistent"), + }, + { + name: "non-existing event expansion", + policy: policy.NewPolicy(), + eventFlags: []eventFlag{{ + full: "nonexistent*", + eventName: "nonexistent*", + }}, + wantErr: InvalidEventError("nonexistent*"), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := parseEventFilters(tc.policy, tc.eventFlags) + + if tc.wantErr != nil { + require.Error(t, err) + assert.Equal(t, tc.wantErr.Error(), err.Error()) + return + } + + require.NoError(t, err) + if tc.validate != nil { + tc.validate(t, tc.policy) + } + }) + } +}