From 8ad600b649be1b9ef5a003e8c5632d89b9aaf790 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Tue, 5 Nov 2024 16:04:35 -0800 Subject: [PATCH] Ensure variables in comprehensions don't collide (#1062) --- ext/comprehensions.go | 66 ++++++++++++++++++++------------------ ext/comprehensions_test.go | 12 +++++++ 2 files changed, 46 insertions(+), 32 deletions(-) diff --git a/ext/comprehensions.go b/ext/comprehensions.go index 6071ea06..1428558d 100644 --- a/ext/comprehensions.go +++ b/ext/comprehensions.go @@ -15,6 +15,8 @@ package ext import ( + "fmt" + "github.com/google/cel-go/cel" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" @@ -220,14 +222,11 @@ func (compreV2Lib) ProgramOptions() []cel.ProgramOption { } func quantifierAll(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { - iterVar1, err := extractIterVar(mef, args[0]) - if err != nil { - return nil, err - } - iterVar2, err := extractIterVar(mef, args[1]) + iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1]) if err != nil { return nil, err } + return mef.NewComprehensionTwoVar( target, iterVar1, @@ -241,14 +240,11 @@ func quantifierAll(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) ( } func quantifierExists(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { - iterVar1, err := extractIterVar(mef, args[0]) - if err != nil { - return nil, err - } - iterVar2, err := extractIterVar(mef, args[1]) + iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1]) if err != nil { return nil, err } + return mef.NewComprehensionTwoVar( target, iterVar1, @@ -262,14 +258,11 @@ func quantifierExists(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr } func quantifierExistsOne(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { - iterVar1, err := extractIterVar(mef, args[0]) - if err != nil { - return nil, err - } - iterVar2, err := extractIterVar(mef, args[1]) + iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1]) if err != nil { return nil, err } + return mef.NewComprehensionTwoVar( target, iterVar1, @@ -285,11 +278,7 @@ func quantifierExistsOne(mef cel.MacroExprFactory, target ast.Expr, args []ast.E } func transformList(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { - iterVar1, err := extractIterVar(mef, args[0]) - if err != nil { - return nil, err - } - iterVar2, err := extractIterVar(mef, args[1]) + iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1]) if err != nil { return nil, err } @@ -324,11 +313,7 @@ func transformList(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) ( } func transformMap(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { - iterVar1, err := extractIterVar(mef, args[0]) - if err != nil { - return nil, err - } - iterVar2, err := extractIterVar(mef, args[1]) + iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1]) if err != nil { return nil, err } @@ -362,11 +347,7 @@ func transformMap(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (a } func transformMapEntry(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { - iterVar1, err := extractIterVar(mef, args[0]) - if err != nil { - return nil, err - } - iterVar2, err := extractIterVar(mef, args[1]) + iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1]) if err != nil { return nil, err } @@ -399,10 +380,31 @@ func transformMapEntry(mef cel.MacroExprFactory, target ast.Expr, args []ast.Exp ), nil } -func extractIterVar(meh cel.MacroExprFactory, target ast.Expr) (string, *cel.Error) { +func extractIterVars(mef cel.MacroExprFactory, arg0, arg1 ast.Expr) (string, string, *cel.Error) { + iterVar1, err := extractIterVar(mef, arg0) + if err != nil { + return "", "", err + } + iterVar2, err := extractIterVar(mef, arg1) + if err != nil { + return "", "", err + } + if iterVar1 == iterVar2 { + return "", "", mef.NewError(arg1.ID(), fmt.Sprintf("duplicate variable name: %s", iterVar1)) + } + if iterVar1 == parser.AccumulatorName { + return "", "", mef.NewError(arg0.ID(), "iteration variable overwrites accumulator variable") + } + if iterVar2 == parser.AccumulatorName { + return "", "", mef.NewError(arg1.ID(), "iteration variable overwrites accumulator variable") + } + return iterVar1, iterVar2, nil +} + +func extractIterVar(mef cel.MacroExprFactory, target ast.Expr) (string, *cel.Error) { iterVar, found := extractIdent(target) if !found { - return "", meh.NewError(target.ID(), "argument must be a simple name") + return "", mef.NewError(target.ID(), "argument must be a simple name") } return iterVar, nil } diff --git a/ext/comprehensions_test.go b/ext/comprehensions_test.go index d457ed0d..84d82c37 100644 --- a/ext/comprehensions_test.go +++ b/ext/comprehensions_test.go @@ -210,6 +210,18 @@ func TestTwoVarComprehensionsStaticErrors(t *testing.T) { expr string err string }{ + { + expr: "[].all(i, i, i < i)", + err: "duplicate variable name: i", + }, + { + expr: "[].all(__result__, i, __result__ < i)", + err: "iteration variable overwrites accumulator variable", + }, + { + expr: "[].all(j, __result__, __result__ < j)", + err: "iteration variable overwrites accumulator variable", + }, { expr: "[].all(i.j, k, i.j < k)", err: "argument must be a simple name",