Skip to content

Commit

Permalink
Allow to override builtins (#522)
Browse files Browse the repository at this point in the history
* Allow to override builtins

* Add :: syntax to access builtin in case of override
  • Loading branch information
antonmedv authored Jan 13, 2024
1 parent db94b96 commit b2f6fb8
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 57 deletions.
84 changes: 59 additions & 25 deletions builtin/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,37 +284,71 @@ func TestBuiltin_memory_limits(t *testing.T) {
}
}

func TestBuiltin_disallow_builtins_override(t *testing.T) {
t.Run("via env", func(t *testing.T) {
env := map[string]any{
"len": func() int { return 42 },
"repeat": func(a string) string {
return a
},
func TestBuiltin_allow_builtins_override(t *testing.T) {
t.Run("via env var", func(t *testing.T) {
for _, name := range builtin.Names {
t.Run(name, func(t *testing.T) {
env := map[string]any{
name: "hello world",
}
program, err := expr.Compile(name, expr.Env(env))
require.NoError(t, err)

out, err := expr.Run(program, env)
require.NoError(t, err)
assert.Equal(t, "hello world", out)
})
}
})
t.Run("via env func", func(t *testing.T) {
for _, name := range builtin.Names {
t.Run(name, func(t *testing.T) {
env := map[string]any{
name: func() int { return 1 },
}
program, err := expr.Compile(fmt.Sprintf("%s()", name), expr.Env(env))
require.NoError(t, err)

out, err := expr.Run(program, env)
require.NoError(t, err)
assert.Equal(t, 1, out)
})
}
assert.Panics(t, func() {
_, _ = expr.Compile(`string(len("foo")) + repeat("0", 2)`, expr.Env(env))
})
})
t.Run("via expr.Function", func(t *testing.T) {
length := expr.Function("len",
func(params ...any) (any, error) {
return 42, nil
},
new(func() int),
)
repeat := expr.Function("repeat",
func(params ...any) (any, error) {
return params[0], nil
},
new(func(string) string),
)
assert.Panics(t, func() {
_, _ = expr.Compile(`string(len("foo")) + repeat("0", 2)`, length, repeat)
})
for _, name := range builtin.Names {
t.Run(name, func(t *testing.T) {
fn := expr.Function(name,
func(params ...any) (any, error) {
return 42, nil
},
new(func() int),
)
program, err := expr.Compile(fmt.Sprintf("%s()", name), fn)
require.NoError(t, err)

out, err := expr.Run(program, nil)
require.NoError(t, err)
assert.Equal(t, 42, out)
})
}
})
}

func TestBuiltin_override_and_still_accessible(t *testing.T) {
env := map[string]any{
"len": func() int { return 42 },
"all": []int{1, 2, 3},
}

program, err := expr.Compile(`::all(all, #>0) && len() == 42 && ::len(all) == 3`, expr.Env(env))
require.NoError(t, err)

out, err := expr.Run(program, env)
require.NoError(t, err)
assert.Equal(t, true, out)
}

func TestBuiltin_DisableBuiltin(t *testing.T) {
t.Run("via env", func(t *testing.T) {
for _, b := range builtin.Builtins {
Expand Down
24 changes: 13 additions & 11 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,24 +156,25 @@ func (v *checker) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info)
if node.Value == "$env" {
return mapType, info{}
}
if fn, ok := v.config.Builtins[node.Value]; ok {
return functionType, info{fn: fn}
}
if fn, ok := v.config.Functions[node.Value]; ok {
return functionType, info{fn: fn}
}
return v.env(node, node.Value, true)
return v.ident(node, node.Value, true, true)
}

// env method returns type of environment variable. env only lookups for
// environment variables, no builtins, no custom functions.
func (v *checker) env(node ast.Node, name string, strict bool) (reflect.Type, info) {
// ident method returns type of environment variable, builtin or function.
func (v *checker) ident(node ast.Node, name string, strict, builtins bool) (reflect.Type, info) {
if t, ok := v.config.Types[name]; ok {
if t.Ambiguous {
return v.error(node, "ambiguous identifier %v", name)
}
return t.Type, info{method: t.Method}
}
if builtins {
if fn, ok := v.config.Functions[name]; ok {
return functionType, info{fn: fn}
}
if fn, ok := v.config.Builtins[name]; ok {
return functionType, info{fn: fn}
}
}
if v.config.Strict && strict {
return v.error(node, "unknown name %v", name)
}
Expand Down Expand Up @@ -433,6 +434,7 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) {
base, _ := v.visit(node.Node)
prop, _ := v.visit(node.Property)

// $env variable
if an, ok := node.Node.(*ast.IdentifierNode); ok && an.Value == "$env" {
if name, ok := node.Property.(*ast.StringNode); ok {
strict := v.config.Strict
Expand All @@ -443,7 +445,7 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) {
// should throw error if field is not found & v.config.Strict.
strict = false
}
return v.env(node, name.Value, strict)
return v.ident(node, name.Value, strict, false /* no builtins and no functions */)
}
return anyType, info{}
}
Expand Down
22 changes: 8 additions & 14 deletions conf/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,20 +98,14 @@ func (c *Config) Check() {
}
}
}
for fnName, t := range c.Types {
if kind(t.Type) == reflect.Func {
for _, b := range c.Builtins {
if b.Name == fnName {
panic(fmt.Errorf(`cannot override builtin %s(): use expr.DisableBuiltin("%s") to override`, b.Name, b.Name))
}
}
}
}

func (c *Config) IsOverridden(name string) bool {
if _, ok := c.Functions[name]; ok {
return true
}
for _, f := range c.Functions {
for _, b := range c.Builtins {
if b.Name == f.Name {
panic(fmt.Errorf(`cannot override builtin %s(); use expr.DisableBuiltin("%s") to override`, f.Name, f.Name))
}
}
if _, ok := c.Types[name]; ok {
return true
}
return false
}
8 changes: 8 additions & 0 deletions parser/lexer/lexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,14 @@ func TestLex(t *testing.T) {
{Kind: EOF},
},
},
{
`: ::`,
[]Token{
{Kind: Operator, Value: ":"},
{Kind: Operator, Value: "::"},
{Kind: EOF},
},
},
}

for _, test := range tests {
Expand Down
5 changes: 4 additions & 1 deletion parser/lexer/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ func root(l *lexer) stateFn {
case r == '|':
l.accept("|")
l.emit(Operator)
case r == ':':
l.accept(":")
l.emit(Operator)
case strings.ContainsRune("([{", r):
l.emit(Bracket)
case strings.ContainsRune(")]}", r):
l.emit(Bracket)
case strings.ContainsRune(",:;%+-^", r): // single rune operator
case strings.ContainsRune(",;%+-^", r): // single rune operator
l.emit(Operator)
case strings.ContainsRune("&!=*<>", r): // possible double rune operator
l.accept("&=*")
Expand Down
21 changes: 15 additions & 6 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,13 @@ func (p *parser) parsePrimary() Node {
}
}

if token.Is(Operator, "::") {
p.next()
token = p.current
p.expect(Identifier)
return p.parsePostfixExpression(p.parseCall(token, false))
}

return p.parseSecondary()
}

Expand All @@ -300,7 +307,7 @@ func (p *parser) parseSecondary() Node {
node.SetLocation(token.Location)
return node
default:
node = p.parseCall(token)
node = p.parseCall(token, true)
}

case Number:
Expand Down Expand Up @@ -379,15 +386,17 @@ func (p *parser) toFloatNode(number float64) Node {
return &FloatNode{Value: number}
}

func (p *parser) parseCall(token Token) Node {
func (p *parser) parseCall(token Token, checkOverrides bool) Node {
var node Node
if p.current.Is(Bracket, "(") {
var arguments []Node

if b, ok := predicates[token.Value]; ok {
p.expect(Bracket, "(")
isOverridden := p.config.IsOverridden(token.Value)
isOverridden = isOverridden && checkOverrides

// TODO: Refactor parser to use builtin.Builtins instead of predicates map.
// TODO: Refactor parser to use builtin.Builtins instead of predicates map.
if b, ok := predicates[token.Value]; ok && !isOverridden {
p.expect(Bracket, "(")

if b.arity == 1 {
arguments = make([]Node, 1)
Expand Down Expand Up @@ -417,7 +426,7 @@ func (p *parser) parseCall(token Token) Node {
Arguments: arguments,
}
node.SetLocation(token.Location)
} else if _, ok := builtin.Index[token.Value]; ok && !p.config.Disabled[token.Value] {
} else if _, ok := builtin.Index[token.Value]; ok && !p.config.Disabled[token.Value] && !isOverridden {
node = &BuiltinNode{
Name: token.Value,
Arguments: p.parseArguments(),
Expand Down
23 changes: 23 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,29 @@ world`},
},
},
},
{
`::split("a,b,c", ",")`,
&BuiltinNode{
Name: "split",
Arguments: []Node{
&StringNode{Value: "a,b,c"},
&StringNode{Value: ","},
},
},
},
{
`::split("a,b,c", ",")[0]`,
&MemberNode{
Node: &BuiltinNode{
Name: "split",
Arguments: []Node{
&StringNode{Value: "a,b,c"},
&StringNode{Value: ","},
},
},
Property: &IntegerNode{Value: 0},
},
},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
Expand Down

0 comments on commit b2f6fb8

Please sign in to comment.