diff --git a/ast/ast.go b/ast/ast.go index 1de8aac0..1d45e13f 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -217,6 +217,7 @@ type Arg interface { func (ExprArg) isArg() {} func (IntervalArg) isArg() {} func (SequenceArg) isArg() {} +func (LambdaArg) isArg() {} // NullHandlingModifier represents IGNORE/RESPECT NULLS of aggregate function calls type NullHandlingModifier interface { @@ -1192,6 +1193,21 @@ type SequenceArg struct { Expr Expr } +// LambdaArg is lambda expression argument of the generic function call. +// +// {{if .Lparen.Invalid}}{{.Args | sqlJoin ", "}}{{else}}({{.Args | sqlJoin ", "}}) -> {{.Expr | sql}} +// +// Note: Args won't be empty. If Lparen is not appeared, Args have exactly one element. +type LambdaArg struct { + // pos = Lparen || Args[0].pos + // end = Expr.end + + Lparen token.Pos // optional + + Args []*Ident // if Lparen.Invalid() then len(Args) = 1 else len(Args) > 0 + Expr Expr +} + // NamedArg represents a name and value pair in named arguments // // {{.Name | sql}} => {{.Value | sql}} diff --git a/ast/pos.go b/ast/pos.go index dd0cd882..3b5c8b06 100644 --- a/ast/pos.go +++ b/ast/pos.go @@ -422,6 +422,14 @@ func (s *SequenceArg) End() token.Pos { return nodeEnd(wrapNode(s.Expr)) } +func (l *LambdaArg) Pos() token.Pos { + return posChoice(l.Lparen, nodePos(nodeSliceIndex(l.Args, 0))) +} + +func (l *LambdaArg) End() token.Pos { + return nodeEnd(wrapNode(l.Expr)) +} + func (n *NamedArg) Pos() token.Pos { return nodePos(wrapNode(n.Name)) } diff --git a/ast/sql.go b/ast/sql.go index 12f864db..e409f5c4 100644 --- a/ast/sql.go +++ b/ast/sql.go @@ -506,6 +506,15 @@ func (c *CallExpr) SQL() string { ")" } +func (l *LambdaArg) SQL() string { + // This implementation is not exactly matched with the doc comment for simplicity. + return strOpt(!l.Lparen.Invalid(), "(") + + sqlJoin(l.Args, ", ") + + strOpt(!l.Lparen.Invalid(), ")") + + " -> " + + l.Expr.SQL() +} + func (n *NamedArg) SQL() string { return n.Name.SQL() + " => " + n.Value.SQL() } func (i *IgnoreNulls) SQL() string { return "IGNORE NULLS" } diff --git a/parser.go b/parser.go index 0655fc52..bc2461ab 100644 --- a/parser.go +++ b/parser.go @@ -1588,6 +1588,42 @@ func (p *Parser) tryParseNamedArg() *ast.NamedArg { } } +func (p *Parser) lookaheadLambdaArg() bool { + lexer := p.Lexer.Clone() + defer func() { + p.Lexer = lexer + }() + + if p.Token.Kind != "(" && p.Token.Kind != token.TokenIdent { + return false + } + + // Note: all lambda patterns can be parsed as expr -> expr. + p.parseExpr() + return p.Token.Kind == "->" +} + +func (p *Parser) parseLambdaArg() *ast.LambdaArg { + lparen := token.InvalidPos + var args []*ast.Ident + if p.Token.Kind == "(" { + lparen = p.expect("(").Pos + args = parseCommaSeparatedList(p, p.parseIdent) + p.expect(")") + } else { + args = []*ast.Ident{p.parseIdent()} + } + + p.expect("->") + expr := p.parseExpr() + + return &ast.LambdaArg{ + Lparen: lparen, + Args: args, + Expr: expr, + } +} + func (p *Parser) parseArg() ast.Arg { if i := p.tryParseIntervalArg(); i != nil { return i @@ -1595,6 +1631,9 @@ func (p *Parser) parseArg() ast.Arg { if s := p.tryParseSequenceArg(); s != nil { return s } + if p.lookaheadLambdaArg() { + return p.parseLambdaArg() + } return p.parseExprArg() } diff --git a/testdata/input/expr/array_functions_array_filter_parenless_lambda.sql b/testdata/input/expr/array_functions_array_filter_parenless_lambda.sql new file mode 100644 index 00000000..e00d2ea4 --- /dev/null +++ b/testdata/input/expr/array_functions_array_filter_parenless_lambda.sql @@ -0,0 +1 @@ +ARRAY_FILTER([1 ,2, 3], e -> e > 1) \ No newline at end of file diff --git a/testdata/input/expr/array_functions_array_filter_two_args_lambda.sql b/testdata/input/expr/array_functions_array_filter_two_args_lambda.sql new file mode 100644 index 00000000..205e21f0 --- /dev/null +++ b/testdata/input/expr/array_functions_array_filter_two_args_lambda.sql @@ -0,0 +1 @@ +ARRAY_FILTER([0, 2, 3], (e, i) -> e > i) \ No newline at end of file diff --git a/testdata/result/expr/array_functions_array_filter_parenless_lambda.sql.txt b/testdata/result/expr/array_functions_array_filter_parenless_lambda.sql.txt new file mode 100644 index 00000000..982a7bab --- /dev/null +++ b/testdata/result/expr/array_functions_array_filter_parenless_lambda.sql.txt @@ -0,0 +1,72 @@ +--- array_functions_array_filter_parenless_lambda.sql +ARRAY_FILTER([1 ,2, 3], e -> e > 1) +--- AST +&ast.CallExpr{ + Rparen: 34, + Func: &ast.Ident{ + NamePos: 0, + NameEnd: 12, + Name: "ARRAY_FILTER", + }, + Distinct: false, + Args: []ast.Arg{ + &ast.ExprArg{ + Expr: &ast.ArrayLiteral{ + Array: -1, + Lbrack: 13, + Rbrack: 21, + Type: nil, + Values: []ast.Expr{ + &ast.IntLiteral{ + ValuePos: 14, + ValueEnd: 15, + Base: 10, + Value: "1", + }, + &ast.IntLiteral{ + ValuePos: 17, + ValueEnd: 18, + Base: 10, + Value: "2", + }, + &ast.IntLiteral{ + ValuePos: 20, + ValueEnd: 21, + Base: 10, + Value: "3", + }, + }, + }, + }, + &ast.LambdaArg{ + Lparen: -1, + Args: []*ast.Ident{ + &ast.Ident{ + NamePos: 24, + NameEnd: 25, + Name: "e", + }, + }, + Expr: &ast.BinaryExpr{ + Op: ">", + Left: &ast.Ident{ + NamePos: 29, + NameEnd: 30, + Name: "e", + }, + Right: &ast.IntLiteral{ + ValuePos: 33, + ValueEnd: 34, + Base: 10, + Value: "1", + }, + }, + }, + }, + NamedArgs: []*ast.NamedArg(nil), + NullHandling: nil, + Having: nil, +} + +--- SQL +ARRAY_FILTER([1, 2, 3], e -> e > 1) diff --git a/testdata/result/expr/array_functions_array_filter_two_args_lambda.sql.txt b/testdata/result/expr/array_functions_array_filter_two_args_lambda.sql.txt new file mode 100644 index 00000000..7d83c243 --- /dev/null +++ b/testdata/result/expr/array_functions_array_filter_two_args_lambda.sql.txt @@ -0,0 +1,76 @@ +--- array_functions_array_filter_two_args_lambda.sql +ARRAY_FILTER([0, 2, 3], (e, i) -> e > i) +--- AST +&ast.CallExpr{ + Rparen: 39, + Func: &ast.Ident{ + NamePos: 0, + NameEnd: 12, + Name: "ARRAY_FILTER", + }, + Distinct: false, + Args: []ast.Arg{ + &ast.ExprArg{ + Expr: &ast.ArrayLiteral{ + Array: -1, + Lbrack: 13, + Rbrack: 21, + Type: nil, + Values: []ast.Expr{ + &ast.IntLiteral{ + ValuePos: 14, + ValueEnd: 15, + Base: 10, + Value: "0", + }, + &ast.IntLiteral{ + ValuePos: 17, + ValueEnd: 18, + Base: 10, + Value: "2", + }, + &ast.IntLiteral{ + ValuePos: 20, + ValueEnd: 21, + Base: 10, + Value: "3", + }, + }, + }, + }, + &ast.LambdaArg{ + Lparen: 24, + Args: []*ast.Ident{ + &ast.Ident{ + NamePos: 25, + NameEnd: 26, + Name: "e", + }, + &ast.Ident{ + NamePos: 28, + NameEnd: 29, + Name: "i", + }, + }, + Expr: &ast.BinaryExpr{ + Op: ">", + Left: &ast.Ident{ + NamePos: 34, + NameEnd: 35, + Name: "e", + }, + Right: &ast.Ident{ + NamePos: 38, + NameEnd: 39, + Name: "i", + }, + }, + }, + }, + NamedArgs: []*ast.NamedArg(nil), + NullHandling: nil, + Having: nil, +} + +--- SQL +ARRAY_FILTER([0, 2, 3], (e, i) -> e > i)