Skip to content

Commit

Permalink
trace+tester: Adding local var values to trace and test report (#6815)
Browse files Browse the repository at this point in the history
Fixing: #2546
Signed-off-by: Johan Fylling <[email protected]>
  • Loading branch information
johanfylling authored Jun 25, 2024
1 parent 31120ce commit 96ecf38
Show file tree
Hide file tree
Showing 21 changed files with 3,061 additions and 72 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pull-request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-22.04, macos-14]
version: ["1.20"]
version: ["1.21"]
steps:
- uses: actions/checkout@v4
- name: Download generated artifacts
Expand Down
149 changes: 121 additions & 28 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,34 +120,35 @@ type Compiler struct {
// Capabliities required by the modules that were compiled.
Required *Capabilities

localvargen *localVarGenerator
moduleLoader ModuleLoader
ruleIndices *util.HashMap
stages []stage
maxErrs int
sorted []string // list of sorted module names
pathExists func([]string) (bool, error)
after map[string][]CompilerStageDefinition
metrics metrics.Metrics
capabilities *Capabilities // user-supplied capabilities
imports map[string][]*Import // saved imports from stripping
builtins map[string]*Builtin // universe of built-in functions
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
enablePrintStatements bool // indicates if print statements should be elided (default)
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
initialized bool // indicates if init() has been called
debug debug.Debug // emits debug information produced during compilation
schemaSet *SchemaSet // user-supplied schemas for input and data documents
inputType types.Type // global input type retrieved from schema set
annotationSet *AnnotationSet // hierarchical set of annotations
strict bool // enforce strict compilation checks
keepModules bool // whether to keep the unprocessed, parse modules (below)
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
allowUndefinedFuncCalls bool // don't error on calls to unknown functions.
evalMode CompilerEvalMode
localvargen *localVarGenerator
moduleLoader ModuleLoader
ruleIndices *util.HashMap
stages []stage
maxErrs int
sorted []string // list of sorted module names
pathExists func([]string) (bool, error)
after map[string][]CompilerStageDefinition
metrics metrics.Metrics
capabilities *Capabilities // user-supplied capabilities
imports map[string][]*Import // saved imports from stripping
builtins map[string]*Builtin // universe of built-in functions
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
enablePrintStatements bool // indicates if print statements should be elided (default)
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
initialized bool // indicates if init() has been called
debug debug.Debug // emits debug information produced during compilation
schemaSet *SchemaSet // user-supplied schemas for input and data documents
inputType types.Type // global input type retrieved from schema set
annotationSet *AnnotationSet // hierarchical set of annotations
strict bool // enforce strict compilation checks
keepModules bool // whether to keep the unprocessed, parse modules (below)
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
allowUndefinedFuncCalls bool // don't error on calls to unknown functions.
evalMode CompilerEvalMode //
rewriteTestRulesForTracing bool // rewrite test rules to capture dynamic values for tracing.
}

// CompilerStage defines the interface for stages in the compiler.
Expand Down Expand Up @@ -346,6 +347,7 @@ func NewCompiler() *Compiler {
{"CheckSafetyRuleBodies", "compile_stage_check_safety_rule_bodies", c.checkSafetyRuleBodies},
{"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals},
{"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms},
{"RewriteTestRulesForTracing", "compile_stage_rewrite_test_rules_for_tracing", c.rewriteTestRuleEqualities}, // must run after RewriteDynamicTerms
{"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion},
{"CheckTypes", "compile_stage_check_types", c.checkTypes}, // must be run after CheckRecursion
{"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins},
Expand Down Expand Up @@ -469,6 +471,13 @@ func (c *Compiler) WithEvalMode(e CompilerEvalMode) *Compiler {
return c
}

// WithRewriteTestRules enables rewriting test rules to capture dynamic values in local variables,
// so they can be accessed by tracing.
func (c *Compiler) WithRewriteTestRules(rewrite bool) *Compiler {
c.rewriteTestRulesForTracing = rewrite
return c
}

// ParsedModules returns the parsed, unprocessed modules from the compiler.
// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`.
// The map includes all modules loaded via the ModuleLoader, if one was used.
Expand Down Expand Up @@ -2167,6 +2176,43 @@ func (c *Compiler) rewriteDynamicTerms() {
}
}

// rewriteTestRuleEqualities rewrites equality expressions in test rule bodies to create local vars for statements that would otherwise
// not have their values captured through tracing, such as refs and comprehensions not unified/assigned to a local var.
// For example, given the following module:
//
// package test
//
// p.q contains v if {
// some v in numbers.range(1, 3)
// }
//
// p.r := "foo"
//
// test_rule {
// p == {
// "q": {4, 5, 6}
// }
// }
//
// `p` in `test_rule` resolves to `data.test.p`, which won't be an entry in the virtual-cache and must therefore be calculated after-the-fact.
// If `p` isn't captured in a local var, there is no trivial way to retrieve its value for test reporting.
func (c *Compiler) rewriteTestRuleEqualities() {
if !c.rewriteTestRulesForTracing {
return
}

f := newEqualityFactory(c.localvargen)
for _, name := range c.sorted {
mod := c.Modules[name]
WalkRules(mod, func(rule *Rule) bool {
if strings.HasPrefix(string(rule.Head.Name), "test_") {
rule.Body = rewriteTestEqualities(f, rule.Body)
}
return false
})
}
}

func (c *Compiler) parseMetadataBlocks() {
// Only parse annotations if rego.metadata built-ins are called
regoMetadataCalled := false
Expand Down Expand Up @@ -4517,6 +4563,41 @@ func rewriteEquals(x interface{}) (modified bool) {
return modified
}

func rewriteTestEqualities(f *equalityFactory, body Body) Body {
result := make(Body, 0, len(body))
for _, expr := range body {
// We can't rewrite negated expressions; if the extracted term is undefined, evaluation would fail before
// reaching the negation check.
if !expr.Negated && !expr.Generated {
switch {
case expr.IsEquality():
terms := expr.Terms.([]*Term)
result, terms[1] = rewriteDynamicsShallow(expr, f, terms[1], result)
result, terms[2] = rewriteDynamicsShallow(expr, f, terms[2], result)
case expr.IsEvery():
// We rewrite equalities inside of every-bodies as a fail here will be the cause of the test-rule fail.
// Failures inside other expressions with closures, such as comprehensions, won't cause the test-rule to fail, so we skip those.
every := expr.Terms.(*Every)
every.Body = rewriteTestEqualities(f, every.Body)
}
}
result = appendExpr(result, expr)
}
return result
}

func rewriteDynamicsShallow(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
switch term.Value.(type) {
case Ref, *ArrayComprehension, *SetComprehension, *ObjectComprehension:
generated := f.Generate(term)
generated.With = original.With
result.Append(generated)
connectGeneratedExprs(original, generated)
return result, result[len(result)-1].Operand(0)
}
return result, term
}

// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and
// comprehensions) are bound to vars earlier in the query. This translation
// results in eager evaluation.
Expand Down Expand Up @@ -4608,6 +4689,7 @@ func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result B
generated := f.Generate(term)
generated.With = original.With
result.Append(generated)
connectGeneratedExprs(original, generated)
return result, result[len(result)-1].Operand(0)
case *Array:
for i := 0; i < v.Len(); i++ {
Expand Down Expand Up @@ -4636,16 +4718,19 @@ func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result B
var extra *Expr
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
result.Append(extra)
connectGeneratedExprs(original, extra)
return result, result[len(result)-1].Operand(0)
case *SetComprehension:
var extra *Expr
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
result.Append(extra)
connectGeneratedExprs(original, extra)
return result, result[len(result)-1].Operand(0)
case *ObjectComprehension:
var extra *Expr
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
result.Append(extra)
connectGeneratedExprs(original, extra)
return result, result[len(result)-1].Operand(0)
}
return result, term
Expand Down Expand Up @@ -4713,6 +4798,7 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
for i := 1; i < len(terms); i++ {
var extras []*Expr
extras, terms[i] = expandExprTerm(gen, terms[i])
connectGeneratedExprs(expr, extras...)
if len(expr.With) > 0 {
for i := range extras {
extras[i].With = expr.With
Expand Down Expand Up @@ -4740,6 +4826,13 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
return
}

func connectGeneratedExprs(parent *Expr, children ...*Expr) {
for _, child := range children {
child.generatedFrom = parent
parent.generates = append(parent.generates, child)
}
}

func expandExprTerm(gen *localVarGenerator, term *Term) (support []*Expr, output *Term) {
output = term
switch v := term.Value.(type) {
Expand Down
162 changes: 162 additions & 0 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10679,3 +10679,165 @@ deny {
t.Fatal(c.Errors)
}
}

func TestCompilerRewriteTestRulesForTracing(t *testing.T) {
tests := []struct {
note string
rewrite bool
module string
exp string
}{
{
note: "ref comparison, no rewrite",
module: `package test
import rego.v1
a := 1
b := 2
test_something if {
a == b
}`,
exp: `package test
a := 1 { true }
b := 2 { true }
test_something = true {
data.test.a = data.test.b
}`,
},
{
note: "ref comparison, rewrite",
rewrite: true,
module: `package test
import rego.v1
a := 1
b := 2
test_something if {
a == b
}`,
// When the test fails on '__local0__ = __local1__', the values for 'a' and 'b' are captured in local bindings,
// accessible by the tracer.
exp: `package test
a := 1 { true }
b := 2 { true }
test_something = true {
__local0__ = data.test.a
__local1__ = data.test.b
__local0__ = __local1__
}`,
},
{
note: "ref comparison, not-stmt, rewrite",
rewrite: true,
module: `package test
import rego.v1
a := 1
b := 2
test_something if {
not a == b
}`,
// We don't break out local vars from a not-stmt, as that would change the semantics of the rule.
exp: `package test
a := 1 { true }
b := 2 { true }
test_something = true {
not data.test.a = data.test.b
}`,
},
{
note: "ref comparison, inside every-stmt, no rewrite",
module: `package test
import rego.v1
a := 1
b := 2
l := [1, 2, 3]
test_something if {
every x in l {
a < b + x
}
}`,
exp: `package test
import future.keywords
a := 1 { true }
b := 2 { true }
l := [1, 2, 3] { true }
test_something = true {
__local2__ = data.test.l
every __local0__, __local1__ in __local2__ {
__local4__ = data.test.b
plus(__local4__, __local1__, __local3__)
__local5__ = data.test.a
lt(__local5__, __local3__)
}
}`,
},
{
note: "ref comparison, inside every-stmt, rewrite",
rewrite: true,
module: `package test
import rego.v1
a := 1
b := 2
l := [1, 2, 3]
test_something if {
every x in l {
a < b + x
}
}`,
// When tests contain an 'every' statement, we're interested in the circumstances that made the every fail,
// so it's body is rewritten.
exp: `package test
import future.keywords
a := 1 { true }
b := 2 { true }
l := [1, 2, 3] { true }
test_something = true {
__local2__ = data.test.l;
every __local0__, __local1__ in __local2__ {
__local4__ = data.test.b
plus(__local4__, __local1__, __local3__)
__local5__ = data.test.a
lt(__local5__, __local3__)
}
}`,
},
}

for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
ms := map[string]string{
"test.rego": tc.module,
}
c := getCompilerWithParsedModules(ms).
WithRewriteTestRules(tc.rewrite)

compileStages(c, c.rewriteTestRuleEqualities)
assertNotFailed(t, c)

result := c.Modules["test.rego"]
exp := MustParseModule(tc.exp)
exp.Imports = nil // We strip the imports since the compiler will too
if result.Compare(exp) != 0 {
t.Fatalf("\nExpected:\n\n%v\n\nGot:\n\n%v", exp, result)
}
})
}
}
Loading

0 comments on commit 96ecf38

Please sign in to comment.