diff --git a/common/decls/decls.go b/common/decls/decls.go index 0a42f81d..1bf4667e 100644 --- a/common/decls/decls.go +++ b/common/decls/decls.go @@ -251,15 +251,15 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) { // are preserved in order to assist with the function resolution step. switch len(args) { case 1: - if o.unaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) { + if o.unaryOp != nil && o.matchesRuntimeSignature(f.disableTypeGuards, args...) { return o.unaryOp(args[0]) } case 2: - if o.binaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) { + if o.binaryOp != nil && o.matchesRuntimeSignature(f.disableTypeGuards, args...) { return o.binaryOp(args[0], args[1]) } } - if o.functionOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) { + if o.functionOp != nil && o.matchesRuntimeSignature(f.disableTypeGuards, args...) { return o.functionOp(args...) } // eventually this will fall through to the noSuchOverload below. diff --git a/ext/BUILD.bazel b/ext/BUILD.bazel index db223da2..3fe5378f 100644 --- a/ext/BUILD.bazel +++ b/ext/BUILD.bazel @@ -24,6 +24,7 @@ go_library( "//cel:go_default_library", "//checker:go_default_library", "//common/ast:go_default_library", + "//common/decls:go_default_library", "//common/overloads:go_default_library", "//common/operators:go_default_library", "//common/types:go_default_library", @@ -61,8 +62,8 @@ go_test( "//common/types/ref:go_default_library", "//common/types/traits:go_default_library", "//test:go_default_library", - "//test/proto2pb:go_default_library", - "//test/proto3pb:go_default_library", + "//test/proto2pb:go_default_library", + "//test/proto3pb:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//types/known/wrapperspb:go_default_library", "@org_golang_google_protobuf//encoding/protojson:go_default_library", diff --git a/ext/lists.go b/ext/lists.go index de2fd709..cfce9a72 100644 --- a/ext/lists.go +++ b/ext/lists.go @@ -19,6 +19,7 @@ import ( "math" "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" @@ -95,6 +96,7 @@ func ListsVersion(version uint32) ListsOption { // CompileOptions implements the Library interface method. func (lib listsLib) CompileOptions() []cel.EnvOption { listType := cel.ListType(cel.TypeParamType("T")) + listListType := cel.ListType(listType) listDyn := cel.ListType(cel.DynType) opts := []cel.EnvOption{ cel.Function("slice", @@ -117,24 +119,35 @@ func (lib listsLib) CompileOptions() []cel.EnvOption { opts = append(opts, cel.Function("flatten", cel.MemberOverload("list_flatten", - []*cel.Type{listDyn}, listDyn, + []*cel.Type{listListType}, listType, cel.UnaryBinding(func(arg ref.Val) ref.Val { - list := arg.(traits.Lister) + list, ok := arg.(traits.Lister) + if !ok { + return types.MaybeNoSuchOverloadErr(arg) + } flatList := flatten(list, 1) return types.DefaultTypeAdapter.NativeToValue(flatList) }), ), - ), - cel.Function("flatten", cel.MemberOverload("list_flatten_int", []*cel.Type{listDyn, types.IntType}, listDyn, cel.BinaryBinding(func(arg1, arg2 ref.Val) ref.Val { - list := arg1.(traits.Lister) - depth := arg2.(types.Int) + list, ok := arg1.(traits.Lister) + if !ok { + return types.MaybeNoSuchOverloadErr(arg1) + } + depth, ok := arg2.(types.Int) + if !ok { + return types.MaybeNoSuchOverloadErr(arg2) + } flatList := flatten(list, int64(depth)) return types.DefaultTypeAdapter.NativeToValue(flatList) }), ), + // To handle the case where a variable of just `list(T)` is provided at runtime + // with a graceful failure more, disable the type guards since the implementation + // can handle lists which are already flat. + decls.DisableTypeGuards(true), ), ) } diff --git a/ext/lists_test.go b/ext/lists_test.go index 1851f838..beeb0ee7 100644 --- a/ext/lists_test.go +++ b/ext/lists_test.go @@ -36,8 +36,8 @@ func TestLists(t *testing.T) { {expr: `[1,2,3,4].slice(0, 10)`, err: "cannot slice(0, 10), list is length 4"}, {expr: `[1,2,3,4].slice(-5, 10)`, err: "cannot slice(-5, 10), negative indexes not supported"}, {expr: `[1,2,3,4].slice(-5, -3)`, err: "cannot slice(-5, -3), negative indexes not supported"}, - {expr: `[].flatten() == []`}, - {expr: `[1,2,3,4].flatten() == [1,2,3,4]`}, + {expr: `dyn([]).flatten() == []`}, + {expr: `dyn([1,2,3,4]).flatten() == [1,2,3,4]`}, {expr: `[1,[2,[3,4]]].flatten() == [1,2,[3,4]]`}, {expr: `[1,2,[],[],[3,4]].flatten() == [1,2,3,4]`}, {expr: `[1,[2,[3,4]]].flatten(2) == [1,2,3,4]`},