From 67c6b1d473999003d682bdb42657bbf3a4a69a9c Mon Sep 17 00:00:00 2001 From: Nikita Pivkin Date: Tue, 7 May 2024 10:20:38 +0600 Subject: [PATCH] perf(misconf): parse rego input once (#6615) Signed-off-by: Simar Co-authored-by: Simar --- pkg/iac/rego/exceptions.go | 8 +++++--- pkg/iac/rego/scanner.go | 35 ++++++++++++++++++++++++++++------- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/pkg/iac/rego/exceptions.go b/pkg/iac/rego/exceptions.go index 7abe9e8a9afa..591c72abdbbb 100644 --- a/pkg/iac/rego/exceptions.go +++ b/pkg/iac/rego/exceptions.go @@ -3,9 +3,11 @@ package rego import ( "context" "fmt" + + "github.com/open-policy-agent/opa/ast" ) -func (s *Scanner) isIgnored(ctx context.Context, namespace, ruleName string, input interface{}) (bool, error) { +func (s *Scanner) isIgnored(ctx context.Context, namespace, ruleName string, input ast.Value) (bool, error) { if ignored, err := s.isNamespaceIgnored(ctx, namespace, input); err != nil { return false, err } else if ignored { @@ -14,7 +16,7 @@ func (s *Scanner) isIgnored(ctx context.Context, namespace, ruleName string, inp return s.isRuleIgnored(ctx, namespace, ruleName, input) } -func (s *Scanner) isNamespaceIgnored(ctx context.Context, namespace string, input interface{}) (bool, error) { +func (s *Scanner) isNamespaceIgnored(ctx context.Context, namespace string, input ast.Value) (bool, error) { exceptionQuery := fmt.Sprintf("data.namespace.exceptions.exception[_] == %q", namespace) result, _, err := s.runQuery(ctx, exceptionQuery, input, true) if err != nil { @@ -23,7 +25,7 @@ func (s *Scanner) isNamespaceIgnored(ctx context.Context, namespace string, inpu return result.Allowed(), nil } -func (s *Scanner) isRuleIgnored(ctx context.Context, namespace, ruleName string, input interface{}) (bool, error) { +func (s *Scanner) isRuleIgnored(ctx context.Context, namespace, ruleName string, input ast.Value) (bool, error) { exceptionQuery := fmt.Sprintf("endswith(%q, data.%s.exception[_][_])", ruleName, namespace) result, _, err := s.runQuery(ctx, exceptionQuery, input, true) if err != nil { diff --git a/pkg/iac/rego/scanner.go b/pkg/iac/rego/scanner.go index 001e8f52a080..f2b9fff0fdf9 100644 --- a/pkg/iac/rego/scanner.go +++ b/pkg/iac/rego/scanner.go @@ -13,6 +13,7 @@ import ( "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/rego" "github.com/open-policy-agent/opa/storage" + "github.com/open-policy-agent/opa/util" "github.com/aquasecurity/trivy/pkg/iac/debug" "github.com/aquasecurity/trivy/pkg/iac/framework" @@ -161,7 +162,7 @@ func (s *Scanner) SetParentDebugLogger(l debug.Logger) { s.debug = l.Extend("rego") } -func (s *Scanner) runQuery(ctx context.Context, query string, input interface{}, disableTracing bool) (rego.ResultSet, []string, error) { +func (s *Scanner) runQuery(ctx context.Context, query string, input ast.Value, disableTracing bool) (rego.ResultSet, []string, error) { trace := (s.traceWriter != nil || s.tracePerResult) && !disableTracing @@ -180,7 +181,7 @@ func (s *Scanner) runQuery(ctx context.Context, query string, input interface{}, } if input != nil { - regoOptions = append(regoOptions, rego.Input(input)) + regoOptions = append(regoOptions, rego.ParsedInput(input)) } instance := rego.New(regoOptions...) @@ -342,6 +343,14 @@ func isPolicyApplicable(staticMetadata *StaticMetadata, inputs ...Input) bool { return false } +func parseRawInput(input any) (ast.Value, error) { + if err := util.RoundTrip(&input); err != nil { + return nil, err + } + + return ast.InterfaceToValue(input) +} + func (s *Scanner) applyRule(ctx context.Context, namespace, rule string, inputs []Input, combined bool) (scan.Results, error) { // handle combined evaluations if possible @@ -354,7 +363,12 @@ func (s *Scanner) applyRule(ctx context.Context, namespace, rule string, inputs qualified := fmt.Sprintf("data.%s.%s", namespace, rule) for _, input := range inputs { s.trace("INPUT", input) - if ignored, err := s.isIgnored(ctx, namespace, rule, input.Contents); err != nil { + parsedInput, err := parseRawInput(input.Contents) + if err != nil { + s.debug.Log("Error occurred while parsing input: %s", err) + continue + } + if ignored, err := s.isIgnored(ctx, namespace, rule, parsedInput); err != nil { return nil, err } else if ignored { var result regoResult @@ -364,7 +378,7 @@ func (s *Scanner) applyRule(ctx context.Context, namespace, rule string, inputs results.AddIgnored(result) continue } - set, traces, err := s.runQuery(ctx, qualified, input.Contents, false) + set, traces, err := s.runQuery(ctx, qualified, parsedInput, false) if err != nil { return nil, err } @@ -388,9 +402,15 @@ func (s *Scanner) applyRuleCombined(ctx context.Context, namespace, rule string, if len(inputs) == 0 { return nil, nil } + + parsed, err := parseRawInput(inputs) + if err != nil { + return nil, fmt.Errorf("failed to parse input: %w", err) + } + var results scan.Results - qualified := fmt.Sprintf("data.%s.%s", namespace, rule) - if ignored, err := s.isIgnored(ctx, namespace, rule, inputs); err != nil { + + if ignored, err := s.isIgnored(ctx, namespace, rule, parsed); err != nil { return nil, err } else if ignored { for _, input := range inputs { @@ -402,7 +422,8 @@ func (s *Scanner) applyRuleCombined(ctx context.Context, namespace, rule string, } return results, nil } - set, traces, err := s.runQuery(ctx, qualified, inputs, false) + qualified := fmt.Sprintf("data.%s.%s", namespace, rule) + set, traces, err := s.runQuery(ctx, qualified, parsed, false) if err != nil { return nil, err }