Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decompose the compile and compose methods #986

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions policy/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ go_library(
srcs = [
"compiler.go",
"conformance.go",
"composer.go",
"config.go",
"parser.go",
"source.go",
Expand All @@ -33,6 +34,7 @@ go_library(
"//cel:go_default_library",
"//common:go_default_library",
"//common/ast:go_default_library",
"//common/decls:go_default_library",
"//common/operators:go_default_library",
"//common/types:go_default_library",
"//ext:go_default_library",
Expand All @@ -46,7 +48,7 @@ go_test(
srcs = [
"compiler_test.go",
"config_test.go",
"helper_test.go",
"helper_test.go",
"parser_test.go",
],
data = glob(["testdata/**"]),
Expand All @@ -58,4 +60,4 @@ go_test(
"//test/proto3pb:go_default_library",
"@in_gopkg_yaml_v3//:go_default_library",
],
)
)
219 changes: 117 additions & 102 deletions policy/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,111 @@ import (
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/types"
)

type compiler struct {
env *cel.Env
info *ast.SourceInfo
src *Source
// CompiledRule represents the variables and match blocks associated with a rule block.
type CompiledRule struct {
id *ValueString
variables []*CompiledVariable
matches []*CompiledMatch
}

type compiledRule struct {
variables []*compiledVariable
matches []*compiledMatch
// ID returns the expression id associated with the rule.
func (r *CompiledRule) ID() *ValueString {
return r.id
}

type compiledVariable struct {
name string
expr *cel.Ast
// Variables rturns the list of CompiledVariable values associated with the rule.
func (r *CompiledRule) Variables() []*CompiledVariable {
return r.variables[:]
}

// Matches returns the list of matches associated with the rule.
func (r *CompiledRule) Matches() []*CompiledMatch {
return r.matches[:]
}

type compiledMatch struct {
// CompiledVariable represents the variable name, expression, and associated type-check declaration.
type CompiledVariable struct {
name string
expr *cel.Ast
varDecl *decls.VariableDecl
}

// Name returns the variable name.
func (v *CompiledVariable) Name() string {
return v.name
}

// Expr returns the compiled expression associated with the variable name.
func (v *CompiledVariable) Expr() *cel.Ast {
return v.expr
}

// Declaration returns the type-check declaration associated with the variable.
func (v *CompiledVariable) Declaration() *decls.VariableDecl {
return v.varDecl
}

// CompiledMatch represents a match block which has an optional condition (true, by default) as well
// as an output or a nested rule (one or the other, but not both).
type CompiledMatch struct {
cond *cel.Ast
output *cel.Ast
nestedRule *compiledRule
output *OutputValue
nestedRule *CompiledRule
}

// Condition returns the compiled predicate expression which must evaluate to true before the output
// or subrule is entered.
func (m *CompiledMatch) Condition() *cel.Ast {
return m.cond
}

// Output returns the compiled output expression associated with the match block, if set.
func (m *CompiledMatch) Output() *OutputValue {
return m.output
}

// NestedRule returns the nested rule, if set.
func (m *CompiledMatch) NestedRule() *CompiledRule {
return m.nestedRule
}

// OutputValue represents the output expression associated with a match block.
type OutputValue struct {
id int64
expr *cel.Ast
}

// ID returns the expression id associated with the output expression.
func (o *OutputValue) ID() int64 {
return o.id
}

// Expr returns the compiled expression associated with the output.
func (o *OutputValue) Expr() *cel.Ast {
return o.expr
}

// Compile generates a single CEL AST from a collection of policy expressions associated with a CEL environment.
// Compile combines the policy compilation and composition steps into a single call.
//
// This generates a single CEL AST from a collection of policy expressions associated with a
// CEL environment.
func Compile(env *cel.Env, p *Policy) (*cel.Ast, *cel.Issues) {
rule, iss := CompileRule(env, p)
if iss.Err() != nil {
return nil, iss
}
composer := NewRuleComposer(env, p)
return composer.Compose(rule)
}

// CompileRule creates a compiled rules from the policy which contains a set of compiled variables and
// match statements. The compiled rule defines an expression graph, which can be composed into a single
// expression via the RuleComposer.Compose method.
func CompileRule(env *cel.Env, p *Policy) (*CompiledRule, *cel.Issues) {
c := &compiler{
env: env,
info: p.SourceInfo(),
Expand All @@ -59,18 +136,18 @@ func Compile(env *cel.Env, p *Policy) (*cel.Ast, *cel.Issues) {
iss := cel.NewIssuesWithSourceInfo(errs, c.info)
rule, ruleIss := c.compileRule(p.Rule(), c.env, iss)
iss = iss.Append(ruleIss)
if iss.Err() != nil {
return nil, iss
}
ruleRoot, _ := env.Compile("true")
opt := cel.NewStaticOptimizer(&ruleComposer{rule: rule})
ruleExprAST, optIss := opt.Optimize(env, ruleRoot)
return ruleExprAST, iss.Append(optIss)
return rule, iss
}

func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*compiledRule, *cel.Issues) {
type compiler struct {
env *cel.Env
info *ast.SourceInfo
src *Source
}

func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*CompiledRule, *cel.Issues) {
var err error
compiledVars := make([]*compiledVariable, len(r.Variables()))
compiledVars := make([]*CompiledVariable, len(r.Variables()))
for i, v := range r.Variables() {
exprSrc := c.relSource(v.Expression())
varAST, exprIss := ruleEnv.CompileSource(exprSrc)
Expand All @@ -88,16 +165,18 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*com

// Introduce the variable into the environment. By extending the environment, the variables
// are effectively scoped such that they must be declared before use.
ruleEnv, err = ruleEnv.Extend(cel.Variable(fmt.Sprintf("%s.%s", variablePrefix, varName), varType))
varDecl := decls.NewVariable(fmt.Sprintf("%s.%s", variablePrefix, varName), varType)
ruleEnv, err = ruleEnv.Extend(cel.Variable(varDecl.Name(), varDecl.Type()))
if err != nil {
iss.ReportErrorAtID(v.Expression().ID, "invalid variable declaration")
}
compiledVars[i] = &compiledVariable{
name: v.name.Value,
expr: varAST,
compiledVars[i] = &CompiledVariable{
name: v.name.Value,
expr: varAST,
varDecl: varDecl,
}
}
compiledMatches := []*compiledMatch{}
compiledMatches := []*CompiledMatch{}
for _, m := range r.Matches() {
condSrc := c.relSource(m.Condition())
condAST, condIss := ruleEnv.CompileSource(condSrc)
Expand All @@ -113,22 +192,26 @@ func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*com
outSrc := c.relSource(m.Output())
outAST, outIss := ruleEnv.CompileSource(outSrc)
iss = iss.Append(outIss)
compiledMatches = append(compiledMatches, &compiledMatch{
cond: condAST,
output: outAST,
compiledMatches = append(compiledMatches, &CompiledMatch{
cond: condAST,
output: &OutputValue{
id: m.Output().ID,
expr: outAST,
},
})
continue
}
if m.HasRule() {
nestedRule, ruleIss := c.compileRule(m.Rule(), ruleEnv, iss)
iss = iss.Append(ruleIss)
compiledMatches = append(compiledMatches, &compiledMatch{
compiledMatches = append(compiledMatches, &CompiledMatch{
cond: condAST,
nestedRule: nestedRule,
})
}
}
return &compiledRule{
return &CompiledRule{
id: r.id,
variables: compiledVars,
matches: compiledMatches,
}, iss
Expand All @@ -146,74 +229,6 @@ func (c *compiler) relSource(pstr ValueString) *RelativeSource {
return c.src.Relative(pstr.Value, line, col)
}

type ruleComposer struct {
rule *compiledRule
}

// Optimize implements an AST optimizer for CEL which composes an expression graph into a single
// expression value.
func (opt *ruleComposer) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST {
// The input to optimize is a dummy expression which is completely replaced according
// to the configuration of the rule composition graph.
ruleExpr, _ := optimizeRule(ctx, opt.rule)
return ctx.NewAST(ruleExpr)
}

func optimizeRule(ctx *cel.OptimizerContext, r *compiledRule) (ast.Expr, bool) {
matchExpr := ctx.NewCall("optional.none")
matches := r.matches
optionalResult := true
for i := len(matches) - 1; i >= 0; i-- {
m := matches[i]
cond := ctx.CopyASTAndMetadata(m.cond.NativeRep())
triviallyTrue := cond.Kind() == ast.LiteralKind && cond.AsLiteral() == types.True
if m.output != nil {
out := ctx.CopyASTAndMetadata(m.output.NativeRep())
if triviallyTrue {
matchExpr = out
optionalResult = false
continue
}
if optionalResult {
out = ctx.NewCall("optional.of", out)
}
matchExpr = ctx.NewCall(
operators.Conditional,
cond,
out,
matchExpr)
continue
}
nestedRule, nestedOptional := optimizeRule(ctx, m.nestedRule)
if optionalResult && !nestedOptional {
nestedRule = ctx.NewCall("optional.of", nestedRule)
}
if !optionalResult && nestedOptional {
matchExpr = ctx.NewCall("optional.of", matchExpr)
optionalResult = true
}
if !optionalResult && !nestedOptional {
ctx.ReportErrorAtID(nestedRule.ID(), "subrule early terminates policy")
continue
}
matchExpr = ctx.NewMemberCall("or", nestedRule, matchExpr)
}

vars := r.variables
for i := len(vars) - 1; i >= 0; i-- {
v := vars[i]
varAST := ctx.CopyASTAndMetadata(v.expr.NativeRep())
// Build up the bindings in reverse order, starting from root, all the way up to the outermost
// binding:
// currExpr = cel.bind(outerVar, outerExpr, currExpr)
varName := fmt.Sprintf("%s.%s", variablePrefix, v.name)
inlined, bindMacro := ctx.NewBindMacro(matchExpr.ID(), varName, varAST, matchExpr)
ctx.UpdateExpr(matchExpr, inlined)
ctx.SetMacroCall(matchExpr.ID(), bindMacro)
}
return matchExpr, optionalResult
}

const (
// Consider making the variables namespace configurable.
variablePrefix = "variables"
Expand Down
Loading