From b2f6fb83b7e749d225dfea77a56acf8624cccbf1 Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Sat, 13 Jan 2024 15:56:49 +0100 Subject: [PATCH] Allow to override builtins (#522) * Allow to override builtins * Add :: syntax to access builtin in case of override --- builtin/builtin_test.go | 84 ++++++++++++++++++++++++++------------ checker/checker.go | 24 ++++++----- conf/config.go | 22 ++++------ parser/lexer/lexer_test.go | 8 ++++ parser/lexer/state.go | 5 ++- parser/parser.go | 21 +++++++--- parser/parser_test.go | 23 +++++++++++ 7 files changed, 130 insertions(+), 57 deletions(-) diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index 23273a24..0cdf9515 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -284,37 +284,71 @@ func TestBuiltin_memory_limits(t *testing.T) { } } -func TestBuiltin_disallow_builtins_override(t *testing.T) { - t.Run("via env", func(t *testing.T) { - env := map[string]any{ - "len": func() int { return 42 }, - "repeat": func(a string) string { - return a - }, +func TestBuiltin_allow_builtins_override(t *testing.T) { + t.Run("via env var", func(t *testing.T) { + for _, name := range builtin.Names { + t.Run(name, func(t *testing.T) { + env := map[string]any{ + name: "hello world", + } + program, err := expr.Compile(name, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + assert.Equal(t, "hello world", out) + }) + } + }) + t.Run("via env func", func(t *testing.T) { + for _, name := range builtin.Names { + t.Run(name, func(t *testing.T) { + env := map[string]any{ + name: func() int { return 1 }, + } + program, err := expr.Compile(fmt.Sprintf("%s()", name), expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + assert.Equal(t, 1, out) + }) } - assert.Panics(t, func() { - _, _ = expr.Compile(`string(len("foo")) + repeat("0", 2)`, expr.Env(env)) - }) }) t.Run("via expr.Function", func(t *testing.T) { - length := expr.Function("len", - func(params ...any) (any, error) { - return 42, nil - }, - new(func() int), - ) - repeat := expr.Function("repeat", - func(params ...any) (any, error) { - return params[0], nil - }, - new(func(string) string), - ) - assert.Panics(t, func() { - _, _ = expr.Compile(`string(len("foo")) + repeat("0", 2)`, length, repeat) - }) + for _, name := range builtin.Names { + t.Run(name, func(t *testing.T) { + fn := expr.Function(name, + func(params ...any) (any, error) { + return 42, nil + }, + new(func() int), + ) + program, err := expr.Compile(fmt.Sprintf("%s()", name), fn) + require.NoError(t, err) + + out, err := expr.Run(program, nil) + require.NoError(t, err) + assert.Equal(t, 42, out) + }) + } }) } +func TestBuiltin_override_and_still_accessible(t *testing.T) { + env := map[string]any{ + "len": func() int { return 42 }, + "all": []int{1, 2, 3}, + } + + program, err := expr.Compile(`::all(all, #>0) && len() == 42 && ::len(all) == 3`, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + assert.Equal(t, true, out) +} + func TestBuiltin_DisableBuiltin(t *testing.T) { t.Run("via env", func(t *testing.T) { for _, b := range builtin.Builtins { diff --git a/checker/checker.go b/checker/checker.go index 4dced34e..5dd722fe 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -156,24 +156,25 @@ func (v *checker) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info) if node.Value == "$env" { return mapType, info{} } - if fn, ok := v.config.Builtins[node.Value]; ok { - return functionType, info{fn: fn} - } - if fn, ok := v.config.Functions[node.Value]; ok { - return functionType, info{fn: fn} - } - return v.env(node, node.Value, true) + return v.ident(node, node.Value, true, true) } -// env method returns type of environment variable. env only lookups for -// environment variables, no builtins, no custom functions. -func (v *checker) env(node ast.Node, name string, strict bool) (reflect.Type, info) { +// ident method returns type of environment variable, builtin or function. +func (v *checker) ident(node ast.Node, name string, strict, builtins bool) (reflect.Type, info) { if t, ok := v.config.Types[name]; ok { if t.Ambiguous { return v.error(node, "ambiguous identifier %v", name) } return t.Type, info{method: t.Method} } + if builtins { + if fn, ok := v.config.Functions[name]; ok { + return functionType, info{fn: fn} + } + if fn, ok := v.config.Builtins[name]; ok { + return functionType, info{fn: fn} + } + } if v.config.Strict && strict { return v.error(node, "unknown name %v", name) } @@ -433,6 +434,7 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) { base, _ := v.visit(node.Node) prop, _ := v.visit(node.Property) + // $env variable if an, ok := node.Node.(*ast.IdentifierNode); ok && an.Value == "$env" { if name, ok := node.Property.(*ast.StringNode); ok { strict := v.config.Strict @@ -443,7 +445,7 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) { // should throw error if field is not found & v.config.Strict. strict = false } - return v.env(node, name.Value, strict) + return v.ident(node, name.Value, strict, false /* no builtins and no functions */) } return anyType, info{} } diff --git a/conf/config.go b/conf/config.go index 48d491a9..baf5dee0 100644 --- a/conf/config.go +++ b/conf/config.go @@ -98,20 +98,14 @@ func (c *Config) Check() { } } } - for fnName, t := range c.Types { - if kind(t.Type) == reflect.Func { - for _, b := range c.Builtins { - if b.Name == fnName { - panic(fmt.Errorf(`cannot override builtin %s(): use expr.DisableBuiltin("%s") to override`, b.Name, b.Name)) - } - } - } +} + +func (c *Config) IsOverridden(name string) bool { + if _, ok := c.Functions[name]; ok { + return true } - for _, f := range c.Functions { - for _, b := range c.Builtins { - if b.Name == f.Name { - panic(fmt.Errorf(`cannot override builtin %s(); use expr.DisableBuiltin("%s") to override`, f.Name, f.Name)) - } - } + if _, ok := c.Types[name]; ok { + return true } + return false } diff --git a/parser/lexer/lexer_test.go b/parser/lexer/lexer_test.go index feecf045..442fd4db 100644 --- a/parser/lexer/lexer_test.go +++ b/parser/lexer/lexer_test.go @@ -225,6 +225,14 @@ func TestLex(t *testing.T) { {Kind: EOF}, }, }, + { + `: ::`, + []Token{ + {Kind: Operator, Value: ":"}, + {Kind: Operator, Value: "::"}, + {Kind: EOF}, + }, + }, } for _, test := range tests { diff --git a/parser/lexer/state.go b/parser/lexer/state.go index 9999fd3c..72f02bf4 100644 --- a/parser/lexer/state.go +++ b/parser/lexer/state.go @@ -37,11 +37,14 @@ func root(l *lexer) stateFn { case r == '|': l.accept("|") l.emit(Operator) + case r == ':': + l.accept(":") + l.emit(Operator) case strings.ContainsRune("([{", r): l.emit(Bracket) case strings.ContainsRune(")]}", r): l.emit(Bracket) - case strings.ContainsRune(",:;%+-^", r): // single rune operator + case strings.ContainsRune(",;%+-^", r): // single rune operator l.emit(Operator) case strings.ContainsRune("&!=*<>", r): // possible double rune operator l.accept("&=*") diff --git a/parser/parser.go b/parser/parser.go index 5f29279b..1ce1cda0 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -275,6 +275,13 @@ func (p *parser) parsePrimary() Node { } } + if token.Is(Operator, "::") { + p.next() + token = p.current + p.expect(Identifier) + return p.parsePostfixExpression(p.parseCall(token, false)) + } + return p.parseSecondary() } @@ -300,7 +307,7 @@ func (p *parser) parseSecondary() Node { node.SetLocation(token.Location) return node default: - node = p.parseCall(token) + node = p.parseCall(token, true) } case Number: @@ -379,15 +386,17 @@ func (p *parser) toFloatNode(number float64) Node { return &FloatNode{Value: number} } -func (p *parser) parseCall(token Token) Node { +func (p *parser) parseCall(token Token, checkOverrides bool) Node { var node Node if p.current.Is(Bracket, "(") { var arguments []Node - if b, ok := predicates[token.Value]; ok { - p.expect(Bracket, "(") + isOverridden := p.config.IsOverridden(token.Value) + isOverridden = isOverridden && checkOverrides - // TODO: Refactor parser to use builtin.Builtins instead of predicates map. + // TODO: Refactor parser to use builtin.Builtins instead of predicates map. + if b, ok := predicates[token.Value]; ok && !isOverridden { + p.expect(Bracket, "(") if b.arity == 1 { arguments = make([]Node, 1) @@ -417,7 +426,7 @@ func (p *parser) parseCall(token Token) Node { Arguments: arguments, } node.SetLocation(token.Location) - } else if _, ok := builtin.Index[token.Value]; ok && !p.config.Disabled[token.Value] { + } else if _, ok := builtin.Index[token.Value]; ok && !p.config.Disabled[token.Value] && !isOverridden { node = &BuiltinNode{ Name: token.Value, Arguments: p.parseArguments(), diff --git a/parser/parser_test.go b/parser/parser_test.go index 453fe91a..0e7a2383 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -498,6 +498,29 @@ world`}, }, }, }, + { + `::split("a,b,c", ",")`, + &BuiltinNode{ + Name: "split", + Arguments: []Node{ + &StringNode{Value: "a,b,c"}, + &StringNode{Value: ","}, + }, + }, + }, + { + `::split("a,b,c", ",")[0]`, + &MemberNode{ + Node: &BuiltinNode{ + Name: "split", + Arguments: []Node{ + &StringNode{Value: "a,b,c"}, + &StringNode{Value: ","}, + }, + }, + Property: &IntegerNode{Value: 0}, + }, + }, } for _, test := range tests { t.Run(test.input, func(t *testing.T) {