Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TVF #209

Merged
merged 9 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ func (PathTableExpr) isTableExpr() {}
func (SubQueryTableExpr) isTableExpr() {}
func (ParenTableExpr) isTableExpr() {}
func (Join) isTableExpr() {}
func (TVFCallExpr) isTableExpr() {}

// JoinCondition represents condition part of JOIN expression.
type JoinCondition interface {
Expand Down Expand Up @@ -220,6 +221,15 @@ func (IntervalArg) isArg() {}
func (SequenceArg) isArg() {}
func (LambdaArg) isArg() {}

type TVFArg interface {
Node
isTVFArg()
}

func (ExprArg) isTVFArg() {}
func (ModelArg) isTVFArg() {}
func (TableArg) isTVFArg() {}

// NullHandlingModifier represents IGNORE/RESPECT NULLS of aggregate function calls
type NullHandlingModifier interface {
Node
Expand Down Expand Up @@ -1159,6 +1169,28 @@ type CallExpr struct {
Having HavingModifier // optional
}

// TVFCallExpr is table-valued function call expression node.
//
// {{.Name | sql}}(
// {{.Args | sqlJoin ", "}}
// {{if len(.Args) > 0 && len(.NamedArgs) > 0}}, {{end}}
// {{.NamedArgs | sqlJoin ", "}}
// )
// {{.Hint | sqlOpt}}
// {{.Sample | sqlOpt}}
type TVFCallExpr struct {
// pos = Name.pos
// end = (Sample ?? Hint).end || Rparen + 1

Rparen token.Pos // position of ")"

Name *Path
Args []TVFArg
NamedArgs []*NamedArg
Hint *Hint // optional
Sample *TableSample // optional
}

// ExprArg is argument of the generic function call.
//
// {{.Expr | sql}}
Expand Down Expand Up @@ -1209,6 +1241,30 @@ type LambdaArg struct {
Expr Expr
}

// ModelArg is argument of model function call.
//
// MODEL {{.Name | sql}}
type ModelArg struct {
// pos = Model
// end = Name.end

Model token.Pos // position of "MODEL" keyword

Name *Path
}

// TableArg is TABLE table_name argument of table valued function call.
//
// TABLE {{.Name | sql}}
type TableArg struct {
// pos = Table
// end = Name.end

Table token.Pos // position of "TABLE" keyword

Name *Path
}

// NamedArg represents a name and value pair in named arguments
//
// {{.Name | sql}} => {{.Value | sql}}
Expand Down
24 changes: 24 additions & 0 deletions ast/pos.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions ast/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,15 @@ func (l *LambdaArg) SQL() string {
l.Expr.SQL()
}

func (c *TVFCallExpr) SQL() string {
return c.Name.SQL() + "(" +
sqlJoin(c.Args, ", ") +
strOpt(len(c.Args) > 0 && len(c.NamedArgs) > 0, ", ") +
sqlJoin(c.NamedArgs, ", ") +
")" +
sqlOpt(" ", c.Hint, "")
}

func (n *NamedArg) SQL() string { return n.Name.SQL() + " => " + n.Value.SQL() }

func (i *IgnoreNulls) SQL() string { return "IGNORE NULLS" }
Expand All @@ -541,6 +550,14 @@ func (s *SequenceArg) SQL() string {
return "SEQUENCE " + s.Expr.SQL()
}

func (s *ModelArg) SQL() string {
return "MODEL " + s.Name.SQL()
}

func (s *TableArg) SQL() string {
return "TABLE " + s.Name.SQL()
}

func (*CountStarExpr) SQL() string {
return "COUNT(*)"
}
Expand Down
60 changes: 60 additions & 0 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,9 @@ func (p *Parser) parseSimpleTableExpr() ast.TableExpr {

if p.Token.Kind == token.TokenIdent {
ids := p.parseIdentOrPath()
if p.Token.Kind == "(" {
return p.parseTVFCallExpr(ids)
}
if len(ids) == 1 {
return p.parseTableNameSuffix(ids[0])
}
Expand All @@ -826,6 +829,63 @@ func (p *Parser) parseSimpleTableExpr() ast.TableExpr {
panic(p.errorfAtToken(&p.Token, "expected token: (, UNNEST, <ident>, but: %s", p.Token.Kind))
}

func (p *Parser) parseTVFCallExpr(ids []*ast.Ident) *ast.TVFCallExpr {
p.expect("(")

var args []ast.TVFArg
if p.Token.Kind != ")" {
for !p.lookaheadNamedArg() {
args = append(args, p.parseTVFArg())
if p.Token.Kind != "," {
break
}
p.nextToken()
}
}

var namedArgs []*ast.NamedArg
if p.lookaheadNamedArg() {
namedArgs = parseCommaSeparatedList(p, p.parseNamedArg)
}

rparen := p.expect(")").Pos
hint := p.tryParseHint()
sample := p.tryParseTableSample()

return &ast.TVFCallExpr{
Rparen: rparen,
Name: &ast.Path{Idents: ids},
Args: args,
NamedArgs: namedArgs,
Hint: hint,
Sample: sample,
}
}

func (p *Parser) parseTVFArg() ast.TVFArg {
pos := p.Token.Pos
switch {
case p.Token.IsKeywordLike("TABLE"):
p.nextToken()
path := p.parsePath()

return &ast.TableArg{
Table: pos,
Name: path,
}
case p.Token.IsKeywordLike("MODEL"):
p.nextToken()
path := p.parsePath()

return &ast.ModelArg{
Model: pos,
Name: path,
}
default:
return p.parseExprArg()
}
}

func (p *Parser) parseIdentOrPath() []*ast.Ident {
ids := []*ast.Ident{p.parseIdent()}
for p.Token.Kind == "." {
Expand Down
12 changes: 12 additions & 0 deletions testdata/input/dml/update_with_safe_ml_predict.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-- https://cloud.google.com/spanner/docs/backfill-embeddings?hl=en#backfill
UPDATE products
SET
products.desc_embed = (
SELECT embeddings.values
FROM SAFE.ML.PREDICT(
MODEL gecko_model,
(SELECT products.description AS content)
) @{remote_udf_max_rows_per_rpc=200}
),
products.desc_embed_model_version = 3
WHERE products.desc_embed IS NULL
6 changes: 6 additions & 0 deletions testdata/input/query/select_from_change_stream.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT ChangeRecord FROM READ_SingersNameStream (
start_timestamp => "2022-05-01T09:00:00Z",
end_timestamp => NULL,
partition_token => NULL,
heartbeat_milliseconds => 10000
)
7 changes: 7 additions & 0 deletions testdata/input/query/select_from_ml_predict_hint.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- https://cloud.google.com/spanner/docs/ml-tutorial-generative-ai?hl=en#register_a_generative_ai_model_in_a_schema
SELECT content
FROM ML.PREDICT(
MODEL TextBison,
(SELECT "Is 13 prime?" AS prompt),
STRUCT(256 AS maxOutputTokens, 0.2 AS temperature, 40 as topK, 0.95 AS topP)
) @{remote_udf_max_rows_per_rpc=1}
2 changes: 2 additions & 0 deletions testdata/input/query/select_from_ml_predict_simple.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT id, color, value
FROM ML.PREDICT(MODEL DiamondAppraise, TABLE Diamonds)
15 changes: 15 additions & 0 deletions testdata/input/query/select_from_ml_predict_textbison.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
SELECT product_id, product_name, content
FROM ML.PREDICT(
MODEL TextBison,
(SELECT
product.id as product_id,
product.name as product_name,
CONCAT("Is this product safe for infants?", "\n",
"Product Name: ", product.name, "\n",
"Category Name: ", category.name, "\n",
"Product Description:", product.description) AS prompt
FROM
Products AS product JOIN Categories AS category
ON product.category_id = category.id),
STRUCT(100 AS maxOutputTokens)
) @{remote_udf_max_rows_per_rpc=1}
Loading
Loading