Skip to content

Commit

Permalink
feat(misconf): loading embedded checks as a fallback (#6502)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikpivkin authored Apr 19, 2024
1 parent 9b7d713 commit 12ec0df
Show file tree
Hide file tree
Showing 10 changed files with 271 additions and 42 deletions.
3 changes: 2 additions & 1 deletion pkg/cloud/report/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/aws/aws-sdk-go-v2/aws/arn"

ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/iac/rego"
"github.com/aquasecurity/trivy/pkg/iac/scan"
"github.com/aquasecurity/trivy/pkg/types"
)
Expand Down Expand Up @@ -57,7 +58,7 @@ func ConvertResults(results scan.Results, provider string, scoped []string) map[

// empty namespace implies a go rule from defsec, "builtin" refers to a built-in rego rule
// this ensures we don't generate bad links for custom policies
if result.RegoNamespace() == "" || strings.HasPrefix(result.RegoNamespace(), "builtin.") {
if result.RegoNamespace() == "" || rego.IsBuiltinNamespace(result.RegoNamespace()) {
primaryURL = fmt.Sprintf("https://avd.aquasec.com/misconfig/%s", strings.ToLower(result.Rule().AVDID))
}

Expand Down
8 changes: 2 additions & 6 deletions pkg/commands/artifact/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/fanal/walker"
"github.com/aquasecurity/trivy/pkg/flag"
"github.com/aquasecurity/trivy/pkg/iac/rego"
"github.com/aquasecurity/trivy/pkg/javadb"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/misconf"
Expand Down Expand Up @@ -50,11 +51,6 @@ const (
)

var (
defaultPolicyNamespaces = []string{
"appshield",
"defsec",
"builtin",
}
SkipScan = errors.New("skip subsequent processes")
)

Expand Down Expand Up @@ -598,7 +594,7 @@ func initScannerConfig(opts flag.Options, cacheClient cache.Cache) (ScannerConfi
configScannerOptions = misconf.ScannerOption{
Debug: opts.Debug,
Trace: opts.Trace,
Namespaces: append(opts.PolicyNamespaces, defaultPolicyNamespaces...),
Namespaces: append(opts.PolicyNamespaces, rego.BuiltinNamespaces()...),
PolicyPaths: append(opts.PolicyPaths, downloadedPolicyPaths...),
DataPaths: opts.DataPaths,
HelmValues: opts.HelmValues,
Expand Down
6 changes: 3 additions & 3 deletions pkg/iac/rego/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

"github.com/open-policy-agent/opa/ast"

rules2 "github.com/aquasecurity/trivy-policies"
checks "github.com/aquasecurity/trivy-policies"
"github.com/aquasecurity/trivy/pkg/iac/rules"
)

Expand Down Expand Up @@ -62,11 +62,11 @@ func RegisterRegoRules(modules map[string]*ast.Module) {
}

func LoadEmbeddedPolicies() (map[string]*ast.Module, error) {
return LoadPoliciesFromDirs(rules2.EmbeddedPolicyFileSystem, ".")
return LoadPoliciesFromDirs(checks.EmbeddedPolicyFileSystem, ".")
}

func LoadEmbeddedLibraries() (map[string]*ast.Module, error) {
return LoadPoliciesFromDirs(rules2.EmbeddedLibraryFileSystem, ".")
return LoadPoliciesFromDirs(checks.EmbeddedLibraryFileSystem, ".")
}

func LoadPoliciesFromDirs(target fs.FS, paths ...string) (map[string]*ast.Module, error) {
Expand Down
130 changes: 106 additions & 24 deletions pkg/iac/rego/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,25 @@ import (

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/bundle"
"github.com/samber/lo"
)

// builtinNamespaces is the set of namespace names that are treated as
// built-in. It is used as a set: only the keys matter, the empty struct
// values carry no data.
var builtinNamespaces = map[string]struct{}{
"builtin": {},
"defsec": {},
"appshield": {},
}

func BuiltinNamespaces() []string {
return lo.Keys(builtinNamespaces)
}

func IsBuiltinNamespace(namespace string) bool {
return lo.ContainsBy(BuiltinNamespaces(), func(ns string) bool {
return strings.HasPrefix(namespace, ns+".")
})
}

// IsRegoFile reports whether name ends with the Rego file extension and is not
// a Rego test file (a name ending in "_test" followed by that extension).
func IsRegoFile(name string) bool {
	if !strings.HasSuffix(name, bundle.RegoExt) {
		return false
	}
	isTest := strings.HasSuffix(name, "_test"+bundle.RegoExt)
	return !isTest
}
Expand Down Expand Up @@ -38,28 +55,20 @@ func (s *Scanner) loadPoliciesFromReaders(readers []io.Reader) (map[string]*ast.
return modules, nil
}

func (s *Scanner) loadEmbedded(enableEmbeddedLibraries, enableEmbeddedPolicies bool) error {
if enableEmbeddedLibraries {
loadedLibs, errLoad := LoadEmbeddedLibraries()
if errLoad != nil {
return fmt.Errorf("failed to load embedded rego libraries: %w", errLoad)
}
for name, policy := range loadedLibs {
s.policies[name] = policy
}
s.debug.Log("Loaded %d embedded libraries.", len(loadedLibs))
func (s *Scanner) loadEmbedded() error {
loaded, err := LoadEmbeddedLibraries()
if err != nil {
return fmt.Errorf("failed to load embedded rego libraries: %w", err)
}
s.embeddedLibs = loaded
s.debug.Log("Loaded %d embedded libraries.", len(loaded))

if enableEmbeddedPolicies {
loaded, err := LoadEmbeddedPolicies()
if err != nil {
return fmt.Errorf("failed to load embedded rego policies: %w", err)
}
for name, policy := range loaded {
s.policies[name] = policy
}
s.debug.Log("Loaded %d embedded policies.", len(loaded))
loaded, err = LoadEmbeddedPolicies()
if err != nil {
return fmt.Errorf("failed to load embedded rego policies: %w", err)
}
s.embeddedChecks = loaded
s.debug.Log("Loaded %d embedded policies.", len(loaded))

return nil
}
Expand All @@ -71,14 +80,22 @@ func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies b
}

if s.policyFS != nil {
s.debug.Log("Overriding filesystem for policies!")
s.debug.Log("Overriding filesystem for checks!")
srcFS = s.policyFS
}

if err := s.loadEmbedded(enableEmbeddedLibraries, enableEmbeddedPolicies); err != nil {
if err := s.loadEmbedded(); err != nil {
return err
}

if enableEmbeddedPolicies {
s.policies = lo.Assign(s.policies, s.embeddedChecks)
}

if enableEmbeddedLibraries {
s.policies = lo.Assign(s.policies, s.embeddedLibs)
}

var err error
if len(paths) > 0 {
loaded, err := LoadPoliciesFromDirs(srcFS, paths...)
Expand All @@ -94,12 +111,12 @@ func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies b
if len(readers) > 0 {
loaded, err := s.loadPoliciesFromReaders(readers)
if err != nil {
return fmt.Errorf("failed to load rego policies from reader(s): %w", err)
return fmt.Errorf("failed to load rego checks from reader(s): %w", err)
}
for name, policy := range loaded {
s.policies[name] = policy
}
s.debug.Log("Loaded %d policies from reader(s).", len(loaded))
s.debug.Log("Loaded %d checks from reader(s).", len(loaded))
}

// gather namespaces
Expand Down Expand Up @@ -127,9 +144,73 @@ func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies b
return s.compilePolicies(srcFS, paths)
}

// fallbackChecks tries to recover from compile errors in built-in checks by
// swapping each failing check for a matching embedded one.
//
// For every compiler error that points at a loaded policy whose namespace is
// built-in, the matching embedded check (as found by findMatchedEmbeddedCheck)
// replaces the failing policy in s.policies. Errors that belong to replaced
// files are then removed from compiler.Errors; every other error, including
// errors without a location, is left in place for the caller to handle.
func (s *Scanner) fallbackChecks(compiler *ast.Compiler) {
	// Files whose failing policy has already been replaced by an embedded check.
	replaced := make(map[string]struct{})

	for _, e := range compiler.Errors {
		// Location is a pointer; an error that carries none cannot be
		// attributed to a file, so there is nothing to fall back from.
		if e.Location == nil {
			continue
		}

		loc := e.Location.File
		if _, done := replaced[loc]; done {
			continue
		}

		badPolicy, exists := s.policies[loc]
		if !exists || badPolicy == nil {
			continue
		}

		// Only built-in checks have an embedded counterpart to fall back to.
		if !IsBuiltinNamespace(getModuleNamespace(badPolicy)) {
			continue
		}

		s.debug.Log("Error occurred while parsing: %s, %s. Trying to fallback to embedded check.", loc, e.Error())

		embedded := s.findMatchedEmbeddedCheck(badPolicy)
		if embedded == nil {
			s.debug.Log("Failed to find embedded check: %s", loc)
			continue
		}

		embeddedFile := embedded.Package.Location.File
		s.debug.Log("Found embedded check: %s", embeddedFile)
		delete(s.policies, loc) // remove bad policy
		s.policies[embeddedFile] = embedded
		delete(s.embeddedChecks, embeddedFile) // avoid infinite loop if embedded check contains ref error
		replaced[loc] = struct{}{}
	}

	compiler.Errors = lo.Filter(compiler.Errors, func(e *ast.Error, _ int) bool {
		// Errors without a location were never candidates for replacement.
		if e.Location == nil {
			return true
		}
		_, dropped := replaced[e.Location.File]
		return !dropped
	})
}

// findMatchedEmbeddedCheck looks for the embedded check that corresponds to
// badPolicy and returns it, or nil when there is no match.
//
// Matching is attempted in two steps:
//  1. by package path: an embedded check with the same package path wins;
//  2. by check ID: otherwise the first embedded check whose metadata has the
//     same non-empty AVDID as badPolicy is returned.
//
// nil is also returned when the metadata of badPolicy cannot be read.
func (s *Scanner) findMatchedEmbeddedCheck(badPolicy *ast.Module) *ast.Module {
	// The path of the failing policy does not change; render it once.
	badPath := badPolicy.Package.Path.String()
	for _, embeddedCheck := range s.embeddedChecks {
		if embeddedCheck.Package.Path.String() == badPath {
			return embeddedCheck
		}
	}

	badPolicyMeta, err := metadataFromRegoModule(badPolicy)
	if err != nil {
		return nil
	}

	// Without an ID there is nothing to match on, so skip reading the
	// metadata of every embedded check.
	badID := badPolicyMeta.AVDID
	if badID == "" {
		return nil
	}

	for _, embeddedCheck := range s.embeddedChecks {
		meta, err := metadataFromRegoModule(embeddedCheck)
		if err != nil {
			// An embedded check with unreadable metadata cannot match by ID.
			continue
		}
		if meta.AVDID == badID {
			return embeddedCheck
		}
	}
	return nil
}

func (s *Scanner) prunePoliciesWithError(compiler *ast.Compiler) error {
if len(compiler.Errors) > s.regoErrorLimit {
s.debug.Log("Error(s) occurred while loading policies")
s.debug.Log("Error(s) occurred while loading checks")
return compiler.Errors
}

Expand Down Expand Up @@ -157,6 +238,7 @@ func (s *Scanner) compilePolicies(srcFS fs.FS, paths []string) error {

compiler.Compile(s.policies)
if compiler.Failed() {
s.fallbackChecks(compiler)
if err := s.prunePoliciesWithError(compiler); err != nil {
return err
}
Expand Down
115 changes: 114 additions & 1 deletion pkg/iac/rego/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"
"testing/fstest"

trivy_policies "github.com/aquasecurity/trivy-policies"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand All @@ -19,6 +20,9 @@ import (
//go:embed all:testdata/policies
var testEmbedFS embed.FS

//go:embed testdata/embedded
var embeddedChecksFS embed.FS

func Test_RegoScanning_WithSomeInvalidPolicies(t *testing.T) {
t.Run("allow no errors", func(t *testing.T) {
var debugBuf bytes.Buffer
Expand All @@ -30,7 +34,7 @@ func Test_RegoScanning_WithSomeInvalidPolicies(t *testing.T) {

err := scanner.LoadPolicies(false, false, testEmbedFS, []string{"."}, nil)
require.ErrorContains(t, err, `want (one of): ["Cmd" "EndLine" "Flags" "JSON" "Original" "Path" "Stage" "StartLine" "SubCmd" "Value"]`)
assert.Contains(t, debugBuf.String(), "Error(s) occurred while loading policies")
assert.Contains(t, debugBuf.String(), "Error(s) occurred while loading checks")
})

t.Run("allow up to max 1 error", func(t *testing.T) {
Expand Down Expand Up @@ -95,3 +99,112 @@ deny {
})

}

// Test_FallbackToEmbedded loads user-supplied checks that live in a built-in
// namespace and verifies the outcome of LoadPolicies when the embedded checks
// are served from the testdata/embedded fixtures.
func Test_FallbackToEmbedded(t *testing.T) {
	// The subtests below replace the package-level embedded checks filesystem.
	// Restore the original value when the test ends so the fixtures do not
	// leak into other tests running in the same binary.
	originalFS := trivy_policies.EmbeddedPolicyFileSystem
	t.Cleanup(func() {
		trivy_policies.EmbeddedPolicyFileSystem = originalFS
	})

	tests := []struct {
		name        string
		files       map[string]*fstest.MapFile
		expectedErr string
	}{
		{
			name: "match by namespace",
			files: map[string]*fstest.MapFile{
				"policies/my-check2.rego": {
					Data: []byte(`# METADATA
# schemas:
# - input: schema["fooschema"]
package builtin.test
deny {
input.evil == "foo bar"
}`,
					),
				},
			},
		},
		{
			name: "match by check ID",
			files: map[string]*fstest.MapFile{
				"policies/my-check2.rego": {
					Data: []byte(`# METADATA
# schemas:
# - input: schema["fooschema"]
# custom:
# avd_id: test-001
package builtin.test2
deny {
input.evil == "foo bar"
}`,
					),
				},
			},
		},
		{
			name: "bad embedded check",
			files: map[string]*fstest.MapFile{
				"policies/my-check2.rego": {
					Data: []byte(`# METADATA
# schemas:
# - input: schema["fooschema"]
package builtin.bad.test
deny {
input.evil == "foo bar"
}`,
					),
				},
			},
			expectedErr: "testdata/embedded/bad-check.rego:8: rego_type_error: undefined ref",
		},
		{
			name: "with non existent function",
			files: map[string]*fstest.MapFile{
				"policies/my-check2.rego": {
					Data: []byte(`# METADATA
# schemas:
# - input: schema["fooschema"]
package builtin.test
deny {
input.foo == fn.is_foo("foo")
}`,
					),
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			scanner := rego.NewScanner(
				types.SourceDockerfile,
				options.ScannerWithRegoErrorLimits(0),
				options.ScannerWithEmbeddedPolicies(false),
			)

			// Every case references schema["fooschema"], so provide it next
			// to the check under test.
			tt.files["schemas/fooschema.json"] = &fstest.MapFile{
				Data: []byte(`{
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {
"foo": {
"type": "string"
}
}
}`),
			}

			fsys := fstest.MapFS(tt.files)
			// Serve the embedded checks from the test fixtures.
			trivy_policies.EmbeddedPolicyFileSystem = embeddedChecksFS
			err := scanner.LoadPolicies(false, false, fsys, []string{"."}, nil)

			if tt.expectedErr != "" {
				assert.ErrorContains(t, err, tt.expectedErr)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
Loading

0 comments on commit 12ec0df

Please sign in to comment.