Skip to content

Commit

Permalink
Revert "feat(scanner): Break out options for enabling libs and polici…
Browse files Browse the repository at this point in the history
…es (#1280)" (#1298)

This reverts commit 97ff1b4.
  • Loading branch information
simar7 authored Apr 28, 2023
1 parent f26eb8e commit 63a8b4f
Show file tree
Hide file tree
Showing 20 changed files with 142 additions and 217 deletions.
1 change: 0 additions & 1 deletion cmd/defsec/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ func scanAWS(stdout, stderr io.Writer) error {

opts := []options.ScannerOption{
options.ScannerWithEmbeddedPolicies(true),
options.ScannerWithEmbeddedLibraries(true),
}

if flagDebug {
Expand Down
1 change: 0 additions & 1 deletion cmd/defsec/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ func scanFS(dir string, stdout, stderr io.Writer) error {

opts := []options.ScannerOption{
options.ScannerWithEmbeddedPolicies(true),
options.ScannerWithEmbeddedLibraries(true),
}

if flagDebug {
Expand Down
35 changes: 12 additions & 23 deletions pkg/rego/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,18 @@ func (s *Scanner) LoadEmbeddedLibraries() error {
return nil
}

func (s *Scanner) loadEmbedded(enableEmbeddedLibraries, enableEmbeddedPolicies bool) error {
if enableEmbeddedLibraries {
func (s *Scanner) LoadPolicies(loadEmbedded bool, srcFS fs.FS, paths []string, readers []io.Reader) error {

if s.policies == nil {
s.policies = make(map[string]*ast.Module)
}

if s.policyFS != nil {
s.debug.Log("Overriding filesystem for policies!")
srcFS = s.policyFS
}

if loadEmbedded {
loadedLibs, errLoad := loadEmbeddedLibraries()
if errLoad != nil {
return fmt.Errorf("failed to load embedded rego libraries: %w", errLoad)
Expand All @@ -103,9 +113,6 @@ func (s *Scanner) loadEmbedded(enableEmbeddedLibraries, enableEmbeddedPolicies b
s.policies[name] = policy
}
s.debug.Log("Loaded %d embedded libraries.", len(loadedLibs))
}

if enableEmbeddedPolicies {
loaded, err := loadEmbeddedPolicies()
if err != nil {
return fmt.Errorf("failed to load embedded rego policies: %w", err)
Expand All @@ -116,24 +123,6 @@ func (s *Scanner) loadEmbedded(enableEmbeddedLibraries, enableEmbeddedPolicies b
s.debug.Log("Loaded %d embedded policies.", len(loaded))
}

return nil
}

func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies bool, srcFS fs.FS, paths []string, readers []io.Reader) error {

if s.policies == nil {
s.policies = make(map[string]*ast.Module)
}

if s.policyFS != nil {
s.debug.Log("Overriding filesystem for policies!")
srcFS = s.policyFS
}

if err := s.loadEmbedded(enableEmbeddedLibraries, enableEmbeddedPolicies); err != nil {
return err
}

var err error
if len(paths) > 0 {
loaded, err := s.loadPoliciesFromDirs(srcFS, paths)
Expand Down
5 changes: 0 additions & 5 deletions pkg/rego/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ type Scanner struct {
sourceType types.Source
}

func (s *Scanner) SetUseEmbeddedLibraries(b bool) {
// handled externally
}

func (s *Scanner) SetSpec(spec string) {
s.spec = spec
}
Expand All @@ -62,7 +58,6 @@ func (s *Scanner) SetFrameworks(frameworks []framework.Framework) {
func (s *Scanner) SetUseEmbeddedPolicies(b bool) {
// handled externally
}

func (s *Scanner) trace(heading string, input interface{}) {
if s.traceWriter == nil {
return
Expand Down
48 changes: 24 additions & 24 deletions pkg/rego/scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ deny {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -70,7 +70,7 @@ deny {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"/policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"/policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -105,7 +105,7 @@ warn {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -137,7 +137,7 @@ deny {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -180,7 +180,7 @@ exception[ns] {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -228,7 +228,7 @@ exception[ns] {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -265,7 +265,7 @@ exception[rules] {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -301,7 +301,7 @@ exception[rules] {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -335,7 +335,7 @@ deny_evil {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -366,7 +366,7 @@ deny[msg] {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -404,7 +404,7 @@ deny[res] {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -446,7 +446,7 @@ deny[res] {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -500,7 +500,7 @@ deny[res] {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -549,7 +549,7 @@ deny {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -583,7 +583,7 @@ deny {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -614,7 +614,7 @@ deny {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -649,7 +649,7 @@ deny {
scanner := NewScanner(types.SourceJSON, options.ScannerWithTrace(traceBuffer))
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -683,7 +683,7 @@ deny {
scanner := NewScanner(types.SourceJSON, options.ScannerWithPerResultTracing(true))
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -721,7 +721,7 @@ deny {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -754,7 +754,7 @@ deny {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -801,7 +801,7 @@ deny {
scanner := NewScanner(types.SourceJSON)
require.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)

results, err := scanner.ScanInput(context.TODO(), Input{
Expand Down Expand Up @@ -839,7 +839,7 @@ deny {
scanner := NewScanner(types.SourceDockerfile)
assert.ErrorContains(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
"undefined ref: input.evil",
)
}
Expand All @@ -861,7 +861,7 @@ deny {
scanner := NewScanner(types.SourceDockerfile)
assert.NoError(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
)
}

Expand All @@ -880,7 +880,7 @@ deny {
scanner := NewScanner(types.SourceJSON)
assert.ErrorContains(
t,
scanner.LoadPolicies(false, false, srcFS, []string{"policies"}, nil),
scanner.LoadPolicies(false, srcFS, []string{"policies"}, nil),
"undefined ref: input.evil",
)
}
33 changes: 14 additions & 19 deletions pkg/scanners/azure/arm/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,17 @@ var _ scanners.FSScanner = (*Scanner)(nil)
var _ options.ConfigurableScanner = (*Scanner)(nil)

type Scanner struct {
scannerOptions []options.ScannerOption
parserOptions []options.ParserOption
debug debug.Logger
frameworks []framework.Framework
skipRequired bool
regoOnly bool
loadEmbeddedPolicies bool
loadEmbeddedLibraries bool
policyDirs []string
policyReaders []io.Reader
regoScanner *rego.Scanner
spec string
scannerOptions []options.ScannerOption
parserOptions []options.ParserOption
debug debug.Logger
frameworks []framework.Framework
skipRequired bool
regoOnly bool
loadEmbedded bool
policyDirs []string
policyReaders []io.Reader
regoScanner *rego.Scanner
spec string
sync.Mutex
}

Expand Down Expand Up @@ -88,12 +87,8 @@ func (s *Scanner) SetDataFilesystem(_ fs.FS) {
// handled by rego when option is passed on
}

func (s *Scanner) SetUseEmbeddedPolicies(b bool) {
s.loadEmbeddedPolicies = b
}

func (s *Scanner) SetUseEmbeddedLibraries(b bool) {
s.loadEmbeddedLibraries = b
func (s *Scanner) SetUseEmbeddedPolicies(loadEmbedded bool) {
s.loadEmbedded = loadEmbedded
}

func (s *Scanner) SetFrameworks(frameworks []framework.Framework) {
Expand All @@ -113,7 +108,7 @@ func (s *Scanner) initRegoScanner(srcFS fs.FS) error {
}
regoScanner := rego.NewScanner(types.SourceCloud, s.scannerOptions...)
regoScanner.SetParentDebugLogger(s.debug)
if err := regoScanner.LoadPolicies(s.loadEmbeddedLibraries, s.loadEmbeddedPolicies, srcFS, s.policyDirs, s.policyReaders); err != nil {
if err := regoScanner.LoadPolicies(s.loadEmbedded, srcFS, s.policyDirs, s.policyReaders); err != nil {
return err
}
s.regoScanner = regoScanner
Expand Down
Loading

0 comments on commit 63a8b4f

Please sign in to comment.