diff --git a/checker/cost.go b/checker/cost.go index f232f30d..fd3f7350 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -520,6 +520,9 @@ func (c *coster) costComprehension(e *exprpb.Expr) CostEstimate { c.iterRanges.pop(comp.GetIterVar()) sum = sum.Add(c.cost(comp.Result)) rangeCnt := c.sizeEstimate(c.newAstNode(comp.GetIterRange())) + + c.computedSizes[e.GetId()] = rangeCnt + rangeCost := rangeCnt.MultiplyByCost(stepCost.Add(loopCost)) sum = sum.Add(rangeCost) diff --git a/checker/cost_test.go b/checker/cost_test.go index c94c1c2b..7781a815 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -502,6 +502,71 @@ func TestCost(t *testing.T) { }, wanted: CostEstimate{Min: 3, Max: 3}, }, + { + name: ".filter list literal", + expr: `[1,2,3,4,5].filter(x, x % 2 == 0)`, + wanted: CostEstimate{Min: 41, Max: 101}, + }, + { + name: ".map list literal", + expr: `[1,2,3,4,5].map(x, x)`, + wanted: CostEstimate{Min: 86, Max: 86}, + }, + { + name: ".map.filter list literal", + expr: `[1,2,3,4,5].map(x, x).filter(x, x % 2 == 0)`, + wanted: CostEstimate{Min: 117, Max: 177}, + }, + { + name: ".map.exists list literal", + expr: `[1,2,3,4,5].map(x, x).exists(x, x == 5) == true`, + wanted: CostEstimate{Min: 108, Max: 118}, + }, + { + name: ".map.map list literal", + expr: `[1,2,3,4,5].map(x, x).map(x, x)`, + wanted: CostEstimate{Min: 162, Max: 162}, + }, + { + name: ".map list literal selection", + expr: `[1,2,3,4,5].map(x, x)[4]`, + wanted: CostEstimate{Min: 87, Max: 87}, + }, + { + name: "nested array selection", + expr: `[[1,2],[1,2],[1,2],[1,2],[1,2]][4]`, + wanted: CostEstimate{Min: 61, Max: 61}, + }, + { + name: "nested array selection", + expr: `{'a': [1,2], 'b': [1,2], 'c': [1,2], 'd': [1,2], 'e': [1,2]}.b`, + wanted: CostEstimate{Min: 81, Max: 81}, + }, + { + // Estimated cost does not track the sizes of nested aggregate types + // (lists, maps, ...) and so assumes a worst case cost when an + // expression applies a comprehension to a nested aggregated type, + // even if the size information is available. + // TODO: This should be fixed. + name: "comprehension on nested list", + expr: `[1,2,3,4,5].map(x, [x, x]).all(y, y.all(y, y == 1))`, + wanted: CostEstimate{Min: 157, Max: 18446744073709551615}, + }, + { + // Make sure we're accounting for not just the iteration range size, + // but also the overall comprehension size. The chained map calls + // will treat the result of one map as the iteration range of the other, + // so they're planned in reverse; however, the `+` should verify that + // the comprehension result has a size. + name: "comprehension size", + expr: `[1,2,3,4,5].map(x, x).map(x, x) + [1]`, + wanted: CostEstimate{Min: 173, Max: 173}, + }, + { + name: "nested comprehension", + expr: `[1,2,3].all(i, i in [1,2,3].map(j, j + j))`, + wanted: CostEstimate{Min: 20, Max: 230}, + }, } for _, tc := range cases { diff --git a/interpreter/runtimecost_test.go b/interpreter/runtimecost_test.go index 1c6ac124..687a47b8 100644 --- a/interpreter/runtimecost_test.go +++ b/interpreter/runtimecost_test.go @@ -791,6 +791,70 @@ func TestRuntimeCost(t *testing.T) { in: map[string]any{}, want: 77, }, + { + name: "list map literal", + expr: `[{'k1': 1}, {'k2': 2}].all(x, true)`, + vars: []*decls.VariableDecl{}, + in: map[string]any{}, + want: 77, + }, + { + name: ".filter list literal", + expr: `[1,2,3,4,5].filter(x, x % 2 == 0)`, + vars: []*decls.VariableDecl{}, + in: map[string]any{}, + want: 62, + }, + { + name: ".map list literal", + expr: `[1,2,3,4,5].map(x, x)`, + vars: []*decls.VariableDecl{}, + in: map[string]any{}, + want: 86, + }, + { + name: ".map.filter list literal", + expr: `[1,2,3,4,5].map(x, x).filter(x, x % 2 == 0)`, + vars: []*decls.VariableDecl{}, + in: map[string]any{}, + want: 138, + }, + { + name: ".map.exists list literal", + expr: `[1,2,3,4,5].map(x, x).exists(x, x == 5) == true`, + vars: []*decls.VariableDecl{}, + in: map[string]any{}, + want: 118, + }, + { + name: ".map.map list literal", + expr: `[1,2,3,4,5].map(x, x).map(x, x)`, + vars: []*decls.VariableDecl{}, + in: map[string]any{}, + want: 162, + }, + { + name: ".map.map list literal", + expr: `[1,2,3,4,5].map(x, [x, x]).filter(z, z.size() == 2)`, + vars: []*decls.VariableDecl{}, + in: map[string]any{}, + want: 232, + }, + { + name: "comprehension on nested list", + expr: `[1,2,3,4,5].map(x, [x, x]).all(y, y.all(y, y == 1))`, + want: 171, + }, + { + name: "comprehension size", + expr: `[1,2,3,4,5].map(x, x).map(x, x) + [1]`, + want: 173, + }, + { + name: "nested comprehension", + expr: `[1,2,3].all(i, i in [1,2,3].map(j, j + j))`, + want: 86, + }, } for _, tc := range cases {