diff --git a/ext/BUILD.bazel b/ext/BUILD.bazel index 57013022..db223da2 100644 --- a/ext/BUILD.bazel +++ b/ext/BUILD.bazel @@ -25,6 +25,7 @@ go_library( "//checker:go_default_library", "//common/ast:go_default_library", "//common/overloads:go_default_library", + "//common/operators:go_default_library", "//common/types:go_default_library", "//common/types/pb:go_default_library", "//common/types/ref:go_default_library", diff --git a/ext/sets.go b/ext/sets.go index 833c15f6..7e941665 100644 --- a/ext/sets.go +++ b/ext/sets.go @@ -19,6 +19,8 @@ import ( "github.com/google/cel-go/cel" "github.com/google/cel-go/checker" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" @@ -119,6 +121,68 @@ func (setsLib) ProgramOptions() []cel.ProgramOption { } } +// NewSetMembershipOptimizer rewrites set membership tests using the `in` operator against a list +// of constant values of enum, int, uint, string, or boolean type into a set membership test against +// a map where the map keys are the elements of the list. 
+func NewSetMembershipOptimizer() (cel.ASTOptimizer, error) { /* Never fails as written: returns a zero-value setsLib and a nil error. */ + return setsLib{}, nil +} + +func (setsLib) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST { /* Optimize gathers every descendant of a accepted by matchInConstantList, builds for each a map with one entry per list element (each value the literal true), and passes the list argument together with that map to ctx.UpdateExpr. The same a that was passed in is returned. */ + root := ast.NavigateAST(a) + matches := ast.MatchDescendants(root, matchInConstantList(a)) + for _, match := range matches { + call := match.AsCall() + listArg := call.Args()[1] /* the matcher has already checked that this second argument is a list */ + entries := make([]ast.EntryExpr, len(listArg.AsList().Elements())) + for i, elem := range listArg.AsList().Elements() { + var entry ast.EntryExpr + if r, found := a.ReferenceMap()[elem.ID()]; found && r.Value != nil { /* the element has a reference-map entry carrying a Value: key the entry by that value wrapped as a literal */ + entry = ctx.NewMapEntry(ctx.NewLiteral(r.Value), ctx.NewLiteral(types.True), false) /* NOTE(review): the trailing false is presumably the optional-entry flag; confirm against NewMapEntry */ + } else { /* otherwise the element expression itself is used as the key */ + entry = ctx.NewMapEntry(elem, ctx.NewLiteral(types.True), false) + } + entries[i] = entry + } + mapArg := ctx.NewMap(entries) + ctx.UpdateExpr(listArg, mapArg) /* presumably replaces the list with the map in place; confirm against cel.OptimizerContext */ + } + return a +} + +func matchInConstantList(a *ast.AST) ast.ExprMatcher { /* matchInConstantList returns a matcher that accepts a call to operators.In whose second argument is a list in which every element either has a reference-map entry with a non-nil Value or is a string, int, uint, or bool literal. A list with no elements also matches, because the element loop never rejects it. */ + return func(e ast.NavigableExpr) bool { + if e.Kind() != ast.CallKind { + return false + } + call := e.AsCall() + if call.FunctionName() != operators.In { + return false + } + aggregateVal := call.Args()[1] + if aggregateVal.Kind() != ast.ListKind { + return false + } + listVal := aggregateVal.AsList() + for _, elem := range listVal.Elements() { + if r, found := a.ReferenceMap()[elem.ID()]; found { + if r.Value != nil { /* reference with a Value: acceptable, skip the literal checks below */ + continue + } + } + if elem.Kind() != ast.LiteralKind { /* a single non-literal element disqualifies the whole list */ + return false + } + lit := elem.AsLiteral() + if !(lit.Type() == cel.StringType || lit.Type() == cel.IntType || + lit.Type() == cel.UintType || lit.Type() == cel.BoolType) { + return false + } + } + return true + } +} + func setsIntersects(listA, listB ref.Val) ref.Val { lA := listA.(traits.Lister) lB := listB.(traits.Lister) diff --git a/ext/sets_test.go b/ext/sets_test.go index ddaf96d8..26232010 100644 --- a/ext/sets_test.go +++ b/ext/sets_test.go @@ -22,6 +22,9 @@ import ( "github.com/google/cel-go/cel" "github.com/google/cel-go/checker" + "github.com/google/cel-go/common/types" + 
"github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/test/proto3pb" ) func TestSets(t *testing.T) { @@ -328,9 +331,174 @@ func TestSets(t *testing.T) { } } +func TestSetsMembershipRewriter(t *testing.T) { /* TestSetsMembershipRewriter compiles each expr, runs the optimizer from NewSetMembershipOptimizer through cel.NewStaticOptimizer, compares the cel.AstToString rendering of the result with optimized, and then evaluates both the original and the optimized AST against in, expecting out from each. */ + tests := []struct { + expr string + optimized string + opts []cel.EnvOption + in map[string]any + out ref.Val + }{ + { + expr: `a in [1, 2, 3, 4]`, + optimized: `a in {1: true, 2: true, 3: true, 4: true}`, + opts: []cel.EnvOption{ + cel.Variable("a", cel.IntType), + }, + in: map[string]any{ + "a": 3, + }, + out: types.True, + }, + { + expr: `a in ['1', '2', '3', 4]`, + optimized: `a in {"1": true, "2": true, "3": true, 4: true}`, + opts: []cel.EnvOption{ + cel.Variable("a", cel.IntType), + }, + in: map[string]any{ + "a": 3, + }, + out: types.False, + }, + { + expr: `a in [1u, '2', '3', 4]`, + optimized: `a in {1u: true, "2": true, "3": true, 4: true}`, + opts: []cel.EnvOption{ + cel.Variable("a", cel.IntType), + }, + in: map[string]any{ + "a": 4, + }, + out: types.True, + }, + { /* 2.0 is a double literal; matchInConstantList accepts only string, int, uint, and bool literals, so this list is expected to stay a list */ + expr: `a in [1u, 2.0, '3', 4]`, + optimized: `a in [1u, 2.0, "3", 4]`, + opts: []cel.EnvOption{ + cel.Variable("a", cel.IntType), + }, + in: map[string]any{ + "a": 4, + }, + out: types.True, + }, + { + expr: `a in [b, 32]`, + optimized: `a in [b, 32]`, + opts: []cel.EnvOption{ + cel.Variable("a", cel.IntType), + cel.Variable("b", cel.IntType), + }, + in: map[string]any{ + "a": 4, + "b": 4, + }, + out: types.True, + }, + { + expr: `a in {b: c}`, + optimized: `a in {b: c}`, + opts: []cel.EnvOption{ + cel.Variable("a", cel.IntType), + cel.Variable("b", cel.IntType), + cel.Variable("c", cel.IntType), + }, + in: map[string]any{ + "a": 4, + "b": 42, + "c": 123, + }, + out: types.False, + }, + { + expr: `a in {3: true}`, + optimized: `a in {3: true}`, + opts: []cel.EnvOption{ + cel.Variable("a", cel.IntType), + }, + in: map[string]any{ + "a": 4, + }, + out: types.False, + }, + { + expr: `a in ["hello", "world"].map(i, i in ["goodbye", "world"], i + i)`, + optimized: `a in 
["hello", "world"].map(i, i in {"goodbye": true, "world": true}, i + i)`, + opts: []cel.EnvOption{ + cel.Variable("a", cel.StringType), + }, + in: map[string]any{ + "a": "worldworld", + }, + out: types.True, + }, + { + expr: `a in [test.GlobalEnum.GOO, test.GlobalEnum.GAR, test.GlobalEnum.GAZ]`, + optimized: `a in {0: true, 1: true, 2: true}`, + opts: []cel.EnvOption{ + cel.Container("google.expr.proto3"), + cel.Variable("a", cel.IntType), + cel.Types(&proto3pb.TestAllTypes{}), + }, + in: map[string]any{ + "a": proto3pb.GlobalEnum_GAZ, + }, + out: types.True, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + env := testSetsEnv(t, tc.opts...) + var asts []*cel.Ast + a, iss := env.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Compile(%v) failed: %v", tc.expr, iss.Err()) + } + asts = append(asts, a) + setsOpt, err := NewSetMembershipOptimizer() + if err != nil { + t.Fatalf("NewSetMembershipOptimizer() failed with error: %v", err) + } + opt := cel.NewStaticOptimizer(setsOpt) + optAST, iss := opt.Optimize(env, a) + if iss.Err() != nil { + t.Fatalf("opt.Optimize() failed: %v", iss.Err()) + } + optExpr, err := cel.AstToString(optAST) + if err != nil { + t.Fatalf("cel.AstToString() failed :%v", err) + } + if tc.optimized != optExpr { + t.Errorf("got %v, wanted optimized expr %v", optExpr, tc.optimized) + } + asts = append(asts, optAST) + + for _, ast := range asts { /* asts holds the compiled AST followed by the optimized one; both must evaluate to tc.out */ + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + in := tc.in + if in == nil { + in = map[string]any{} + } + out, _, err := prg.Eval(in) + if err != nil { + t.Fatalf("prg.Eval() failed: %v", err) + } + if out != tc.out { + t.Errorf("prg.Eval() got %v, wanted %v for expr: %s", out, tc.out, tc.expr) + } + } + }) + } +} + func testSetsEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env { t.Helper() - baseOpts := []cel.EnvOption{Sets()} + baseOpts := []cel.EnvOption{cel.EnableMacroCallTracking(), Sets()} env, err := 
cel.NewEnv(append(baseOpts, opts...)...) if err != nil { t.Fatalf("cel.NewEnv(Sets()) failed: %v", err)