Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update eval_cancel_error logic to separate context canceled, timeout errors #7202

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions rego/rego_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func TestRegoCancellation(t *testing.T) {
if err == nil {
t.Fatalf("Expected cancellation error but got: %v", rs)
}
exp := topdown.Error{Code: topdown.CancelErr, Message: "caller cancelled query execution"}
exp := topdown.Error{Code: topdown.CancelErr, Message: context.DeadlineExceeded.Error()}
if !errors.Is(err, &exp) {
t.Errorf("error: expected %v, got: %v", exp, err)
}
Expand Down Expand Up @@ -1137,7 +1137,7 @@ func TestPrepareAndPartial(t *testing.T) {
mod := `
package test
import rego.v1

default p = false
p if {
input.x = 1
Expand Down Expand Up @@ -1326,7 +1326,7 @@ func TestPartialResultWithInput(t *testing.T) {
mod := `
package test
import rego.v1

default p = false
p if {
input.x == 1
Expand Down Expand Up @@ -1355,7 +1355,7 @@ func TestPartialResultWithNamespace(t *testing.T) {
mod := `
package test
import rego.v1

p if {
true
}
Expand Down Expand Up @@ -1398,7 +1398,7 @@ func TestPreparedPartialResultWithTracer(t *testing.T) {
mod := `
package test
import rego.v1

default p = false
p if {
input.x = 1
Expand Down Expand Up @@ -1440,7 +1440,7 @@ func TestPreparedPartialResultWithQueryTracer(t *testing.T) {
mod := `
package test
import rego.v1

default p = false
p if {
input.x = 1
Expand Down Expand Up @@ -2174,7 +2174,7 @@ func TestRegoLoadBundleWithProvidedStore(t *testing.T) {
func TestRegoCustomBuiltinPartialPropagate(t *testing.T) {
mod := `package test
import rego.v1

p if {
x = trim_and_split(input.foo, "/")
x == ["foo", "bar", "baz"]
Expand Down Expand Up @@ -2279,15 +2279,15 @@ func TestShallowInliningOption(t *testing.T) {
SetRegoVersion(ast.RegoV1),
Module("example.rego", `
package test

p if {
q = true
}

q if {
input.x = r
}

r = 7
`),
ShallowInlining(true))
Expand Down Expand Up @@ -2318,19 +2318,19 @@ func TestRegoPartialResultSortedRules(t *testing.T) {
SetRegoVersion(ast.RegoV1),
Module("example.rego", `
package test

default p = false

p if {
r = (input.d * input.a) + input.c
r < s
}

p if {
r = (input.d * input.b) + input.c
r < s
}

s = 100
`))

Expand Down
9 changes: 9 additions & 0 deletions tester/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package tester_test

import (
"context"
"errors"
"fmt"
"reflect"
"testing"
Expand Down Expand Up @@ -408,6 +409,10 @@ func testCancel(t *testing.T, bench bool) {
if !topdown.IsCancel(results[0].Error) {
t.Fatalf("Expected cancel error for first test but got: %v", results[0].Error)
}

if !errors.Is(results[0].Error, context.Canceled) {
t.Fatalf("Expected error to be of type context.Canceled but got: %v", results[0].Error)
}
})
}

Expand Down Expand Up @@ -476,6 +481,10 @@ func testTimeout(t *testing.T, bench bool) {
t.Fatalf("Expected no error for second test, but it timed out")
}
}

if !errors.Is(results[0].Error, context.DeadlineExceeded) {
t.Fatalf("Expected error to be of type context.DeadlineExceeded but got: %v", results[0].Error)
}
})
}

Expand Down
8 changes: 8 additions & 0 deletions topdown/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,14 @@ func (e *eval) evalExpr(iter evalIterator) error {
return &earlyExitError{prev: err, e: e}
}

if e.ctx != nil && e.ctx.Err() != nil {
return &Error{
Code: CancelErr,
Message: e.ctx.Err().Error(),
err: e.ctx.Err(),
}
}

if e.cancel != nil && e.cancel.Cancelled() {
return &Error{
Code: CancelErr,
Expand Down
78 changes: 78 additions & 0 deletions topdown/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ package topdown
import (
"context"
"encoding/json"
"errors"
"strings"
"testing"
"time"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/metrics"
Expand Down Expand Up @@ -1568,3 +1570,79 @@ func TestPartialRule(t *testing.T) {
})
}
}

func TestContextErrorHandling(t *testing.T) {
t.Parallel()

ctx := context.Background()
store := inmem.New()

tests := []struct {
note string
before func() context.Context
module string
expErr string
expErrType error
}{
{
note: "context deadline exceeded is handled",
before: func() context.Context {
ctx, cancel := context.WithTimeout(ctx, 1*time.Millisecond)
time.Sleep(10 * time.Millisecond)
cancel()
return ctx
},
module: `package test
p contains v if {
v := [1, 2, 3][_]
}
`,
expErr: context.DeadlineExceeded.Error(),
expErrType: context.DeadlineExceeded,
},
{
note: "context cancellation is handled",
before: func() context.Context {
ctx, cancel := context.WithCancel(ctx)
cancel()
return ctx
},
module: `package test
p contains v if {
v := [1, 2, 3][_]
}
`,
expErr: context.Canceled.Error(),
expErrType: context.Canceled,
},
}

for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
t.Parallel()

compiler := compileModules([]string{tc.module})
txn := storage.NewTransactionOrDie(ctx, store)
defer store.Abort(ctx, txn)

query := NewQuery(ast.MustParseBody("")).
WithCompiler(compiler).
WithStore(store).
WithTransaction(txn)

testCtx := tc.before()
qrs, err := query.Run(testCtx)

if err == nil {
t.Fatalf("Expected error %v but got result: %v", tc.expErr, qrs)
}
if exp, act := tc.expErr, err.Error(); !strings.Contains(act, exp) {
t.Fatalf("Expected error %v but got: %v", exp, act)
}

if et := tc.expErrType; et != nil && !errors.Is(err, tc.expErrType) {
t.Fatalf("Expected error to be of type %#v, but got %#v", et, err)
}
})
}
}