Skip to content

Commit

Permalink
perf(misconf): parse rego input once (aquasecurity#6615)
Browse files Browse the repository at this point in the history
Signed-off-by: Simar <[email protected]>
Co-authored-by: Simar <[email protected]>
  • Loading branch information
nikpivkin and simar7 authored May 7, 2024
1 parent a2c522d commit 67c6b1d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 10 deletions.
8 changes: 5 additions & 3 deletions pkg/iac/rego/exceptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package rego
import (
"context"
"fmt"

"github.com/open-policy-agent/opa/ast"
)

func (s *Scanner) isIgnored(ctx context.Context, namespace, ruleName string, input interface{}) (bool, error) {
func (s *Scanner) isIgnored(ctx context.Context, namespace, ruleName string, input ast.Value) (bool, error) {
if ignored, err := s.isNamespaceIgnored(ctx, namespace, input); err != nil {
return false, err
} else if ignored {
Expand All @@ -14,7 +16,7 @@ func (s *Scanner) isIgnored(ctx context.Context, namespace, ruleName string, inp
return s.isRuleIgnored(ctx, namespace, ruleName, input)
}

func (s *Scanner) isNamespaceIgnored(ctx context.Context, namespace string, input interface{}) (bool, error) {
func (s *Scanner) isNamespaceIgnored(ctx context.Context, namespace string, input ast.Value) (bool, error) {
exceptionQuery := fmt.Sprintf("data.namespace.exceptions.exception[_] == %q", namespace)
result, _, err := s.runQuery(ctx, exceptionQuery, input, true)
if err != nil {
Expand All @@ -23,7 +25,7 @@ func (s *Scanner) isNamespaceIgnored(ctx context.Context, namespace string, inpu
return result.Allowed(), nil
}

func (s *Scanner) isRuleIgnored(ctx context.Context, namespace, ruleName string, input interface{}) (bool, error) {
func (s *Scanner) isRuleIgnored(ctx context.Context, namespace, ruleName string, input ast.Value) (bool, error) {
exceptionQuery := fmt.Sprintf("endswith(%q, data.%s.exception[_][_])", ruleName, namespace)
result, _, err := s.runQuery(ctx, exceptionQuery, input, true)
if err != nil {
Expand Down
35 changes: 28 additions & 7 deletions pkg/iac/rego/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/storage"
"github.com/open-policy-agent/opa/util"

"github.com/aquasecurity/trivy/pkg/iac/debug"
"github.com/aquasecurity/trivy/pkg/iac/framework"
Expand Down Expand Up @@ -161,7 +162,7 @@ func (s *Scanner) SetParentDebugLogger(l debug.Logger) {
s.debug = l.Extend("rego")
}

func (s *Scanner) runQuery(ctx context.Context, query string, input interface{}, disableTracing bool) (rego.ResultSet, []string, error) {
func (s *Scanner) runQuery(ctx context.Context, query string, input ast.Value, disableTracing bool) (rego.ResultSet, []string, error) {

trace := (s.traceWriter != nil || s.tracePerResult) && !disableTracing

Expand All @@ -180,7 +181,7 @@ func (s *Scanner) runQuery(ctx context.Context, query string, input interface{},
}

if input != nil {
regoOptions = append(regoOptions, rego.Input(input))
regoOptions = append(regoOptions, rego.ParsedInput(input))
}

instance := rego.New(regoOptions...)
Expand Down Expand Up @@ -342,6 +343,14 @@ func isPolicyApplicable(staticMetadata *StaticMetadata, inputs ...Input) bool {
return false
}

func parseRawInput(input any) (ast.Value, error) {
if err := util.RoundTrip(&input); err != nil {
return nil, err
}

return ast.InterfaceToValue(input)
}

func (s *Scanner) applyRule(ctx context.Context, namespace, rule string, inputs []Input, combined bool) (scan.Results, error) {

// handle combined evaluations if possible
Expand All @@ -354,7 +363,12 @@ func (s *Scanner) applyRule(ctx context.Context, namespace, rule string, inputs
qualified := fmt.Sprintf("data.%s.%s", namespace, rule)
for _, input := range inputs {
s.trace("INPUT", input)
if ignored, err := s.isIgnored(ctx, namespace, rule, input.Contents); err != nil {
parsedInput, err := parseRawInput(input.Contents)
if err != nil {
s.debug.Log("Error occurred while parsing input: %s", err)
continue
}
if ignored, err := s.isIgnored(ctx, namespace, rule, parsedInput); err != nil {
return nil, err
} else if ignored {
var result regoResult
Expand All @@ -364,7 +378,7 @@ func (s *Scanner) applyRule(ctx context.Context, namespace, rule string, inputs
results.AddIgnored(result)
continue
}
set, traces, err := s.runQuery(ctx, qualified, input.Contents, false)
set, traces, err := s.runQuery(ctx, qualified, parsedInput, false)
if err != nil {
return nil, err
}
Expand All @@ -388,9 +402,15 @@ func (s *Scanner) applyRuleCombined(ctx context.Context, namespace, rule string,
if len(inputs) == 0 {
return nil, nil
}

parsed, err := parseRawInput(inputs)
if err != nil {
return nil, fmt.Errorf("failed to parse input: %w", err)
}

var results scan.Results
qualified := fmt.Sprintf("data.%s.%s", namespace, rule)
if ignored, err := s.isIgnored(ctx, namespace, rule, inputs); err != nil {

if ignored, err := s.isIgnored(ctx, namespace, rule, parsed); err != nil {
return nil, err
} else if ignored {
for _, input := range inputs {
Expand All @@ -402,7 +422,8 @@ func (s *Scanner) applyRuleCombined(ctx context.Context, namespace, rule string,
}
return results, nil
}
set, traces, err := s.runQuery(ctx, qualified, inputs, false)
qualified := fmt.Sprintf("data.%s.%s", namespace, rule)
set, traces, err := s.runQuery(ctx, qualified, parsed, false)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 67c6b1d

Please sign in to comment.