Skip to content

Commit

Permalink
refactor(misconf): pass options to Rego scanner as is
Browse files Browse the repository at this point in the history
Signed-off-by: nikpivkin <[email protected]>
  • Loading branch information
nikpivkin committed Sep 17, 2024
1 parent f768d3a commit 39bda25
Show file tree
Hide file tree
Showing 33 changed files with 560 additions and 1,013 deletions.
18 changes: 9 additions & 9 deletions pkg/iac/rego/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (s *Scanner) loadEmbedded() error {
return nil
}

func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies bool, srcFS fs.FS, paths []string, readers []io.Reader) error {
func (s *Scanner) LoadPolicies(srcFS fs.FS) error {

if s.policies == nil {
s.policies = make(map[string]*ast.Module)
Expand All @@ -90,28 +90,28 @@ func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies b
return err
}

if enableEmbeddedPolicies {
if s.includeEmbeddedPolicies {
s.policies = lo.Assign(s.policies, s.embeddedChecks)
}

if enableEmbeddedLibraries {
if s.includeEmbeddedLibraries {
s.policies = lo.Assign(s.policies, s.embeddedLibs)
}

var err error
if len(paths) > 0 {
loaded, err := LoadPoliciesFromDirs(srcFS, paths...)
if len(s.policyDirs) > 0 {
loaded, err := LoadPoliciesFromDirs(srcFS, s.policyDirs...)
if err != nil {
return fmt.Errorf("failed to load rego checks from %s: %w", paths, err)
return fmt.Errorf("failed to load rego checks from %s: %w", s.policyDirs, err)
}
for name, policy := range loaded {
s.policies[name] = policy
}
s.logger.Debug("Checks from disk are loaded", log.Int("count", len(loaded)))
}

if len(readers) > 0 {
loaded, err := s.loadPoliciesFromReaders(readers)
if len(s.policyReaders) > 0 {
loaded, err := s.loadPoliciesFromReaders(s.policyReaders)
if err != nil {
return fmt.Errorf("failed to load rego checks from reader(s): %w", err)
}
Expand Down Expand Up @@ -143,7 +143,7 @@ func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies b
}
s.store = store

return s.compilePolicies(srcFS, paths)
return s.compilePolicies(srcFS, s.policyDirs)
}

func (s *Scanner) fallbackChecks(compiler *ast.Compiler) {
Expand Down
49 changes: 31 additions & 18 deletions pkg/iac/rego/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"embed"
"fmt"
"io"
"log/slog"
"strings"
"testing"
Expand All @@ -16,7 +15,6 @@ import (

checks "github.com/aquasecurity/trivy-checks"
"github.com/aquasecurity/trivy/pkg/iac/rego"
"github.com/aquasecurity/trivy/pkg/iac/scanners/options"
"github.com/aquasecurity/trivy/pkg/iac/types"
"github.com/aquasecurity/trivy/pkg/log"
)
Expand All @@ -33,10 +31,11 @@ func Test_RegoScanning_WithSomeInvalidPolicies(t *testing.T) {
slog.SetDefault(log.New(log.NewHandler(&debugBuf, nil)))
scanner := rego.NewScanner(
types.SourceDockerfile,
options.ScannerWithRegoErrorLimits(0),
rego.WithRegoErrorLimits(0),
rego.WithPolicyDirs("."),
)

err := scanner.LoadPolicies(false, false, testEmbedFS, []string{"."}, nil)
err := scanner.LoadPolicies(testEmbedFS)
require.ErrorContains(t, err, `want (one of): ["Cmd" "EndLine" "Flags" "JSON" "Original" "Path" "Stage" "StartLine" "SubCmd" "Value"]`)
assert.Contains(t, debugBuf.String(), "Error(s) occurred while loading checks")
})
Expand All @@ -46,10 +45,11 @@ func Test_RegoScanning_WithSomeInvalidPolicies(t *testing.T) {
slog.SetDefault(log.New(log.NewHandler(&debugBuf, nil)))
scanner := rego.NewScanner(
types.SourceDockerfile,
options.ScannerWithRegoErrorLimits(1),
rego.WithRegoErrorLimits(1),
rego.WithPolicyDirs("."),
)

err := scanner.LoadPolicies(false, false, testEmbedFS, []string{"."}, nil)
err := scanner.LoadPolicies(testEmbedFS)
require.NoError(t, err)

assert.Contains(t, debugBuf.String(), "Error occurred while parsing\tfile_path=\"testdata/policies/invalid.rego\" err=\"testdata/policies/invalid.rego:7")
Expand All @@ -64,9 +64,13 @@ package mypackage
deny {
input.evil == "foo bar"
}`
scanner := rego.NewScanner(types.SourceJSON)
scanner := rego.NewScanner(
types.SourceJSON,
rego.WithPolicyDirs("."),
rego.WithPolicyReader(strings.NewReader(check)),
)

err := scanner.LoadPolicies(false, false, fstest.MapFS{}, []string{"."}, []io.Reader{strings.NewReader(check)})
err := scanner.LoadPolicies(fstest.MapFS{})
assert.ErrorContains(t, err, "could not find schema \"fooschema\"")
})

Expand All @@ -79,15 +83,19 @@ package mypackage
deny {
input.evil == "foo bar"
}`
scanner := rego.NewScanner(types.SourceJSON)
scanner := rego.NewScanner(
types.SourceJSON,
rego.WithPolicyDirs("."),
rego.WithPolicyReader(strings.NewReader(check)),
)

fsys := fstest.MapFS{
"schemas/fooschema.json": &fstest.MapFile{
Data: []byte("bad json"),
},
}

err := scanner.LoadPolicies(false, false, fsys, []string{"."}, []io.Reader{strings.NewReader(check)})
err := scanner.LoadPolicies(fsys)
assert.ErrorContains(t, err, "could not parse schema \"fooschema\"")
})

Expand All @@ -97,8 +105,12 @@ deny {
deny {
input.evil == "foo bar"
}`
scanner := rego.NewScanner(types.SourceJSON)
err := scanner.LoadPolicies(false, false, fstest.MapFS{}, []string{"."}, []io.Reader{strings.NewReader(check)})
scanner := rego.NewScanner(
types.SourceJSON,
rego.WithPolicyDirs("."),
rego.WithPolicyReader(strings.NewReader(check)),
)
err := scanner.LoadPolicies(fstest.MapFS{})
require.NoError(t, err)
})

Expand Down Expand Up @@ -184,8 +196,9 @@ deny {
t.Run(tt.name, func(t *testing.T) {
scanner := rego.NewScanner(
types.SourceDockerfile,
options.ScannerWithRegoErrorLimits(0),
options.ScannerWithEmbeddedPolicies(false),
rego.WithRegoErrorLimits(0),
rego.WithEmbeddedPolicies(false),
rego.WithPolicyDirs("."),
)

tt.files["schemas/fooschema.json"] = &fstest.MapFile{
Expand All @@ -200,9 +213,8 @@ deny {
}`),
}

fsys := fstest.MapFS(tt.files)
checks.EmbeddedPolicyFileSystem = embeddedChecksFS
err := scanner.LoadPolicies(false, false, fsys, []string{"."}, nil)
err := scanner.LoadPolicies(fstest.MapFS(tt.files))

if tt.expectedErr != "" {
assert.ErrorContains(t, err, tt.expectedErr)
Expand Down Expand Up @@ -244,8 +256,9 @@ deny {

scanner := rego.NewScanner(
types.SourceDockerfile,
options.ScannerWithEmbeddedPolicies(false),
rego.WithEmbeddedPolicies(false),
rego.WithPolicyDirs("."),
)
err := scanner.LoadPolicies(false, false, fsys, []string{"."}, nil)
err := scanner.LoadPolicies(fsys)
require.Error(t, err)
}
108 changes: 108 additions & 0 deletions pkg/iac/rego/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package rego

import (
"io"
"io/fs"

"github.com/aquasecurity/trivy/pkg/iac/scanners/options"
)

func WithPolicyReader(readers ...io.Reader) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.policyReaders = readers
}
}
}

func WithEmbeddedPolicies(include bool) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.includeEmbeddedPolicies = include
}
}
}

func WithEmbeddedLibraries(include bool) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.includeEmbeddedLibraries = include
}
}
}

// WithTrace specifies an io.Writer for trace logs (mainly rego tracing) - if not set, they are discarded
func WithTrace(w io.Writer) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.traceWriter = w
}
}
}

func WithPerResultTracing(enabled bool) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.tracePerResult = enabled
}
}
}

func WithPolicyDirs(paths ...string) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.policyDirs = paths
}
}
}

func WithDataDirs(paths ...string) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.dataDirs = paths
}
}
}

// WithPolicyNamespaces - namespaces which indicate rego policies containing enforced rules
func WithPolicyNamespaces(namespaces ...string) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
for _, namespace := range namespaces {
ss.ruleNamespaces[namespace] = struct{}{}
}
}
}
}

func WithPolicyFilesystem(fsys fs.FS) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.policyFS = fsys
}
}
}

func WithDataFilesystem(fsys fs.FS) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.dataFS = fsys
}
}
}

func WithRegoErrorLimits(limit int) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.regoErrorLimit = limit
}
}
}

func WithCustomSchemas(schemas map[string][]byte) options.ScannerOption {
return func(s options.ConfigurableScanner) {
if ss, ok := s.(*Scanner); ok {
ss.customSchemas = schemas
}
}
}
40 changes: 22 additions & 18 deletions pkg/iac/rego/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,28 @@ func makeSupportedProviders() map[string]struct{} {
var _ options.ConfigurableScanner = (*Scanner)(nil)

type Scanner struct {
ruleNamespaces map[string]struct{}
policies map[string]*ast.Module
store storage.Store
dataDirs []string
runtimeValues *ast.Term
compiler *ast.Compiler
regoErrorLimit int
logger *log.Logger
traceWriter io.Writer
tracePerResult bool
retriever *MetadataRetriever
policyFS fs.FS
dataFS fs.FS
frameworks []framework.Framework
spec string
inputSchema any // unmarshalled into this from a json schema document
sourceType types.Source
includeDeprecatedChecks bool
ruleNamespaces map[string]struct{}
policies map[string]*ast.Module
store storage.Store
runtimeValues *ast.Term
compiler *ast.Compiler
regoErrorLimit int
logger *log.Logger
traceWriter io.Writer
tracePerResult bool
retriever *MetadataRetriever
policyFS fs.FS
policyDirs []string
policyReaders []io.Reader
dataFS fs.FS
dataDirs []string
frameworks []framework.Framework
spec string
inputSchema any // unmarshalled into this from a json schema document
sourceType types.Source
includeDeprecatedChecks bool
includeEmbeddedPolicies bool
includeEmbeddedLibraries bool

embeddedLibs map[string]*ast.Module
embeddedChecks map[string]*ast.Module
Expand Down
Loading

0 comments on commit 39bda25

Please sign in to comment.