Skip to content

Commit

Permalink
Update to flatten list to have a stronger type-check signature (#1004)
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones authored Aug 16, 2024
1 parent 8d9b9d3 commit 2e58e6e
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 13 deletions.
6 changes: 3 additions & 3 deletions common/decls/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,15 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
// are preserved in order to assist with the function resolution step.
switch len(args) {
case 1:
if o.unaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
if o.unaryOp != nil && o.matchesRuntimeSignature(f.disableTypeGuards, args...) {
return o.unaryOp(args[0])
}
case 2:
if o.binaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
if o.binaryOp != nil && o.matchesRuntimeSignature(f.disableTypeGuards, args...) {
return o.binaryOp(args[0], args[1])
}
}
if o.functionOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
if o.functionOp != nil && o.matchesRuntimeSignature(f.disableTypeGuards, args...) {
return o.functionOp(args...)
}
// eventually this will fall through to the noSuchOverload below.
Expand Down
5 changes: 3 additions & 2 deletions ext/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ go_library(
"//cel:go_default_library",
"//checker:go_default_library",
"//common/ast:go_default_library",
"//common/decls:go_default_library",
"//common/overloads:go_default_library",
"//common/operators:go_default_library",
"//common/types:go_default_library",
Expand Down Expand Up @@ -61,8 +62,8 @@ go_test(
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/wrapperspb:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
Expand Down
25 changes: 19 additions & 6 deletions ext/lists.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"math"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
Expand Down Expand Up @@ -95,6 +96,7 @@ func ListsVersion(version uint32) ListsOption {
// CompileOptions implements the Library interface method.
func (lib listsLib) CompileOptions() []cel.EnvOption {
listType := cel.ListType(cel.TypeParamType("T"))
listListType := cel.ListType(listType)
listDyn := cel.ListType(cel.DynType)
opts := []cel.EnvOption{
cel.Function("slice",
Expand All @@ -117,24 +119,35 @@ func (lib listsLib) CompileOptions() []cel.EnvOption {
opts = append(opts,
cel.Function("flatten",
cel.MemberOverload("list_flatten",
[]*cel.Type{listDyn}, listDyn,
[]*cel.Type{listListType}, listType,
cel.UnaryBinding(func(arg ref.Val) ref.Val {
list := arg.(traits.Lister)
list, ok := arg.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
flatList := flatten(list, 1)
return types.DefaultTypeAdapter.NativeToValue(flatList)
}),
),
),
cel.Function("flatten",
cel.MemberOverload("list_flatten_int",
[]*cel.Type{listDyn, types.IntType}, listDyn,
cel.BinaryBinding(func(arg1, arg2 ref.Val) ref.Val {
list := arg1.(traits.Lister)
depth := arg2.(types.Int)
list, ok := arg1.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(arg1)
}
depth, ok := arg2.(types.Int)
if !ok {
return types.MaybeNoSuchOverloadErr(arg2)
}
flatList := flatten(list, int64(depth))
return types.DefaultTypeAdapter.NativeToValue(flatList)
}),
),
// To handle the case where a variable of just `list(T)` is provided at runtime
// with a graceful failure more, disable the type guards since the implementation
// can handle lists which are already flat.
decls.DisableTypeGuards(true),
),
)
}
Expand Down
4 changes: 2 additions & 2 deletions ext/lists_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ func TestLists(t *testing.T) {
{expr: `[1,2,3,4].slice(0, 10)`, err: "cannot slice(0, 10), list is length 4"},
{expr: `[1,2,3,4].slice(-5, 10)`, err: "cannot slice(-5, 10), negative indexes not supported"},
{expr: `[1,2,3,4].slice(-5, -3)`, err: "cannot slice(-5, -3), negative indexes not supported"},
{expr: `[].flatten() == []`},
{expr: `[1,2,3,4].flatten() == [1,2,3,4]`},
{expr: `dyn([]).flatten() == []`},
{expr: `dyn([1,2,3,4]).flatten() == [1,2,3,4]`},
{expr: `[1,[2,[3,4]]].flatten() == [1,2,[3,4]]`},
{expr: `[1,2,[],[],[3,4]].flatten() == [1,2,3,4]`},
{expr: `[1,[2,[3,4]]].flatten(2) == [1,2,3,4]`},
Expand Down

0 comments on commit 2e58e6e

Please sign in to comment.